add basic files
This commit is contained in:
128
main.py
Normal file
128
main.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from fastapi import FastAPI, File, UploadFile, Form
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.requests import Request
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from processing import process_image
|
||||
from zipfile import ZipFile
|
||||
import shutil
|
||||
import uuid
|
||||
from rembg import remove
|
||||
import os
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 文件路径设置
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
STATIC_DIR = BASE_DIR / "static"
|
||||
OUTPUT_DIR = STATIC_DIR / "output"
|
||||
UPLOAD_DIR = STATIC_DIR / "uploads"
|
||||
TEMPLATE_DIR = BASE_DIR / "templates"
|
||||
|
||||
# 确保目录存在
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 挂载静态目录和模板
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
|
||||
|
||||
@app.post("/upload")
|
||||
async def upload_images(files: List[UploadFile] = File(...)):
|
||||
session_id = str(uuid.uuid4())
|
||||
session_upload_dir = UPLOAD_DIR / session_id
|
||||
session_output_dir = OUTPUT_DIR / session_id
|
||||
session_zip_path = OUTPUT_DIR / f"{session_id}.zip"
|
||||
|
||||
session_upload_dir.mkdir(parents=True)
|
||||
session_output_dir.mkdir(parents=True)
|
||||
|
||||
for file in files:
|
||||
contents = await file.read()
|
||||
input_path = session_upload_dir / file.filename
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
output_path = session_output_dir / file.filename
|
||||
process_image(input_path, output_path)
|
||||
|
||||
# 打包为 zip
|
||||
with ZipFile(session_zip_path, 'w') as zipf:
|
||||
for image_file in session_output_dir.iterdir():
|
||||
zipf.write(image_file, arcname=image_file.name)
|
||||
|
||||
# 清理上传目录(可选)
|
||||
shutil.rmtree(session_upload_dir)
|
||||
shutil.rmtree(session_output_dir)
|
||||
|
||||
return {"download_url": f"/static/output/{session_zip_path.name}"}
|
||||
|
||||
|
||||
@app.get("/download/{zip_filename}")
|
||||
async def download_zip(zip_filename: str):
|
||||
file_path = OUTPUT_DIR / zip_filename
|
||||
if file_path.exists():
|
||||
return FileResponse(path=file_path, filename=zip_filename, media_type='application/zip')
|
||||
return {"error": "File not found"}
|
||||
|
||||
|
||||
@app.post("/remove-bg")
|
||||
async def remove_bg(files: List[UploadFile] = File(...)):
|
||||
session_id = str(uuid.uuid4())
|
||||
session_upload_dir = UPLOAD_DIR / session_id
|
||||
session_output_dir = OUTPUT_DIR / session_id
|
||||
session_zip_path = OUTPUT_DIR / f"{session_id}.zip"
|
||||
|
||||
session_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
session_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result_files = []
|
||||
|
||||
for file in files:
|
||||
if file.filename == '':
|
||||
return JSONResponse(content={"error": "No selected file"}, status_code=400)
|
||||
|
||||
input_path = session_upload_dir / file.filename
|
||||
output_path = session_output_dir / file.filename
|
||||
|
||||
contents = await file.read()
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
with open(input_path, 'rb') as input_file:
|
||||
with open(output_path, 'wb') as output_file:
|
||||
input_data = input_file.read()
|
||||
output_data = remove(
|
||||
input_data,
|
||||
alpha_matting=True,
|
||||
alpha_matting_erode_size=15,
|
||||
alpha_matting_background_threshold=5,
|
||||
alpha_matting_foreground_threshold=250,
|
||||
)
|
||||
output_file.write(output_data)
|
||||
|
||||
result_files.append(f'/static/output/{session_id}/{file.filename}')
|
||||
|
||||
# 打包为 zip 文件
|
||||
with ZipFile(session_zip_path, 'w') as zipf:
|
||||
for image_file in session_output_dir.iterdir():
|
||||
zipf.write(image_file, arcname=image_file.name)
|
||||
|
||||
# 清理临时目录
|
||||
shutil.rmtree(session_upload_dir)
|
||||
shutil.rmtree(session_output_dir)
|
||||
|
||||
return JSONResponse(content={"download_url": f"/static/output/{session_zip_path.name}"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="127.0.0.1", port=7310)
|
||||
Reference in New Issue
Block a user