129 lines
4.1 KiB
Python
129 lines
4.1 KiB
Python
from fastapi import FastAPI, File, UploadFile, Form
|
|
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from fastapi.requests import Request
|
|
from typing import List
|
|
from pathlib import Path
|
|
from processing import process_image
|
|
from zipfile import ZipFile
|
|
import shutil
|
|
import uuid
|
|
from rembg import remove
|
|
import os
|
|
|
|
app = FastAPI()

# Filesystem layout: all paths are anchored at the directory holding this module.
BASE_DIR = Path(__file__).resolve().parent
TEMPLATE_DIR = BASE_DIR / "templates"
STATIC_DIR = BASE_DIR / "static"
UPLOAD_DIR = STATIC_DIR / "uploads"
OUTPUT_DIR = STATIC_DIR / "output"

# Create the working directories up front (parents included; no-op if present).
# This must happen before the StaticFiles mount below, because creating
# OUTPUT_DIR is what brings STATIC_DIR into existence on a fresh checkout.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Expose STATIC_DIR under /static and load Jinja2 templates from TEMPLATE_DIR.
# NOTE(review): uploads/ and output/ live inside STATIC_DIR, so their contents
# are publicly reachable under /static while they exist.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template as the landing page.

    Args:
        request: The incoming request; Jinja2Templates requires it in the
            template context.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
@app.post("/upload")
async def upload_images(files: List[UploadFile] = File(...)):
    """Run ``process_image`` on each upload and bundle the results in a zip.

    Each upload is saved into a per-request directory under ``UPLOAD_DIR``.
    ``process_image`` is then called with that path and a same-named path in a
    per-request directory under ``OUTPUT_DIR``. Every file found in the output
    directory is packed into ``OUTPUT_DIR/<session_id>.zip``. Both per-request
    directories are removed afterwards, including when a step raises.

    Args:
        files: The uploaded images (multipart field ``files``).

    Returns:
        ``{"download_url": "/static/output/<session_id>.zip"}`` on success, or
        a 400 JSON response ``{"error": "No selected file"}`` when any upload
        has no usable file name. Nothing is processed in that case.
    """
    # Keep only the final path component of each client-supplied name.
    # Otherwise "../../x" or an absolute name escapes the session directory
    # (joining an absolute path with "/" discards the left-hand side).
    names = [Path(upload.filename or "").name for upload in files]
    if any(name in ("", ".", "..") for name in names):
        return JSONResponse(content={"error": "No selected file"}, status_code=400)

    session_id = str(uuid.uuid4())
    session_upload_dir = UPLOAD_DIR / session_id
    session_output_dir = OUTPUT_DIR / session_id
    session_zip_path = OUTPUT_DIR / f"{session_id}.zip"

    try:
        session_upload_dir.mkdir(parents=True)
        session_output_dir.mkdir(parents=True)

        for upload, name in zip(files, names):
            input_path = session_upload_dir / name
            input_path.write_bytes(await upload.read())
            # NOTE(review): process_image is called synchronously inside this
            # async handler, so it blocks the event loop while it runs.
            process_image(input_path, session_output_dir / name)

        # Pack every produced file into one archive (flat, names only).
        with ZipFile(session_zip_path, "w") as zipf:
            for image_file in session_output_dir.iterdir():
                zipf.write(image_file, arcname=image_file.name)
    finally:
        # Always drop the per-request working directories, even on failure.
        # ignore_errors keeps a cleanup problem from masking the real error.
        shutil.rmtree(session_upload_dir, ignore_errors=True)
        shutil.rmtree(session_output_dir, ignore_errors=True)

    return {"download_url": f"/static/output/{session_zip_path.name}"}
|
|
|
|
|
|
@app.get("/download/{zip_filename}")
async def download_zip(zip_filename: str):
    """Serve a file stored directly in ``OUTPUT_DIR`` as a zip download.

    Args:
        zip_filename: Name of the file inside ``OUTPUT_DIR``, e.g. the
            ``<session_id>.zip`` produced by the upload endpoints.

    Returns:
        A ``FileResponse`` with media type ``application/zip``, or a 404 JSON
        response ``{"error": "File not found"}`` when no such file exists.
    """
    output_root = OUTPUT_DIR.resolve()
    file_path = (OUTPUT_DIR / zip_filename).resolve()
    # Only serve regular files located directly in OUTPUT_DIR. A bare
    # exists() check also accepted ".." and per-session directories, which
    # FileResponse cannot send.
    if file_path.parent == output_root and file_path.is_file():
        return FileResponse(path=file_path, filename=zip_filename, media_type='application/zip')
    return JSONResponse(content={"error": "File not found"}, status_code=404)
|
|
|
|
|
|
@app.post("/remove-bg")
async def remove_bg(files: List[UploadFile] = File(...)):
    """Remove the background of each uploaded image and bundle the results.

    Each upload is saved into a per-request directory under ``UPLOAD_DIR``.
    It is then passed to ``rembg.remove`` with alpha matting enabled, and the
    result is written to a per-request directory under ``OUTPUT_DIR`` as
    ``<stem>.png``. All results are packed into
    ``OUTPUT_DIR/<session_id>.zip``. Both per-request directories are removed
    afterwards, including when a step raises.

    Args:
        files: The uploaded images (multipart field ``files``).

    Returns:
        ``JSONResponse({"download_url": "/static/output/<session_id>.zip"})``
        on success, or a 400 JSON response ``{"error": "No selected file"}``
        when any upload has no usable file name. Names are validated before
        any work starts, so nothing is left on disk in that case.
    """
    # Keep only the final path component of each client-supplied name, so
    # "../../x" or an absolute name cannot escape the session directories.
    names = [Path(upload.filename or "").name for upload in files]
    if any(name in ("", ".", "..") for name in names):
        return JSONResponse(content={"error": "No selected file"}, status_code=400)

    session_id = str(uuid.uuid4())
    session_upload_dir = UPLOAD_DIR / session_id
    session_output_dir = OUTPUT_DIR / session_id
    session_zip_path = OUTPUT_DIR / f"{session_id}.zip"

    try:
        session_upload_dir.mkdir(parents=True, exist_ok=True)
        session_output_dir.mkdir(parents=True, exist_ok=True)

        for upload, name in zip(files, names):
            contents = await upload.read()
            (session_upload_dir / name).write_bytes(contents)

            # Use the bytes already in memory instead of re-reading the file.
            # NOTE(review): remove() runs synchronously here and blocks the
            # event loop for the duration of the matting.
            output_data = remove(
                contents,
                alpha_matting=True,
                alpha_matting_erode_size=15,
                alpha_matting_background_threshold=5,
                alpha_matting_foreground_threshold=250,
            )
            # rembg returns PNG-encoded bytes for bytes input, so give the
            # result a matching extension rather than reusing e.g. ".jpg".
            (session_output_dir / name).with_suffix(".png").write_bytes(output_data)

        # Pack every result into one archive (flat, names only).
        with ZipFile(session_zip_path, "w") as zipf:
            for image_file in session_output_dir.iterdir():
                zipf.write(image_file, arcname=image_file.name)
    finally:
        # Always drop the per-request working directories, even on failure.
        # ignore_errors keeps a cleanup problem from masking the real error.
        shutil.rmtree(session_upload_dir, ignore_errors=True)
        shutil.rmtree(session_output_dir, ignore_errors=True)

    return JSONResponse(content={"download_url": f"/static/output/{session_zip_path.name}"})
|
|
|
|
|
|
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly as a
    # script, so import it lazily here.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=7310, host="127.0.0.1")
|