dev: 支持加密数据
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
# coding: utf8
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from os import PathLike
|
||||
|
||||
__version__ = "1.1.0"
|
||||
__version_info__ = (1, 1, 0)
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
|
||||
__version__ = "2.2.0"
|
||||
__version_info__ = tuple(map(int, __version__.split(".")))
|
||||
|
||||
|
||||
class DataType(StrEnum):
|
||||
@@ -27,10 +32,28 @@ class NullType(object):
|
||||
class BlobType(object):
|
||||
|
||||
def __init__(self, data: bytes = b""):
|
||||
self._data = data
|
||||
self.data = data
|
||||
|
||||
def __str__(self):
|
||||
return f"X'{self._data.hex()}'"
|
||||
return f"X'{self.data.hex()}'"
|
||||
|
||||
|
||||
class NotRandomFernet(Fernet):
|
||||
"""固定下来每次相同的 key 的加密结果相同,方便条件查询"""
|
||||
|
||||
def __init__(self, key: bytes | str, fix_time: int, fix_iv: bytes, backend=None):
|
||||
super().__init__(key, backend)
|
||||
self._fix_time = fix_time
|
||||
self._fix_iv = fix_iv
|
||||
|
||||
def encrypt(self, data: bytes) -> bytes:
|
||||
return self._encrypt_from_parts(data, self._fix_time, self._fix_iv)
|
||||
|
||||
|
||||
def _encrypt_blob(blob: BlobType, fernet: NotRandomFernet) -> BlobType:
|
||||
if fernet is None:
|
||||
raise ValueError("Key is not set")
|
||||
return BlobType(fernet.encrypt(blob.data))
|
||||
|
||||
|
||||
def _get_type(data_type: DataType) -> type:
|
||||
@@ -79,8 +102,14 @@ class Column(object):
|
||||
has_default: bool = False
|
||||
default: NullType | int | float | str | BlobType = 0
|
||||
|
||||
secure: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.secure is True and self.data_type != DataType.BLOB:
|
||||
raise ValueError("Only BLOB data can be secured")
|
||||
|
||||
def __str__(self):
|
||||
head = f"{self.name} {self.data_type.name}"
|
||||
head = f"{self.name} {self.data_type.value}"
|
||||
if self.primary_key:
|
||||
head = f"{head} PRIMARY KEY"
|
||||
if not self.nullable:
|
||||
@@ -94,63 +123,6 @@ class Column(object):
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsHeadRow(ABC):
|
||||
# 这里的几个其实都是 list[str],但为了防止实例类中提示类型错误,就写成了 list
|
||||
primary_keys: list = field(default_factory=list)
|
||||
not_nulls: list = field(default_factory=list)
|
||||
uniques: list = field(default_factory=list)
|
||||
defaults: dict = field(default_factory=dict)
|
||||
|
||||
@abstractmethod
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
def __getattr__(self, item):
|
||||
# 此类声明的类变量自然都是不会被定义的,所以在 __post_init__ 中获取时就会跳转到这
|
||||
# 此时需要的其实就是这个类变量名称本身,因此直接返回
|
||||
if item in self.__annotations__:
|
||||
return item
|
||||
raise AttributeError
|
||||
|
||||
def to_columns(self) -> list[Column]:
|
||||
columns: list[Column] = []
|
||||
|
||||
# self.__annotations__ 会返回实例对象中定义的类变量名称和类型的键值对
|
||||
# 但不包括此抽象类中的
|
||||
for column_name in self.__annotations__:
|
||||
column_type = self.__annotations__[column_name]
|
||||
data_type = _get_data_type(column_type)
|
||||
|
||||
# 实例对象的这个类变量设置了值,在此作为默认值
|
||||
if column_name in self.defaults:
|
||||
default = self.defaults[column_name]
|
||||
if type(default) is not column_type:
|
||||
raise TypeError(f"Column {column_name} must be {column_type} but was {type(default)}")
|
||||
has_default = True
|
||||
else:
|
||||
default = 0 # 此处设置的值无用,但是默认给它一个值
|
||||
has_default = False
|
||||
|
||||
primary_key = column_name in self.primary_keys
|
||||
nullable = column_name not in self.not_nulls
|
||||
unique = column_name in self.uniques
|
||||
|
||||
columns.append(Column(
|
||||
name=column_name,
|
||||
data_type=data_type,
|
||||
primary_key=primary_key,
|
||||
nullable=nullable,
|
||||
unique=unique,
|
||||
has_default=has_default,
|
||||
default=default,
|
||||
))
|
||||
return columns
|
||||
|
||||
def to_column_dict(self) -> dict[str, Column]:
|
||||
return {col.name: col for col in self.to_columns()}
|
||||
|
||||
|
||||
class Expression(object):
|
||||
|
||||
def __init__(self, expr: str):
|
||||
@@ -174,14 +146,33 @@ class Expression(object):
|
||||
|
||||
class Operand(object):
|
||||
|
||||
def __init__(self, column: Column | str):
|
||||
def __init__(
|
||||
self,
|
||||
column: Column | str,
|
||||
key: bytes = None,
|
||||
fix_time: int = None,
|
||||
fix_iv: bytes = None,
|
||||
):
|
||||
self._column = column
|
||||
self._key = key
|
||||
self._fix_time = fix_time
|
||||
self._fix_iv = fix_iv
|
||||
self._name = column.name if isinstance(column, Column) else column
|
||||
|
||||
def equal_to(self, value):
|
||||
return Expression(f"{self._name} = {_to_string(value)}")
|
||||
|
||||
def not_equal_to(self, value):
|
||||
return Expression(f"{self._name} =! {_to_string(value)}")
|
||||
def equal_to(self, value, not_: bool = False):
|
||||
if isinstance(self._column, Column):
|
||||
if self._column.data_type == DataType.BLOB and type(value) is str:
|
||||
value = BlobType(value.encode("utf-8"))
|
||||
# 这里不能换成 elif
|
||||
if self._key is not None and self._column.secure and isinstance(value, BlobType):
|
||||
fix_time = self._fix_time if self._fix_time is not None else int(time.time())
|
||||
fix_iv = self._fix_iv if self._fix_iv is not None else os.urandom(16)
|
||||
try:
|
||||
value = _encrypt_blob(value, NotRandomFernet(self._key, fix_time, fix_iv))
|
||||
except ValueError:
|
||||
pass
|
||||
op = "!=" if not_ else "="
|
||||
return Expression(f"{self._name} {op} {_to_string(value)}")
|
||||
|
||||
def less_than(self, value):
|
||||
return Expression(f"{self._name} < {_to_string(value)}")
|
||||
@@ -254,11 +245,25 @@ def order(column: Column | str | int,
|
||||
|
||||
class Sqlite3Worker(object):
|
||||
|
||||
def __init__(self, db_name: str | PathLike[str] = ":memory:"):
|
||||
def __init__(
|
||||
self,
|
||||
db_name: str | PathLike[str] = ":memory:",
|
||||
key: bytes = None,
|
||||
fix_time: int = None,
|
||||
fix_iv: bytes = None,
|
||||
):
|
||||
self._db_name = db_name
|
||||
self._conn = sqlite3.connect(db_name)
|
||||
self._cursor = self._conn.cursor()
|
||||
self._is_closed = False
|
||||
self._fernet = None
|
||||
if key is not None:
|
||||
fix_time = fix_time if fix_time is not None else int(time.time())
|
||||
fix_iv = fix_iv if fix_iv is not None else os.urandom(16)
|
||||
try:
|
||||
self._fernet = NotRandomFernet(key, fix_time, fix_iv)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
@@ -276,13 +281,12 @@ class Sqlite3Worker(object):
|
||||
def commit(self):
|
||||
self._conn.commit()
|
||||
|
||||
def create_table(self, table_name: str, columns: list[Column] | AbsHeadRow,
|
||||
if_not_exists: bool = False, schema_name: str = "", *, execute: bool = True) -> str:
|
||||
def create_table(self, table_name: str, columns: list[Column],
|
||||
if_not_exists: bool = False, schema_name: str = "",
|
||||
*, execute: bool = True) -> str:
|
||||
if table_name.startswith("sqlite_"):
|
||||
raise ValueError("Table name must not start with 'sqlite_')")
|
||||
|
||||
if isinstance(columns, AbsHeadRow):
|
||||
columns = columns.to_columns()
|
||||
columns_str = ", ".join([str(col) for col in columns])
|
||||
head = "CREATE TABLE"
|
||||
if if_not_exists:
|
||||
@@ -363,27 +367,41 @@ class Sqlite3Worker(object):
|
||||
return ", ".join(columns_str_ls)
|
||||
|
||||
def insert_into(self, table_name: str, columns: list[Column | str],
|
||||
values: list[list[str | int | float | BlobType]],
|
||||
values: list[list[NullType | str | int | float | BlobType]],
|
||||
*, execute: bool = True, commit: bool = True) -> str:
|
||||
col_count = len(columns)
|
||||
columns_str = self._columns_to_string(columns)
|
||||
|
||||
values_str_ls = []
|
||||
for value in values:
|
||||
if len(value) != col_count:
|
||||
for value_row in values:
|
||||
if len(value_row) != col_count:
|
||||
raise ValueError(f"Length of values must be {col_count}")
|
||||
|
||||
value_row_str_ls = []
|
||||
for i in range(col_count):
|
||||
column = columns[i]
|
||||
value = value_row[i]
|
||||
if isinstance(column, Column):
|
||||
type_ = _get_type(column.data_type)
|
||||
col_type = _get_type(column.data_type)
|
||||
val_type = type(value)
|
||||
# 支持将 int 隐式转为 float
|
||||
if type(value[i]) is int and type_ is float:
|
||||
continue
|
||||
if type(value[i]) is not type_:
|
||||
raise ValueError(f"The {i + 1}(th) type of value must be {type_},"
|
||||
f" because the column type is {column.data_type}")
|
||||
# 这里的 value 是一行数据,是一个多值列表
|
||||
values_str_ls.append(f"({', '.join([_to_string(val) for val in value])})")
|
||||
if val_type is int and col_type is float:
|
||||
pass
|
||||
# 支持将 NULL 值插入任意类型的列,除了 NOT NULL 限制的
|
||||
elif val_type is NullType and column.nullable is True:
|
||||
pass
|
||||
# 其他类型不匹配
|
||||
elif val_type is not col_type:
|
||||
raise ValueError(f"The {i + 1}(th) type of value must be {col_type},"
|
||||
f" because the column type is {column.data_type},"
|
||||
f" found {val_type}")
|
||||
# 如果加密
|
||||
if column.secure and val_type is not NullType:
|
||||
value = _to_string(_encrypt_blob(value, self._fernet))
|
||||
|
||||
value_row_str_ls.append(_to_string(value))
|
||||
|
||||
values_str_ls.append(f"({', '.join(value_row_str_ls)})")
|
||||
|
||||
values_str = ", ".join(values_str_ls)
|
||||
|
||||
@@ -391,8 +409,8 @@ class Sqlite3Worker(object):
|
||||
statement = f"{head} {table_name} ({columns_str}) VALUES {values_str};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
return statement
|
||||
|
||||
@staticmethod
|
||||
@@ -415,7 +433,7 @@ class Sqlite3Worker(object):
|
||||
where: Expression = None,
|
||||
order_by: list[str] | str = None,
|
||||
limit: int = None, offset: int = None,
|
||||
*, execute: bool = True) -> tuple[str, list[tuple]]:
|
||||
*, execute: bool = True) -> tuple[str, list[list]]:
|
||||
if len(columns) == 0:
|
||||
columns_str = "*"
|
||||
else:
|
||||
@@ -431,6 +449,22 @@ class Sqlite3Worker(object):
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
rows = self._cursor.fetchall()
|
||||
rows = [list(row) for row in rows] # 将每行转成列表,方便替换解密数据
|
||||
# 下面的整个循环都是为了找到需要解密的数据尝试解密
|
||||
for i in range(len(columns)):
|
||||
column = columns[i]
|
||||
if isinstance(column, Column) and column.secure:
|
||||
for row in rows:
|
||||
# 如果是加密的 BLOB 但是值不为 NULL 才解密
|
||||
if row[i] is not None and self._fernet is not None:
|
||||
# 不管是key错误还是密文错误,都是 InvalidToken,貌似没法区分
|
||||
# 因此如果有的数据不是加密过的,应该跳过,不应该影响之后的密文解密,
|
||||
# 因此这里还是得继续循环下去
|
||||
try:
|
||||
row[i] = self._fernet.decrypt(row[i])
|
||||
except InvalidToken:
|
||||
pass
|
||||
|
||||
return statement, rows
|
||||
else:
|
||||
return statement, []
|
||||
@@ -445,18 +479,30 @@ class Sqlite3Worker(object):
|
||||
statement = f"{body};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
return statement
|
||||
|
||||
def update(self, table_name: str, new_values: list[tuple[Column, NullType | int | float | str | BlobType]],
|
||||
def update(self, table_name: str, new_values: list[tuple[Column | str, NullType | int | float | str | BlobType]],
|
||||
where: Expression = None,
|
||||
*, execute: bool = True, commit: bool = True) -> str:
|
||||
new_values_str_ls = []
|
||||
for column, value in new_values:
|
||||
if _get_data_type(type(value)) != column.data_type:
|
||||
raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}")
|
||||
new_values_str_ls.append(f"{column.name} = {_to_string(value)}")
|
||||
if isinstance(column, Column):
|
||||
# 支持将 NULL 值填入任意类型的列,除了 NOT NULL 限制的
|
||||
if type(value) is NullType and column.nullable is True:
|
||||
pass
|
||||
elif _get_data_type(type(value)) != column.data_type:
|
||||
raise ValueError(f"Type of {column.name} must be {column.data_type}, found {type(value)}")
|
||||
|
||||
if column.secure and type(value) is not NullType:
|
||||
value = _encrypt_blob(value, self._fernet)
|
||||
|
||||
name = column.name
|
||||
else:
|
||||
name = column
|
||||
|
||||
new_values_str_ls.append(f"{name} = {_to_string(value)}")
|
||||
|
||||
head = f"UPDATE {table_name}"
|
||||
body = f"{head} SET {', '.join(new_values_str_ls)}"
|
||||
@@ -466,6 +512,6 @@ class Sqlite3Worker(object):
|
||||
statement = f"{body};"
|
||||
if execute:
|
||||
self._cursor.execute(statement)
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
if commit:
|
||||
self._conn.commit()
|
||||
return statement
|
||||
|
||||
@@ -4,6 +4,8 @@ import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
def path_not_exist(path: str | Path) -> bool:
|
||||
"""
|
||||
@@ -50,7 +52,6 @@ def read_config(org_name: str, app_name: str) -> dict:
|
||||
config_path = get_config_path(org_name, app_name)
|
||||
if not config_path.exists():
|
||||
config = {
|
||||
"table_name": "entries",
|
||||
"button_min_width": 120,
|
||||
"last_db_path": "",
|
||||
"loaded_memory": {}
|
||||
@@ -76,3 +77,15 @@ def get_default_db_path(config: dict, org_name: str, app_name: str) -> str:
|
||||
def get_secrets_path(org_name: str, app_name: str) -> str:
|
||||
app_dir = get_app_dir(org_name, app_name)
|
||||
return str(app_dir / "secrets.db")
|
||||
|
||||
|
||||
def get_or_generate_key(db_name: str, org_name: str, app_name: str) -> bytes:
|
||||
app_dir = get_app_dir(org_name, app_name)
|
||||
name = Path(db_name).name
|
||||
key_path = app_dir / f"{name}.key"
|
||||
if key_path.exists():
|
||||
key = key_path.read_bytes()
|
||||
else:
|
||||
key = Fernet.generate_key()
|
||||
key_path.write_bytes(key)
|
||||
return key
|
||||
|
||||
@@ -5,10 +5,10 @@ columns_d = {
|
||||
"entry_id": Column("entry_id", DataType.INTEGER, primary_key=True, unique=True),
|
||||
"title": Column("title", DataType.BLOB),
|
||||
"username": Column("username", DataType.BLOB),
|
||||
"password": Column("password", DataType.BLOB),
|
||||
"opt": Column("opt", DataType.TEXT),
|
||||
"password": Column("password", DataType.BLOB, secure=True),
|
||||
"opt": Column("opt", DataType.BLOB, secure=True),
|
||||
"url": Column("url", DataType.BLOB),
|
||||
"notes": Column("notes", DataType.BLOB),
|
||||
"notes": Column("notes", DataType.BLOB, secure=True),
|
||||
"uuid": Column("uuid", DataType.TEXT, nullable=False),
|
||||
"filepath": Column("filepath", DataType.BLOB, nullable=False),
|
||||
"path": Column("path", DataType.BLOB),
|
||||
|
||||
@@ -38,7 +38,7 @@ def read_kps_to_db(kps_file: str | PathLike[str], password: str,
|
||||
blob_fy(trim_str(entry.title)),
|
||||
blob_fy(trim_str(entry.username)),
|
||||
blob_fy(entry.password),
|
||||
extract_otp(entry.otp),
|
||||
blob_fy(extract_otp(entry.otp)),
|
||||
blob_fy(trim_str(entry.url)),
|
||||
blob_fy(entry.notes),
|
||||
str(entry.uuid),
|
||||
|
||||
@@ -4,7 +4,7 @@ from .Sqlite3Helper import Column, DataType
|
||||
sec_columns_d = {
|
||||
"secret_id": Column("secret_id", DataType.INTEGER, primary_key=True, unique=True),
|
||||
"filepath": Column("filepath", DataType.BLOB),
|
||||
"password": Column("password", DataType.BLOB),
|
||||
"password": Column("password", DataType.BLOB, secure=True),
|
||||
}
|
||||
|
||||
sec_all_columns = [
|
||||
|
||||
5
main.py
5
main.py
@@ -11,7 +11,7 @@ from lib.config_utils import (
|
||||
from src.mw_kps_unifier import KpsUnifier
|
||||
import src.rc_kps_unifier
|
||||
|
||||
__version__ = '0.1.0'
|
||||
__version__ = '0.2.0'
|
||||
__version_info__ = tuple(map(int, __version__.split('.')))
|
||||
|
||||
ORG_NAME = "JnPrograms"
|
||||
@@ -27,7 +27,8 @@ def main():
|
||||
db_path = get_default_db_path(config, ORG_NAME, APP_NAME)
|
||||
secrets_path = get_secrets_path(ORG_NAME, APP_NAME)
|
||||
|
||||
win = KpsUnifier(db_path, secrets_path, config, __version__)
|
||||
win = KpsUnifier(db_path, secrets_path, config,
|
||||
ORG_NAME, APP_NAME, __version__)
|
||||
win.show()
|
||||
return app.exec()
|
||||
|
||||
|
||||
@@ -71,12 +71,11 @@ class DaEntryInfo(QtWidgets.QDialog):
|
||||
def __init__(
|
||||
self,
|
||||
entry_id: int,
|
||||
config: dict,
|
||||
sqh: Sqlite3Worker,
|
||||
parent: QtWidgets.QWidget = None
|
||||
):
|
||||
super().__init__(parent)
|
||||
_, results = sqh.select(config["table_name"], all_columns,
|
||||
_, results = sqh.select("entries", all_columns,
|
||||
where=Operand(entry_id_col).equal_to(entry_id))
|
||||
|
||||
entry = results[0]
|
||||
@@ -88,5 +87,3 @@ class DaEntryInfo(QtWidgets.QDialog):
|
||||
|
||||
def sizeHint(self):
|
||||
return QtCore.QSize(640, 360)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# coding: utf8
|
||||
from PySide6 import QtWidgets, QtCore, QtGui
|
||||
from pykeepass import PyKeePass
|
||||
from pykeepass.exceptions import HeaderChecksumError, CredentialsError
|
||||
|
||||
|
||||
class DaTargetLogin(QtWidgets.QDialog):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setWindowTitle("目标文件")
|
||||
class UiDaTargetLogin(object):
|
||||
def __init__(self, window: QtWidgets.QWidget):
|
||||
window.setWindowTitle("目标文件")
|
||||
self.vly_m = QtWidgets.QVBoxLayout()
|
||||
self.setLayout(self.vly_m)
|
||||
window.setLayout(self.vly_m)
|
||||
|
||||
icon_ellipsis = QtGui.QIcon(":/asset/img/ellipsis.svg")
|
||||
self.icon_eye = QtGui.QIcon(":/asset/img/eye.svg")
|
||||
@@ -15,61 +16,77 @@ class DaTargetLogin(QtWidgets.QDialog):
|
||||
|
||||
self.hly_path = QtWidgets.QHBoxLayout()
|
||||
self.vly_m.addLayout(self.hly_path)
|
||||
self.lb_path = QtWidgets.QLabel("路径:", self)
|
||||
self.lne_path = QtWidgets.QLineEdit(self)
|
||||
self.pbn_browse = QtWidgets.QPushButton(icon=icon_ellipsis, parent=self)
|
||||
self.lb_path = QtWidgets.QLabel("路径:", window)
|
||||
self.lne_path = QtWidgets.QLineEdit(window)
|
||||
self.pbn_browse = QtWidgets.QPushButton(icon=icon_ellipsis, parent=window)
|
||||
self.hly_path.addWidget(self.lb_path)
|
||||
self.hly_path.addWidget(self.lne_path)
|
||||
self.hly_path.addWidget(self.pbn_browse)
|
||||
|
||||
self.hly_password = QtWidgets.QHBoxLayout()
|
||||
self.vly_m.addLayout(self.hly_password)
|
||||
self.lb_password = QtWidgets.QLabel("密码:", self)
|
||||
self.lne_password = QtWidgets.QLineEdit(self)
|
||||
self.lb_password = QtWidgets.QLabel("密码:", window)
|
||||
self.lne_password = QtWidgets.QLineEdit(window)
|
||||
self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password)
|
||||
self.pbn_eye = QtWidgets.QPushButton(icon=self.icon_eye_off, parent=self)
|
||||
self.pbn_eye = QtWidgets.QPushButton(icon=self.icon_eye_off, parent=window)
|
||||
self.hly_password.addWidget(self.lb_password)
|
||||
self.hly_password.addWidget(self.lne_password)
|
||||
self.hly_password.addWidget(self.pbn_eye)
|
||||
|
||||
self.hly_bottom = QtWidgets.QHBoxLayout()
|
||||
self.vly_m.addLayout(self.hly_bottom)
|
||||
self.pbn_ok = QtWidgets.QPushButton("确定", self)
|
||||
self.pbn_cancel = QtWidgets.QPushButton("取消", self)
|
||||
self.pbn_ok = QtWidgets.QPushButton("确定", window)
|
||||
self.pbn_cancel = QtWidgets.QPushButton("取消", window)
|
||||
self.hly_bottom.addStretch(1)
|
||||
self.hly_bottom.addWidget(self.pbn_ok)
|
||||
self.hly_bottom.addWidget(self.pbn_cancel)
|
||||
|
||||
self.vly_m.addStretch(1)
|
||||
|
||||
self.pbn_browse.clicked.connect(self.on_pbn_browse_clicked)
|
||||
self.pbn_eye.clicked.connect(self.on_pbn_eye_clicked)
|
||||
self.pbn_ok.clicked.connect(self.on_pbn_ok_clicked)
|
||||
self.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked)
|
||||
self.pbn_ok.setFocus()
|
||||
|
||||
|
||||
class DaTargetLogin(QtWidgets.QDialog):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.ui = UiDaTargetLogin(self)
|
||||
self.tar_kp: PyKeePass | None = None
|
||||
|
||||
self.ui.pbn_browse.clicked.connect(self.on_pbn_browse_clicked)
|
||||
self.ui.pbn_eye.clicked.connect(self.on_pbn_eye_clicked)
|
||||
self.ui.pbn_ok.clicked.connect(self.on_pbn_ok_clicked)
|
||||
self.ui.pbn_cancel.clicked.connect(self.on_pbn_cancel_clicked)
|
||||
|
||||
def on_pbn_browse_clicked(self):
|
||||
filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, "选择", "../",
|
||||
filter="KeePass 2 数据库 (*.kdbx);;所有文件 (*)")
|
||||
if len(filename) == 0:
|
||||
return
|
||||
self.lne_path.setText(filename)
|
||||
self.ui.lne_path.setText(filename)
|
||||
|
||||
def on_pbn_eye_clicked(self):
|
||||
if self.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password:
|
||||
self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Normal)
|
||||
self.pbn_eye.setIcon(self.icon_eye)
|
||||
if self.ui.lne_password.echoMode() == QtWidgets.QLineEdit.EchoMode.Password:
|
||||
self.ui.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Normal)
|
||||
self.ui.pbn_eye.setIcon(self.ui.icon_eye)
|
||||
else:
|
||||
self.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password)
|
||||
self.pbn_eye.setIcon(self.icon_eye_off)
|
||||
self.ui.lne_password.setEchoMode(QtWidgets.QLineEdit.EchoMode.Password)
|
||||
self.ui.pbn_eye.setIcon(self.ui.icon_eye_off)
|
||||
|
||||
def sizeHint(self):
|
||||
return QtCore.QSize(540, 40)
|
||||
|
||||
def on_pbn_ok_clicked(self):
|
||||
try:
|
||||
self.tar_kp = PyKeePass(self.ui.lne_path.text(), self.ui.lne_password.text())
|
||||
except CredentialsError:
|
||||
QtWidgets.QMessageBox.critical(self, "错误", "keepass 密码错误")
|
||||
return
|
||||
except (FileNotFoundError, HeaderChecksumError):
|
||||
QtWidgets.QMessageBox.critical(self, "错误", "文件不存在或不是 keepass 文件")
|
||||
return
|
||||
|
||||
self.accept()
|
||||
|
||||
def on_pbn_cancel_clicked(self):
|
||||
self.tar_kp = None
|
||||
self.reject()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class GbxKpsLogin(QtWidgets.QGroupBox):
|
||||
kp = read_kps_to_db(
|
||||
kps_file=self.lne_path.text(),
|
||||
password=self.lne_password.text(),
|
||||
table_name=self.config["table_name"],
|
||||
table_name="entries",
|
||||
sqh=self.sqh
|
||||
)
|
||||
except CredentialsError:
|
||||
|
||||
@@ -10,7 +10,7 @@ from .cmbx_styles import StyleComboBox
|
||||
from lib.Sqlite3Helper import Sqlite3Worker
|
||||
from lib.db_columns_def import all_columns
|
||||
from lib.sec_db_columns_def import sec_all_columns
|
||||
from lib.config_utils import write_config
|
||||
from lib.config_utils import write_config, get_or_generate_key
|
||||
|
||||
|
||||
class UiKpsUnifier(object):
|
||||
@@ -78,10 +78,15 @@ class KpsUnifier(QtWidgets.QMainWindow):
|
||||
db_path: str,
|
||||
secrets_path: str,
|
||||
config: dict,
|
||||
org_name: str,
|
||||
app_name: str,
|
||||
version: str,
|
||||
parent: QtWidgets.QWidget = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.org_name = org_name
|
||||
self.app_name = app_name
|
||||
|
||||
self.db_path = db_path
|
||||
self.secrets_path = secrets_path
|
||||
self.config = config
|
||||
@@ -103,18 +108,23 @@ class KpsUnifier(QtWidgets.QMainWindow):
|
||||
|
||||
def __del__(self):
|
||||
self.config["last_db_path"] = self.db_path
|
||||
write_config(self.config,
|
||||
QtWidgets.QApplication.organizationName(),
|
||||
QtWidgets.QApplication.applicationName())
|
||||
write_config(self.config, self.org_name, self.app_name)
|
||||
|
||||
def sizeHint(self):
|
||||
return QtCore.QSize(860, 640)
|
||||
|
||||
def init_db(self) -> Sqlite3Worker:
|
||||
sqh = Sqlite3Worker(self.db_path)
|
||||
sqh.create_table(self.config["table_name"], all_columns, if_not_exists=True)
|
||||
key = get_or_generate_key(self.db_path, self.org_name, self.app_name)
|
||||
sqh = Sqlite3Worker(self.db_path, key)
|
||||
sqh.create_table("entries", all_columns, if_not_exists=True)
|
||||
return sqh
|
||||
|
||||
def init_secrets_db(self) -> Sqlite3Worker:
|
||||
key = get_or_generate_key(self.secrets_path, self.org_name, self.app_name)
|
||||
sec_sqh = Sqlite3Worker(self.secrets_path, key)
|
||||
sec_sqh.create_table("secrets", sec_all_columns, if_not_exists=True)
|
||||
return sec_sqh
|
||||
|
||||
def update_db(self, filename: str):
|
||||
self.db_path = filename
|
||||
self.sqh = self.init_db()
|
||||
@@ -153,8 +163,3 @@ class KpsUnifier(QtWidgets.QMainWindow):
|
||||
|
||||
def on_act_about_qt_triggered(self):
|
||||
QtWidgets.QMessageBox.aboutQt(self, "关于 Qt")
|
||||
|
||||
def init_secrets_db(self) -> Sqlite3Worker:
|
||||
sec_sqh = Sqlite3Worker(self.secrets_path)
|
||||
sec_sqh.create_table("secrets", sec_all_columns, if_not_exists=True)
|
||||
return sec_sqh
|
||||
|
||||
@@ -5,7 +5,8 @@ from PySide6 import QtWidgets
|
||||
from pykeepass import PyKeePass
|
||||
|
||||
from .gbx_kps_login import GbxKpsLogin
|
||||
from .utils import accept_warning
|
||||
from .da_target_login import DaTargetLogin
|
||||
from .utils import accept_warning, HorizontalLine
|
||||
from lib.Sqlite3Helper import Sqlite3Worker
|
||||
|
||||
|
||||
@@ -81,6 +82,15 @@ class PageLoad(QtWidgets.QWidget):
|
||||
self.vly_left = QtWidgets.QVBoxLayout()
|
||||
self.hly_m.addLayout(self.vly_left)
|
||||
|
||||
self.pbn_set_target = QtWidgets.QPushButton("设置目标文件", self)
|
||||
# 暂时没用
|
||||
self.pbn_set_target.setDisabled(True)
|
||||
self.pbn_set_target.setMinimumWidth(config["button_min_width"])
|
||||
self.vly_left.addWidget(self.pbn_set_target)
|
||||
|
||||
self.hln_1 = HorizontalLine(self)
|
||||
self.vly_left.addWidget(self.hln_1)
|
||||
|
||||
self.pbn_add_kps = QtWidgets.QPushButton("添加 KPS", self)
|
||||
self.pbn_add_kps.setMinimumWidth(config["button_min_width"])
|
||||
self.vly_left.addWidget(self.pbn_add_kps)
|
||||
@@ -101,6 +111,8 @@ class PageLoad(QtWidgets.QWidget):
|
||||
self.wg_sa = WgLoadKps(config, file_kp, sqh, sec_sqh, self.sa_m)
|
||||
self.sa_m.setWidget(self.wg_sa)
|
||||
|
||||
self.pbn_set_target.clicked.connect(self.on_pbn_set_target_clicked)
|
||||
|
||||
self.pbn_add_kps.clicked.connect(self.on_pbn_add_kps_clicked)
|
||||
self.pbn_clear_db.clicked.connect(self.on_pbn_clear_db_clicked)
|
||||
self.pbn_clear_loaded_mem.clicked.connect(self.on_pbn_clear_loaded_mem_clicked)
|
||||
@@ -121,7 +133,7 @@ class PageLoad(QtWidgets.QWidget):
|
||||
return
|
||||
|
||||
try:
|
||||
self.sqh.delete_from(self.config["table_name"])
|
||||
self.sqh.delete_from("entries")
|
||||
except sqlite3.OperationalError as e:
|
||||
QtWidgets.QMessageBox.critical(self, "错误", f"清空数据库失败:\n{e}")
|
||||
else:
|
||||
@@ -148,3 +160,7 @@ class PageLoad(QtWidgets.QWidget):
|
||||
# 更新kps加载状态
|
||||
for wg in self.wg_sa.kps_wgs:
|
||||
self.wg_sa.update_load_status(wg)
|
||||
|
||||
def on_pbn_set_target_clicked(self):
|
||||
da_target_login = DaTargetLogin(self)
|
||||
da_target_login.exec()
|
||||
|
||||
@@ -5,7 +5,6 @@ from PySide6 import QtWidgets, QtCore, QtGui
|
||||
from pykeepass import PyKeePass
|
||||
|
||||
from .da_entry_info import DaEntryInfo
|
||||
from .da_target_login import DaTargetLogin
|
||||
from .utils import HorizontalLine, get_filepath_uuids_map, accept_warning
|
||||
from lib.Sqlite3Helper import Sqlite3Worker, Expression, Operand
|
||||
from lib.db_columns_def import (
|
||||
@@ -18,7 +17,7 @@ from lib.kps_operations import blob_fy
|
||||
|
||||
class QueryTableModel(QtCore.QAbstractTableModel):
|
||||
|
||||
def __init__(self, query_results: list[tuple], parent=None):
|
||||
def __init__(self, query_results: list[list], parent=None):
|
||||
super().__init__(parent)
|
||||
self.query_results = query_results
|
||||
self.headers = ["序号", "标题", "用户名", "URL"]
|
||||
@@ -98,10 +97,6 @@ class UiPageQuery(object):
|
||||
self.pbn_execute.setMinimumWidth(config["button_min_width"])
|
||||
self.vly_sa_wg.addWidget(self.pbn_execute)
|
||||
|
||||
self.pbn_set_target = QtWidgets.QPushButton("目标文件", window)
|
||||
self.pbn_set_target.setMinimumWidth(config["button_min_width"])
|
||||
self.vly_sa_wg.addWidget(self.pbn_set_target)
|
||||
|
||||
self.hln_1 = HorizontalLine(window)
|
||||
self.vly_sa_wg.addWidget(self.hln_1)
|
||||
|
||||
@@ -147,7 +142,6 @@ class PageQuery(QtWidgets.QWidget):
|
||||
self.ui.act_delete.triggered_with_str.connect(self.on_act_mark_triggered_with_str)
|
||||
|
||||
self.ui.pbn_execute.clicked.connect(self.on_pbn_execute_clicked)
|
||||
self.ui.pbn_set_target.clicked.connect(self.on_pbn_set_target_clicked)
|
||||
|
||||
self.ui.pbn_all.clicked.connect(self.on_pbn_all_clicked)
|
||||
self.ui.pbn_deleted.clicked.connect(self.on_pbn_deleted_clicked)
|
||||
@@ -169,7 +163,7 @@ class PageQuery(QtWidgets.QWidget):
|
||||
},
|
||||
{
|
||||
"name": "谷歌文档",
|
||||
"where": "url LIKE 'https://docs.google.com/%' or url LIKE 'https://drive.google.com/%'"
|
||||
"where": "(url LIKE 'https://docs.google.com/%' OR url LIKE 'https://drive.google.com/%')"
|
||||
},
|
||||
]
|
||||
for fil in default_filters:
|
||||
@@ -193,7 +187,7 @@ class PageQuery(QtWidgets.QWidget):
|
||||
self.set_filter_button(fil)
|
||||
|
||||
def on_custom_filters_clicked_with_data(self, data: dict):
|
||||
_, results = self.sqh.select(self.config["table_name"], query_columns,
|
||||
_, results = self.sqh.select("entries", query_columns,
|
||||
where=Expression(data["where"]).and_(Operand(deleted_col).equal_to(0)))
|
||||
model = QueryTableModel(results, self)
|
||||
self.ui.trv_m.setModel(model)
|
||||
@@ -202,20 +196,20 @@ class PageQuery(QtWidgets.QWidget):
|
||||
self.sqh = sqh
|
||||
|
||||
def on_pbn_all_clicked(self):
|
||||
_, results = self.sqh.select(self.config["table_name"], query_columns,
|
||||
_, results = self.sqh.select("entries", query_columns,
|
||||
where=Operand(deleted_col).equal_to(0))
|
||||
model = QueryTableModel(results, self)
|
||||
self.ui.trv_m.setModel(model)
|
||||
|
||||
def on_pbn_deleted_clicked(self):
|
||||
_, results = self.sqh.select(self.config["table_name"], query_columns,
|
||||
_, results = self.sqh.select("entries", query_columns,
|
||||
where=Operand(deleted_col).equal_to(1))
|
||||
model = QueryTableModel(results, self)
|
||||
self.ui.trv_m.setModel(model)
|
||||
|
||||
def on_trv_m_double_clicked(self, index: QtCore.QModelIndex):
|
||||
entry_id = index.siblingAtColumn(0).data(QtCore.Qt.ItemDataRole.DisplayRole)
|
||||
da_entry_info = DaEntryInfo(entry_id, self.config, self.sqh, self)
|
||||
da_entry_info = DaEntryInfo(entry_id, self.sqh, self)
|
||||
da_entry_info.exec()
|
||||
|
||||
def on_trv_m_custom_context_menu_requested(self, pos: QtCore.QPoint):
|
||||
@@ -228,7 +222,7 @@ class PageQuery(QtWidgets.QWidget):
|
||||
for index in indexes if index.column() == 0
|
||||
]
|
||||
|
||||
self.sqh.update(self.config["table_name"], [(status_col, info)],
|
||||
self.sqh.update("entries", [(status_col, info)],
|
||||
where=Operand(entry_id_col).in_(entry_ids))
|
||||
|
||||
def on_pbn_execute_clicked(self):
|
||||
@@ -236,7 +230,7 @@ class PageQuery(QtWidgets.QWidget):
|
||||
return
|
||||
|
||||
# 删除功能
|
||||
_, results = self.sqh.select(self.config["table_name"], sim_columns,
|
||||
_, results = self.sqh.select("entries", sim_columns,
|
||||
where=Operand(status_col).equal_to("delete"))
|
||||
file_uuids = get_filepath_uuids_map(results)
|
||||
|
||||
@@ -252,7 +246,7 @@ class PageQuery(QtWidgets.QWidget):
|
||||
|
||||
for u in file_uuids[file]:
|
||||
total += 1
|
||||
self.sqh.update(self.config["table_name"], [(deleted_col, 1)],
|
||||
self.sqh.update("entries", [(deleted_col, 1)],
|
||||
where=Operand(uuid_col).equal_to(u).and_(
|
||||
Operand(filepath_col).equal_to(blob_fy(file))))
|
||||
|
||||
@@ -269,10 +263,6 @@ class PageQuery(QtWidgets.QWidget):
|
||||
QtWidgets.QMessageBox.information(self, "提示",
|
||||
f"共 {total} 条标记的条目,已删除 {success} 条,无效 {invalid} 条。")
|
||||
|
||||
def on_pbn_set_target_clicked(self):
|
||||
da_target_login = DaTargetLogin(self)
|
||||
da_target_login.exec()
|
||||
|
||||
|
||||
class PushButtonWithData(QtWidgets.QPushButton):
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# coding: utf8
|
||||
from itertools import combinations
|
||||
from uuid import UUID
|
||||
from PySide6 import QtWidgets, QtCore
|
||||
from PySide6.QtCore import QAbstractTableModel
|
||||
|
||||
@@ -70,7 +69,7 @@ class PageSimilar(QtWidgets.QWidget):
|
||||
self.pbn_delete_invalid_data.clicked.connect(self.on_pbn_delete_invalid_data_clicked)
|
||||
|
||||
def on_pbn_read_db_clicked(self):
|
||||
_, results = self.sqh.select(self.config["table_name"], sim_columns)
|
||||
_, results = self.sqh.select("entries", sim_columns)
|
||||
file_uuids = get_filepath_uuids_map(results)
|
||||
|
||||
files = file_uuids.keys()
|
||||
@@ -98,11 +97,11 @@ class PageSimilar(QtWidgets.QWidget):
|
||||
if accept_warning(self, True, "警告", "你确定要从数据库删除无效文件的记录吗?"):
|
||||
return
|
||||
|
||||
_, filepaths = self.sqh.select(self.config["table_name"], [filepath_col,])
|
||||
_, filepaths = self.sqh.select("entries", [filepath_col,])
|
||||
unique_filepaths = set([p[0].decode("utf8") for p in filepaths])
|
||||
invalid_filepaths = [p for p in unique_filepaths if path_not_exist(p)]
|
||||
for path in invalid_filepaths:
|
||||
self.sqh.delete_from(self.config["table_name"],
|
||||
self.sqh.delete_from("entries",
|
||||
where=Operand(filepath_col).equal_to(blob_fy(path)),
|
||||
commit=False)
|
||||
self.sqh.commit()
|
||||
|
||||
@@ -19,7 +19,7 @@ class HorizontalLine(QtWidgets.QFrame):
|
||||
self.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
|
||||
|
||||
|
||||
def get_filepath_uuids_map(query_results: list[tuple]) -> dict[str, list[str]]:
|
||||
def get_filepath_uuids_map(query_results: list[list]) -> dict[str, list[str]]:
|
||||
file_uuids: dict[str, list[str]] = {}
|
||||
for u, filepath in query_results:
|
||||
filepath = filepath.decode("utf8")
|
||||
|
||||
Reference in New Issue
Block a user