train_translate.py 875 B

1234567891011121314151617181920212223242526
  1. import os
  2. from pathlib import Path
  3. from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  4. DATASET_PATH = Path("/work/models/data/dataset.json")
  5. OUTPUT_DIR = Path("/work/models/nllb-custom")
  6. BASE_MODEL = "facebook/nllb-200-distilled-600M"
  7. def train_from_local_dataset() -> str:
  8. """
  9. Placeholder de fine-tuning NLLB. Pour un environnement CPU et rapide, on se contente
  10. de préparer le répertoire custom avec le tokenizer et le modèle de base afin
  11. de permettre les tests de pipeline. Dans un environnement GPU, remplacer par un
  12. entraînement réel (Trainer, datasets, etc.).
  13. """
  14. OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
  15. tok = AutoTokenizer.from_pretrained(BASE_MODEL)
  16. mdl = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
  17. tok.save_pretrained(OUTPUT_DIR)
  18. mdl.save_pretrained(OUTPUT_DIR)
  19. return str(OUTPUT_DIR)