201 lines
7.3 KiB
Python
201 lines
7.3 KiB
Python
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File as FastAPIFile, Request
|
|
from fastapi.responses import StreamingResponse, RedirectResponse
|
|
from sqlalchemy.orm import Session
|
|
from . import crud, models, schemas
|
|
from .database import SessionLocal, init_db
|
|
from .services import drive, file_handler
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from dotenv import load_dotenv
|
|
import os
|
|
|
|
# Load variables from a local .env file into the process environment first,
# so later module-level reads such as os.getenv("SECRET_KEY") can see them.
load_dotenv()

# The application object every route and middleware below is attached to.
app = FastAPI(
    version="0.1.0",
    title="MultiDrive Box",
    description="Union de varias cuentas de Google Drive para formar un único pool de almacenamiento.",
)
|
|
|
|
# CORS: the frontend dev server (port 3000) calls this backend (port 8000)
# from a different origin, so those origins must be explicitly allowed.
origins = ["http://localhost:3000", "http://127.0.0.1:3000"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    # Cookies (including the session cookie) may be sent cross-origin.
    allow_credentials=True,
)
|
|
|
|
# Session middleware: the session carries the OAuth "state" between
# /api/add-account and /api/oauth2callback. Its key must be supplied via the
# environment; refuse to start with a missing or empty key.
SECRET_KEY = os.getenv("SECRET_KEY")
if SECRET_KEY is None or SECRET_KEY == "":
    raise ValueError("No SECRET_KEY set for session middleware.")
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
|
|
|
# Request-scoped database dependency (used as ``Depends(get_db)``).
def get_db():
    """Yield one ``SessionLocal`` session and close it when the request ends.

    The ``finally`` guarantees the session is closed even if the endpoint
    using it raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Run ``init_db()`` once when the application starts up."""
    init_db()
|
|
|
|
@app.get("/api/storage-status", response_model=schemas.StorageStatus)
def get_storage_status(db: Session = Depends(get_db)):
    """Report pooled Drive capacity over every account of the (dummy) user.

    Returns a mapping shaped for ``schemas.StorageStatus``:
    ``total_space`` and ``used_space`` are sums of the accounts'
    ``drive_space_total`` / ``drive_space_used``, ``free_space`` is their
    difference, and ``accounts`` is the account list itself.
    """
    # TODO: the user is hard-coded to id 1 (created on first use) until
    # real authentication exists.
    user = crud.get_user(db, 1) or crud.create_user(
        db, schemas.UserCreate(email="dummy@example.com")
    )

    accounts = crud.get_accounts(db, user_id=user.id)
    total_space = sum(account.drive_space_total for account in accounts)
    used_space = sum(account.drive_space_used for account in accounts)

    return dict(
        total_space=total_space,
        used_space=used_space,
        free_space=total_space - used_space,
        accounts=accounts,
    )
|
|
|
|
@app.post("/api/upload-file")
async def upload_file(file: UploadFile = FastAPIFile(...), db: Session = Depends(get_db)):
    """Split an uploaded file across the user's linked Drive accounts.

    The upload is saved to a temporary file (``file_handler.save_temp_file``
    also returns its SHA-256), split into parts by
    ``file_handler.split_file`` (each part names the account it goes to),
    each part is uploaded to that account's Drive, and the file together with
    its part records is stored via ``crud.create_file``. All local temporary
    files are removed afterwards, whether or not the upload succeeded.

    Returns:
        ``{"download_link": "/api/file/<token>"}``.

    Raises:
        HTTPException 400: the user has no linked Drive accounts.
        HTTPException 401: a Drive service could not be built for an account.
    """
    from contextlib import suppress

    from starlette.concurrency import run_in_threadpool

    # TODO: the user is hard-coded to id 1 (created on first use) until
    # real authentication exists.
    user = crud.get_user(db, 1)
    if not user:
        user = crud.create_user(db, schemas.UserCreate(email="dummy@example.com"))

    accounts = crud.get_accounts(db, user_id=user.id)
    if not accounts:
        raise HTTPException(status_code=400, detail="No Google Drive accounts linked.")

    file_path, sha256 = await file_handler.save_temp_file(file)
    # Every local file this request creates; deleted in ``finally`` so neither
    # successful nor failed uploads leave user data behind on disk.
    temp_paths = [file_path]
    try:
        # Splitting and the Drive calls are blocking; run them in the
        # threadpool so this ``async`` endpoint does not stall the event loop.
        # ``list`` materializes the parts so every part path is known up front
        # for cleanup.
        parts = await run_in_threadpool(
            lambda: list(file_handler.split_file(file_path, accounts))
        )
        temp_paths.extend(part['path'] for part in parts)

        uploaded_parts = []
        for part in parts:
            account = part['account']
            drive_service = await run_in_threadpool(
                drive.get_drive_service, credentials_info=account.credentials
            )
            if not drive_service:
                raise HTTPException(status_code=401, detail=f"Could not get drive service for account {account.google_email}")
            # NOTE(review): download_file() calls the module-level
            # ``drive.download_file(service, ...)`` and oauth2callback uses the
            # service as a googleapiclient resource (``.about()``); confirm the
            # object returned here really has an ``upload_file`` method.
            drive_file_id = await run_in_threadpool(
                drive_service.upload_file, part['path'], account
            )
            uploaded_parts.append({
                "part_index": part['index'],
                "size": part['size'],
                "account_id": account.id,
                "drive_file_id": drive_file_id,
                "sha256": part['sha256'],
            })

        # ``UploadFile.size`` is optional (absent or None on some Starlette
        # versions); fall back to the size of the bytes actually saved.
        original_size = getattr(file, "size", None)
        if original_size is None:
            original_size = os.path.getsize(file_path)

        download_token = file_handler.generate_token()
        file_data = schemas.FileCreate(
            filename=file.filename,
            original_size=original_size,
            sha256=sha256,
            download_token=download_token,
        )
        crud.create_file(db, file=file_data, user_id=user.id, parts=uploaded_parts)
    finally:
        # Best-effort cleanup: a file that is already gone (or cannot be
        # removed) must not turn the request into an error.
        for path in temp_paths:
            with suppress(OSError):
                os.remove(path)

    return {"download_link": f"/api/file/{download_token}"}
|
|
|
|
@app.get("/api/file/{token}", response_model=schemas.File)
def get_file_metadata(token: str, db: Session = Depends(get_db)):
    """Return the stored file record behind a download token.

    Raises:
        HTTPException 404: no file is registered under ``token``.
    """
    found = crud.get_file_by_token(db, token)
    if found:
        return found
    raise HTTPException(status_code=404, detail="File not found")
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
import shutil
|
|
import tempfile
|
|
|
|
def _attachment_disposition(filename):
    """Build a ``Content-Disposition: attachment`` value safe for any filename.

    Plain names are sent as ``filename="..."``. Anything that percent-encoding
    would change (non-ASCII characters, spaces, quotes, CR/LF, ...) is sent in
    the RFC 6266 / RFC 5987 ``filename*=utf-8''...`` form. A raw non-ASCII
    header value cannot be encoded as latin-1 by the server and would fail.
    """
    from urllib.parse import quote

    filename = filename or "download"
    encoded = quote(filename)
    if encoded != filename:
        return f"attachment; filename*=utf-8''{encoded}"
    return f'attachment; filename="{filename}"'


@app.get("/api/file/{token}/download")
def download_file(token: str, db: Session = Depends(get_db)):
    """Reassemble a split file from its Drive parts and stream it back.

    Parts are downloaded into a fresh temporary directory in ``part_index``
    order, merged with ``file_handler.merge_files``, and the merged file is
    streamed as an attachment. The temporary directory is removed when
    reassembly fails, when streaming finishes, or when the stream is closed
    early.

    Raises:
        HTTPException 404: no file is registered under ``token``.
        HTTPException 500: the account holding one of the parts is missing.
        HTTPException 401: a Drive service could not be built for an account.
    """
    db_file = crud.get_file_by_token(db, token)
    if not db_file:
        raise HTTPException(status_code=404, detail="File not found")

    temp_dir = tempfile.mkdtemp()
    try:
        part_paths = []
        for part in sorted(db_file.parts, key=lambda p: p.part_index):
            account = crud.get_account(db, part.account_id)
            if account is None:
                raise HTTPException(
                    status_code=500,
                    detail=f"Storage account {part.account_id} for part {part.part_index} not found",
                )
            drive_service = drive.get_drive_service(credentials_info=account.credentials)
            if not drive_service:
                raise HTTPException(status_code=401, detail=f"Could not get drive service for account {account.google_email}")

            # Local names never include the user-supplied filename: a name like
            # "../x" or an absolute path would escape temp_dir.
            part_path = os.path.join(temp_dir, f"part{part.part_index}")
            drive.download_file(drive_service, part.drive_file_id, part_path)
            part_paths.append(part_path)

        merged_file_path = os.path.join(temp_dir, "merged")
        file_handler.merge_files(part_paths, merged_file_path)
    except Exception:
        # Nothing will stream, so nothing else will clean up the directory.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    def file_iterator(file_path, chunk_size=64 * 1024):
        """Yield the file in fixed-size chunks, then delete the temp dir."""
        # Fixed-size reads: iterating a binary file yields newline-delimited
        # "lines", which can be arbitrarily large for binary data.
        try:
            with open(file_path, 'rb') as f:
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break
                    yield chunk
        finally:
            # Runs on normal exhaustion, on errors, and when the generator
            # is closed early (e.g. the client disconnects).
            shutil.rmtree(temp_dir, ignore_errors=True)

    return StreamingResponse(
        file_iterator(merged_file_path),
        media_type="application/octet-stream",
        headers={"Content-Disposition": _attachment_disposition(db_file.filename)},
    )
|
|
|
|
@app.get("/api/add-account")
def add_account(request: Request):
    """Start the Google OAuth flow that links another Drive account.

    Builds the callback URL of the ``oauth2callback`` route, gets the
    authorization URL and ``state`` from ``drive.authenticate``, stores the
    state in the session (``oauth2callback`` compares it against the state
    Google returns), and redirects the browser to the authorization URL.

    Raises:
        HTTPException 500: any failure while preparing the redirect. The
        original exception is logged with its traceback and chained as the
        cause.
    """
    import logging

    logger = logging.getLogger(__name__)
    logger.info("Received request for /api/add-account")
    try:
        # Depending on the Starlette version, url_for returns a URL object
        # rather than a str; pass a plain string to drive.authenticate.
        redirect_uri = str(request.url_for('oauth2callback'))
        logger.debug("Redirect URI for oauth2callback: %s", redirect_uri)
        authorization_url, state = drive.authenticate(redirect_uri)
        # The state is the anti-forgery value checked by the callback; keep
        # it in the session only and do not write it to logs.
        request.session['state'] = state
        return RedirectResponse(authorization_url)
    except Exception as e:
        logger.exception("Error in /api/add-account")
        raise HTTPException(status_code=500, detail="Internal server error") from e
|
|
|
|
@app.get("/api/oauth2callback")
def oauth2callback(request: Request, code: str, state: str, db: Session = Depends(get_db)):
    """Finish the Google OAuth flow and register the newly linked Drive account.

    Verifies the ``state`` against the one stored in the session by
    ``add_account``, exchanges ``code`` for credentials, reads the account's
    email and storage quota from Drive's ``about`` endpoint, and stores the
    account for the (dummy) user. Then redirects to ``/``.

    Raises:
        HTTPException 400: the session has no state, or it does not match.
        HTTPException 401: no Drive service could be built from the new
        credentials.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Pop rather than get: each stored state may complete exactly one
    # authorization round-trip, so a replayed callback URL is rejected.
    session_state = request.session.pop('state', None)
    if not session_state or state != session_state:
        # Neither state value nor session content is logged: they are the
        # anti-forgery secret for this flow.
        logger.warning("OAuth state mismatch in /api/oauth2callback")
        raise HTTPException(status_code=400, detail="State mismatch")

    # str(): url_for may return a URL object depending on the Starlette version.
    redirect_uri = str(request.url_for('oauth2callback'))
    credentials = drive.exchange_code_for_credentials(code, redirect_uri, state)
    credentials_json = credentials.to_json()

    drive_service = drive.get_drive_service(credentials_info=credentials_json)
    if not drive_service:
        # Same check the upload/download endpoints perform.
        raise HTTPException(status_code=401, detail="Could not get drive service for the new account")
    about = drive_service.about().get(fields="user, storageQuota").execute()
    user_info = about['user']
    storage_quota = about['storageQuota']

    # TODO: the user is hard-coded to id 1 (created on first use) until
    # real authentication exists.
    user = crud.get_user(db, 1)
    if not user:
        user = crud.create_user(db, schemas.UserCreate(email="dummy@example.com"))

    account_data = schemas.AccountCreate(
        google_email=user_info['emailAddress'],
        credentials=credentials_json,
        # NOTE(review): per the Drive API docs 'limit' is omitted for
        # unlimited storage, which is recorded here as a total of 0.
        drive_space_total=int(storage_quota.get('limit', 0)),
        drive_space_used=int(storage_quota.get('usage', 0)),
    )
    crud.create_account(db, account=account_data, user_id=user.id)

    return RedirectResponse(url="/")
|