import os
from functools import lru_cache
from typing import Optional

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

BASE_MODEL_NAME = "facebook/nllb-200-distilled-600M"
CUSTOM_MODEL_DIR = "/work/models/nllb-custom"


@lru_cache(maxsize=1)
def get_tokenizer(model_name_or_path: Optional[str] = None):
    # Cache a single tokenizer per process; reloading it on every call is slow.
    name = model_name_or_path or BASE_MODEL_NAME
    return AutoTokenizer.from_pretrained(name)


@lru_cache(maxsize=1)
def get_model(model_name_or_path: Optional[str] = None):
    # Same caching strategy for the (much heavier) model weights.
    name = model_name_or_path or BASE_MODEL_NAME
    return AutoModelForSeq2SeqLM.from_pretrained(name)

def load_custom_or_base():
    """Load the custom model if it exists, otherwise the base model."""
    if os.path.isdir(CUSTOM_MODEL_DIR):
        tok = get_tokenizer(CUSTOM_MODEL_DIR)
        mdl = get_model(CUSTOM_MODEL_DIR)
    else:
        tok = get_tokenizer(BASE_MODEL_NAME)
        mdl = get_model(BASE_MODEL_NAME)
    return tok, mdl

def translate_text(text: str, src_lang: str, tgt_lang: str) -> str:
    """
    Translate text from src_lang to tgt_lang with NLLB.
    NLLB language codes look like 'fra_Latn', 'lin_Latn', etc.
    """
    tok, mdl = load_custom_or_base()
    # NLLB requires the source language to be set on the tokenizer so the
    # input is prefixed with the matching language token.
    tok.src_lang = src_lang
    inputs = tok(text, return_tensors="pt")
    # Force the decoder to start with the target-language token; without
    # this, the model does not know which language to generate.
    generated_tokens = mdl.generate(
        **inputs,
        forced_bos_token_id=tok.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=200,
    )
    out = tok.batch_decode(generated_tokens, skip_special_tokens=True)
    return out[0] if out else ""
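
# Minimal usage sketch (an illustrative addition, not part of the original
# module): exercises translate_text with the FLORES-200 codes mentioned in
# its docstring, 'fra_Latn' (French) -> 'lin_Latn' (Lingala). The first call
# downloads the base checkpoint unless CUSTOM_MODEL_DIR exists.
if __name__ == "__main__":
    sample = "Bonjour, comment allez-vous ?"
    print(translate_text(sample, src_lang="fra_Latn", tgt_lang="lin_Latn"))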