dev: 支持加密数据

This commit is contained in:
Julian Freeman
2024-07-20 16:58:31 -04:00
parent b476da4697
commit 4c6ee85f5a
14 changed files with 257 additions and 173 deletions

View File

@@ -1,13 +1,18 @@
# coding: utf8
from __future__ import annotations
import os
import sqlite3
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import time
from dataclasses import dataclass
from enum import StrEnum
from os import PathLike
__version__ = "1.1.0"
__version_info__ = (1, 1, 0)
from cryptography.fernet import Fernet, InvalidToken
__version__ = "2.2.0"
__version_info__ = tuple(map(int, __version__.split(".")))
class DataType(StrEnum):
@@ -27,10 +32,28 @@ class NullType(object):
class BlobType(object):
def __init__(self, data: bytes = b""):
self._data = data
self.data = data
def __str__(self):
return f"X'{self._data.hex()}'"
return f"X'{self.data.hex()}'"
class NotRandomFernet(Fernet):
"""固定下来每次相同的 key 的加密结果相同,方便条件查询"""
def __init__(self, key: bytes | str, fix_time: int, fix_iv: bytes, backend=None):
super().__init__(key, backend)
self._fix_time = fix_time
self._fix_iv = fix_iv
def encrypt(self, data: bytes) -> bytes:
return self._encrypt_from_parts(data, self._fix_time, self._fix_iv)
def _encrypt_blob(blob: BlobType, fernet: NotRandomFernet) -> BlobType:
if fernet is None:
raise ValueError("Key is not set")
return BlobType(fernet.encrypt(blob.data))
def _get_type(data_type: DataType) -> type:
@@ -79,8 +102,14 @@ class Column(object):
has_default: bool = False
default: NullType | int | float | str | BlobType = 0
secure: bool = False
def __post_init__(self):
if self.secure is True and self.data_type != DataType.BLOB:
raise ValueError("Only BLOB data can be secured")
def __str__(self):
head = f"{self.name} {self.data_type.name}"
head = f"{self.name} {self.data_type.value}"
if self.primary_key:
head = f"{head} PRIMARY KEY"
if not self.nullable:
@@ -94,63 +123,6 @@ class Column(object):
__repr__ = __str__
@dataclass
class AbsHeadRow(ABC):
# 这里的几个其实都是 list[str],但为了防止实例类中提示类型错误,就写成了 list
primary_keys: list = field(default_factory=list)
not_nulls: list = field(default_factory=list)
uniques: list = field(default_factory=list)
defaults: dict = field(default_factory=dict)
@abstractmethod
def __post_init__(self):
pass
def __getattr__(self, item):
# 此类声明的类变量自然都是不会被定义的,所以在 __post_init__ 中获取时就会跳转到这
# 此时需要的其实就是这个类变量名称本身,因此直接返回
if item in self.__annotations__:
return item
raise AttributeError
def to_columns(self) -> list[Column]:
columns: list[Column] = []
# self.__annotations__ 会返回实例对象中定义的类变量名称和类型的键值对
# 但不包括此抽象类中的
for column_name in self.__annotations__:
column_type = self.__annotations__[column_name]
data_type = _get_data_type(column_type)
# 实例对象的这个类变量设置了值,在此作为默认值
if column_name in self.defaults:
default = self.defaults[column_name]
if type(default) is not column_type:
raise TypeError(f"Column {column_name} must be {column_type} but was {type(default)}")
has_default = True
else:
default = 0 # 此处设置的值无用,但是默认给它一个值
has_default = False
primary_key = column_name in self.primary_keys
nullable = column_name not in self.not_nulls
unique = column_name in self.uniques
columns.append(Column(
name=column_name,
data_type=data_type,
primary_key=primary_key,
nullable=nullable,
unique=unique,
has_default=has_default,
default=default,
))
return columns
def to_column_dict(self) -> dict[str, Column]:
return {col.name: col for col in self.to_columns()}
class Expression(object):
def __init__(self, expr: str):
@@ -174,14 +146,33 @@ class Expression(object):
class Operand(object):
def __init__(self, column: Column | str):
def __init__(
self,
column: Column | str,
key: bytes = None,
fix_time: int = None,
fix_iv: bytes = None,
):
self._column = column
self._key = key
self._fix_time = fix_time
self._fix_iv = fix_iv
self._name = column.name if isinstance(column, Column) else column
def equal_to(self, value):
return Expression(f"{self._name} = {_to_string(value)}")
def not_equal_to(self, value):
return Expression(f"{self._name} =! {_to_string(value)}")
def equal_to(self, value, not_: bool = False):
if isinstance(self._column, Column):
if self._column.data_type == DataType.BLOB and type(value) is str:
value = BlobType(value.encode("utf-8"))
# 这里不能换成 elif
if self._key is not None and self._column.secure and isinstance(value, BlobType):
fix_time = self._fix_time if self._fix_time is not None else int(time.time())
fix_iv = self._fix_iv if self._fix_iv is not None else os.urandom(16)
try:
value = _encrypt_blob(value, NotRandomFernet(self._key, fix_time, fix_iv))
except ValueError:
pass
op = "!=" if not_ else "="
return Expression(f"{self._name} {op} {_to_string(value)}")
def less_than(self, value):
return Expression(f"{self._name} < {_to_string(value)}")
@@ -254,11 +245,25 @@ def order(column: Column | str | int,
class Sqlite3Worker(object):
def __init__(self, db_name: str | PathLike[str] = ":memory:"):
def __init__(
self,
db_name: str | PathLike[str] = ":memory:",
key: bytes = None,
fix_time: int = None,
fix_iv: bytes = None,
):
self._db_name = db_name
self._conn = sqlite3.connect(db_name)
self._cursor = self._conn.cursor()
self._is_closed = False
self._fernet = None
if key is not None:
fix_time = fix_time if fix_time is not None else int(time.time())
fix_iv = fix_iv if fix_iv is not None else os.urandom(16)
try:
self._fernet = NotRandomFernet(key, fix_time, fix_iv)
except ValueError:
pass
def __del__(self):
self.close()
@@ -276,13 +281,12 @@ class Sqlite3Worker(object):
def commit(self):
self._conn.commit()
def create_table(self, table_name: str, columns: list[Column] | AbsHeadRow,
if_not_exists: bool = False, schema_name: str = "", *, execute: bool = True) -> str:
def create_table(self, table_name: str, columns: list[Column],
if_not_exists: bool = False, schema_name: str = "",
*, execute: bool = True) -> str:
if table_name.startswith("sqlite_"):
raise ValueError("Table name must not start with 'sqlite_')")
if isinstance(columns, AbsHeadRow):
columns = columns.to_columns()
columns_str = ", ".join([str(col) for col in columns])
head = "CREATE TABLE"
if if_not_exists:
@@ -363,27 +367,41 @@ class Sqlite3Worker(object):
return ", ".join(columns_str_ls)
def insert_into(self, table_name: str, columns: list[Column | str],
values: list[list[str | int | float | BlobType]],
values: list[list[NullType | str | int | float | BlobType]],
*, execute: bool = True, commit: bool = True) -> str:
col_count = len(columns)
columns_str = self._columns_to_string(columns)
values_str_ls = []
for value in values:
if len(value) != col_count:
for value_row in values:
if len(value_row) != col_count:
raise ValueError(f"Length of values must be {col_count}")
value_row_str_ls = []
for i in range(col_count):
column = columns[i]
value = value_row[i]
if isinstance(column, Column):
type_ = _get_type(column.data_type)
col_type = _get_type(column.data_type)
val_type = type(value)
# 支持将 int 隐式转为 float
if type(value[i]) is int and type_ is float:
continue
if type(value[i]) is not type_:
raise ValueError(f"The {i + 1}(th) type of value must be {type_},"
f" because the column type is {column.data_type}")
# 这里的 value 是一行数据,是一个多值列表
values_str_ls.append(f"({', '.join([_to_string(val) for val in value])})")
if val_type is int and col_type is float:
pass
# 支持将 NULL 值插入任意类型的列,除了 NOT NULL 限制的
elif val_type is NullType and column.nullable is True:
pass
# 其他类型不匹配
elif val_type is not col_type:
raise ValueError(f"The {i + 1}(th) type of value must be {col_type},"
f" because the column type is {column.data_type},"
f" found {val_type}")
# 如果加密
if column.secure and val_type is not NullType:
value = _to_string(_encrypt_blob(value, self._fernet))
value_row_str_ls.append(_to_string(value))
values_str_ls.append(f"({', '.join(value_row_str_ls)})")
values_str = ", ".join(values_str_ls)
@@ -415,7 +433,7 @@ class Sqlite3Worker(object):
where: Expression = None,
order_by: list[str] | str = None,
limit: int = None, offset: int = None,
*, execute: bool = True) -> tuple[str, list[tuple]]:
*, execute: bool = True) -> tuple[str, list[list]]:
if len(columns) == 0:
columns_str = "*"
else:
@@ -431,6 +449,22 @@ class Sqlite3Worker(object):
if execute:
self._cursor.execute(statement)
rows = self._cursor.fetchall()
rows = [list(row) for row in rows] # 将每行转成列表,方便替换解密数据
# 下面的整个循环都是为了找到需要解密的数据尝试解密
for i in range(len(columns)):
column = columns[i]
if isinstance(column, Column) and column.secure:
for row in rows:
# 如果是加密的 BLOB 但是值不为 NULL 才解密
if row[i] is not None and self._fernet is not None:
# 不管是key错误还是密文错误都是 InvalidToken貌似没法区分
# 因此如果有的数据不是加密过的,应该跳过,不应该影响之后的密文解密,
# 因此这里还是得继续循环下去
try:
row[i] = self._fernet.decrypt(row[i])
except InvalidToken:
pass
return statement, rows
else:
return statement, []
@@ -449,14 +483,26 @@ class Sqlite3Worker(object):
self._conn.commit()
return statement
def update(self, table_name: str, new_values: list[tuple[Column, NullType | int | float | str | BlobType]],
def update(self, table_name: str, new_values: list[tuple[Column | str, NullType | int | float | str | BlobType]],
where: Expression = None,
*, execute: bool = True, commit: bool = True) -> str:
new_values_str_ls = []
for column, value in new_values:
if _get_data_type(type(value)) != column.data_type:
if isinstance(column, Column):
# 支持将 NULL 值填入任意类型的列,除了 NOT NULL 限制的
if type(value) is NullType and column.nullable is True:
pass
elif _get_data_type(type(value)) != column.data_type:
raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}")
new_values_str_ls.append(f"{column.name} = {_to_string(value)}")
if column.secure and type(value) is not NullType:
value = _encrypt_blob(value, self._fernet)
name = column.name
else:
name = column
new_values_str_ls.append(f"{name} = {_to_string(value)}")
head = f"UPDATE {table_name}"
body = f"{head} SET {', '.join(new_values_str_ls)}"

View File

@@ -4,6 +4,8 @@ import sys
import json
from pathlib import Path
from cryptography.fernet import Fernet
def path_not_exist(path: str | Path) -> bool:
"""
@@ -50,7 +52,6 @@ def read_config(org_name: str, app_name: str) -> dict:
config_path = get_config_path(org_name, app_name)
if not config_path.exists():
config = {
"table_name": "entries",
"button_min_width": 120,
"last_db_path": "",
"loaded_memory": {}
@@ -76,3 +77,15 @@ def get_default_db_path(config: dict, org_name: str, app_name: str) -> str:
def get_secrets_path(org_name: str, app_name: str) -> str:
app_dir = get_app_dir(org_name, app_name)
return str(app_dir / "secrets.db")
def get_or_generate_key(db_name: str, org_name: str, app_name: str) -> bytes:
app_dir = get_app_dir(org_name, app_name)
name = Path(db_name).name
key_path = app_dir / f"{name}.key"
if key_path.exists():
key = key_path.read_bytes()
else:
key = Fernet.generate_key()
key_path.write_bytes(key)
return key

View File

@@ -5,10 +5,10 @@ columns_d = {
"entry_id": Column("entry_id", DataType.INTEGER, primary_key=True, unique=True),
"title": Column("title", DataType.BLOB),
"username": Column("username", DataType.BLOB),
"password": Column("password", DataType.BLOB),
"opt": Column("opt", DataType.TEXT),
"password": Column("password", DataType.BLOB, secure=True),
"opt": Column("opt", DataType.BLOB, secure=True),
"url": Column("url", DataType.BLOB),
"notes": Column("notes", DataType.BLOB),
"notes": Column("notes", DataType.BLOB, secure=True),
"uuid": Column("uuid", DataType.TEXT, nullable=False),
"filepath": Column("filepath", DataType.BLOB, nullable=False),
"path": Column("path", DataType.BLOB),

View File

@@ -38,7 +38,7 @@ def read_kps_to_db(kps_file: str | PathLike[str], password: str,
blob_fy(trim_str(entry.title)),
blob_fy(trim_str(entry.username)),
blob_fy(entry.password),
extract_otp(entry.otp),
blob_fy(extract_otp(entry.otp)),
blob_fy(trim_str(entry.url)),
blob_fy(entry.notes),
str(entry.uuid),

View File

@@ -4,7 +4,7 @@ from .Sqlite3Helper import Column, DataType
sec_columns_d = {
"secret_id": Column("secret_id", DataType.INTEGER, primary_key=True, unique=True),
"filepath": Column("filepath", DataType.BLOB),
"password": Column("password", DataType.BLOB),
"password": Column("password", DataType.BLOB, secure=True),
}
sec_all_columns = [

View File

@@ -11,7 +11,7 @@ from lib.config_utils import (
from src.mw_kps_unifier import KpsUnifier
import src.rc_kps_unifier
__version__ = '0.1.0'
__version__ = '0.2.0'
__version_info__ = tuple(map(int, __version__.split('.')))
ORG_NAME = "JnPrograms"
@@ -27,7 +27,8 @@ def main():
db_path = get_default_db_path(config, ORG_NAME, APP_NAME)
secrets_path = get_secrets_path(ORG_NAME, APP_NAME)
win = KpsUnifier(db_path, secrets_path, config, __version__)
win = KpsUnifier(db_path, secrets_path, config,
ORG_NAME, APP_NAME, __version__)
win.show()
return app.exec()

View File

@@ -71,12 +71,11 @@ class DaEntryInfo(QtWidgets.QDialog):
def __init__(
self,
entry_id: int,
config: dict,
sqh: Sqlite3Worker,
parent: QtWidgets.QWidget = None
):
super().__init__(parent)
_, results = sqh.select(config["table_name"], all_columns,
_, results = sqh.select("entries", all_columns,
where=Operand(entry_id_col).equal_to(entry_id))
entry = results[0]
@@ -88,5 +87,3 @@ class DaEntryInfo(QtWidgets.QDialog):
def sizeHint(self):
return QtCore.QSize(640, 360)

View File

@@ -1,13 +1,14 @@
# coding: utf8
from PySide6 import QtWidgets, QtCore, QtGui
from pykeepass import PyKeePass
from pykeepass.exceptions import HeaderChecksumError, CredentialsError
class DaTargetLogin(QtWidgets.QDialog):
def __init__(self, parent=None):
super().__init__(parent)
self.setWindowTitle("目标文件")
class UiDaTargetLogin(object):
def __init__(self, window: QtWidgets.QWidget):
window.setWindowTitle("目标文件")
self.vly_m = QtWidgets.QVBoxLayout()
self.setLayout(self.vly_m)
window.setLayout(self.vly_m)
icon_ellipsis = QtGui.QIcon(":/asset/img/ellipsis.svg")
self.icon_eye = QtGui.QIcon(":/asset/img/eye.svg")
@@ -15,61 +16,77 @@ class DaTargetLogin(QtWidgets.QDialog):
self.hly_path = QtWidgets.QHBoxLayout()
self.vly_m.addLayout(self.hly_path)
self.lb_path = QtWidgets.QLabel("路径:", self)
self.lne_path = QtWidgets.QLineEdit(self)
self.pbn_browse = QtWidgets.QPushButton(icon=icon_ellipsis, parent=self)
self.lb_path = QtWidgets.QLabel("路径:", window)
self.lne_path = QtWidgets.QLineEdit(window)
self.pbn_browse = QtWidgets.QPushButton(icon=icon_ellipsis, parent=window)
self.hly_path.addWidget(self.lb_path)
self.hly_path.addWidget(self.lne_path)
self.hly_path.addWidget(self.pbn_browse)
self.hly_password = QtWidgets.QHBoxLayout()
self.vly_m.addLayout(self.hly_password)
self.lb_password = QtWidgets.QLabel("密码:", self)
self.lne_password = QtWidgets.QLineEdit(self)
self.lb_password = QtWidgets.QLabel("密码:", window)
self.lne_password = QtWidgets.QLineEdit(window)
self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password)
self.pbn_eye = QtWidgets.QPushButton(icon=self.icon_eye_off, parent=self)
self.pbn_eye = QtWidgets.QPushButton(icon=self.icon_eye_off, parent=window)
self.hly_password.addWidget(self.lb_password)
self.hly_password.addWidget(self.lne_password)
self.hly_password.addWidget(self.pbn_eye)
self.hly_bottom = QtWidgets.QHBoxLayout()
self.vly_m.addLayout(self.hly_bottom)
self.pbn_ok = QtWidgets.QPushButton("确定", self)
self.pbn_cancel = QtWidgets.QPushButton("取消", self)
self.pbn_ok = QtWidgets.QPushButton("确定", window)
self.pbn_cancel = QtWidgets.QPushButton("取消", window)
self.hly_bottom.addStretch(1)
self.hly_bottom.addWidget(self.pbn_ok)
self.hly_bottom.addWidget(self.pbn_cancel)
self.vly_m.addStretch(1)
self.pbn_browse.clicked.connect(self.on_pbn_browse_clicked)
self.pbn_eye.clicked.connect(self.on_pbn_eye_clicked)
self.pbn_ok.clicked.connect(self.on_pbn_ok_clicked)
self.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked)
self.pbn_ok.setFocus()
class DaTargetLogin(QtWidgets.QDialog):
def __init__(self, parent=None):
super().__init__(parent)
self.ui = UiDaTargetLogin(self)
self.tar_kp: PyKeePass | None = None
self.ui.pbn_browse.clicked.connect(self.on_pbn_browse_clicked)
self.ui.pbn_eye.clicked.connect(self.on_pbn_eye_clicked)
self.ui.pbn_ok.clicked.connect(self.on_pbn_ok_clicked)
self.ui.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked)
def on_pbn_browse_clicked(self):
filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "选择", "../",
filter="KeePass 2 数据库 (*.kdbx);;所有文件 (*)")
if len(filename) == 0:
return
self.lne_path.setText(filename)
self.ui.lne_path.setText(filename)
def on_pbn_eye_clicked(self):
if self.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password:
self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Normal)
self.pbn_eye.setIcon(self.icon_eye)
if self.ui.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password:
self.ui.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Normal)
self.ui.pbn_eye.setIcon(self.ui.icon_eye)
else:
self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password)
self.pbn_eye.setIcon(self.icon_eye_off)
self.ui.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password)
self.ui.pbn_eye.setIcon(self.ui.icon_eye_off)
def sizeHint(self):
return QtCore.QSize(540, 40)
def on_pbn_ok_clicked(self):
try:
self.tar_kp = PyKeePass(self.ui.lne_path.text(), self.ui.lne_password.text())
except CredentialsError:
QtWidgets.QMessageBox.critical(self, "错误", "keepass 密码错误")
return
except (FileNotFoundError, HeaderChecksumError):
QtWidgets.QMessageBox.critical(self, "错误", "文件不存在或不是 keepass 文件")
return
self.accept()
def on_pbn_cancel_clicked(self):
self.tar_kp = None
self.reject()

View File

@@ -90,7 +90,7 @@ class GbxKpsLogin(QtWidgets.QGroupBox):
kp = read_kps_to_db(
kps_file=self.lne_path.text(),
password=self.lne_password.text(),
table_name=self.config["table_name"],
table_name="entries",
sqh=self.sqh
)
except CredentialsError:

View File

@@ -10,7 +10,7 @@ from .cmbx_styles import StyleComboBox
from lib.Sqlite3Helper import Sqlite3Worker
from lib.db_columns_def import all_columns
from lib.sec_db_columns_def import sec_all_columns
from lib.config_utils import write_config
from lib.config_utils import write_config, get_or_generate_key
class UiKpsUnifier(object):
@@ -78,10 +78,15 @@ class KpsUnifier(QtWidgets.QMainWindow):
db_path: str,
secrets_path: str,
config: dict,
org_name: str,
app_name: str,
version: str,
parent: QtWidgets.QWidget = None,
):
super().__init__(parent)
self.org_name = org_name
self.app_name = app_name
self.db_path = db_path
self.secrets_path = secrets_path
self.config = config
@@ -103,18 +108,23 @@ class KpsUnifier(QtWidgets.QMainWindow):
def __del__(self):
self.config["last_db_path"] = self.db_path
write_config(self.config,
QtWidgets.QApplication.organizationName(),
QtWidgets.QApplication.applicationName())
write_config(self.config, self.org_name, self.app_name)
def sizeHint(self):
return QtCore.QSize(860, 640)
def init_db(self) -> Sqlite3Worker:
sqh = Sqlite3Worker(self.db_path)
sqh.create_table(self.config["table_name"], all_columns, if_not_exists=True)
key = get_or_generate_key(self.db_path, self.org_name, self.app_name)
sqh = Sqlite3Worker(self.db_path, key)
sqh.create_table("entries", all_columns, if_not_exists=True)
return sqh
def init_secrets_db(self) -> Sqlite3Worker:
key = get_or_generate_key(self.secrets_path, self.org_name, self.app_name)
sec_sqh = Sqlite3Worker(self.secrets_path, key)
sec_sqh.create_table("secrets", sec_all_columns, if_not_exists=True)
return sec_sqh
def update_db(self, filename: str):
self.db_path = filename
self.sqh = self.init_db()
@@ -153,8 +163,3 @@ class KpsUnifier(QtWidgets.QMainWindow):
def on_act_about_qt_triggered(self):
QtWidgets.QMessageBox.aboutQt(self, "关于 Qt")
def init_secrets_db(self) -> Sqlite3Worker:
sec_sqh = Sqlite3Worker(self.secrets_path)
sec_sqh.create_table("secrets", sec_all_columns, if_not_exists=True)
return sec_sqh

View File

@@ -5,7 +5,8 @@ from PySide6 import QtWidgets
from pykeepass import PyKeePass
from .gbx_kps_login import GbxKpsLogin
from .utils import accept_warning
from .da_target_login import DaTargetLogin
from .utils import accept_warning, HorizontalLine
from lib.Sqlite3Helper import Sqlite3Worker
@@ -81,6 +82,15 @@ class PageLoad(QtWidgets.QWidget):
self.vly_left = QtWidgets.QVBoxLayout()
self.hly_m.addLayout(self.vly_left)
self.pbn_set_target = QtWidgets.QPushButton("设置目标文件", self)
# 暂时没用
self.pbn_set_target.setDisabled(True)
self.pbn_set_target.setMinimumWidth(config["button_min_width"])
self.vly_left.addWidget(self.pbn_set_target)
self.hln_1 = HorizontalLine(self)
self.vly_left.addWidget(self.hln_1)
self.pbn_add_kps = QtWidgets.QPushButton("添加 KPS", self)
self.pbn_add_kps.setMinimumWidth(config["button_min_width"])
self.vly_left.addWidget(self.pbn_add_kps)
@@ -101,6 +111,8 @@ class PageLoad(QtWidgets.QWidget):
self.wg_sa = WgLoadKps(config, file_kp, sqh, sec_sqh, self.sa_m)
self.sa_m.setWidget(self.wg_sa)
self.pbn_set_target.clicked.connect(self.on_pbn_set_target_clicked)
self.pbn_add_kps.clicked.connect(self.on_pbn_add_kps_clicked)
self.pbn_clear_db.clicked.connect(self.on_pbn_clear_db_clicked)
self.pbn_clear_loaded_mem.clicked.connect(self.on_pbn_clear_loaded_mem_clicked)
@@ -121,7 +133,7 @@ class PageLoad(QtWidgets.QWidget):
return
try:
self.sqh.delete_from(self.config["table_name"])
self.sqh.delete_from("entries")
except sqlite3.OperationalError as e:
QtWidgets.QMessageBox.critical(self, "错误", f"清空数据库失败:\n{e}")
else:
@@ -148,3 +160,7 @@ class PageLoad(QtWidgets.QWidget):
# 更新kps加载状态
for wg in self.wg_sa.kps_wgs:
self.wg_sa.update_load_status(wg)
def on_pbn_set_target_clicked(self):
da_target_login = DaTargetLogin(self)
da_target_login.exec()

View File

@@ -5,7 +5,6 @@ from PySide6 import QtWidgets, QtCore, QtGui
from pykeepass import PyKeePass
from .da_entry_info import DaEntryInfo
from .da_target_login import DaTargetLogin
from .utils import HorizontalLine, get_filepath_uuids_map, accept_warning
from lib.Sqlite3Helper import Sqlite3Worker, Expression, Operand
from lib.db_columns_def import (
@@ -18,7 +17,7 @@ from lib.kps_operations import blob_fy
class QueryTableModel(QtCore.QAbstractTableModel):
def __init__(self, query_results: list[tuple], parent=None):
def __init__(self, query_results: list[list], parent=None):
super().__init__(parent)
self.query_results = query_results
self.headers = ["序号", "标题", "用户名", "URL"]
@@ -98,10 +97,6 @@ class UiPageQuery(object):
self.pbn_execute.setMinimumWidth(config["button_min_width"])
self.vly_sa_wg.addWidget(self.pbn_execute)
self.pbn_set_target = QtWidgets.QPushButton("目标文件", window)
self.pbn_set_target.setMinimumWidth(config["button_min_width"])
self.vly_sa_wg.addWidget(self.pbn_set_target)
self.hln_1 = HorizontalLine(window)
self.vly_sa_wg.addWidget(self.hln_1)
@@ -147,7 +142,6 @@ class PageQuery(QtWidgets.QWidget):
self.ui.act_delete.triggered_with_str.connect(self.on_act_mark_triggered_with_str)
self.ui.pbn_execute.clicked.connect(self.on_pbn_execute_clicked)
self.ui.pbn_set_target.clicked.connect(self.on_pbn_set_target_clicked)
self.ui.pbn_all.clicked.connect(self.on_pbn_all_clicked)
self.ui.pbn_deleted.clicked.connect(self.on_pbn_deleted_clicked)
@@ -169,7 +163,7 @@ class PageQuery(QtWidgets.QWidget):
},
{
"name": "谷歌文档",
"where": "url LIKE 'https://docs.google.com/%' or url LIKE 'https://drive.google.com/%'"
"where": "(url LIKE 'https://docs.google.com/%' OR url LIKE 'https://drive.google.com/%')"
},
]
for fil in default_filters:
@@ -193,7 +187,7 @@ class PageQuery(QtWidgets.QWidget):
self.set_filter_button(fil)
def on_custom_filters_clicked_with_data(self, data: dict):
_, results = self.sqh.select(self.config["table_name"], query_columns,
_, results = self.sqh.select("entries", query_columns,
where=Expression(data["where"]).and_(Operand(deleted_col).equal_to(0)))
model = QueryTableModel(results, self)
self.ui.trv_m.setModel(model)
@@ -202,20 +196,20 @@ class PageQuery(QtWidgets.QWidget):
self.sqh = sqh
def on_pbn_all_clicked(self):
_, results = self.sqh.select(self.config["table_name"], query_columns,
_, results = self.sqh.select("entries", query_columns,
where=Operand(deleted_col).equal_to(0))
model = QueryTableModel(results, self)
self.ui.trv_m.setModel(model)
def on_pbn_deleted_clicked(self):
_, results = self.sqh.select(self.config["table_name"], query_columns,
_, results = self.sqh.select("entries", query_columns,
where=Operand(deleted_col).equal_to(1))
model = QueryTableModel(results, self)
self.ui.trv_m.setModel(model)
def on_trv_m_double_clicked(self, index: QtCore.QModelIndex):
entry_id = index.siblingAtColumn(0).data(QtCore.Qt.ItemDataRole.DisplayRole)
da_entry_info = DaEntryInfo(entry_id, self.config, self.sqh, self)
da_entry_info = DaEntryInfo(entry_id, self.sqh, self)
da_entry_info.exec()
def on_trv_m_custom_context_menu_requested(self, pos: QtCore.QPoint):
@@ -228,7 +222,7 @@ class PageQuery(QtWidgets.QWidget):
for index in indexes if index.column() == 0
]
self.sqh.update(self.config["table_name"], [(status_col, info)],
self.sqh.update("entries", [(status_col, info)],
where=Operand(entry_id_col).in_(entry_ids))
def on_pbn_execute_clicked(self):
@@ -236,7 +230,7 @@ class PageQuery(QtWidgets.QWidget):
return
# 删除功能
_, results = self.sqh.select(self.config["table_name"], sim_columns,
_, results = self.sqh.select("entries", sim_columns,
where=Operand(status_col).equal_to("delete"))
file_uuids = get_filepath_uuids_map(results)
@@ -252,7 +246,7 @@ class PageQuery(QtWidgets.QWidget):
for u in file_uuids[file]:
total += 1
self.sqh.update(self.config["table_name"], [(deleted_col, 1)],
self.sqh.update("entries", [(deleted_col, 1)],
where=Operand(uuid_col).equal_to(u).and_(
Operand(filepath_col).equal_to(blob_fy(file))))
@@ -269,10 +263,6 @@ class PageQuery(QtWidgets.QWidget):
QtWidgets.QMessageBox.information(self, "提示",
f"{total} 条标记的条目,已删除 {success} 条,无效 {invalid} 条。")
def on_pbn_set_target_clicked(self):
da_target_login = DaTargetLogin(self)
da_target_login.exec()
class PushButtonWithData(QtWidgets.QPushButton):

View File

@@ -1,6 +1,5 @@
# coding: utf8
from itertools import combinations
from uuid import UUID
from PySide6 import QtWidgets, QtCore
from PySide6.QtCore import QAbstractTableModel
@@ -70,7 +69,7 @@ class PageSimilar(QtWidgets.QWidget):
self.pbn_delete_invalid_data.clicked.connect(self.on_pbn_delete_invalid_data_clicked)
def on_pbn_read_db_clicked(self):
_, results = self.sqh.select(self.config["table_name"], sim_columns)
_, results = self.sqh.select("entries", sim_columns)
file_uuids = get_filepath_uuids_map(results)
files = file_uuids.keys()
@@ -98,11 +97,11 @@ class PageSimilar(QtWidgets.QWidget):
if accept_warning(self, True, "警告", "你确定要从数据库删除无效文件的记录吗?"):
return
_, filepaths = self.sqh.select(self.config["table_name"], [filepath_col,])
_, filepaths = self.sqh.select("entries", [filepath_col,])
unique_filepaths = set([p[0].decode("utf8") for p in filepaths])
invalid_filepaths = [p for p in unique_filepaths if path_not_exist(p)]
for path in invalid_filepaths:
self.sqh.delete_from(self.config["table_name"],
self.sqh.delete_from("entries",
where=Operand(filepath_col).equal_to(blob_fy(path)),
commit=False)
self.sqh.commit()

View File

@@ -19,7 +19,7 @@ class HorizontalLine(QtWidgets.QFrame):
self.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
def get_filepath_uuids_map(query_results: list[tuple]) -> dict[str, list[str]]:
def get_filepath_uuids_map(query_results: list[list]) -> dict[str, list[str]]:
file_uuids: dict[str, list[str]] = {}
for u, filepath in query_results:
filepath = filepath.decode("utf8")