main.py

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.staticfiles import StaticFiles
from pathlib import Path
from .translate_pipeline import process_audio_file, AUDIO_DIR
from .train_translate import train_from_local_dataset

app = FastAPI(title="ML API - Speech-to-Speech")

# CORS (can be adjusted via the WEB_ORIGIN env var if needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
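# A minimal sketch (assumption, not wired in here): if origins must be restricted,
# the WEB_ORIGIN variable mentioned in the comment above could feed allow_origins
# instead of the wildcard, assuming it holds a comma-separated list of origins:
#
#   import os
#   origins = [o for o in os.environ.get("WEB_ORIGIN", "").split(",") if o] or ["*"]
#   app.add_middleware(CORSMiddleware, allow_origins=origins, ...)
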
# Static files for audio
AUDIO_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/audio", StaticFiles(directory=str(AUDIO_DIR)), name="audio")
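# Note (inferred from the mount, not confirmed in this file): process_audio_file
# presumably writes generated audio into AUDIO_DIR, so results become reachable
# at /audio/<filename>.
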
@app.get("/health")
def health():
    # Simple liveness check.
    return {"status": "ok"}


@app.post("/translate")
async def translate(file: UploadFile = File(...)):
    # Run the speech-to-speech pipeline on the uploaded audio file.
    data = await file.read()
    result = process_audio_file(data, file.filename)
    return JSONResponse(result)


@app.post("/train")
def train():
    # Train from the local dataset and return where the model was written.
    out_dir = train_from_local_dataset()
    return {"status": "ok", "model_path": out_dir}

@app.post("/predict")
async def predict_compat(file: UploadFile = File(None)):
    """
    Compatibility endpoint: if an audio file is received, handle it like /translate.
    Otherwise, return a guidance message.
    """
    if file is not None:
        data = await file.read()
        result = process_audio_file(data, file.filename)
        return JSONResponse(result)
    return JSONResponse({"detail": "Utilisez /translate pour la traduction vocale."}, status_code=400)
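
# Usage sketch (assumption): the launch command and port are not defined in this file.
# Since the imports are package-relative, the app would typically be started as
#   uvicorn <package>.main:app --host 0.0.0.0 --port 8000
# where <package> is the directory containing this module. Example calls, with
# sample.wav as a hypothetical input file:
#   curl http://localhost:8000/health
#   curl -F "file=@sample.wav" http://localhost:8000/translate
#   curl -X POST http://localhost:8000/train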