diff --git a/README.md b/README.md index dccb4eb..96e4b8f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,17 @@ # KpsUnifier 一个可以同时处理多个 keepass 文件的工具 + +## 版本日志 + +### v1.2.0 + +- 修复添加的 keepass 文件为空时静默报错的问题 +- 增加记录上次打开目录的位置 +- 转移条目时直接转移整个条目而不是逐项复制(可以转移历史) +- 查询页增加一个输入框显示当前查询条数 + +### v1.1.1 + +- 修复无法保存配置文件的问题 +- 修复切换数据库时 kps 文件加载状态更新错误的问题 diff --git a/lib/Sqlite3Helper.py b/lib/Sqlite3Helper.py index 0202284..7b4cdb2 100644 --- a/lib/Sqlite3Helper.py +++ b/lib/Sqlite3Helper.py @@ -7,11 +7,12 @@ import time from dataclasses import dataclass from enum import StrEnum from os import PathLike +from types import NoneType from cryptography.fernet import Fernet, InvalidToken -__version__ = "2.2.0" +__version__ = "2.2.2" __version_info__ = tuple(map(int, __version__.split("."))) @@ -32,13 +33,18 @@ class NullType(object): class BlobType(object): def __init__(self, data: bytes = b""): - self.data = data + self._data = data def __str__(self): - return f"X'{self.data.hex()}'" + return f"X'{self._data.hex()}'" + + def encrypt(self, fernet: _NotRandomFernet) -> BlobType: + if fernet is None: + raise ValueError("Key is not set") + return BlobType(fernet.encrypt(self._data)) -class NotRandomFernet(Fernet): +class _NotRandomFernet(Fernet): """固定下来每次相同的 key 的加密结果相同,方便条件查询""" def __init__(self, key: bytes | str, fix_time: int, fix_iv: bytes, backend=None): @@ -50,44 +56,53 @@ class NotRandomFernet(Fernet): return self._encrypt_from_parts(data, self._fix_time, self._fix_iv) -def _encrypt_blob(blob: BlobType, fernet: NotRandomFernet) -> BlobType: - if fernet is None: - raise ValueError("Key is not set") - return BlobType(fernet.encrypt(blob.data)) +VALUE_TYPES = None | NullType | int | float | str | bytes | BlobType -def _get_type(data_type: DataType) -> type: +def _check_data_type(data_type: DataType, allow_null: bool, value) -> bool: + value_type = type(value) + allow_types = [] if data_type == DataType.NULL: - return NullType - if data_type == DataType.INTEGER: - return int - if data_type == DataType.REAL: - return float - if data_type == DataType.TEXT: - return str + pass + elif data_type == DataType.INTEGER: + allow_types.extend([int, ]) + elif data_type == DataType.REAL: + allow_types.extend([int, float]) + elif data_type == DataType.TEXT: + allow_types.extend([str, ]) + elif data_type == DataType.BLOB: + allow_types.extend([str, bytes, BlobType]) + + if allow_null: + allow_types.extend([NoneType, NullType]) + + return value_type in allow_types + + +def _implicitly_convert(data_type: DataType, value): + if data_type == DataType.REAL and type(value) is int: + return float(value) if data_type == DataType.BLOB: - return BlobType - raise TypeError(f"Data type {data_type} is not supported") + if type(value) is str: + return BlobType(value.encode("utf-8")) + if type(value) is bytes: + return BlobType(value) + + return value -def _get_data_type(type_: type) -> DataType: - if type_ is NullType: - return DataType.NULL - if type_ is int: - return DataType.INTEGER - if type_ is float: - return DataType.REAL - if type_ is str: - return DataType.TEXT - if type_ is BlobType: - return DataType.BLOB - raise TypeError(f"Data type {type_} is not supported") +def _is_null(value) -> bool: + return type(value) in (NoneType, NullType) def _to_string(value): - # 如果传入的类型不是 text 会直接返回原值 - if type(value) is str: - if not (value.startswith("'") or value.endswith("'")): + if value is None: + value = NullType() + elif type(value) is str: + # 只要开头或者结尾任意一个字符不是单引号 + if not (value.startswith("'") and value.endswith("'")): + # 把单引号换为两个单引号转义 + value = value.replace("'", "''") value = f"'{value}'" return str(value) @@ -171,7 +186,7 @@ class Operand(object): fix_time = self._fix_time if self._fix_time is not None else int(time.time()) fix_iv = self._fix_iv if self._fix_iv is not None else os.urandom(16) try: - value = _encrypt_blob(value, NotRandomFernet(self._key, fix_time, fix_iv)) + value = value.encrypt(_NotRandomFernet(self._key, fix_time, fix_iv)) except ValueError: pass op = "!=" if not_ else "=" @@ -264,7 +279,7 @@ class Sqlite3Worker(object): fix_time = fix_time if fix_time is not None else int(time.time()) fix_iv = fix_iv if fix_iv is not None else os.urandom(16) try: - self._fernet = NotRandomFernet(key, fix_time, fix_iv) + self._fernet = _NotRandomFernet(key, fix_time, fix_iv) except ValueError: pass @@ -284,6 +299,12 @@ class Sqlite3Worker(object): def commit(self): self._conn.commit() + def _execute(self, statement: str): + try: + self._cursor.execute(statement) + except sqlite3.Error as e: + raise sqlite3.Error(f"Error name: {e.sqlite_errorname};\nError statement: {statement}") + def create_table(self, table_name: str, columns: list[Column], if_not_exists: bool = False, schema_name: str = "", *, execute: bool = True) -> str: @@ -301,7 +322,7 @@ class Sqlite3Worker(object): statement = f"{head} {name} ({columns_str});" if execute: - self._cursor.execute(statement) + self._execute(statement) return statement def drop_table(self, table_name: str, if_exists: bool = False, @@ -316,14 +337,14 @@ class Sqlite3Worker(object): statement = f"{head} {name};" if execute: - self._cursor.execute(statement) + self._execute(statement) return statement def rename_table(self, table_name: str, new_name: str, *, execute: bool = True) -> str: head = "ALTER TABLE" statement = f"{head} {table_name} RENAME TO {new_name};" if execute: - self._cursor.execute(statement) + self._execute(statement) return statement def add_column(self, table_name: str, column: Column, *, execute: bool = True) -> str: @@ -338,7 +359,7 @@ class Sqlite3Worker(object): head = "ALTER TABLE" statement = f"{head} {table_name} ADD COLUMN {str(column)};" if execute: - self._cursor.execute(statement) + self._execute(statement) return statement def rename_column(self, table_name: str, column_name: str, @@ -349,7 +370,7 @@ class Sqlite3Worker(object): head = "ALTER TABLE" statement = f"{head} {table_name} RENAME COLUMN {column_name} TO {new_name};" if execute: - self._cursor.execute(statement) + self._execute(statement) return statement def show_tables(self) -> list[str]: @@ -370,7 +391,7 @@ class Sqlite3Worker(object): return ", ".join(columns_str_ls) def insert_into(self, table_name: str, columns: list[Column | str], - values: list[list[NullType | str | int | float | BlobType]], + values: list[list[VALUE_TYPES]], *, execute: bool = True, commit: bool = True) -> str: col_count = len(columns) columns_str = self._columns_to_string(columns) @@ -381,26 +402,15 @@ class Sqlite3Worker(object): raise ValueError(f"Length of values must be {col_count}") value_row_str_ls = [] - for i in range(col_count): - column = columns[i] - value = value_row[i] + for column, value in zip(columns, value_row): if isinstance(column, Column): - col_type = _get_type(column.data_type) - val_type = type(value) - # 支持将 int 隐式转为 float - if val_type is int and col_type is float: - pass - # 支持将 NULL 值插入任意类型的列,除了 NOT NULL 限制的 - elif val_type is NullType and column.nullable is True: - pass - # 其他类型不匹配 - elif val_type is not col_type: - raise ValueError(f"The {i + 1}(th) type of value must be {col_type}," - f" because the column type is {column.data_type}," - f" found {val_type}") + if not _check_data_type(column.data_type, column.nullable, value): + raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}") + # 这一步一定在加密之前 + value = _implicitly_convert(column.data_type, value) # 如果加密 - if column.secure and val_type is not NullType: - value = _to_string(_encrypt_blob(value, self._fernet)) + if column.secure and not _is_null(value): + value = value.encrypt(self._fernet) value_row_str_ls.append(_to_string(value)) @@ -411,7 +421,7 @@ class Sqlite3Worker(object): head = "INSERT INTO" statement = f"{head} {table_name} ({columns_str}) VALUES {values_str};" if execute: - self._cursor.execute(statement) + self._execute(statement) if commit: self._conn.commit() return statement @@ -450,7 +460,7 @@ class Sqlite3Worker(object): statement = f"{body};" if execute: - self._cursor.execute(statement) + self._execute(statement) rows = self._cursor.fetchall() rows = [list(row) for row in rows] # 将每行转成列表,方便替换解密数据 # 下面的整个循环都是为了找到需要解密的数据尝试解密 @@ -481,25 +491,24 @@ class Sqlite3Worker(object): statement = f"{body};" if execute: - self._cursor.execute(statement) + self._execute(statement) if commit: self._conn.commit() return statement - def update(self, table_name: str, new_values: list[tuple[Column | str, NullType | int | float | str | BlobType]], + def update(self, table_name: str, new_values: list[tuple[Column | str, VALUE_TYPES]], where: Expression = None, *, execute: bool = True, commit: bool = True) -> str: new_values_str_ls = [] for column, value in new_values: if isinstance(column, Column): - # 支持将 NULL 值填入任意类型的列,除了 NOT NULL 限制的 - if type(value) is NullType and column.nullable is True: - pass - elif _get_data_type(type(value)) != column.data_type: + if not _check_data_type(column.data_type, column.nullable, value): raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}") - - if column.secure and type(value) is not NullType: - value = _encrypt_blob(value, self._fernet) + # 这一步一定在加密之前 + value = _implicitly_convert(column.data_type, value) + # 如果加密 + if column.secure and not _is_null(value): + value = value.encrypt(self._fernet) name = column.name else: @@ -514,7 +523,7 @@ class Sqlite3Worker(object): statement = f"{body};" if execute: - self._cursor.execute(statement) + self._execute(statement) if commit: self._conn.commit() return statement diff --git a/lib/config_utils.py b/lib/config_utils.py index 76cc8c0..570ed36 100644 --- a/lib/config_utils.py +++ b/lib/config_utils.py @@ -49,17 +49,22 @@ def get_config_path(org_name: str, app_name: str) -> Path: def read_config(org_name: str, app_name: str) -> dict: + config = { + "button_min_width": 120, + "last_db_path": "", + "last_open_path": "../", + "loaded_memory": {} + } config_path = get_config_path(org_name, app_name) if not config_path.exists(): - config = { - "button_min_width": 120, - "last_db_path": "", - "loaded_memory": {} - } config_path.write_text(json.dumps(config, ensure_ascii=False, indent=4), encoding="utf-8") return config else: - return json.loads(config_path.read_text(encoding="utf-8")) + exist_config = json.loads(config_path.read_text(encoding="utf-8")) + for key, value in config.items(): + if key not in exist_config: + exist_config[key] = value + return exist_config def write_config(config: dict, org_name: str, app_name: str): diff --git a/lib/kps_operations.py b/lib/kps_operations.py index 47e5454..d3df350 100644 --- a/lib/kps_operations.py +++ b/lib/kps_operations.py @@ -46,5 +46,8 @@ def read_kps_to_db(kps_file: str | PathLike[str], password: str, blob_fy("::".join(entry.path[:-1])), ]) + if len(values) == 0: + raise ValueError("Keepass 文件为空") + sqh.insert_into(table_name, insert_columns, values) return kp diff --git a/main.py b/main.py index 9d000a9..1d27fdc 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ from lib.config_utils import ( from src.mw_kps_unifier import KpsUnifier import src.rc_kps_unifier -__version__ = '1.1.1' +__version__ = '1.2.0' __version_info__ = tuple(map(int, __version__.split('.'))) ORG_NAME = "JnPrograms" diff --git a/src/da_target_login.py b/src/da_target_login.py index 5d8b2ff..f8bc4db 100644 --- a/src/da_target_login.py +++ b/src/da_target_login.py @@ -1,4 +1,6 @@ # coding: utf8 +from pathlib import Path + from PySide6 import QtWidgets, QtCore, QtGui from pykeepass import PyKeePass from pykeepass.exceptions import HeaderChecksumError, CredentialsError @@ -45,10 +47,11 @@ class UiDaTargetLogin(object): class DaTargetLogin(QtWidgets.QDialog): - def __init__(self, parent=None): + def __init__(self, config: dict, parent=None): super().__init__(parent) self.ui = UiDaTargetLogin(self) self.tar_kp: PyKeePass | None = None + self.config = config self.ui.pbn_browse.clicked.connect(self.on_pbn_browse_clicked) self.ui.pbn_eye.clicked.connect(self.on_pbn_eye_clicked) @@ -56,11 +59,12 @@ class DaTargetLogin(QtWidgets.QDialog): self.ui.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked) def on_pbn_browse_clicked(self): - filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "选择", "../", + filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "选择", self.config["last_open_path"], filter="KeePass 2 数据库 (*.kdbx);;所有文件 (*)") if len(filename) == 0: return self.ui.lne_path.setText(filename) + self.config["last_open_path"] = str(Path(filename).parent) def on_pbn_eye_clicked(self): if self.ui.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password: diff --git a/src/gbx_kps_login.py b/src/gbx_kps_login.py index ea607ec..b0e0213 100644 --- a/src/gbx_kps_login.py +++ b/src/gbx_kps_login.py @@ -97,6 +97,9 @@ class GbxKpsLogin(QtWidgets.QGroupBox): QtWidgets.QMessageBox.critical(self, "密码错误", f"{self.lne_path.text()}\n密码错误") return + except ValueError as e: + QtWidgets.QMessageBox.critical(self, "错误", str(e)) + return self.file_kp[self.lne_path.text()] = kp self.sec_sqh.insert_into("secrets", insert_sec_columns, [ diff --git a/src/mw_kps_unifier.py b/src/mw_kps_unifier.py index bae4eb1..77bf8fb 100644 --- a/src/mw_kps_unifier.py +++ b/src/mw_kps_unifier.py @@ -1,4 +1,6 @@ # coding: utf8 +from pathlib import Path + from PySide6 import QtWidgets, QtCore, QtGui from pykeepass import PyKeePass @@ -131,18 +133,20 @@ class KpsUnifier(QtWidgets.QMainWindow): self.ui.lne_db_path.setText(filename) def on_act_new_triggered(self): - filename, _ = QtWidgets.QFileDialog.getSaveFileName(self, "新建", "../", + filename, _ = QtWidgets.QFileDialog.getSaveFileName(self, "新建", self.config["last_open_path"], filter="数据库 (*.db);;所有文件 (*)") if len(filename) == 0: return self.update_db(filename) + self.config["last_open_path"] = str(Path(filename).parent) def on_act_open_triggered(self): - filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "打开", "../", + filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "打开", self.config["last_open_path"], filter="数据库 (*.db);;所有文件 (*)") if len(filename) == 0: return self.update_db(filename) + self.config["last_open_path"] = str(Path(filename).parent) def on_act_load_triggered(self): self.ui.sw_m.setCurrentIndex(0) diff --git a/src/page_load.py b/src/page_load.py index bf29811..da36193 100644 --- a/src/page_load.py +++ b/src/page_load.py @@ -109,15 +109,17 @@ class PageLoad(QtWidgets.QWidget): self.pbn_clear_loaded_mem.clicked.connect(self.on_pbn_clear_loaded_mem_clicked) def update_sqh(self, sqh: Sqlite3Worker): + self.sqh = sqh self.wg_sa.update_sqh(sqh) def on_pbn_add_kps_clicked(self): - filenames, _ = QtWidgets.QFileDialog.getOpenFileNames(self, "选择", "../", + filenames, _ = QtWidgets.QFileDialog.getOpenFileNames(self, "选择", self.config["last_open_path"], filter="KeePass 2 数据库 (*.kdbx);;所有文件 (*)") if len(filenames) == 0: return for filename in filenames: self.wg_sa.add_kps(filename) + self.config["last_open_path"] = str(Path(filenames[0]).parent) def on_pbn_clear_db_clicked(self): if accept_warning(self, True, "警告", "你确定要清空当前数据库吗?"): @@ -141,12 +143,16 @@ class PageLoad(QtWidgets.QWidget): self.wg_sa.update_load_status(wg) def on_pbn_clear_loaded_mem_clicked(self): - if accept_warning(self, True, "警告", "你确定要清空所有加载记忆吗?"): + if accept_warning(self, True, "警告", "你确定要清空当前加载记忆吗?"): return - loaded_mem: dict = self.config["loaded_memory"] - loaded_mem.clear() - QtWidgets.QMessageBox.information(self, "提示", "已清空加载记忆") + filename = str(Path(self.sqh.db_name).name) + loaded_mem: list = self.config["loaded_memory"].get(filename, None) + if loaded_mem is None: + QtWidgets.QMessageBox.warning(self, "警告", f"没有找到 {filename} 的加载记忆") + else: + loaded_mem.clear() + QtWidgets.QMessageBox.information(self, "提示", "已清空加载记忆") # 更新kps加载状态 for wg in self.wg_sa.kps_wgs: diff --git a/src/page_query.py b/src/page_query.py index d43879e..fcdd02a 100644 --- a/src/page_query.py +++ b/src/page_query.py @@ -1,5 +1,6 @@ # coding: utf8 import json +from pathlib import Path from uuid import UUID from PySide6 import QtWidgets, QtCore, QtGui from pykeepass import PyKeePass @@ -117,7 +118,13 @@ class UiPageQuery(object): self.vly_sa_wg.addStretch(1) - self.pbn_read_filters = QtWidgets.QPushButton("更多过滤", window) + self.lne_entries_count = QtWidgets.QLineEdit(self.sa_wg) + self.lne_entries_count.setDisabled(True) + self.lne_entries_count.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + self.lne_entries_count.setMinimumWidth(config["button_min_width"]) + self.vly_sa_wg.addWidget(self.lne_entries_count) + + self.pbn_read_filters = QtWidgets.QPushButton("更多过滤", self.sa_wg) self.pbn_read_filters.setMinimumWidth(config["button_min_width"]) self.vly_sa_wg.addWidget(self.pbn_read_filters) @@ -181,11 +188,11 @@ class PageQuery(QtWidgets.QWidget): def set_filter_button(self, fil: dict): pbn_fil = PushButtonWithData(fil, self.ui.sa_wg, fil["name"]) pbn_fil.setMinimumWidth(self.config["button_min_width"]) - self.ui.vly_sa_wg.insertWidget(self.ui.vly_sa_wg.count() - 2, pbn_fil) + self.ui.vly_sa_wg.insertWidget(self.ui.vly_sa_wg.count() - 3, pbn_fil) pbn_fil.clicked_with_data.connect(self.on_custom_filters_clicked_with_data) def on_pbn_read_filters_clicked(self): - filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "打开", "../", + filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "打开", self.config["last_open_path"], filter="JSON 文件 (*.json);;所有文件 (*)") if len(filename) == 0: return @@ -194,12 +201,14 @@ class PageQuery(QtWidgets.QWidget): for fil in filter_ls: self.set_filter_button(fil) + self.config["last_open_path"] = str(Path(filename).parent) def on_custom_filters_clicked_with_data(self, data: dict): _, results = self.sqh.select("entries", query_columns, where=Expression(data["where"]).and_(Operand(deleted_col).equal_to(0))) model = QueryTableModel(results, self) self.ui.trv_m.setModel(model) + self.ui.lne_entries_count.setText(str(model.rowCount())) def update_sqh(self, sqh: Sqlite3Worker): self.sqh = sqh @@ -209,6 +218,7 @@ class PageQuery(QtWidgets.QWidget): where=Operand(deleted_col).equal_to(0)) model = QueryTableModel(results, self) self.ui.trv_m.setModel(model) + self.ui.lne_entries_count.setText(str(model.rowCount())) def on_pbn_deleted_clicked(self): _, results = self.sqh.select("entries", query_columns, @@ -247,13 +257,10 @@ class PageQuery(QtWidgets.QWidget): kp = self.file_kp[filepath] return kp - def delete_the_delete_and_transfer(self, transfer: bool = False): - cond = Operand(status_col).equal_to("delete") - if transfer is True: - cond = cond.or_(Operand(status_col).equal_to("transfer"), high_priority=True) - cond = cond.and_(Operand(deleted_col).equal_to(0)) - - _, results = self.sqh.select("entries", sim_columns, where=cond) + def delete_the_delete(self): + _, results = self.sqh.select("entries", sim_columns, + where=Operand(status_col).equal_to("delete") + .and_(Operand(deleted_col).equal_to(0))) file_uuids = get_filepath_uuids_map(results) total, success, invalid = 0, 0, 0 @@ -272,10 +279,11 @@ class PageQuery(QtWidgets.QWidget): continue kp.delete_entry(entry) + success += 1 + self.sqh.update("entries", [(deleted_col, 1)], where=Operand(uuid_col).equal_to(u).and_( Operand(filepath_col).equal_to(blob_fy(file)))) - success += 1 kp.save() @@ -309,17 +317,15 @@ class PageQuery(QtWidgets.QWidget): invalid += 1 continue - self.tar_kp.add_entry( - dest_group, - entry.title or "", - entry.username or "", - entry.password or "", - entry.url, - entry.notes, - otp=entry.otp, - force_creation=True - ) + kp.move_entry(entry, dest_group) success += 1 + + self.sqh.update("entries", [(deleted_col, 1)], + where=Operand(uuid_col).equal_to(u).and_( + Operand(filepath_col).equal_to(blob_fy(file)))) + + kp.save() + self.tar_kp.save() QtWidgets.QMessageBox.information(self, "提示", f"共 {total} 条转移条目,成功 {success} 条,失败 {invalid} 条。") @@ -336,10 +342,10 @@ class PageQuery(QtWidgets.QWidget): if transfer: self.transfer_the_transfer() - self.delete_the_delete_and_transfer(transfer) + self.delete_the_delete() def on_pbn_set_target_clicked(self): - da_target_login = DaTargetLogin(self) + da_target_login = DaTargetLogin(self.config, self) da_target_login.exec() self.tar_kp = da_target_login.tar_kp