import datetime
from contextlib import asynccontextmanager
from enum import IntEnum

from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import create_engine, Column, String, Integer, Text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker, Session, declarative_base

# --- Database setup ---
SQLALCHEMY_DATABASE_URL = "sqlite:///./safe_marks.db"

# SQLite connections refuse cross-thread use by default; sync FastAPI
# endpoints may run on worker threads, so that check is disabled here.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
# autoflush=False: objects added to a session are NOT visible to queries in
# the same session until commit/flush (relevant to add_batch_extensions).
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


# --- SQLAlchemy model (table layout) ---
class Extension(Base):
    """One extension entry and its safety mark, stored in table ``extensions``."""

    __tablename__ = "extensions"

    ID = Column(String, primary_key=True, index=True)
    NAME = Column(String, nullable=False)
    SAFE = Column(Integer, nullable=False)  # written as a SafeStatus value by this API
    UPDATE_DATE = Column(String, nullable=False)  # datetime.isoformat() string
    NOTES = Column(Text, nullable=True)


# --- Pydantic models (API request/response validation) ---
# Pydantic models are documented with comments rather than docstrings so the
# generated OpenAPI schema stays exactly as before.


# Allowed values of the SAFE field.
class SafeStatus(IntEnum):
    safe = 1
    unsure = 0
    unsafe = -1
    unknown = -2  # assigned automatically by the batch-add endpoint


# Fields shared by every create/read model.
class ExtensionBase(BaseModel):
    ID: str = Field(..., description="插件唯一ID")
    NAME: str = Field(..., description="插件名称")
    NOTES: str | None = Field(None, description="备注")


# Body for add_one: the caller must supply SAFE explicitly.
class ExtensionCreate(ExtensionBase):
    SAFE: SafeStatus = Field(..., description="安全状态: 1, 0, -1, -2")


# Item for add_batch: no SAFE field; the endpoint sets SafeStatus.unknown.
class ExtensionBatchCreateItem(ExtensionBase):
    pass


# Body for update_one: every field optional; only fields present in the JSON
# body are applied (the endpoint uses exclude_unset=True).
class ExtensionUpdatePayload(BaseModel):
    NAME: str | None = None
    SAFE: SafeStatus | None = None
    NOTES: str | None = None


# Payload fields whose DB columns are NOT NULL, so an explicit null is rejected.
_NON_NULLABLE_UPDATE_FIELDS = ("NAME", "SAFE")


# Full response model; from_attributes lets it validate ORM objects directly.
class ExtensionInDB(ExtensionBase):
    SAFE: SafeStatus
    UPDATE_DATE: str

    model_config = {
        "from_attributes": True
    }


# Minimal response model: only ID and SAFE.
class ExtensionNecessary(BaseModel):
    ID: str
    SAFE: SafeStatus

    model_config = {
        "from_attributes": True
    }


# --- FastAPI application ---
@asynccontextmanager
async def _lifespan(_app: FastAPI):
    """Create any missing tables at startup.

    ``create_all`` only issues CREATE TABLE for tables that do not exist yet,
    so an existing database is left untouched.
    """
    Base.metadata.create_all(bind=engine)
    yield


app = FastAPI(title="Safe Marks API", lifespan=_lifespan)


# --- Database dependency ---
def get_db():
    """Yield a new SQLAlchemy session for one request and always close it."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# --- Helpers ---
def get_extension_by_id(db: Session, item_id: str) -> Extension | None:
    """Return the extension whose primary key is *item_id*, or None if absent."""
    return db.query(Extension).filter(Extension.ID == item_id).first()


def _now_iso() -> str:
    """Return the current time as an ISO-8601 string for UPDATE_DATE."""
    # NOTE(review): this is naive local time (no UTC offset). It is left
    # unchanged so new rows match the format of rows already stored.
    return datetime.datetime.now().isoformat()


def _commit_or_conflict(db: Session, detail: str) -> None:
    """Commit *db*; on a constraint violation roll back and raise HTTP 400.

    Raises:
        HTTPException: 400 with *detail* if the commit hits an IntegrityError.
    """
    try:
        db.commit()
    except IntegrityError as exc:
        # Another request may have inserted the same ID between our existence
        # check and this commit; report it like the pre-checked duplicate case.
        db.rollback()
        raise HTTPException(status_code=400, detail=detail) from exc


# --- API routes ---
@app.get(
    "/api/v1/ext/query_all",
    response_model=list[ExtensionInDB],
    summary="查询所有插件 (完整信息)",
)
def query_all_extensions(db: Session = Depends(get_db)):
    """Return every stored extension with all of its fields."""
    return db.query(Extension).all()


@app.get(
    "/api/v1/ext/query_necessary",
    response_model=list[ExtensionNecessary],
    summary="查询所有插件 (仅 ID 和 SAFE)",
)
def query_necessary_extensions(db: Session = Depends(get_db)):
    """Return every stored extension, serialized as ID and SAFE only."""
    return db.query(Extension).all()


@app.post("/api/v1/ext/add_one", response_model=ExtensionInDB, summary="添加单个插件")
def add_one_extension(extension: ExtensionCreate, db: Session = Depends(get_db)):
    """Create one extension with an explicit SAFE status.

    Raises:
        HTTPException: 400 if an extension with the same ID already exists.
    """
    duplicate_detail = f"Extension with ID {extension.ID} already exists."
    if get_extension_by_id(db, extension.ID):
        raise HTTPException(status_code=400, detail=duplicate_detail)

    db_extension = Extension(**extension.model_dump(), UPDATE_DATE=_now_iso())
    db.add(db_extension)
    _commit_or_conflict(db, duplicate_detail)
    db.refresh(db_extension)
    return db_extension


@app.post(
    "/api/v1/ext/add_batch",
    response_model=list[ExtensionInDB],
    summary="批量添加插件 (SAFE 自动设为 'unknown')",
)
def add_batch_extensions(items: list[ExtensionBatchCreateItem], db: Session = Depends(get_db)):
    """Create several extensions at once, all with SAFE set to ``unknown``.

    The batch is all-or-nothing: every ID is validated before any row is
    added, and all rows are committed in a single transaction.

    Raises:
        HTTPException: 400 if an ID is repeated within the batch or already
            exists in the database.
    """
    seen_ids: set[str] = set()
    for item in items:
        # The DB lookup cannot see earlier items of this same batch (they are
        # not inserted yet), so repeats inside the request are tracked here.
        if item.ID in seen_ids:
            raise HTTPException(status_code=400, detail=f"Duplicate extension ID {item.ID} in batch.")
        if get_extension_by_id(db, item.ID):
            raise HTTPException(status_code=400, detail=f"Extension with ID {item.ID} already exists.")
        seen_ids.add(item.ID)

    current_date = _now_iso()  # one timestamp shared by the whole batch
    created_extensions = [
        Extension(**item.model_dump(), SAFE=SafeStatus.unknown, UPDATE_DATE=current_date)
        for item in items
    ]
    db.add_all(created_extensions)
    _commit_or_conflict(db, "One or more extension IDs already exist.")
    for ext in created_extensions:
        db.refresh(ext)
    return created_extensions


@app.put(
    "/api/v1/ext/update_one/{item_id}",
    response_model=ExtensionInDB,
    summary="更新单个插件",
)
def update_one_extension(
    item_id: str,
    update_data: ExtensionUpdatePayload,
    db: Session = Depends(get_db),
):
    """Apply a partial update to one extension.

    Only fields present in the request body are considered. UPDATE_DATE is
    refreshed, and the session committed, only when a value actually changes.

    Raises:
        HTTPException: 404 if *item_id* does not exist; 400 if the body sets
            no fields, or sets NAME/SAFE to null (those columns are NOT NULL).
    """
    db_ext = get_extension_by_id(db, item_id)
    if not db_ext:
        raise HTTPException(status_code=404, detail=f"Extension with ID {item_id} not found.")

    # exclude_unset keeps explicitly sent nulls (so NOTES can be cleared)
    # while dropping fields the client did not mention.
    update_dict = update_data.model_dump(exclude_unset=True)
    if not update_dict:
        raise HTTPException(status_code=400, detail="No update data provided.")

    null_fields = [
        key for key in _NON_NULLABLE_UPDATE_FIELDS
        if key in update_dict and update_dict[key] is None
    ]
    if null_fields:
        raise HTTPException(
            status_code=400,
            detail=f"Fields cannot be null: {', '.join(null_fields)}.",
        )

    changed = False
    for key, value in update_dict.items():
        if getattr(db_ext, key) != value:
            setattr(db_ext, key, value)
            changed = True

    if changed:
        db_ext.UPDATE_DATE = _now_iso()
        db.commit()
        db.refresh(db_ext)
    return db_ext


@app.delete(
    "/api/v1/ext/delete_one/{item_id}",
    response_model=dict[str, str],
    summary="删除单个插件",
)
def delete_one_extension(item_id: str, db: Session = Depends(get_db)):
    """Delete one extension and return a confirmation message.

    Raises:
        HTTPException: 404 if *item_id* does not exist.
    """
    db_ext = get_extension_by_id(db, item_id)
    if not db_ext:
        raise HTTPException(status_code=404, detail=f"Extension with ID {item_id} not found.")

    db.delete(db_ext)
    db.commit()
    return {"message": f"Extension with ID {item_id} successfully deleted."}


# --- Running the service (development) ---
# In production, start the app with Gunicorn or the Uvicorn CLI instead, e.g.
#   uvicorn api_server:app --reload    (drop --reload outside development)
# if __name__ == "__main__":
#     import uvicorn
#     # print("启动 FastAPI 服务,访问 http://127.0.0.1:8000/docs 查看 API 文档")
#     uvicorn.run(app, host="127.0.0.1", port=8000)