Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Mar 25, 2021
1 parent a94060a commit 164c8dc
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,28 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
logger.log_hyperparams(params)


@mock.patch('pytorch_lightning.loggers.mlflow.time')
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_logger_with_artifact_location(client, mlflow, tmpdir):
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
"""
Test that the logger raises warning with special characters not accepted by MLFlow.
Test that the logger calls methods on the mlflow experiment correctly.
"""
time.return_value = 1

logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location')
logger._mlflow_client.get_experiment_by_name.return_value = None

params = {'test': 'test_param'}
logger.log_hyperparams(params)

logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param')

metrics = {'some_metric': 10}
logger.log_metrics(metrics)

logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None)

logger._mlflow_client.create_experiment.assert_called_once_with(
name='test',
artifact_location='my_artifact_location',
Expand Down

0 comments on commit 164c8dc

Please sign in to comment.