Skip to content

Commit

Permalink
Add conditional model training based on configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
Hamish Burke committed Jan 14, 2025
1 parent 30972d3 commit 26a91b2
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions scripts/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,27 @@ def run_pipeline(config_path: str, log_file: str = None):
# Run data preprocessing and poisoning
logger.info("Starting data preprocessing and poisoning")
pipeline.run()

# Create and train model
logger.info("Creating model")
model_factory = ModelFactory()
model = model_factory.create_model(config.model_type)

# Train model
logger.info("Training model")
X_train, y_train, X_val, y_val = pipeline.get_training_data()
model.train(X_train, y_train, X_val, y_val)

# Save model and metadata
model_path = os.path.join(config.model_output, "model")
model.save(model_path)
config.save(os.path.join(config.model_output, "config.yaml"))


if config.train_model:
# Create and train model
logger.info("Creating model")
model_factory = ModelFactory()
model = model_factory.create_model(config.model_type)

# Train model
logger.info("Training model")
X_train, y_train, X_val, y_val = pipeline.get_training_data()
model.train(X_train, y_train, X_val, y_val)

# Save model and metadata
model_path = os.path.join(config.model_output, "model")
model.save(model_path)
config.save(os.path.join(config.model_output, "config.yaml"))

logger.info("Model saved to {}".format(model_path))
else:
logger.info("Skipping model training")

logger.info("Pipeline completed successfully")

except Exception as e:
Expand Down

0 comments on commit 26a91b2

Please sign in to comment.