train_iris.py 900 B

1234567891011121314151617181920212223242526272829303132
  """Standalone Iris training script (runs outside the API), usable from a notebook or the console.
  Saves the model to /work/models/iris_model.pkl.
  """
  4. from pathlib import Path
  5. import joblib
  6. from sklearn.datasets import load_iris
  7. from sklearn.model_selection import train_test_split
  8. from sklearn.linear_model import LogisticRegression
  # Output location for the trained classifier pickle written by main().
  MODEL_PATH = Path("/work/models", "iris_model.pkl")
  MODELS_DIR = MODEL_PATH.parent
  def main(
      model_path: "str | Path" = MODEL_PATH,
      *,
      test_size: float = 0.2,
      random_state: int = 42,
  ) -> dict:
      """Train a logistic-regression classifier on Iris and persist it with joblib.

      Args:
          model_path: Destination of the saved model (``str`` or ``Path``).
              Defaults to ``MODEL_PATH``. Its parent directory is created if missing.
          test_size: Fraction of the dataset held out to measure accuracy.
          random_state: Seed for the train/test split, so the score is reproducible.

      Returns:
          ``{"status": "trained", "accuracy": <float>, "model_path": <str>}`` --
          the same mapping that is printed, so notebook callers can use it
          without parsing stdout.
      """
      model_path = Path(model_path)

      iris = load_iris()
      X_train, X_test, y_train, y_test = train_test_split(
          iris.data, iris.target, test_size=test_size, random_state=random_state
      )

      # Raised above sklearn's default of 100, presumably so the solver converges
      # on the unscaled features -- confirm if tuning.
      clf = LogisticRegression(max_iter=1000)
      clf.fit(X_train, y_train)
      accuracy = float(clf.score(X_test, y_test))

      model_path.parent.mkdir(parents=True, exist_ok=True)
      joblib.dump(clf, model_path)

      result = {"status": "trained", "accuracy": accuracy, "model_path": str(model_path)}
      print(result)
      return result
  # Train only when executed as a script (e.g. `python train_iris.py`),
  # not when the module is imported from a notebook.
  if __name__ == "__main__":
      main()