"""Loaders and translation helpers for NLLB (base model or local fine-tuned checkpoint)."""

import os
from functools import lru_cache
from typing import Optional

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

BASE_MODEL_NAME = "facebook/nllb-200-distilled-600M"
CUSTOM_MODEL_DIR = "/work/models/nllb-custom"


# maxsize=2 so the base and custom entries can coexist without evicting
# each other (the original maxsize=1 could thrash between the two names).
@lru_cache(maxsize=2)
def get_tokenizer(model_name_or_path: Optional[str] = None):
    """Return a cached tokenizer for *model_name_or_path* (base model when None)."""
    name = model_name_or_path or BASE_MODEL_NAME
    return AutoTokenizer.from_pretrained(name)


@lru_cache(maxsize=2)
def get_model(model_name_or_path: Optional[str] = None):
    """Return a cached seq2seq model for *model_name_or_path* (base model when None)."""
    name = model_name_or_path or BASE_MODEL_NAME
    return AutoModelForSeq2SeqLM.from_pretrained(name)


def load_custom_or_base():
    """Load the custom model if its directory exists, otherwise the base model.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    # Prefer the locally fine-tuned checkpoint when it is present on disk.
    name = CUSTOM_MODEL_DIR if os.path.isdir(CUSTOM_MODEL_DIR) else BASE_MODEL_NAME
    return get_tokenizer(name), get_model(name)


def translate_text(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate *text* from *src_lang* to *tgt_lang* via NLLB.

    NLLB language codes look like ``'fra_Latn'``, ``'lin_Latn'``, etc.

    Args:
        text: Source text to translate.
        src_lang: NLLB code of the source language.
        tgt_lang: NLLB code of the target language.

    Returns:
        The translated string, or ``""`` for empty input or empty output.
    """
    if not text:
        return ""
    tok, mdl = load_custom_or_base()
    # BUG FIX: the original ignored both language parameters entirely.
    # NLLB tokenizers require src_lang to be set BEFORE tokenizing so the
    # correct source-language token is prepended to the encoder input.
    tok.src_lang = src_lang
    inputs = tok(text, return_tensors="pt")
    # Force the decoder to start with the target-language token; without
    # forced_bos_token_id the model does not know which language to emit.
    forced_bos = tok.convert_tokens_to_ids(tgt_lang)
    generated_tokens = mdl.generate(
        **inputs,
        forced_bos_token_id=forced_bos,
        max_new_tokens=200,
    )
    out = tok.batch_decode(generated_tokens, skip_special_tokens=True)
    return out[0] if out else ""