diff --git a/mlserver/parallel/registry.py b/mlserver/parallel/registry.py index 86bd51ce8..55c1109f9 100644 --- a/mlserver/parallel/registry.py +++ b/mlserver/parallel/registry.py @@ -130,6 +130,12 @@ def model_initialiser(self, model_settings: ModelSettings) -> MLModel: # as normal. return model_initialiser(model_settings) + parameters = model_settings.parameters + if not parameters or not parameters.environment_tarball: + # If model is not using a custom environment, instantiate the model + # as normal. + return model_initialiser(model_settings) + # Otherwise, return a dummy model for now and wait for the load_model # hook to create the actual thing. # This avoids instantiating the model's actual class within the diff --git a/runtimes/alibi-explain/tests/conftest.py b/runtimes/alibi-explain/tests/conftest.py index 2bffbe59a..3298272a9 100644 --- a/runtimes/alibi-explain/tests/conftest.py +++ b/runtimes/alibi-explain/tests/conftest.py @@ -115,6 +115,7 @@ async def model_registry( on_model_load=[inference_pool_registry.load_model], on_model_reload=[inference_pool_registry.reload_model], on_model_unload=[inference_pool_registry.unload_model], + model_initialiser=inference_pool_registry.model_initialiser, ) await model_registry.load(custom_runtime_tf_settings) diff --git a/runtimes/mlflow/tests/rest/conftest.py b/runtimes/mlflow/tests/rest/conftest.py index 80af20dc4..8a5df840c 100644 --- a/runtimes/mlflow/tests/rest/conftest.py +++ b/runtimes/mlflow/tests/rest/conftest.py @@ -71,6 +71,7 @@ async def model_registry( on_model_load=[inference_pool_registry.load_model], on_model_reload=[inference_pool_registry.reload_model], on_model_unload=[inference_pool_registry.unload_model], + model_initialiser=inference_pool_registry.model_initialiser, ) await model_registry.load(model_settings) diff --git a/tests/grpc/conftest.py b/tests/grpc/conftest.py index 7c66a08ab..be4d5c765 100644 --- a/tests/grpc/conftest.py +++ b/tests/grpc/conftest.py @@ -37,6 +37,7 @@ async def model_registry( on_model_load=[inference_pool_registry.load_model, load_batching], on_model_reload=[inference_pool_registry.reload_model], on_model_unload=[inference_pool_registry.unload_model], + model_initialiser=inference_pool_registry.model_initialiser, ) model_name = sum_model_settings.name diff --git a/tests/rest/conftest.py b/tests/rest/conftest.py index 4bf8767d6..e31ad5473 100644 --- a/tests/rest/conftest.py +++ b/tests/rest/conftest.py @@ -22,6 +22,7 @@ async def model_registry( on_model_load=[inference_pool_registry.load_model, load_batching], on_model_reload=[inference_pool_registry.reload_model], on_model_unload=[inference_pool_registry.unload_model], + model_initialiser=inference_pool_registry.model_initialiser, ) model_name = sum_model_settings.name