fix: 添加 Sqlite3Helper

This commit is contained in:
Julian Freeman
2024-07-19 00:38:40 -04:00
parent efdf13444f
commit 0a13c1fe0e
2 changed files with 467 additions and 1 deletions

2
.gitignore vendored
View File

@@ -14,7 +14,7 @@ dist/
downloads/ downloads/
eggs/ eggs/
.eggs/ .eggs/
lib/ #lib/
lib64/ lib64/
parts/ parts/
sdist/ sdist/

466
lib/Sqlite3Helper.py Normal file
View File

@@ -0,0 +1,466 @@
# coding: utf8
from __future__ import annotations
import sqlite3
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import StrEnum
from os import PathLike
__version__ = "1.1.0"
__version_info__ = (1, 1, 0)
class DataType(StrEnum):
NULL = "NULL"
INTEGER = "INTEGER"
REAL = "REAL"
TEXT = "TEXT"
BLOB = "BLOB"
class NullType(object):
def __str__(self):
return "NULL"
class BlobType(object):
def __init__(self, data: bytes = b""):
self._data = data
def __str__(self):
return f"X'{self._data.hex()}'"
def _get_type(data_type: DataType) -> type:
if data_type == DataType.NULL:
return NullType
if data_type == DataType.INTEGER:
return int
if data_type == DataType.REAL:
return float
if data_type == DataType.TEXT:
return str
if data_type == DataType.BLOB:
return BlobType
raise TypeError(f"Data type {data_type} is not supported")
def _get_data_type(type_: type) -> DataType:
if type_ is NullType:
return DataType.NULL
if type_ is int:
return DataType.INTEGER
if type_ is float:
return DataType.REAL
if type_ is str:
return DataType.TEXT
if type_ is BlobType:
return DataType.BLOB
raise TypeError(f"Data type {type_} is not supported")
def _to_string(value):
# 如果传入的类型不是 text 会直接返回原值
if type(value) is str:
if not (value.startswith("'") or value.endswith("'")):
value = f"'{value}'"
return str(value)
@dataclass
class Column(object):
name: str
data_type: DataType
primary_key: bool = False
nullable: bool = True
unique: bool = False
has_default: bool = False
default: NullType | int | float | str | BlobType = 0
def __str__(self):
head = f"{self.name} {self.data_type.name}"
if self.primary_key:
head = f"{head} PRIMARY KEY"
if not self.nullable:
head = f"{head} NOT NULL"
if self.unique:
head = f"{head} UNIQUE"
if self.has_default:
head = f"{head} DEFAULT {_to_string(self.default)}"
return head
__repr__ = __str__
@dataclass
class AbsHeadRow(ABC):
# 这里的几个其实都是 list[str],但为了防止实例类中提示类型错误,就写成了 list
primary_keys: list = field(default_factory=list)
not_nulls: list = field(default_factory=list)
uniques: list = field(default_factory=list)
defaults: dict = field(default_factory=dict)
@abstractmethod
def __post_init__(self):
pass
def __getattr__(self, item):
# 此类声明的类变量自然都是不会被定义的,所以在 __post_init__ 中获取时就会跳转到这
# 此时需要的其实就是这个类变量名称本身,因此直接返回
if item in self.__annotations__:
return item
raise AttributeError
def to_columns(self) -> list[Column]:
columns: list[Column] = []
# self.__annotations__ 会返回实例对象中定义的类变量名称和类型的键值对
# 但不包括此抽象类中的
for column_name in self.__annotations__:
column_type = self.__annotations__[column_name]
data_type = _get_data_type(column_type)
# 实例对象的这个类变量设置了值,在此作为默认值
if column_name in self.defaults:
default = self.defaults[column_name]
if type(default) is not column_type:
raise TypeError(f"Column {column_name} must be {column_type} but was {type(default)}")
has_default = True
else:
default = 0 # 此处设置的值无用,但是默认给它一个值
has_default = False
primary_key = column_name in self.primary_keys
nullable = column_name not in self.not_nulls
unique = column_name in self.uniques
columns.append(Column(
name=column_name,
data_type=data_type,
primary_key=primary_key,
nullable=nullable,
unique=unique,
has_default=has_default,
default=default,
))
return columns
def to_column_dict(self) -> dict[str, Column]:
return {col.name: col for col in self.to_columns()}
class Expression(object):
def __init__(self, expr: str):
self._expr = expr
def __str__(self):
return self._expr
def and_(self, expression: Expression):
return Expression(f"{self._expr} AND {expression}")
def or_(self, expression: Expression):
return Expression(f"{self._expr} OR {expression}")
def exists(self, not_: bool = False):
mark = "EXISTS"
if not_:
mark = "NOT EXISTS"
return Expression(f"{mark} ({self._expr})")
class Operand(object):
def __init__(self, column: Column | str):
self._name = column.name if isinstance(column, Column) else column
def equal_to(self, value):
return Expression(f"{self._name} = {_to_string(value)}")
def not_equal_to(self, value):
return Expression(f"{self._name} =! {_to_string(value)}")
def less_than(self, value):
return Expression(f"{self._name} < {_to_string(value)}")
def greater_than(self, value):
return Expression(f"{self._name} > {_to_string(value)}")
def less_equal(self, value):
return Expression(f"{self._name} <= {_to_string(value)}")
def greater_equal(self, value):
return Expression(f"{self._name} >= {_to_string(value)}")
def between(self, minimum, maximum, not_: bool = False):
mark = "BETWEEN"
if not_:
mark = "NOT BETWEEN"
return Expression(f"{self._name} {mark} {_to_string(minimum)} AND {_to_string(maximum)}")
def in_(self, values: list | str, not_: bool = False):
if isinstance(values, list):
values = ", ".join([_to_string(value) for value in values])
mark = "IN"
if not_:
mark = "NOT IN"
return Expression(f"{self._name} {mark} ({values})")
def like(self, regx: str, escape: str = "", not_: bool = False):
head = "LIKE"
if not_:
head = "NOT LIKE"
body = f"{head} {_to_string(regx)}"
if len(escape) != 0:
body = f"{body} ESCAPE {_to_string(escape)}"
return Expression(f"{self._name} {body}")
def is_null(self, not_: bool = False):
mark = "IS NULL"
if not_:
mark = "IS NOT NULL"
return Expression(f"{self._name} {mark}")
def glob(self, regx: str):
return Expression(f"{self._name} GLOB {_to_string(regx)}")
class SortOption(StrEnum):
NONE = ""
ASC = "ASC"
DESC = "DESC"
class NullOption(StrEnum):
NONE = ""
NULLS_FIRST = "NULLS FIRST"
NULLS_LAST = "NULLS LAST"
def order(column: Column | str | int,
sort_option: SortOption = SortOption.NONE,
null_option: NullOption = NullOption.NONE) -> str:
name = column.name if isinstance(column, Column) else str(column)
if sort_option != SortOption.NONE:
name = f"{name} {sort_option.value}"
if sqlite3.sqlite_version_info >= (3, 30, 0):
if null_option != NullOption.NONE:
name = f"{name} {null_option.value}"
return name
class Sqlite3Worker(object):
def __init__(self, db_name: str | PathLike[str] = ":memory:"):
self._conn = sqlite3.connect(db_name)
self._cursor = self._conn.cursor()
self._is_closed = False
def __del__(self):
self.close()
def close(self):
if self._is_closed is False:
self._cursor.close()
self._conn.close()
self._is_closed = True
def commit(self):
self._conn.commit()
def create_table(self, table_name: str, columns: list[Column] | AbsHeadRow,
if_not_exists: bool = False, schema_name: str = "", *, execute: bool = True) -> str:
if table_name.startswith("sqlite_"):
raise ValueError("Table name must not start with 'sqlite_')")
if isinstance(columns, AbsHeadRow):
columns = columns.to_columns()
columns_str = ", ".join([str(col) for col in columns])
head = "CREATE TABLE"
if if_not_exists:
head = f"{head} IF NOT EXISTS"
name = table_name
if len(schema_name) != 0:
name = f"{schema_name}.{name}"
statement = f"{head} {name} ({columns_str});"
if execute:
self._cursor.execute(statement)
return statement
def drop_table(self, table_name: str, if_exists: bool = False,
schema_name: str = "", *, execute: bool = True) -> str:
head = "DROP TABLE"
if if_exists:
head = f"{head} IF EXISTS"
name = table_name
if len(schema_name) != 0:
name = f"{schema_name}.{name}"
statement = f"{head} {name};"
if execute:
self._cursor.execute(statement)
return statement
def rename_table(self, table_name: str, new_name: str, *, execute: bool = True) -> str:
head = "ALTER TABLE"
statement = f"{head} {table_name} RENAME TO {new_name};"
if execute:
self._cursor.execute(statement)
return statement
def add_column(self, table_name: str, column: Column, *, execute: bool = True) -> str:
if column.primary_key or column.unique:
raise ValueError("The new column cannot have primary key or unique")
if not column.nullable:
if not column.has_default:
raise ValueError("If the new column is not null, it must have default value")
if column.default is NullType:
raise ValueError("If the new column is not null, its default value must not be NULL")
head = "ALTER TABLE"
statement = f"{head} {table_name} ADD COLUMN {str(column)};"
if execute:
self._cursor.execute(statement)
return statement
def rename_column(self, table_name: str, column_name: str,
new_name: str, *, execute: bool = True) -> str:
if sqlite3.sqlite_version_info < (3, 25, 0):
raise ValueError("SQLite under 3.25.0 does not support rename column")
head = "ALTER TABLE"
statement = f"{head} {table_name} RENAME COLUMN {column_name} TO {new_name};"
if execute:
self._cursor.execute(statement)
return statement
def show_tables(self) -> list[str]:
cond = Operand("type").equal_to("table").and_(Operand("name").like("sqlite_%", not_=True))
_, tables = self.select("sqlite_schema", ["name"], where=cond)
return [table[0] for table in tables]
@staticmethod
def _columns_to_string(columns: list[Column | str]) -> str:
columns_str_ls = []
for column in columns:
if isinstance(column, Column):
columns_str_ls.append(column.name)
elif isinstance(column, str):
columns_str_ls.append(column)
else:
raise ValueError(f"Column must be str or Column object, found {type(column)}")
return ", ".join(columns_str_ls)
def insert_into(self, table_name: str, columns: list[Column | str],
values: list[list[str | int | float]],
*, execute: bool = True, commit: bool = True) -> str:
col_count = len(columns)
columns_str = self._columns_to_string(columns)
values_str_ls = []
for value in values:
if len(value) != col_count:
raise ValueError(f"Length of values must be {col_count}")
for i in range(col_count):
column = columns[i]
if isinstance(column, Column):
type_ = _get_type(column.data_type)
# 支持将 int 隐式转为 float
if type(value[i]) is int and type_ is float:
continue
if type(value[i]) is not type_:
raise ValueError(f"The {i + 1}(th) type of value must be {type_},"
f" because the column type is {column.data_type}")
# 这里的 value 是一行数据,是一个多值列表
values_str_ls.append(f"({', '.join([_to_string(val) for val in value])})")
values_str = ", ".join(values_str_ls)
head = "INSERT INTO"
statement = f"{head} {table_name} ({columns_str}) VALUES {values_str};"
if execute:
self._cursor.execute(statement)
if commit:
self._conn.commit()
return statement
@staticmethod
def _join_where_order_limit(body: str,
where: Expression, order_by: list[str] | str,
limit: int, offset: int) -> str:
if where is not None:
body = f"{body} WHERE {where}"
if order_by is not None:
if not isinstance(order_by, list):
order_by = [order_by]
body = f"{body} ORDER BY {', '.join(order_by)}"
if limit is not None:
body = f"{body} LIMIT {limit}"
if offset is not None:
body = f"{body} OFFSET {offset}"
return body
def select(self, table_name: str, columns: list[Column | str], distinct: bool = False,
where: Expression = None,
order_by: list[str] | str = None,
limit: int = None, offset: int = None,
*, execute: bool = True) -> tuple[str, list[tuple]]:
if len(columns) == 0:
columns_str = "*"
else:
columns_str = self._columns_to_string(columns)
head = "SELECT"
if distinct:
head = f"{head} DISTINCT"
body = f"{head} {columns_str} FROM {table_name}"
body = self._join_where_order_limit(body, where, order_by, limit, offset)
statement = f"{body};"
if execute:
self._cursor.execute(statement)
rows = self._cursor.fetchall()
return statement, rows
else:
return statement, []
def delete_from(self, table_name: str, where: Expression = None,
*, execute: bool = True, commit: bool = True) -> str:
head = "DELETE FROM"
body = f"{head} {table_name}"
if where is not None:
body = f"{body} WHERE {where}"
statement = f"{body};"
if execute:
self._cursor.execute(statement)
if commit:
self._conn.commit()
return statement
def update(self, table_name: str, new_values: list[tuple[Column, NullType | int | float | str | BlobType]],
where: Expression = None,
*, execute: bool = True, commit: bool = True) -> str:
new_values_str_ls = []
for column, value in new_values:
if _get_data_type(type(value)) != column.data_type:
raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}")
new_values_str_ls.append(f"{column.name} = {_to_string(value)}")
head = f"UPDATE {table_name}"
body = f"{head} SET {', '.join(new_values_str_ls)}"
if where is not None:
body = f"{body} WHERE {where}"
statement = f"{body};"
if execute:
self._cursor.execute(statement)
if commit:
self._conn.commit()
return statement