diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 4297d6f519bf2..35bad766798b1 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -201,18 +201,28 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): logger.log_hyperparams(params) +@mock.patch('pytorch_lightning.loggers.mlflow.time') @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') @mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') -def test_mlflow_logger_with_artifact_location(client, mlflow, tmpdir): +def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): """ - Test that the logger raises warning with special characters not accepted by MLFlow. + Test that the logger calls methods on the mlflow experiment correctly. """ + time.return_value = 1 + logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location') logger._mlflow_client.get_experiment_by_name.return_value = None + params = {'test': 'test_param'} + logger.log_hyperparams(params) + + logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param') + metrics = {'some_metric': 10} logger.log_metrics(metrics) + logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None) + logger._mlflow_client.create_experiment.assert_called_once_with( name='test', artifact_location='my_artifact_location',