# translate.py
  1. from functools import lru_cache
  2. from typing import Optional
  3. from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face hub id of the base NLLB translation checkpoint.
BASE_MODEL_NAME = "facebook/nllb-200-distilled-600M"
# Local directory holding a fine-tuned checkpoint; used instead of the base
# model when it exists on disk (see load_custom_or_base).
CUSTOM_MODEL_DIR = "/work/models/nllb-custom"
  6. @lru_cache(maxsize=1)
  7. def get_tokenizer(model_name_or_path: Optional[str] = None):
  8. name = model_name_or_path or BASE_MODEL_NAME
  9. return AutoTokenizer.from_pretrained(name)
  10. @lru_cache(maxsize=1)
  11. def get_model(model_name_or_path: Optional[str] = None):
  12. name = model_name_or_path or BASE_MODEL_NAME
  13. return AutoModelForSeq2SeqLM.from_pretrained(name)
  14. def load_custom_or_base():
  15. """Charge le modèle custom s'il existe, sinon le modèle de base."""
  16. import os
  17. if os.path.isdir(CUSTOM_MODEL_DIR):
  18. tok = get_tokenizer(CUSTOM_MODEL_DIR)
  19. mdl = get_model(CUSTOM_MODEL_DIR)
  20. else:
  21. tok = get_tokenizer(BASE_MODEL_NAME)
  22. mdl = get_model(BASE_MODEL_NAME)
  23. return tok, mdl
  24. def translate_text(text: str, src_lang: str, tgt_lang: str) -> str:
  25. """
  26. Traduit le texte de src_lang vers tgt_lang via NLLB.
  27. Les codes de langue NLLB ressemblent à 'fra_Latn', 'lin_Latn', etc.
  28. """
  29. tok, mdl = load_custom_or_base()
  30. inputs = tok(text, return_tensors="pt")
  31. # Indication des langues (optionnelle, selon le tokenizer)
  32. # Quelques tokenizers NLLB utilisent des tokens de langue spécifiques.
  33. generated_tokens = mdl.generate(**inputs, max_new_tokens=200)
  34. out = tok.batch_decode(generated_tokens, skip_special_tokens=True)
  35. return out[0] if out else ""