Fixes for run_with_accelerate #2935

Merged · 8 commits · Aug 19, 2024
106 changes: 62 additions & 44 deletions examples/quickstart/quickstart.ipynb
@@ -71,7 +71,7 @@
" # Pull required modules from this example\n",
" !git clone -b main https://github.com/zenml-io/zenml\n",
" !cp -r zenml/examples/quickstart/* .\n",
" !rm -rf zenml\n"
" !rm -rf zenml"
]
},
{
@@ -84,6 +84,7 @@
"!zenml integration install sklearn -y\n",
"\n",
"import IPython\n",
"\n",
"IPython.Application.instance().kernel.do_shutdown(restart=True)"
]
},
@@ -145,28 +146,22 @@
"outputs": [],
"source": [
"# Do the imports at the top\n",
"from typing_extensions import Annotated\n",
"from sklearn.datasets import load_breast_cancer\n",
"\n",
"import random\n",
"import pandas as pd\n",
"from zenml import step, pipeline, Model, get_step_context\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"from typing import List, Optional\n",
"from uuid import UUID\n",
"\n",
"from typing import Optional, List\n",
"\n",
"from zenml import pipeline\n",
"\n",
"import pandas as pd\n",
"from sklearn.datasets import load_breast_cancer\n",
"from steps import (\n",
" data_loader,\n",
" data_preprocessor,\n",
" data_splitter,\n",
" inference_preprocessor,\n",
" model_evaluator,\n",
" inference_preprocessor\n",
")\n",
"from typing_extensions import Annotated\n",
"\n",
"from zenml import Model, get_step_context, pipeline, step\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"\n",
"logger = get_logger(__name__)\n",
Expand Down Expand Up @@ -205,20 +200,22 @@
"@step\n",
"def data_loader_simplified(\n",
" random_state: int, is_inference: bool = False, target: str = \"target\"\n",
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset \n",
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n",
" \"\"\"Dataset reader step.\"\"\"\n",
" dataset = load_breast_cancer(as_frame=True)\n",
" inference_size = int(len(dataset.target) * 0.05)\n",
" dataset: pd.DataFrame = dataset.frame\n",
" inference_subset = dataset.sample(inference_size, random_state=random_state)\n",
" inference_subset = dataset.sample(\n",
" inference_size, random_state=random_state\n",
" )\n",
" if is_inference:\n",
" dataset = inference_subset\n",
" dataset.drop(columns=target, inplace=True)\n",
" else:\n",
" dataset.drop(inference_subset.index, inplace=True)\n",
" dataset.reset_index(drop=True, inplace=True)\n",
" logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n",
" return dataset\n"
" return dataset"
]
},
{
@@ -291,7 +288,7 @@
" normalize: Optional[bool] = None,\n",
" drop_columns: Optional[List[str]] = None,\n",
" target: Optional[str] = \"target\",\n",
" random_state: int = 17\n",
" random_state: int = 17,\n",
"):\n",
" \"\"\"Feature engineering pipeline.\"\"\"\n",
" # Link all the steps together by calling them and passing the output\n",
@@ -402,7 +399,6 @@
"from zenml.environment import Environment\n",
"from zenml.zen_stores.rest_zen_store import RestZenStore\n",
"\n",
"\n",
"if not isinstance(client.zen_store, RestZenStore):\n",
" # Only spin up a local Dashboard in case you aren't already connected to a remote server\n",
" if Environment.in_google_colab():\n",
@@ -479,7 +475,9 @@
"outputs": [],
"source": [
"# Get artifact version from our run\n",
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n",
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n",
" \"dataset_trn\"\n",
"]\n",
"\n",
"# Get latest version from client directly\n",
"dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n",
@@ -498,7 +496,9 @@
"source": [
"# Fetch the rest of the artifacts\n",
"dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n",
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")"
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\n",
" \"preprocess_pipeline\"\n",
")"
]
},
{
@@ -566,6 +566,7 @@
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import SGDClassifier\n",
"from typing_extensions import Annotated\n",
"\n",
"from zenml import ArtifactConfig, step\n",
"from zenml.logger import get_logger\n",
"\n",
@@ -576,23 +577,26 @@
"def model_trainer(\n",
" dataset_trn: pd.DataFrame,\n",
" model_type: str = \"sgd\",\n",
") -> Annotated[ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)]:\n",
") -> Annotated[\n",
" ClassifierMixin,\n",
" ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True),\n",
"]:\n",
" \"\"\"Configure and train a model on the training dataset.\"\"\"\n",
" target = \"target\"\n",
" if model_type == \"sgd\":\n",
" model = SGDClassifier()\n",
" elif model_type == \"rf\":\n",
" model = RandomForestClassifier()\n",
" else:\n",
" raise ValueError(f\"Unknown model type {model_type}\") \n",
" raise ValueError(f\"Unknown model type {model_type}\")\n",
"\n",
" logger.info(f\"Training model {model}...\")\n",
"\n",
" model.fit(\n",
" dataset_trn.drop(columns=[target]),\n",
" dataset_trn[target],\n",
" )\n",
" return model\n"
" return model"
]
},
{
@@ -630,14 +634,18 @@
" min_train_accuracy: float = 0.0,\n",
" min_test_accuracy: float = 0.0,\n",
"):\n",
" \"\"\"Model training pipeline.\"\"\" \n",
" \"\"\"Model training pipeline.\"\"\"\n",
" if train_dataset_id is None or test_dataset_id is None:\n",
" # If we dont pass the IDs, this will run the feature engineering pipeline \n",
" # If we dont pass the IDs, this will run the feature engineering pipeline\n",
" dataset_trn, dataset_tst = feature_engineering()\n",
" else:\n",
" # Load the datasets from an older pipeline\n",
" dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n",
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id) \n",
" dataset_trn = client.get_artifact_version(\n",
" name_id_or_prefix=train_dataset_id\n",
" )\n",
" dataset_tst = client.get_artifact_version(\n",
" name_id_or_prefix=test_dataset_id\n",
" )\n",
"\n",
" trained_model = model_trainer(\n",
" dataset_trn=dataset_trn,\n",
@@ -676,7 +684,7 @@
"training(\n",
" model_type=\"rf\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")\n",
"\n",
"rf_run = client.get_pipeline(\"training\").last_run"
@@ -693,7 +701,7 @@
"sgd_run = training(\n",
" model_type=\"sgd\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")\n",
"\n",
"sgd_run = client.get_pipeline(\"training\").last_run"
@@ -717,7 +725,9 @@
"outputs": [],
"source": [
"# The evaluator returns a float value with the accuracy\n",
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\"model_evaluator\"].output.load()"
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\n",
" \"model_evaluator\"\n",
"].output.load()"
]
},
{
@@ -776,7 +786,7 @@
"training_configured(\n",
" model_type=\"sgd\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")"
]
},
@@ -798,7 +808,7 @@
"training_configured(\n",
" model_type=\"rf\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")"
]
},
@@ -845,10 +855,14 @@
"outputs": [],
"source": [
"# Let's load the RF version\n",
"rf_zenml_model_version = client.get_model_version(\"breast_cancer_classifier\", \"rf\")\n",
"rf_zenml_model_version = client.get_model_version(\n",
" \"breast_cancer_classifier\", \"rf\"\n",
")\n",
"\n",
"# We can now load our classifier directly as well\n",
"random_forest_classifier = rf_zenml_model_version.get_artifact(\"sklearn_classifier\").load()\n",
"random_forest_classifier = rf_zenml_model_version.get_artifact(\n",
" \"sklearn_classifier\"\n",
").load()\n",
"\n",
"random_forest_classifier"
]
@@ -945,7 +959,9 @@
"outputs": [],
"source": [
"@step\n",
"def inference_predict(dataset_inf: pd.DataFrame) -> Annotated[pd.Series, \"predictions\"]:\n",
"def inference_predict(\n",
" dataset_inf: pd.DataFrame,\n",
") -> Annotated[pd.Series, \"predictions\"]:\n",
" \"\"\"Predictions step\"\"\"\n",
" # Get the model\n",
" model = get_step_context().model\n",
@@ -956,7 +972,7 @@
"\n",
" predictions = pd.Series(predictions, name=\"predicted\")\n",
"\n",
" return predictions\n"
" return predictions"
]
},
{
@@ -983,18 +999,18 @@
" random_state = 42\n",
" target = \"target\"\n",
"\n",
" df_inference = data_loader(\n",
" random_state=random_state, is_inference=True\n",
" )\n",
" df_inference = data_loader(random_state=random_state, is_inference=True)\n",
" df_inference = inference_preprocessor(\n",
" dataset_inf=df_inference,\n",
" # We use the preprocess pipeline from the feature engineering pipeline\n",
" preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n",
" preprocess_pipeline=client.get_artifact_version(\n",
" name_id_or_prefix=preprocess_pipeline_id\n",
" ),\n",
" target=target,\n",
" )\n",
" inference_predict(\n",
" dataset_inf=df_inference,\n",
" )\n"
" )"
]
},
{
Expand All @@ -1018,7 +1034,7 @@
"# Lets add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" version=\"production\", # We can pass in the stage name here!\n",
" version=\"production\", # We can pass in the stage name here!\n",
" license=\"Apache 2.0\",\n",
" description=\"A breast cancer classifier\",\n",
" tags=[\"breast_cancer\", \"classifier\"],\n",
@@ -1061,7 +1077,9 @@
"outputs": [],
"source": [
"# Fetch production model\n",
"production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n",
"production_model_version = client.get_model_version(\n",
" \"breast_cancer_classifier\", \"production\"\n",
")\n",
"\n",
"# Get the predictions artifact\n",
"production_model_version.get_artifact(\"predictions\").load()"