diff --git a/lib/Sqlite3Helper.py b/lib/Sqlite3Helper.py index 60b9776..abf5dfa 100644 --- a/lib/Sqlite3Helper.py +++ b/lib/Sqlite3Helper.py @@ -1,13 +1,18 @@ # coding: utf8 from __future__ import annotations + +import os import sqlite3 -from abc import ABC, abstractmethod -from dataclasses import dataclass, field +import time +from dataclasses import dataclass from enum import StrEnum from os import PathLike -__version__ = "1.1.0" -__version_info__ = (1, 1, 0) +from cryptography.fernet import Fernet, InvalidToken + + +__version__ = "2.2.0" +__version_info__ = tuple(map(int, __version__.split("."))) class DataType(StrEnum): @@ -27,10 +32,28 @@ class NullType(object): class BlobType(object): def __init__(self, data: bytes = b""): - self._data = data + self.data = data def __str__(self): - return f"X'{self._data.hex()}'" + return f"X'{self.data.hex()}'" + + +class NotRandomFernet(Fernet): + """固定下来每次相同的 key 的加密结果相同,方便条件查询""" + + def __init__(self, key: bytes | str, fix_time: int, fix_iv: bytes, backend=None): + super().__init__(key, backend) + self._fix_time = fix_time + self._fix_iv = fix_iv + + def encrypt(self, data: bytes) -> bytes: + return self._encrypt_from_parts(data, self._fix_time, self._fix_iv) + + +def _encrypt_blob(blob: BlobType, fernet: NotRandomFernet) -> BlobType: + if fernet is None: + raise ValueError("Key is not set") + return BlobType(fernet.encrypt(blob.data)) def _get_type(data_type: DataType) -> type: @@ -79,8 +102,14 @@ class Column(object): has_default: bool = False default: NullType | int | float | str | BlobType = 0 + secure: bool = False + + def __post_init__(self): + if self.secure is True and self.data_type != DataType.BLOB: + raise ValueError("Only BLOB data can be secured") + def __str__(self): - head = f"{self.name} {self.data_type.name}" + head = f"{self.name} {self.data_type.value}" if self.primary_key: head = f"{head} PRIMARY KEY" if not self.nullable: @@ -94,63 +123,6 @@ class Column(object): __repr__ = __str__ -@dataclass -class AbsHeadRow(ABC): - # 这里的几个其实都是 list[str],但为了防止实例类中提示类型错误,就写成了 list - primary_keys: list = field(default_factory=list) - not_nulls: list = field(default_factory=list) - uniques: list = field(default_factory=list) - defaults: dict = field(default_factory=dict) - - @abstractmethod - def __post_init__(self): - pass - - def __getattr__(self, item): - # 此类声明的类变量自然都是不会被定义的,所以在 __post_init__ 中获取时就会跳转到这 - # 此时需要的其实就是这个类变量名称本身,因此直接返回 - if item in self.__annotations__: - return item - raise AttributeError - - def to_columns(self) -> list[Column]: - columns: list[Column] = [] - - # self.__annotations__ 会返回实例对象中定义的类变量名称和类型的键值对 - # 但不包括此抽象类中的 - for column_name in self.__annotations__: - column_type = self.__annotations__[column_name] - data_type = _get_data_type(column_type) - - # 实例对象的这个类变量设置了值,在此作为默认值 - if column_name in self.defaults: - default = self.defaults[column_name] - if type(default) is not column_type: - raise TypeError(f"Column {column_name} must be {column_type} but was {type(default)}") - has_default = True - else: - default = 0 # 此处设置的值无用,但是默认给它一个值 - has_default = False - - primary_key = column_name in self.primary_keys - nullable = column_name not in self.not_nulls - unique = column_name in self.uniques - - columns.append(Column( - name=column_name, - data_type=data_type, - primary_key=primary_key, - nullable=nullable, - unique=unique, - has_default=has_default, - default=default, - )) - return columns - - def to_column_dict(self) -> dict[str, Column]: - return {col.name: col for col in self.to_columns()} - - class Expression(object): def __init__(self, expr: str): @@ -174,14 +146,33 @@ class Expression(object): class Operand(object): - def __init__(self, column: Column | str): + def __init__( + self, + column: Column | str, + key: bytes = None, + fix_time: int = None, + fix_iv: bytes = None, + ): + self._column = column + self._key = key + self._fix_time = fix_time + self._fix_iv = fix_iv self._name = column.name if isinstance(column, Column) else column - def equal_to(self, value): - return Expression(f"{self._name} = {_to_string(value)}") - - def not_equal_to(self, value): - return Expression(f"{self._name} =! {_to_string(value)}") + def equal_to(self, value, not_: bool = False): + if isinstance(self._column, Column): + if self._column.data_type == DataType.BLOB and type(value) is str: + value = BlobType(value.encode("utf-8")) + # 这里不能换成 elif + if self._key is not None and self._column.secure and isinstance(value, BlobType): + fix_time = self._fix_time if self._fix_time is not None else int(time.time()) + fix_iv = self._fix_iv if self._fix_iv is not None else os.urandom(16) + try: + value = _encrypt_blob(value, NotRandomFernet(self._key, fix_time, fix_iv)) + except ValueError: + pass + op = "!=" if not_ else "=" + return Expression(f"{self._name} {op} {_to_string(value)}") def less_than(self, value): return Expression(f"{self._name} < {_to_string(value)}") @@ -254,11 +245,25 @@ def order(column: Column | str | int, class Sqlite3Worker(object): - def __init__(self, db_name: str | PathLike[str] = ":memory:"): + def __init__( + self, + db_name: str | PathLike[str] = ":memory:", + key: bytes = None, + fix_time: int = None, + fix_iv: bytes = None, + ): self._db_name = db_name self._conn = sqlite3.connect(db_name) self._cursor = self._conn.cursor() self._is_closed = False + self._fernet = None + if key is not None: + fix_time = fix_time if fix_time is not None else int(time.time()) + fix_iv = fix_iv if fix_iv is not None else os.urandom(16) + try: + self._fernet = NotRandomFernet(key, fix_time, fix_iv) + except ValueError: + pass def __del__(self): self.close() @@ -276,13 +281,12 @@ class Sqlite3Worker(object): def commit(self): self._conn.commit() - def create_table(self, table_name: str, columns: list[Column] | AbsHeadRow, - if_not_exists: bool = False, schema_name: str = "", *, execute: bool = True) -> str: + def create_table(self, table_name: str, columns: list[Column], + if_not_exists: bool = False, schema_name: str = "", + *, execute: bool = True) -> str: if table_name.startswith("sqlite_"): raise ValueError("Table name must not start with 'sqlite_')") - if isinstance(columns, AbsHeadRow): - columns = columns.to_columns() columns_str = ", ".join([str(col) for col in columns]) head = "CREATE TABLE" if if_not_exists: @@ -363,27 +367,41 @@ class Sqlite3Worker(object): return ", ".join(columns_str_ls) def insert_into(self, table_name: str, columns: list[Column | str], - values: list[list[str | int | float | BlobType]], + values: list[list[NullType | str | int | float | BlobType]], *, execute: bool = True, commit: bool = True) -> str: col_count = len(columns) columns_str = self._columns_to_string(columns) values_str_ls = [] - for value in values: - if len(value) != col_count: + for value_row in values: + if len(value_row) != col_count: raise ValueError(f"Length of values must be {col_count}") + + value_row_str_ls = [] for i in range(col_count): column = columns[i] + value = value_row[i] if isinstance(column, Column): - type_ = _get_type(column.data_type) + col_type = _get_type(column.data_type) + val_type = type(value) # 支持将 int 隐式转为 float - if type(value[i]) is int and type_ is float: - continue - if type(value[i]) is not type_: - raise ValueError(f"The {i + 1}(th) type of value must be {type_}," - f" because the column type is {column.data_type}") - # 这里的 value 是一行数据,是一个多值列表 - values_str_ls.append(f"({', '.join([_to_string(val) for val in value])})") + if val_type is int and col_type is float: + pass + # 支持将 NULL 值插入任意类型的列,除了 NOT NULL 限制的 + elif val_type is NullType and column.nullable is True: + pass + # 其他类型不匹配 + elif val_type is not col_type: + raise ValueError(f"The {i + 1}(th) type of value must be {col_type}," + f" because the column type is {column.data_type}," + f" found {val_type}") + # 如果加密 + if column.secure and val_type is not NullType: + value = _to_string(_encrypt_blob(value, self._fernet)) + + value_row_str_ls.append(_to_string(value)) + + values_str_ls.append(f"({', '.join(value_row_str_ls)})") values_str = ", ".join(values_str_ls) @@ -391,8 +409,8 @@ class Sqlite3Worker(object): statement = f"{head} {table_name} ({columns_str}) VALUES {values_str};" if execute: self._cursor.execute(statement) - if commit: - self._conn.commit() + if commit: + self._conn.commit() return statement @staticmethod @@ -415,7 +433,7 @@ class Sqlite3Worker(object): where: Expression = None, order_by: list[str] | str = None, limit: int = None, offset: int = None, - *, execute: bool = True) -> tuple[str, list[tuple]]: + *, execute: bool = True) -> tuple[str, list[list]]: if len(columns) == 0: columns_str = "*" else: @@ -431,6 +449,22 @@ class Sqlite3Worker(object): if execute: self._cursor.execute(statement) rows = self._cursor.fetchall() + rows = [list(row) for row in rows] # 将每行转成列表,方便替换解密数据 + # 下面的整个循环都是为了找到需要解密的数据尝试解密 + for i in range(len(columns)): + column = columns[i] + if isinstance(column, Column) and column.secure: + for row in rows: + # 如果是加密的 BLOB 但是值不为 NULL 才解密 + if row[i] is not None and self._fernet is not None: + # 不管是key错误还是密文错误,都是 InvalidToken,貌似没法区分 + # 因此如果有的数据不是加密过的,应该跳过,不应该影响之后的密文解密, + # 因此这里还是得继续循环下去 + try: + row[i] = self._fernet.decrypt(row[i]) + except InvalidToken: + pass + return statement, rows else: return statement, [] @@ -445,18 +479,30 @@ class Sqlite3Worker(object): statement = f"{body};" if execute: self._cursor.execute(statement) - if commit: - self._conn.commit() + if commit: + self._conn.commit() return statement - def update(self, table_name: str, new_values: list[tuple[Column, NullType | int | float | str | BlobType]], + def update(self, table_name: str, new_values: list[tuple[Column | str, NullType | int | float | str | BlobType]], where: Expression = None, *, execute: bool = True, commit: bool = True) -> str: new_values_str_ls = [] for column, value in new_values: - if _get_data_type(type(value)) != column.data_type: - raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}") - new_values_str_ls.append(f"{column.name} = {_to_string(value)}") + if isinstance(column, Column): + # 支持将 NULL 值填入任意类型的列,除了 NOT NULL 限制的 + if type(value) is NullType and column.nullable is True: + pass + elif _get_data_type(type(value)) != column.data_type: + raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}") + + if column.secure and type(value) is not NullType: + value = _encrypt_blob(value, self._fernet) + + name = column.name + else: + name = column + + new_values_str_ls.append(f"{name} = {_to_string(value)}") head = f"UPDATE {table_name}" body = f"{head} SET {', '.join(new_values_str_ls)}" @@ -466,6 +512,6 @@ class Sqlite3Worker(object): statement = f"{body};" if execute: self._cursor.execute(statement) - if commit: - self._conn.commit() + if commit: + self._conn.commit() return statement diff --git a/lib/config_utils.py b/lib/config_utils.py index 86f997d..76cc8c0 100644 --- a/lib/config_utils.py +++ b/lib/config_utils.py @@ -4,6 +4,8 @@ import sys import json from pathlib import Path +from cryptography.fernet import Fernet + def path_not_exist(path: str | Path) -> bool: """ @@ -50,7 +52,6 @@ def read_config(org_name: str, app_name: str) -> dict: config_path = get_config_path(org_name, app_name) if not config_path.exists(): config = { - "table_name": "entries", "button_min_width": 120, "last_db_path": "", "loaded_memory": {} @@ -76,3 +77,15 @@ def get_default_db_path(config: dict, org_name: str, app_name: str) -> str: def get_secrets_path(org_name: str, app_name: str) -> str: app_dir = get_app_dir(org_name, app_name) return str(app_dir / "secrets.db") + + +def get_or_generate_key(db_name: str, org_name: str, app_name: str) -> bytes: + app_dir = get_app_dir(org_name, app_name) + name = Path(db_name).name + key_path = app_dir / f"{name}.key" + if key_path.exists(): + key = key_path.read_bytes() + else: + key = Fernet.generate_key() + key_path.write_bytes(key) + return key diff --git a/lib/db_columns_def.py b/lib/db_columns_def.py index 0da353d..6943fe7 100644 --- a/lib/db_columns_def.py +++ b/lib/db_columns_def.py @@ -5,10 +5,10 @@ columns_d = { "entry_id": Column("entry_id", DataType.INTEGER, primary_key=True, unique=True), "title": Column("title", DataType.BLOB), "username": Column("username", DataType.BLOB), - "password": Column("password", DataType.BLOB), - "opt": Column("opt", DataType.TEXT), + "password": Column("password", DataType.BLOB, secure=True), + "opt": Column("opt", DataType.BLOB, secure=True), "url": Column("url", DataType.BLOB), - "notes": Column("notes", DataType.BLOB), + "notes": Column("notes", DataType.BLOB, secure=True), "uuid": Column("uuid", DataType.TEXT, nullable=False), "filepath": Column("filepath", DataType.BLOB, nullable=False), "path": Column("path", DataType.BLOB), diff --git a/lib/kps_operations.py b/lib/kps_operations.py index 6f90134..47e5454 100644 --- a/lib/kps_operations.py +++ b/lib/kps_operations.py @@ -38,7 +38,7 @@ def read_kps_to_db(kps_file: str | PathLike[str], password: str, blob_fy(trim_str(entry.title)), blob_fy(trim_str(entry.username)), blob_fy(entry.password), - extract_otp(entry.otp), + blob_fy(extract_otp(entry.otp)), blob_fy(trim_str(entry.url)), blob_fy(entry.notes), str(entry.uuid), diff --git a/lib/sec_db_columns_def.py b/lib/sec_db_columns_def.py index 0584bac..c66aff2 100644 --- a/lib/sec_db_columns_def.py +++ b/lib/sec_db_columns_def.py @@ -4,7 +4,7 @@ from .Sqlite3Helper import Column, DataType sec_columns_d = { "secret_id": Column("secret_id", DataType.INTEGER, primary_key=True, unique=True), "filepath": Column("filepath", DataType.BLOB), - "password": Column("password", DataType.BLOB), + "password": Column("password", DataType.BLOB, secure=True), } sec_all_columns = [ diff --git a/main.py b/main.py index d2c93a4..2ff521a 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ from lib.config_utils import ( from src.mw_kps_unifier import KpsUnifier import src.rc_kps_unifier -__version__ = '0.1.0' +__version__ = '0.2.0' __version_info__ = tuple(map(int, __version__.split('.'))) ORG_NAME = "JnPrograms" @@ -27,7 +27,8 @@ def main(): db_path = get_default_db_path(config, ORG_NAME, APP_NAME) secrets_path = get_secrets_path(ORG_NAME, APP_NAME) - win = KpsUnifier(db_path, secrets_path, config, __version__) + win = KpsUnifier(db_path, secrets_path, config, + ORG_NAME, APP_NAME, __version__) win.show() return app.exec() diff --git a/src/da_entry_info.py b/src/da_entry_info.py index a67f833..4e12d22 100644 --- a/src/da_entry_info.py +++ b/src/da_entry_info.py @@ -71,12 +71,11 @@ class DaEntryInfo(QtWidgets.QDialog): def __init__( self, entry_id: int, - config: dict, sqh: Sqlite3Worker, parent: QtWidgets.QWidget = None ): super().__init__(parent) - _, results = sqh.select(config["table_name"], all_columns, + _, results = sqh.select("entries", all_columns, where=Operand(entry_id_col).equal_to(entry_id)) entry = results[0] @@ -88,5 +87,3 @@ class DaEntryInfo(QtWidgets.QDialog): def sizeHint(self): return QtCore.QSize(640, 360) - - diff --git a/src/da_target_login.py b/src/da_target_login.py index 04cdd43..2d4ccd5 100644 --- a/src/da_target_login.py +++ b/src/da_target_login.py @@ -1,13 +1,14 @@ # coding: utf8 from PySide6 import QtWidgets, QtCore, QtGui +from pykeepass import PyKeePass +from pykeepass.exceptions import HeaderChecksumError, CredentialsError -class DaTargetLogin(QtWidgets.QDialog): - def __init__(self, parent=None): - super().__init__(parent) - self.setWindowTitle("目标文件") +class UiDaTargetLogin(object): + def __init__(self, window: QtWidgets.QWidget): + window.setWindowTitle("目标文件") self.vly_m = QtWidgets.QVBoxLayout() - self.setLayout(self.vly_m) + window.setLayout(self.vly_m) icon_ellipsis = QtGui.QIcon(":/asset/img/ellipsis.svg") self.icon_eye = QtGui.QIcon(":/asset/img/eye.svg") @@ -15,61 +16,77 @@ class DaTargetLogin(QtWidgets.QDialog): self.hly_path = QtWidgets.QHBoxLayout() self.vly_m.addLayout(self.hly_path) - self.lb_path = QtWidgets.QLabel("路径:", self) - self.lne_path = QtWidgets.QLineEdit(self) - self.pbn_browse = QtWidgets.QPushButton(icon=icon_ellipsis, parent=self) + self.lb_path = QtWidgets.QLabel("路径:", window) + self.lne_path = QtWidgets.QLineEdit(window) + self.pbn_browse = QtWidgets.QPushButton(icon=icon_ellipsis, parent=window) self.hly_path.addWidget(self.lb_path) self.hly_path.addWidget(self.lne_path) self.hly_path.addWidget(self.pbn_browse) self.hly_password = QtWidgets.QHBoxLayout() self.vly_m.addLayout(self.hly_password) - self.lb_password = QtWidgets.QLabel("密码:", self) - self.lne_password = QtWidgets.QLineEdit(self) + self.lb_password = QtWidgets.QLabel("密码:", window) + self.lne_password = QtWidgets.QLineEdit(window) self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password) - self.pbn_eye = QtWidgets.QPushButton(icon=self.icon_eye_off, parent=self) + self.pbn_eye = QtWidgets.QPushButton(icon=self.icon_eye_off, parent=window) self.hly_password.addWidget(self.lb_password) self.hly_password.addWidget(self.lne_password) self.hly_password.addWidget(self.pbn_eye) self.hly_bottom = QtWidgets.QHBoxLayout() self.vly_m.addLayout(self.hly_bottom) - self.pbn_ok = QtWidgets.QPushButton("确定", self) - self.pbn_cancel = QtWidgets.QPushButton("取消", self) + self.pbn_ok = QtWidgets.QPushButton("确定", window) + self.pbn_cancel = QtWidgets.QPushButton("取消", window) self.hly_bottom.addStretch(1) self.hly_bottom.addWidget(self.pbn_ok) self.hly_bottom.addWidget(self.pbn_cancel) self.vly_m.addStretch(1) - self.pbn_browse.clicked.connect(self.on_pbn_browse_clicked) - self.pbn_eye.clicked.connect(self.on_pbn_eye_clicked) - self.pbn_ok.clicked.connect(self.on_pbn_ok_clicked) - self.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked) + self.pbn_ok.setFocus() + + +class DaTargetLogin(QtWidgets.QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.ui = UiDaTargetLogin(self) + self.tar_kp: PyKeePass | None = None + + self.ui.pbn_browse.clicked.connect(self.on_pbn_browse_clicked) + self.ui.pbn_eye.clicked.connect(self.on_pbn_eye_clicked) + self.ui.pbn_ok.clicked.connect(self.on_pbn_ok_clicked) + self.ui.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked) def on_pbn_browse_clicked(self): filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "选择", "../", filter="KeePass 2 数据库 (*.kdbx);;所有文件 (*)") if len(filename) == 0: return - self.lne_path.setText(filename) + self.ui.lne_path.setText(filename) def on_pbn_eye_clicked(self): - if self.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password: - self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Normal) - self.pbn_eye.setIcon(self.icon_eye) + if self.ui.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password: + self.ui.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Normal) + self.ui.pbn_eye.setIcon(self.ui.icon_eye) else: - self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password) - self.pbn_eye.setIcon(self.icon_eye_off) + self.ui.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password) + self.ui.pbn_eye.setIcon(self.ui.icon_eye_off) def sizeHint(self): return QtCore.QSize(540, 40) def on_pbn_ok_clicked(self): + try: + self.tar_kp = PyKeePass(self.ui.lne_path.text(), self.ui.lne_password.text()) + except CredentialsError: + QtWidgets.QMessageBox.critical(self, "错误", "keepass 密码错误") + return + except (FileNotFoundError, HeaderChecksumError): + QtWidgets.QMessageBox.critical(self, "错误", "文件不存在或不是 keepass 文件") + return + self.accept() def on_pbn_cancel_clicked(self): + self.tar_kp = None self.reject() - - - diff --git a/src/gbx_kps_login.py b/src/gbx_kps_login.py index 2afdfae..ea607ec 100644 --- a/src/gbx_kps_login.py +++ b/src/gbx_kps_login.py @@ -90,7 +90,7 @@ class GbxKpsLogin(QtWidgets.QGroupBox): kp = read_kps_to_db( kps_file=self.lne_path.text(), password=self.lne_password.text(), - table_name=self.config["table_name"], + table_name="entries", sqh=self.sqh ) except CredentialsError: diff --git a/src/mw_kps_unifier.py b/src/mw_kps_unifier.py index 42aac69..a5626fb 100644 --- a/src/mw_kps_unifier.py +++ b/src/mw_kps_unifier.py @@ -10,7 +10,7 @@ from .cmbx_styles import StyleComboBox from lib.Sqlite3Helper import Sqlite3Worker from lib.db_columns_def import all_columns from lib.sec_db_columns_def import sec_all_columns -from lib.config_utils import write_config +from lib.config_utils import write_config, get_or_generate_key class UiKpsUnifier(object): @@ -78,10 +78,15 @@ class KpsUnifier(QtWidgets.QMainWindow): db_path: str, secrets_path: str, config: dict, + org_name: str, + app_name: str, version: str, parent: QtWidgets.QWidget = None, ): super().__init__(parent) + self.org_name = org_name + self.app_name = app_name + self.db_path = db_path self.secrets_path = secrets_path self.config = config @@ -103,18 +108,23 @@ class KpsUnifier(QtWidgets.QMainWindow): def __del__(self): self.config["last_db_path"] = self.db_path - write_config(self.config, - QtWidgets.QApplication.organizationName(), - QtWidgets.QApplication.applicationName()) + write_config(self.config, self.org_name, self.app_name) def sizeHint(self): return QtCore.QSize(860, 640) def init_db(self) -> Sqlite3Worker: - sqh = Sqlite3Worker(self.db_path) - sqh.create_table(self.config["table_name"], all_columns, if_not_exists=True) + key = get_or_generate_key(self.db_path, self.org_name, self.app_name) + sqh = Sqlite3Worker(self.db_path, key) + sqh.create_table("entries", all_columns, if_not_exists=True) return sqh + def init_secrets_db(self) -> Sqlite3Worker: + key = get_or_generate_key(self.secrets_path, self.org_name, self.app_name) + sec_sqh = Sqlite3Worker(self.secrets_path, key) + sec_sqh.create_table("secrets", sec_all_columns, if_not_exists=True) + return sec_sqh + def update_db(self, filename: str): self.db_path = filename self.sqh = self.init_db() @@ -153,8 +163,3 @@ class KpsUnifier(QtWidgets.QMainWindow): def on_act_about_qt_triggered(self): QtWidgets.QMessageBox.aboutQt(self, "关于 Qt") - - def init_secrets_db(self) -> Sqlite3Worker: - sec_sqh = Sqlite3Worker(self.secrets_path) - sec_sqh.create_table("secrets", sec_all_columns, if_not_exists=True) - return sec_sqh diff --git a/src/page_load.py b/src/page_load.py index 831bf6d..72a5d5a 100644 --- a/src/page_load.py +++ b/src/page_load.py @@ -5,7 +5,8 @@ from PySide6 import QtWidgets from pykeepass import PyKeePass from .gbx_kps_login import GbxKpsLogin -from .utils import accept_warning +from .da_target_login import DaTargetLogin +from .utils import accept_warning, HorizontalLine from lib.Sqlite3Helper import Sqlite3Worker @@ -81,6 +82,15 @@ class PageLoad(QtWidgets.QWidget): self.vly_left = QtWidgets.QVBoxLayout() self.hly_m.addLayout(self.vly_left) + self.pbn_set_target = QtWidgets.QPushButton("设置目标文件", self) + # 暂时没用 + self.pbn_set_target.setDisabled(True) + self.pbn_set_target.setMinimumWidth(config["button_min_width"]) + self.vly_left.addWidget(self.pbn_set_target) + + self.hln_1 = HorizontalLine(self) + self.vly_left.addWidget(self.hln_1) + self.pbn_add_kps = QtWidgets.QPushButton("添加 KPS", self) self.pbn_add_kps.setMinimumWidth(config["button_min_width"]) self.vly_left.addWidget(self.pbn_add_kps) @@ -101,6 +111,8 @@ class PageLoad(QtWidgets.QWidget): self.wg_sa = WgLoadKps(config, file_kp, sqh, sec_sqh, self.sa_m) self.sa_m.setWidget(self.wg_sa) + self.pbn_set_target.clicked.connect(self.on_pbn_set_target_clicked) + self.pbn_add_kps.clicked.connect(self.on_pbn_add_kps_clicked) self.pbn_clear_db.clicked.connect(self.on_pbn_clear_db_clicked) self.pbn_clear_loaded_mem.clicked.connect(self.on_pbn_clear_loaded_mem_clicked) @@ -121,7 +133,7 @@ class PageLoad(QtWidgets.QWidget): return try: - self.sqh.delete_from(self.config["table_name"]) + self.sqh.delete_from("entries") except sqlite3.OperationalError as e: QtWidgets.QMessageBox.critical(self, "错误", f"清空数据库失败:\n{e}") else: @@ -148,3 +160,7 @@ class PageLoad(QtWidgets.QWidget): # 更新kps加载状态 for wg in self.wg_sa.kps_wgs: self.wg_sa.update_load_status(wg) + + def on_pbn_set_target_clicked(self): + da_target_login = DaTargetLogin(self) + da_target_login.exec() diff --git a/src/page_query.py b/src/page_query.py index 6223f76..13b68ec 100644 --- a/src/page_query.py +++ b/src/page_query.py @@ -5,7 +5,6 @@ from PySide6 import QtWidgets, QtCore, QtGui from pykeepass import PyKeePass from .da_entry_info import DaEntryInfo -from .da_target_login import DaTargetLogin from .utils import HorizontalLine, get_filepath_uuids_map, accept_warning from lib.Sqlite3Helper import Sqlite3Worker, Expression, Operand from lib.db_columns_def import ( @@ -18,7 +17,7 @@ from lib.kps_operations import blob_fy class QueryTableModel(QtCore.QAbstractTableModel): - def __init__(self, query_results: list[tuple], parent=None): + def __init__(self, query_results: list[list], parent=None): super().__init__(parent) self.query_results = query_results self.headers = ["序号", "标题", "用户名", "URL"] @@ -98,10 +97,6 @@ class UiPageQuery(object): self.pbn_execute.setMinimumWidth(config["button_min_width"]) self.vly_sa_wg.addWidget(self.pbn_execute) - self.pbn_set_target = QtWidgets.QPushButton("目标文件", window) - self.pbn_set_target.setMinimumWidth(config["button_min_width"]) - self.vly_sa_wg.addWidget(self.pbn_set_target) - self.hln_1 = HorizontalLine(window) self.vly_sa_wg.addWidget(self.hln_1) @@ -147,7 +142,6 @@ class PageQuery(QtWidgets.QWidget): self.ui.act_delete.triggered_with_str.connect(self.on_act_mark_triggered_with_str) self.ui.pbn_execute.clicked.connect(self.on_pbn_execute_clicked) - self.ui.pbn_set_target.clicked.connect(self.on_pbn_set_target_clicked) self.ui.pbn_all.clicked.connect(self.on_pbn_all_clicked) self.ui.pbn_deleted.clicked.connect(self.on_pbn_deleted_clicked) @@ -169,7 +163,7 @@ class PageQuery(QtWidgets.QWidget): }, { "name": "谷歌文档", - "where": "url LIKE 'https://docs.google.com/%' or url LIKE 'https://drive.google.com/%'" + "where": "(url LIKE 'https://docs.google.com/%' OR url LIKE 'https://drive.google.com/%')" }, ] for fil in default_filters: @@ -193,7 +187,7 @@ class PageQuery(QtWidgets.QWidget): self.set_filter_button(fil) def on_custom_filters_clicked_with_data(self, data: dict): - _, results = self.sqh.select(self.config["table_name"], query_columns, + _, results = self.sqh.select("entries", query_columns, where=Expression(data["where"]).and_(Operand(deleted_col).equal_to(0))) model = QueryTableModel(results, self) self.ui.trv_m.setModel(model) @@ -202,20 +196,20 @@ class PageQuery(QtWidgets.QWidget): self.sqh = sqh def on_pbn_all_clicked(self): - _, results = self.sqh.select(self.config["table_name"], query_columns, + _, results = self.sqh.select("entries", query_columns, where=Operand(deleted_col).equal_to(0)) model = QueryTableModel(results, self) self.ui.trv_m.setModel(model) def on_pbn_deleted_clicked(self): - _, results = self.sqh.select(self.config["table_name"], query_columns, + _, results = self.sqh.select("entries", query_columns, where=Operand(deleted_col).equal_to(1)) model = QueryTableModel(results, self) self.ui.trv_m.setModel(model) def on_trv_m_double_clicked(self, index: QtCore.QModelIndex): entry_id = index.siblingAtColumn(0).data(QtCore.Qt.ItemDataRole.DisplayRole) - da_entry_info = DaEntryInfo(entry_id, self.config, self.sqh, self) + da_entry_info = DaEntryInfo(entry_id, self.sqh, self) da_entry_info.exec() def on_trv_m_custom_context_menu_requested(self, pos: QtCore.QPoint): @@ -228,7 +222,7 @@ class PageQuery(QtWidgets.QWidget): for index in indexes if index.column() == 0 ] - self.sqh.update(self.config["table_name"], [(status_col, info)], + self.sqh.update("entries", [(status_col, info)], where=Operand(entry_id_col).in_(entry_ids)) def on_pbn_execute_clicked(self): @@ -236,7 +230,7 @@ class PageQuery(QtWidgets.QWidget): return # 删除功能 - _, results = self.sqh.select(self.config["table_name"], sim_columns, + _, results = self.sqh.select("entries", sim_columns, where=Operand(status_col).equal_to("delete")) file_uuids = get_filepath_uuids_map(results) @@ -252,7 +246,7 @@ class PageQuery(QtWidgets.QWidget): for u in file_uuids[file]: total += 1 - self.sqh.update(self.config["table_name"], [(deleted_col, 1)], + self.sqh.update("entries", [(deleted_col, 1)], where=Operand(uuid_col).equal_to(u).and_( Operand(filepath_col).equal_to(blob_fy(file)))) @@ -269,10 +263,6 @@ class PageQuery(QtWidgets.QWidget): QtWidgets.QMessageBox.information(self, "提示", f"共 {total} 条标记的条目,已删除 {success} 条,无效 {invalid} 条。") - def on_pbn_set_target_clicked(self): - da_target_login = DaTargetLogin(self) - da_target_login.exec() - class PushButtonWithData(QtWidgets.QPushButton): diff --git a/src/page_similar.py b/src/page_similar.py index 6bbe320..6da55e5 100644 --- a/src/page_similar.py +++ b/src/page_similar.py @@ -1,6 +1,5 @@ # coding: utf8 from itertools import combinations -from uuid import UUID from PySide6 import QtWidgets, QtCore from PySide6.QtCore import QAbstractTableModel @@ -70,7 +69,7 @@ class PageSimilar(QtWidgets.QWidget): self.pbn_delete_invalid_data.clicked.connect(self.on_pbn_delete_invalid_data_clicked) def on_pbn_read_db_clicked(self): - _, results = self.sqh.select(self.config["table_name"], sim_columns) + _, results = self.sqh.select("entries", sim_columns) file_uuids = get_filepath_uuids_map(results) files = file_uuids.keys() @@ -98,11 +97,11 @@ class PageSimilar(QtWidgets.QWidget): if accept_warning(self, True, "警告", "你确定要从数据库删除无效文件的记录吗?"): return - _, filepaths = self.sqh.select(self.config["table_name"], [filepath_col,]) + _, filepaths = self.sqh.select("entries", [filepath_col,]) unique_filepaths = set([p[0].decode("utf8") for p in filepaths]) invalid_filepaths = [p for p in unique_filepaths if path_not_exist(p)] for path in invalid_filepaths: - self.sqh.delete_from(self.config["table_name"], + self.sqh.delete_from("entries", where=Operand(filepath_col).equal_to(blob_fy(path)), commit=False) self.sqh.commit() diff --git a/src/utils.py b/src/utils.py index 2bb79c5..82e59f9 100644 --- a/src/utils.py +++ b/src/utils.py @@ -19,7 +19,7 @@ class HorizontalLine(QtWidgets.QFrame): self.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken) -def get_filepath_uuids_map(query_results: list[tuple]) -> dict[str, list[str]]: +def get_filepath_uuids_map(query_results: list[list]) -> dict[str, list[str]]: file_uuids: dict[str, list[str]] = {} for u, filepath in query_results: filepath = filepath.decode("utf8")