v1.2.0
This commit is contained in:
@@ -7,11 +7,12 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from os import PathLike
|
||||
from types import NoneType
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
|
||||
__version__ = "2.2.0"
|
||||
__version__ = "2.2.2"
|
||||
__version_info__ = tuple(map(int, __version__.split(".")))
|
||||
|
||||
|
||||
@@ -32,13 +33,18 @@ class NullType(object):
|
||||
class BlobType(object):
|
||||
|
||||
def __init__(self, data: bytes = b""):
|
||||
self.data = data
|
||||
self._data = data
|
||||
|
||||
def __str__(self):
|
||||
return f"X'{self.data.hex()}'"
|
||||
return f"X'{self._data.hex()}'"
|
||||
|
||||
def encrypt(self, fernet: _NotRandomFernet) -> BlobType:
|
||||
if fernet is None:
|
||||
raise ValueError("Key is not set")
|
||||
return BlobType(fernet.encrypt(self._data))
|
||||
|
||||
|
||||
class NotRandomFernet(Fernet):
|
||||
class _NotRandomFernet(Fernet):
|
||||
"""固定下来每次相同的 key 的加密结果相同,方便条件查询"""
|
||||
|
||||
def __init__(self, key: bytes | str, fix_time: int, fix_iv: bytes, backend=None):
|
||||
@@ -50,44 +56,53 @@ class NotRandomFernet(Fernet):
|
||||
return self._encrypt_from_parts(data, self._fix_time, self._fix_iv)
|
||||
|
||||
|
||||
def _encrypt_blob(blob: BlobType, fernet: NotRandomFernet) -> BlobType:
|
||||
if fernet is None:
|
||||
raise ValueError("Key is not set")
|
||||
return BlobType(fernet.encrypt(blob.data))
|
||||
VALUE_TYPES = None | NullType | int | float | str | bytes | BlobType
|
||||
|
||||
|
||||
def _get_type(data_type: DataType) -> type:
|
||||
def _check_data_type(data_type: DataType, allow_null: bool, value) -> bool:
|
||||
value_type = type(value)
|
||||
allow_types = []
|
||||
if data_type == DataType.NULL:
|
||||
return NullType
|
||||
if data_type == DataType.INTEGER:
|
||||
return int
|
||||
if data_type == DataType.REAL:
|
||||
return float
|
||||
if data_type == DataType.TEXT:
|
||||
return str
|
||||
pass
|
||||
elif data_type == DataType.INTEGER:
|
||||
allow_types.extend([int, ])
|
||||
elif data_type == DataType.REAL:
|
||||
allow_types.extend([int, float])
|
||||
elif data_type == DataType.TEXT:
|
||||
allow_types.extend([str, ])
|
||||
elif data_type == DataType.BLOB:
|
||||
allow_types.extend([str, bytes, BlobType])
|
||||
|
||||
if allow_null:
|
||||
allow_types.extend([NoneType, NullType])
|
||||
|
||||
return value_type in allow_types
|
||||
|
||||
|
||||
def _implicitly_convert(data_type: DataType, value):
|
||||
if data_type == DataType.REAL and type(value) is int:
|
||||
return float(value)
|
||||
if data_type == DataType.BLOB:
|
||||
return BlobType
|
||||
raise TypeError(f"Data type {data_type} is not supported")
|
||||
if type(value) is str:
|
||||
return BlobType(value.encode("utf-8"))
|
||||
if type(value) is bytes:
|
||||
return BlobType(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _get_data_type(type_: type) -> DataType:
|
||||
if type_ is NullType:
|
||||
return DataType.NULL
|
||||
if type_ is int:
|
||||
return DataType.INTEGER
|
||||
if type_ is float:
|
||||
return DataType.REAL
|
||||
if type_ is str:
|
||||
return DataType.TEXT
|
||||
if type_ is BlobType:
|
||||
return DataType.BLOB
|
||||
raise TypeError(f"Data type {type_} is not supported")
|
||||
def _is_null(value) -> bool:
|
||||
return type(value) in (NoneType, NullType)
|
||||
|
||||
|
||||
def _to_string(value):
|
||||
# 如果传入的类型不是 text 会直接返回原值
|
||||
if type(value) is str:
|
||||
if not (value.startswith("'") or value.endswith("'")):
|
||||
if value is None:
|
||||
value = NullType()
|
||||
elif type(value) is str:
|
||||
# 只要开头或者结尾任意一个字符不是单引号
|
||||
if not (value.startswith("'") and value.endswith("'")):
|
||||
# 把单引号换为两个单引号转义
|
||||
value = value.replace("'", "''")
|
||||
value = f"'{value}'"
|
||||
return str(value)
|
||||
|
||||
@@ -171,7 +186,7 @@ class Operand(object):
|
||||
fix_time = self._fix_time if self._fix_time is not None else int(time.time())
|
||||
fix_iv = self._fix_iv if self._fix_iv is not None else os.urandom(16)
|
||||
try:
|
||||
value = _encrypt_blob(value, NotRandomFernet(self._key, fix_time, fix_iv))
|
||||
value = value.encrypt(_NotRandomFernet(self._key, fix_time, fix_iv))
|
||||
except ValueError:
|
||||
pass
|
||||
op = "!=" if not_ else "="
|
||||
@@ -264,7 +279,7 @@ class Sqlite3Worker(object):
|
||||
fix_time = fix_time if fix_time is not None else int(time.time())
|
||||
fix_iv = fix_iv if fix_iv is not None else os.urandom(16)
|
||||
try:
|
||||
self._fernet = NotRandomFernet(key, fix_time, fix_iv)
|
||||
self._fernet = _NotRandomFernet(key, fix_time, fix_iv)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@@ -284,6 +299,12 @@ class Sqlite3Worker(object):
|
||||
def commit(self):
|
||||
self._conn.commit()
|
||||
|
||||
def _execute(self, statement: str):
|
||||
try:
|
||||
self._cursor.execute(statement)
|
||||
except sqlite3.Error as e:
|
||||
raise sqlite3.Error(f"Error name: {e.sqlite_errorname};\nError statement: {statement}")
|
||||
|
||||
def create_table(self, table_name: str, columns: list[Column],
|
||||
if_not_exists: bool = False, schema_name: str = "",
|
||||
*, execute: bool = True) -> str:
|
||||
@@ -301,7 +322,7 @@ class Sqlite3Worker(object):
|
||||
statement = f"{head} {name} ({columns_str});"
|
||||
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
return statement
|
||||
|
||||
def drop_table(self, table_name: str, if_exists: bool = False,
|
||||
@@ -316,14 +337,14 @@ class Sqlite3Worker(object):
|
||||
statement = f"{head} {name};"
|
||||
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
return statement
|
||||
|
||||
def rename_table(self, table_name: str, new_name: str, *, execute: bool = True) -> str:
|
||||
head = "ALTER TABLE"
|
||||
statement = f"{head} {table_name} RENAME TO {new_name};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
return statement
|
||||
|
||||
def add_column(self, table_name: str, column: Column, *, execute: bool = True) -> str:
|
||||
@@ -338,7 +359,7 @@ class Sqlite3Worker(object):
|
||||
head = "ALTER TABLE"
|
||||
statement = f"{head} {table_name} ADD COLUMN {str(column)};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
return statement
|
||||
|
||||
def rename_column(self, table_name: str, column_name: str,
|
||||
@@ -349,7 +370,7 @@ class Sqlite3Worker(object):
|
||||
head = "ALTER TABLE"
|
||||
statement = f"{head} {table_name} RENAME COLUMN {column_name} TO {new_name};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
return statement
|
||||
|
||||
def show_tables(self) -> list[str]:
|
||||
@@ -370,7 +391,7 @@ class Sqlite3Worker(object):
|
||||
return ", ".join(columns_str_ls)
|
||||
|
||||
def insert_into(self, table_name: str, columns: list[Column | str],
|
||||
values: list[list[NullType | str | int | float | BlobType]],
|
||||
values: list[list[VALUE_TYPES]],
|
||||
*, execute: bool = True, commit: bool = True) -> str:
|
||||
col_count = len(columns)
|
||||
columns_str = self._columns_to_string(columns)
|
||||
@@ -381,26 +402,15 @@ class Sqlite3Worker(object):
|
||||
raise ValueError(f"Length of values must be {col_count}")
|
||||
|
||||
value_row_str_ls = []
|
||||
for i in range(col_count):
|
||||
column = columns[i]
|
||||
value = value_row[i]
|
||||
for column, value in zip(columns, value_row):
|
||||
if isinstance(column, Column):
|
||||
col_type = _get_type(column.data_type)
|
||||
val_type = type(value)
|
||||
# 支持将 int 隐式转为 float
|
||||
if val_type is int and col_type is float:
|
||||
pass
|
||||
# 支持将 NULL 值插入任意类型的列,除了 NOT NULL 限制的
|
||||
elif val_type is NullType and column.nullable is True:
|
||||
pass
|
||||
# 其他类型不匹配
|
||||
elif val_type is not col_type:
|
||||
raise ValueError(f"The {i + 1}(th) type of value must be {col_type},"
|
||||
f" because the column type is {column.data_type},"
|
||||
f" found {val_type}")
|
||||
if not _check_data_type(column.data_type, column.nullable, value):
|
||||
raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}")
|
||||
# 这一步一定在加密之前
|
||||
value = _implicitly_convert(column.data_type, value)
|
||||
# 如果加密
|
||||
if column.secure and val_type is not NullType:
|
||||
value = _to_string(_encrypt_blob(value, self._fernet))
|
||||
if column.secure and not _is_null(value):
|
||||
value = value.encrypt(self._fernet)
|
||||
|
||||
value_row_str_ls.append(_to_string(value))
|
||||
|
||||
@@ -411,7 +421,7 @@ class Sqlite3Worker(object):
|
||||
head = "INSERT INTO"
|
||||
statement = f"{head} {table_name} ({columns_str}) VALUES {values_str};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
return statement
|
||||
@@ -450,7 +460,7 @@ class Sqlite3Worker(object):
|
||||
|
||||
statement = f"{body};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
rows = self._cursor.fetchall()
|
||||
rows = [list(row) for row in rows] # 将每行转成列表,方便替换解密数据
|
||||
# 下面的整个循环都是为了找到需要解密的数据尝试解密
|
||||
@@ -481,25 +491,24 @@ class Sqlite3Worker(object):
|
||||
|
||||
statement = f"{body};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
return statement
|
||||
|
||||
def update(self, table_name: str, new_values: list[tuple[Column | str, NullType | int | float | str | BlobType]],
|
||||
def update(self, table_name: str, new_values: list[tuple[Column | str, VALUE_TYPES]],
|
||||
where: Expression = None,
|
||||
*, execute: bool = True, commit: bool = True) -> str:
|
||||
new_values_str_ls = []
|
||||
for column, value in new_values:
|
||||
if isinstance(column, Column):
|
||||
# 支持将 NULL 值填入任意类型的列,除了 NOT NULL 限制的
|
||||
if type(value) is NullType and column.nullable is True:
|
||||
pass
|
||||
elif _get_data_type(type(value)) != column.data_type:
|
||||
if not _check_data_type(column.data_type, column.nullable, value):
|
||||
raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}")
|
||||
|
||||
if column.secure and type(value) is not NullType:
|
||||
value = _encrypt_blob(value, self._fernet)
|
||||
# 这一步一定在加密之前
|
||||
value = _implicitly_convert(column.data_type, value)
|
||||
# 如果加密
|
||||
if column.secure and not _is_null(value):
|
||||
value = value.encrypt(self._fernet)
|
||||
|
||||
name = column.name
|
||||
else:
|
||||
@@ -514,7 +523,7 @@ class Sqlite3Worker(object):
|
||||
|
||||
statement = f"{body};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
self._execute(statement)
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
return statement
|
||||
|
||||
Reference in New Issue
Block a user