add basic files

This commit is contained in:
Julian Freeman
2025-06-26 13:54:35 -04:00
parent efdd76d11e
commit e6b7c8d34d
5 changed files with 521 additions and 0 deletions

128
main.py Normal file
View File

@@ -0,0 +1,128 @@
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.requests import Request
from typing import List
from pathlib import Path
from processing import process_image
from zipfile import ZipFile
import shutil
import uuid
from rembg import remove
import os
app = FastAPI()
# 文件路径设置
BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR / "static"
OUTPUT_DIR = STATIC_DIR / "output"
UPLOAD_DIR = STATIC_DIR / "uploads"
TEMPLATE_DIR = BASE_DIR / "templates"
# 确保目录存在
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# 挂载静态目录和模板
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=TEMPLATE_DIR)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/upload")
async def upload_images(files: List[UploadFile] = File(...)):
session_id = str(uuid.uuid4())
session_upload_dir = UPLOAD_DIR / session_id
session_output_dir = OUTPUT_DIR / session_id
session_zip_path = OUTPUT_DIR / f"{session_id}.zip"
session_upload_dir.mkdir(parents=True)
session_output_dir.mkdir(parents=True)
for file in files:
contents = await file.read()
input_path = session_upload_dir / file.filename
with open(input_path, "wb") as f:
f.write(contents)
output_path = session_output_dir / file.filename
process_image(input_path, output_path)
# 打包为 zip
with ZipFile(session_zip_path, 'w') as zipf:
for image_file in session_output_dir.iterdir():
zipf.write(image_file, arcname=image_file.name)
# 清理上传目录(可选)
shutil.rmtree(session_upload_dir)
shutil.rmtree(session_output_dir)
return {"download_url": f"/static/output/{session_zip_path.name}"}
@app.get("/download/{zip_filename}")
async def download_zip(zip_filename: str):
file_path = OUTPUT_DIR / zip_filename
if file_path.exists():
return FileResponse(path=file_path, filename=zip_filename, media_type='application/zip')
return {"error": "File not found"}
@app.post("/remove-bg")
async def remove_bg(files: List[UploadFile] = File(...)):
session_id = str(uuid.uuid4())
session_upload_dir = UPLOAD_DIR / session_id
session_output_dir = OUTPUT_DIR / session_id
session_zip_path = OUTPUT_DIR / f"{session_id}.zip"
session_upload_dir.mkdir(parents=True, exist_ok=True)
session_output_dir.mkdir(parents=True, exist_ok=True)
result_files = []
for file in files:
if file.filename == '':
return JSONResponse(content={"error": "No selected file"}, status_code=400)
input_path = session_upload_dir / file.filename
output_path = session_output_dir / file.filename
contents = await file.read()
with open(input_path, "wb") as f:
f.write(contents)
with open(input_path, 'rb') as input_file:
with open(output_path, 'wb') as output_file:
input_data = input_file.read()
output_data = remove(
input_data,
alpha_matting=True,
alpha_matting_erode_size=15,
alpha_matting_background_threshold=5,
alpha_matting_foreground_threshold=250,
)
output_file.write(output_data)
result_files.append(f'/static/output/{session_id}/{file.filename}')
# 打包为 zip 文件
with ZipFile(session_zip_path, 'w') as zipf:
for image_file in session_output_dir.iterdir():
zipf.write(image_file, arcname=image_file.name)
# 清理临时目录
shutil.rmtree(session_upload_dir)
shutil.rmtree(session_output_dir)
return JSONResponse(content={"download_url": f"/static/output/{session_zip_path.name}"})
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=7310)