"""FastAPI service for batch image processing and background removal.

Each POST request gets its own session directory under ``static/uploads``
and ``static/output``. Uploaded files are transformed one by one, the
results are bundled into ``static/output/<session_id>.zip`` and the
temporary session directories are removed again.
"""

import os
import shutil
import uuid
from pathlib import Path
from typing import Callable, List, Optional
from zipfile import ZipFile

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.requests import Request
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from rembg import remove

from processing import process_image

app = FastAPI()

# Filesystem layout
BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR / "static"
OUTPUT_DIR = STATIC_DIR / "output"
UPLOAD_DIR = STATIC_DIR / "uploads"
TEMPLATE_DIR = BASE_DIR / "templates"

# Make sure the working directories exist before they are mounted/used
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Mount the static directory and configure the template loader
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=TEMPLATE_DIR)


def _safe_filename(raw: Optional[str]) -> str:
    """Return the final path component of *raw*, or ``""`` if it is unusable.

    Client-supplied names are untrusted: directory parts (with either slash
    style) are dropped so the result can never escape the directory it is
    joined onto.
    """
    if not raw:
        return ""
    name = Path(raw.replace("\\", "/")).name
    if name in ("", ".", ".."):
        return ""
    return name


def _remove_background(input_path: Path, output_path: Path) -> None:
    """Run ``rembg.remove`` on *input_path* and write the result to *output_path*.

    The output file is only created after ``remove`` has returned, so a
    failure does not leave an empty file behind.
    """
    output_data = remove(
        input_path.read_bytes(),
        alpha_matting=True,
        alpha_matting_erode_size=15,
        alpha_matting_background_threshold=5,
        alpha_matting_foreground_threshold=250,
    )
    output_path.write_bytes(output_data)


async def _process_batch(
    files: List[UploadFile],
    transform: Callable[[Path, Path], object],
) -> Optional[str]:
    """Save *files*, apply *transform* to each one and zip the outputs.

    Args:
        files: The uploaded files of one request.
        transform: Called as ``transform(input_path, output_path)`` for every
            saved upload; whatever it leaves in the session output directory
            is added to the zip archive.

    Returns:
        The ``/static/output/<session_id>.zip`` URL of the archive, or
        ``None`` when any upload lacks a usable filename (nothing is written
        to disk in that case).

    Raises:
        Whatever *transform* or the filesystem operations raise; the session
        directories and any partial archive are removed before propagating.
    """
    # Validate every name up front so a bad request creates no directories.
    names = [_safe_filename(upload.filename) for upload in files]
    if not all(names):
        return None

    session_id = str(uuid.uuid4())
    session_upload_dir = UPLOAD_DIR / session_id
    session_output_dir = OUTPUT_DIR / session_id
    session_zip_path = OUTPUT_DIR / f"{session_id}.zip"

    session_upload_dir.mkdir(parents=True, exist_ok=True)
    session_output_dir.mkdir(parents=True, exist_ok=True)

    try:
        for upload, name in zip(files, names):
            input_path = session_upload_dir / name
            input_path.write_bytes(await upload.read())
            transform(input_path, session_output_dir / name)

        # Bundle the results into a zip archive
        with ZipFile(session_zip_path, "w") as archive:
            for image_file in session_output_dir.iterdir():
                archive.write(image_file, arcname=image_file.name)
    except BaseException:
        # Do not leave a half-written archive in the publicly served directory.
        if session_zip_path.exists():
            session_zip_path.unlink()
        raise
    finally:
        # Remove the temporary session directories on success and on failure
        shutil.rmtree(session_upload_dir, ignore_errors=True)
        shutil.rmtree(session_output_dir, ignore_errors=True)

    return f"/static/output/{session_zip_path.name}"


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/upload")
async def upload_images(files: List[UploadFile] = File(...)):
    """Run ``process_image`` on every uploaded file and return a zip URL.

    Responds with ``{"download_url": ...}`` on success, or with a 400
    ``{"error": "No selected file"}`` when an upload has no usable filename.
    """
    download_url = await _process_batch(files, process_image)
    if download_url is None:
        return JSONResponse(content={"error": "No selected file"}, status_code=400)
    return {"download_url": download_url}


@app.get("/download/{zip_filename}")
async def download_zip(zip_filename: str):
    """Send a file from the output directory as a zip download.

    Responds with 404 ``{"error": "File not found"}`` when the name is not a
    plain file name or does not refer to an existing regular file.
    """
    safe_name = _safe_filename(zip_filename)
    file_path = OUTPUT_DIR / safe_name
    # Reject anything that is not a bare file name, and directories such as "..".
    if not safe_name or safe_name != zip_filename or not file_path.is_file():
        return JSONResponse(content={"error": "File not found"}, status_code=404)
    return FileResponse(
        path=file_path,
        filename=zip_filename,
        media_type="application/zip",
    )


@app.post("/remove-bg")
async def remove_bg(files: List[UploadFile] = File(...)):
    """Remove the background of every uploaded image and return a zip URL.

    Responds with ``{"download_url": ...}`` on success, or with a 400
    ``{"error": "No selected file"}`` when an upload has no usable filename.
    """
    download_url = await _process_batch(files, _remove_background)
    if download_url is None:
        return JSONResponse(content={"error": "No selected file"}, status_code=400)
    return JSONResponse(content={"download_url": download_url})


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7310)