@@ -1,13 +1,18 @@
# coding: utf8
from __future__ import annotations
import os
import sqlite3
from abc import ABC , abstractmethod
from dataclasses import dataclass , field
import time
from dataclasses import dataclass
from enum import StrEnum
from os import PathLike
__version__ = " 1.1.0 "
__version_info__ = ( 1 , 1 , 0 )
from cryptography . fernet import Fernet , InvalidToken
__version__ = " 2.2.0 "
__version_info__ = tuple ( map ( int , __version__ . split ( " . " ) ) )
class DataType ( StrEnum ) :
@@ -27,10 +32,28 @@ class NullType(object):
class BlobType ( object ) :
def __init__ ( self , data : bytes = b " " ) :
self . _ data = data
self . data = data
def __str__ ( self ) :
return f " X ' { self . _ data. hex ( ) } ' "
return f " X ' { self . data . hex ( ) } ' "
class NotRandomFernet ( Fernet ) :
""" 固定下来每次相同的 key 的加密结果相同,方便条件查询 """
def __init__ ( self , key : bytes | str , fix_time : int , fix_iv : bytes , backend = None ) :
super ( ) . __init__ ( key , backend )
self . _fix_time = fix_time
self . _fix_iv = fix_iv
def encrypt ( self , data : bytes ) - > bytes :
return self . _encrypt_from_parts ( data , self . _fix_time , self . _fix_iv )
def _encrypt_blob ( blob : BlobType , fernet : NotRandomFernet ) - > BlobType :
if fernet is None :
raise ValueError ( " Key is not set " )
return BlobType ( fernet . encrypt ( blob . data ) )
def _get_type ( data_type : DataType ) - > type :
@@ -79,8 +102,14 @@ class Column(object):
has_default : bool = False
default : NullType | int | float | str | BlobType = 0
secure : bool = False
def __post_init__ ( self ) :
if self . secure is True and self . data_type != DataType . BLOB :
raise ValueError ( " Only BLOB data can be secured " )
def __str__ ( self ) :
head = f " { self . name } { self . data_type . nam e} "
head = f " { self . name } { self . data_type . valu e} "
if self . primary_key :
head = f " { head } PRIMARY KEY "
if not self . nullable :
@@ -94,63 +123,6 @@ class Column(object):
__repr__ = __str__
@dataclass
class AbsHeadRow ( ABC ) :
# 这里的几个其实都是 list[str],但为了防止实例类中提示类型错误,就写成了 list
primary_keys : list = field ( default_factory = list )
not_nulls : list = field ( default_factory = list )
uniques : list = field ( default_factory = list )
defaults : dict = field ( default_factory = dict )
@abstractmethod
def __post_init__ ( self ) :
pass
def __getattr__ ( self , item ) :
# 此类声明的类变量自然都是不会被定义的,所以在 __post_init__ 中获取时就会跳转到这
# 此时需要的其实就是这个类变量名称本身,因此直接返回
if item in self . __annotations__ :
return item
raise AttributeError
def to_columns ( self ) - > list [ Column ] :
columns : list [ Column ] = [ ]
# self.__annotations__ 会返回实例对象中定义的类变量名称和类型的键值对
# 但不包括此抽象类中的
for column_name in self . __annotations__ :
column_type = self . __annotations__ [ column_name ]
data_type = _get_data_type ( column_type )
# 实例对象的这个类变量设置了值,在此作为默认值
if column_name in self . defaults :
default = self . defaults [ column_name ]
if type ( default ) is not column_type :
raise TypeError ( f " Column { column_name } must be { column_type } but was { type ( default ) } " )
has_default = True
else :
default = 0 # 此处设置的值无用,但是默认给它一个值
has_default = False
primary_key = column_name in self . primary_keys
nullable = column_name not in self . not_nulls
unique = column_name in self . uniques
columns . append ( Column (
name = column_name ,
data_type = data_type ,
primary_key = primary_key ,
nullable = nullable ,
unique = unique ,
has_default = has_default ,
default = default ,
) )
return columns
def to_column_dict ( self ) - > dict [ str , Column ] :
return { col . name : col for col in self . to_columns ( ) }
class Expression ( object ) :
def __init__ ( self , expr : str ) :
@@ -174,14 +146,33 @@ class Expression(object):
class Operand ( object ) :
def __init__ ( self , column : Column | str ) :
def __init__ (
self ,
column : Column | str ,
key : bytes = None ,
fix_time : int = None ,
fix_iv : bytes = None ,
) :
self . _column = column
self . _key = key
self . _fix_time = fix_time
self . _fix_iv = fix_iv
self . _name = column . name if isinstance ( column , Column ) else column
def equal_to ( self , value ) :
return Expression ( f " { self . _name } = { _to_string ( value ) } " )
def not_equal_to ( self , value ) :
return Expression ( f " { self . _name } =! { _to_string ( value ) } " )
def equal_to ( self , value , not_ : bool = False ):
if isinstance ( self . _column , Column ) :
if self . _column . data_type == DataType . BLOB and type ( value ) is str :
value = BlobType ( value . encode ( " utf-8 " ) )
# 这里不能换成 elif
if self . _key is not None and self . _column . secure and isinstance ( value , BlobType ) :
fix_time = self . _fix_time if self . _fix_time is not None else int ( time . time ( ) )
fix_iv = self . _fix_iv if self . _fix_iv is not None else os . urandom ( 16 )
try :
value = _encrypt_blob ( value , NotRandomFernet ( self . _key , fix_time , fix_iv ) )
except ValueError :
pass
op = " != " if not_ else " = "
return Expression ( f " { self . _name } { op } { _to_string ( value ) } " )
def less_than ( self , value ) :
return Expression ( f " { self . _name } < { _to_string ( value ) } " )
@@ -254,11 +245,25 @@ def order(column: Column | str | int,
class Sqlite3Worker ( object ) :
def __init__ ( self , db_name : str | PathLike [ str ] = " :memory: " ) :
def __init__ (
self ,
db_name : str | PathLike [ str ] = " :memory: " ,
key : bytes = None ,
fix_time : int = None ,
fix_iv : bytes = None ,
) :
self . _db_name = db_name
self . _conn = sqlite3 . connect ( db_name )
self . _cursor = self . _conn . cursor ( )
self . _is_closed = False
self . _fernet = None
if key is not None :
fix_time = fix_time if fix_time is not None else int ( time . time ( ) )
fix_iv = fix_iv if fix_iv is not None else os . urandom ( 16 )
try :
self . _fernet = NotRandomFernet ( key , fix_time , fix_iv )
except ValueError :
pass
def __del__ ( self ) :
self . close ( )
@@ -276,13 +281,12 @@ class Sqlite3Worker(object):
def commit ( self ) :
self . _conn . commit ( )
def create_table ( self , table_name : str , columns : list [ Column ] | AbsHeadRow ,
if_not_exists : bool = False , schema_name : str = " " , * , execute : bool = True ) - > str :
def create_table ( self , table_name : str , columns : list [ Column ] ,
if_not_exists : bool = False , schema_name : str = " " ,
* , execute : bool = True ) - > str :
if table_name . startswith ( " sqlite_ " ) :
raise ValueError ( " Table name must not start with ' sqlite_ ' ) " )
if isinstance ( columns , AbsHeadRow ) :
columns = columns . to_columns ( )
columns_str = " , " . join ( [ str ( col ) for col in columns ] )
head = " CREATE TABLE "
if if_not_exists :
@@ -363,27 +367,41 @@ class Sqlite3Worker(object):
return " , " . join ( columns_str_ls )
def insert_into ( self , table_name : str , columns : list [ Column | str ] ,
values : list [ list [ str | int | float | BlobType ] ] ,
values : list [ list [ NullType | str | int | float | BlobType ] ] ,
* , execute : bool = True , commit : bool = True ) - > str :
col_count = len ( columns )
columns_str = self . _columns_to_string ( columns )
values_str_ls = [ ]
for value in values :
if len ( value ) != col_count :
for value_row in values :
if len ( value_row ) != col_count :
raise ValueError ( f " Length of values must be { col_count } " )
value_row_str_ls = [ ]
for i in range ( col_count ) :
column = columns [ i ]
value = value_row [ i ]
if isinstance ( column , Column ) :
type_ = _get_type ( column . data_type )
col_ type = _get_type ( column . data_type )
val_type = type ( value )
# 支持将 int 隐式转为 float
if type ( value [ i ] ) is int and type_ is float :
continue
if type ( value [ i ] ) is not type_ :
raise ValueError ( f " The { i + 1 } (th) type of value must be { type_ } , "
f " because the column type is { column . data_type } " )
# 这里的 value 是一行数据,是一个多值列表
values_str_ls . append ( f " ( { ' , ' . join ( [ _to_string ( val ) for val in value ] ) } ) " )
if val_type is int and col_ type is float :
pass
# 支持将 NULL 值插入任意类型的列,除了 NOT NULL 限制的
elif val_type is NullType and column . nullable is True :
pass
# 其他类型不匹配
elif val_type is not col_type :
raise ValueError ( f " The { i + 1 } (th) type of value must be { col_type } , "
f " because the column type is { column . data_type } , "
f " found { val_type } " )
# 如果加密
if column . secure and val_type is not NullType :
value = _to_string ( _encrypt_blob ( value , self . _fernet ) )
value_row_str_ls . append ( _to_string ( value ) )
values_str_ls . append ( f " ( { ' , ' . join ( value_row_str_ls ) } ) " )
values_str = " , " . join ( values_str_ls )
@@ -391,8 +409,8 @@ class Sqlite3Worker(object):
statement = f " { head } { table_name } ( { columns_str } ) VALUES { values_str } ; "
if execute :
self . _cursor . execute ( statement )
if commit :
self . _conn . commit ( )
if commit :
self . _conn . commit ( )
return statement
@staticmethod
@@ -415,7 +433,7 @@ class Sqlite3Worker(object):
where : Expression = None ,
order_by : list [ str ] | str = None ,
limit : int = None , offset : int = None ,
* , execute : bool = True ) - > tuple [ str , list [ tuple ] ] :
* , execute : bool = True ) - > tuple [ str , list [ list ] ] :
if len ( columns ) == 0 :
columns_str = " * "
else :
@@ -431,6 +449,22 @@ class Sqlite3Worker(object):
if execute :
self . _cursor . execute ( statement )
rows = self . _cursor . fetchall ( )
rows = [ list ( row ) for row in rows ] # 将每行转成列表,方便替换解密数据
# 下面的整个循环都是为了找到需要解密的数据尝试解密
for i in range ( len ( columns ) ) :
column = columns [ i ]
if isinstance ( column , Column ) and column . secure :
for row in rows :
# 如果是加密的 BLOB 但是值不为 NULL 才解密
if row [ i ] is not None and self . _fernet is not None :
# 不管是key错误还是密文错误, 都是 InvalidToken, 貌似没法区分
# 因此如果有的数据不是加密过的,应该跳过,不应该影响之后的密文解密,
# 因此这里还是得继续循环下去
try :
row [ i ] = self . _fernet . decrypt ( row [ i ] )
except InvalidToken :
pass
return statement , rows
else :
return statement , [ ]
@@ -445,18 +479,30 @@ class Sqlite3Worker(object):
statement = f " { body } ; "
if execute :
self . _cursor . execute ( statement )
if commit :
self . _conn . commit ( )
if commit :
self . _conn . commit ( )
return statement
def update ( self , table_name : str , new_values : list [ tuple [ Column , NullType | int | float | str | BlobType ] ] ,
def update ( self , table_name : str , new_values : list [ tuple [ Column | str , NullType | int | float | str | BlobType ] ] ,
where : Expression = None ,
* , execute : bool = True , commit : bool = True ) - > str :
new_values_str_ls = [ ]
for column , value in new_values :
if _get_data_type ( type ( value ) ) != c olumn. data_type :
raise ValueError ( f " Type of { column . name } must be { column . data_type } , found { type ( value ) } " )
new_values_str_ls . append ( f " { column . name } = { _to_string ( value ) } " )
if isinstance ( column , C olumn) :
# 支持将 NULL 值填入任意类型的列,除了 NOT NULL 限制的
if type ( value ) is NullType and column . nullable is True :
pass
elif _get_data_type ( type ( value ) ) != column . data_type :
raise ValueError ( f " Type of { column . name } must be { column . data_type } , found { type ( value ) } " )
if column . secure and type ( value ) is not NullType :
value = _encrypt_blob ( value , self . _fernet )
name = column . name
else :
name = column
new_values_str_ls . append ( f " { name } = { _to_string ( value ) } " )
head = f " UPDATE { table_name } "
body = f " { head } SET { ' , ' . join ( new_values_str_ls ) } "
@@ -466,6 +512,6 @@ class Sqlite3Worker(object):
statement = f " { body } ; "
if execute :
self . _cursor . execute ( statement )
if commit :
self . _conn . commit ( )
if commit :
self . _conn . commit ( )
return statement