diff --git a/docs/book/.gitbook/assets/argilla-interface-embeddings-finetuning.png b/docs/book/.gitbook/assets/argilla-interface-embeddings-finetuning.png new file mode 100644 index 00000000000..fd2e6c2c69b Binary files /dev/null and b/docs/book/.gitbook/assets/argilla-interface-embeddings-finetuning.png differ diff --git a/docs/book/.gitbook/assets/distilabel-synthetic-dataset-hf.png b/docs/book/.gitbook/assets/distilabel-synthetic-dataset-hf.png new file mode 100644 index 00000000000..ff339696e92 Binary files /dev/null and b/docs/book/.gitbook/assets/distilabel-synthetic-dataset-hf.png differ diff --git a/docs/book/.gitbook/assets/finetuning-embeddings-visualization.png b/docs/book/.gitbook/assets/finetuning-embeddings-visualization.png new file mode 100644 index 00000000000..220fa05d693 Binary files /dev/null and b/docs/book/.gitbook/assets/finetuning-embeddings-visualization.png differ diff --git a/docs/book/.gitbook/assets/mcp-embeddings.gif b/docs/book/.gitbook/assets/mcp-embeddings.gif new file mode 100644 index 00000000000..3b72b4faeeb Binary files /dev/null and b/docs/book/.gitbook/assets/mcp-embeddings.gif differ diff --git a/docs/book/.gitbook/assets/rag-dataset-hf.png b/docs/book/.gitbook/assets/rag-dataset-hf.png new file mode 100644 index 00000000000..22a66009597 Binary files /dev/null and b/docs/book/.gitbook/assets/rag-dataset-hf.png differ diff --git a/docs/book/.gitbook/assets/rag-finetuning-embeddings-pipeline.png b/docs/book/.gitbook/assets/rag-finetuning-embeddings-pipeline.png new file mode 100644 index 00000000000..056c0cda097 Binary files /dev/null and b/docs/book/.gitbook/assets/rag-finetuning-embeddings-pipeline.png differ diff --git a/docs/book/.gitbook/assets/rag-synthetic-data-pipeline.png b/docs/book/.gitbook/assets/rag-synthetic-data-pipeline.png new file mode 100644 index 00000000000..5fca85b6098 Binary files /dev/null and b/docs/book/.gitbook/assets/rag-synthetic-data-pipeline.png differ diff --git a/docs/book/component-guide/orchestrators/azureml.md b/docs/book/component-guide/orchestrators/azureml.md index fe29f399e91..4b518f710fb 100644 --- a/docs/book/component-guide/orchestrators/azureml.md +++ b/docs/book/component-guide/orchestrators/azureml.md @@ -25,7 +25,8 @@ You should use the AzureML orchestrator if: The ZenML AzureML orchestrator implementation uses [the Python SDK v2 of AzureML](https://learn.microsoft.com/en-gb/python/api/overview/azure/ai-ml-readme?view=azure-python) to allow our users to build their Machine Learning pipelines. For each ZenML step, -it creates an AzureML `[CommandComponent](https://learn.microsoft.com/en-us/python/api/azure-ai-ml/azure.ai.ml.entities.commandcomponent?view=azure-python)` and brings them together in a pipeline. +it creates an AzureML [CommandComponent](https://learn.microsoft.com/en-us/python/api/azure-ai-ml/azure.ai.ml.entities.commandcomponent?view=azure-python) +and brings them together in a pipeline. ## How to deploy it @@ -149,7 +150,8 @@ def pipeline(): ### Run pipelines on a schedule The AzureML orchestrator supports running pipelines on a schedule using -its `[JobSchedules](https://learn.microsoft.com/en-us/azure/templates/microsoft.automation/2023-11-01/automationaccounts/jobschedules?pivots=deployment-language-bicep)`. Both cron expression and intervals are supported. +its [JobSchedules](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-schedule-pipeline-job?view=azureml-api-2&tabs=python). +Both cron expression and intervals are supported. ```python from zenml.config.schedule import Schedule diff --git a/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md b/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md index 50374cb89b5..72cd628c5f1 100644 --- a/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md +++ b/docs/book/how-to/run-remote-pipelines-from-notebooks/README.md @@ -8,7 +8,7 @@ ZenML steps and pipelines can be defined in a Jupyter notebook and executed remo Learn more about it in the following sections: -
Define steps in notebook cellsdefine-steps-in-notebook-cells.md
Configure the notebook path
+
Define steps in notebook cellsdefine-steps-in-notebook-cells.md
ZenML Scarf
diff --git a/docs/book/toc.md b/docs/book/toc.md index 649ca169c56..76ca613eac4 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -56,8 +56,11 @@ * [Understanding reranking](user-guide/llmops-guide/reranking/understanding-reranking.md) * [Implementing reranking in ZenML](user-guide/llmops-guide/reranking/implementing-reranking.md) * [Evaluating reranking performance](user-guide/llmops-guide/reranking/evaluating-reranking-performance.md) - * [Improve retrieval by finetuning embeddings](user-guide/llmops-guide/finetuning-embeddings.md) - * [Finetuning LLMs with ZenML](user-guide/llmops-guide/finetuning-llms.md) + * [Improve retrieval by finetuning embeddings](user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings.md) + * [Synthetic data generation](user-guide/llmops-guide/finetuning-embeddings/synthetic-data-generation.md) + * [Finetuning embeddings with Sentence Transformers](user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md) + * [Evaluating finetuned embeddings](user-guide/llmops-guide/finetuning-embeddings/evaluating-finetuned-embeddings.md) + * [Finetuning LLMs with ZenML](user-guide/llmops-guide/finetuning-llms/finetuning-llms.md) ## How-To diff --git a/docs/book/user-guide/llmops-guide/README.md b/docs/book/user-guide/llmops-guide/README.md index 1196648d9ab..3f32d79e746 100644 --- a/docs/book/user-guide/llmops-guide/README.md +++ b/docs/book/user-guide/llmops-guide/README.md @@ -26,8 +26,11 @@ In this guide, we'll explore various aspects of working with LLMs in ZenML, incl * [Understanding reranking](reranking/understanding-reranking.md) * [Implementing reranking in ZenML](reranking/implementing-reranking.md) * [Evaluating reranking performance](reranking/evaluating-reranking-performance.md) -* [Improve retrieval by finetuning embeddings](finetuning-embeddings.md) -* [Finetuning LLMs with ZenML](finetuning-llms.md) +* [Improve retrieval by finetuning embeddings](finetuning-embeddings/finetuning-embeddings.md) + * [Synthetic data generation](finetuning-embeddings/synthetic-data-generation.md) + * [Finetuning embeddings with Sentence Transformers](finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md) + * [Evaluating finetuned embeddings](finetuning-embeddings/evaluating-finetuned-embeddings.md) +* [Finetuning LLMs with ZenML](finetuning-llms/finetuning-llms.md) To follow along with the examples and tutorials in this guide, ensure you have a Python environment set up with ZenML installed. Familiarity with the concepts covered in the [Starter Guide](../starter-guide/README.md) and [Production Guide](../production-guide/README.md) is recommended. diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings.md deleted file mode 100644 index 95754ac17f2..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-embeddings.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Finetune embeddings to improve retrieval performance. ---- - -🚧 This guide is a work in progress. Please check back soon for updates. - -Coming soon! -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings/evaluating-finetuned-embeddings.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings/evaluating-finetuned-embeddings.md new file mode 100644 index 00000000000..89851016aaa --- /dev/null +++ b/docs/book/user-guide/llmops-guide/finetuning-embeddings/evaluating-finetuned-embeddings.md @@ -0,0 +1,139 @@ +--- +description: Evaluate finetuned embeddings and compare to original base embeddings. +--- + +Now that we've finetuned our embeddings, we can evaluate them and compare to the +base embeddings. We have all the data saved and versioned already, and we will +reuse the same MatryoshkaLoss function for evaluation. + +In code, our evaluation steps are easy to comprehend. Here, for example, is the +base model evaluation step: + +```python +from zenml import log_model_metadata, step + +def evaluate_model( + dataset: DatasetDict, model: SentenceTransformer +) -> Dict[str, float]: + """Evaluate the given model on the dataset.""" + evaluator = get_evaluator( + dataset=dataset, + model=model, + ) + return evaluator(model) + +@step +def evaluate_base_model( + dataset: DatasetDict, +) -> Annotated[Dict[str, float], "base_model_evaluation_results"]: + """Evaluate the base model on the given dataset.""" + model = SentenceTransformer( + EMBEDDINGS_MODEL_ID_BASELINE, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + results = evaluate_model( + dataset=dataset, + model=model, + ) + + # Convert numpy.float64 values to regular Python floats + # (needed for serialization) + base_model_eval = { + f"dim_{dim}_cosine_ndcg@10": float( + results[f"dim_{dim}_cosine_ndcg@10"] + ) + for dim in EMBEDDINGS_MODEL_MATRYOSHKA_DIMS + } + + log_model_metadata( + metadata={"base_model_eval": base_model_eval}, + ) + + return results +``` + +We log the results for our core Matryoshka dimensions as model metadata to ZenML +within our evaluation step. This will allow us to inspect these results from +within [the Model Control Plane](https://docs.zenml.io/how-to/use-the-model-control-plane) (see +below for more details). Our results come in the form of a dictionary of string +keys and float values which will, like all step inputs and outputs, be +versioned, tracked and saved in your artifact store. + +## Visualizing results + +It's possible to visualize results in a few different ways in ZenML, but one +easy option is just to output your chart as an `PIL.Image` object. (See our +[documentation on more ways to visualize your +results](../../../how-to/visualize-artifacts/README.md).) The rest the +implementation of our `visualize_results` step is just simple `matplotlib` code +to plot out the base model evaluation against the finetuned model evaluation. We +represent the results as percentage values and horizontally stack the two sets +to make comparison a little easier. + +![Visualizing finetuned embeddings evaluation +results](../../../.gitbook/assets/finetuning-embeddings-visualization.png) + +We can see that our finetuned embeddings have improved the recall of our +retrieval system across all of the dimensions, but the results are still not +amazing. In a production setting, we would likely want to focus on improving the +data being used for the embeddings training. In particular, we could consider +stripping out some of the logs output from the documentation, and perhaps omit +some pages which offer low signal for the retrieval task. This embeddings +finetuning was run purely on the full set of synthetic data generated by +`distilabel` and `gpt-4o`, so we wouldn't necessarily expect to see huge +improvements out of the box, especially when the underlying data chunks are +complex and contain multiple topics. + +## Model Control Plane as unified interface + +Once all our pipelines are finished running, the best place to inspect our +results as well as the artifacts and models we generated is the Model Control +Plane. + +![Model Control Plane](../../../.gitbook/assets/mcp-embeddings.gif) + +The interface is split into sections that correspond to: + +- the artifacts generated by our steps +- the models generated by our steps +- the metadata logged by our steps +- (potentially) any deployments of models made, though we didn't use this in + this guide so far +- any pipeline runs associated with this 'Model' + +We can easily see which are the latest artifact or technical model versions, as +well as compare the actual values of our evals or inspect the hardware or +hyperparameters used for training. + +This one-stop-shop interface is available on ZenML Pro and you can learn more +about it in the [Model Control Plane +documentation](https://docs.zenml.io/how-to/use-the-model-control-plane). + +## Next Steps + +Now that we've finetuned our embeddings and evaluated them, when they were in a +good shape for use we could bring these into [the original RAG pipeline](../rag/basic-rag-inference-pipeline.md), +regenerate a new series of embeddings for our data and then rerun our RAG +retrieval evaluations to see how they've improved in our hand-crafted and +LLM-powered evaluations. + +The next section will cover [LLM finetuning and deployment](../finetuning-llms/finetuning-llms.md) as the +final part of our LLMops guide. (This section is currently still a work in +progress, but if you're eager to try out LLM finetuning with ZenML, you can use +[our LoRA +project](https://github.com/zenml-io/zenml-projects/blob/main/llm-lora-finetuning/README.md) +to get started. We also have [a +blogpost](https://www.zenml.io/blog/how-to-finetune-llama-3-1-with-zenml) guide which +takes you through +[all the steps you need to finetune Llama 3.1](https://www.zenml.io/blog/how-to-finetune-llama-3-1-with-zenml) using GCP's Vertex AI with ZenML, +including one-click stack creation!) + +To try out the two pipelines, please follow the instructions in [the project +repository README](https://github.com/zenml-io/zenml-projects/blob/main/llm-complete-guide/README.md), +and you can find the full code in that same directory. + + +
ZenML Scarf
+ + diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-for-better-retrieval-performance.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-for-better-retrieval-performance.md deleted file mode 100644 index d92d9d19e6b..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-for-better-retrieval-performance.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Learn how to fine-tune embeddings for better retrieval performance. ---- - -# Finetuning Embeddings for Better Retrieval Performance - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md new file mode 100644 index 00000000000..c4ee10b9157 --- /dev/null +++ b/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md @@ -0,0 +1,102 @@ +--- +description: Finetune embeddings with Sentence Transformers. +--- + +We now have a dataset that we can use to finetune our embeddings. You can +[inspect the positive and negative examples](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0_distilabel) on the Hugging Face [datasets page](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0_distilabel) since +our previous pipeline pushed the data there. + +![Synthetic data generated with distilabel for embeddings finetuning](../../../.gitbook/assets/distilabel-synthetic-dataset-hf.png) + +Our pipeline for finetuning the embeddings is relatively simple. We'll do the +following: + +- load our data either from Hugging Face or [from Argilla via the ZenML + annotation integration](../../../component-guide/annotators/argilla.md) +- finetune our model using the [Sentence + Transformers](https://www.sbert.net/) library +- evaluate the base and finetuned embeddings +- visualise the results of the evaluation + +![Embeddings finetuning pipeline with Sentence Transformers and +ZenML](../../../.gitbook/assets/rag-finetuning-embeddings-pipeline.png) + +## Loading data + +By default the pipeline will load the data from our Hugging Face dataset. If +you've annotated your data in Argilla, you can load the data from there instead. +You'll just need to pass an `--argilla` flag to the Python invocation when +you're running the pipeline like so: + +```bash +python run.py --embeddings --argilla +``` + +This assumes that you've set up an Argilla annotator in your stack. The code +checks for the annotator and downloads the data that was annotated in Argilla. +Please see our [guide to using the Argilla integration with ZenML](../../../component-guide/annotators/argilla.md) for more details. + +## Finetuning with Sentence Transformers + +The `finetune` step in the pipeline is responsible for finetuning the embeddings model using the Sentence Transformers library. Let's break down the key aspects of this step: + +1. **Model Loading**: The code loads the base model (`EMBEDDINGS_MODEL_ID_BASELINE`) using the Sentence Transformers library. It utilizes the SDPA (Self-Distilled Pruned Attention) implementation for efficient training with Flash Attention 2. + +2. **Loss Function**: The finetuning process employs a custom loss function called `MatryoshkaLoss`. This loss function is a wrapper around the `MultipleNegativesRankingLoss` provided by Sentence Transformers. The Matryoshka approach involves training the model with different embedding dimensions simultaneously. It allows the model to learn embeddings at various granularities, improving its performance across different embedding sizes. + +3. **Dataset Preparation**: The training dataset is loaded from the provided `dataset` parameter. The code saves the training data to a temporary JSON file and then loads it using the Hugging Face `load_dataset` function. + +4. **Evaluator**: An evaluator is created using the `get_evaluator` function. The evaluator is responsible for assessing the model's performance during training. + +5. **Training Arguments**: The code sets up the training arguments using the `SentenceTransformerTrainingArguments` class. It specifies various hyperparameters such as the number of epochs, batch size, learning rate, optimizer, precision (TF32 and BF16), and evaluation strategy. + +6. **Trainer**: The `SentenceTransformerTrainer` is initialized with the model, + training arguments, training dataset, loss function, and evaluator. The + trainer handles the training process. The `trainer.train()` method is called + to start the finetuning process. The model is trained for the specified + number of epochs using the provided hyperparameters. + +7. **Model Saving**: After training, the finetuned model is pushed to the Hugging Face Hub using the `trainer.model.push_to_hub()` method. The model is saved with the specified ID (`EMBEDDINGS_MODEL_ID_FINE_TUNED`). + +9. **Metadata Logging**: The code logs relevant metadata about the training process, including the training parameters, hardware information, and accelerator details. + +10. **Model Rehydration**: To handle materialization errors, the code saves the + trained model to a temporary file, loads it back into a new + `SentenceTransformer` instance, and returns the rehydrated model. + +(*Thanks and credit to Phil Schmid for [his tutorial on finetuning embeddings](https://www.philschmid.de/fine-tune-embedding-model-for-rag) with Sentence +Transformers and a Matryoshka loss function. This project uses many ideas and +some code from his implementation.*) + +## Finetuning in code + +Here's a simplified code snippet highlighting the key parts of the finetuning process: + +```python +# Load the base model +model = SentenceTransformer(EMBEDDINGS_MODEL_ID_BASELINE) +# Define the loss function +train_loss = MatryoshkaLoss(model, MultipleNegativesRankingLoss(model)) +# Prepare the training dataset +train_dataset = load_dataset("json", data_files=train_dataset_path) +# Set up the training arguments +args = SentenceTransformerTrainingArguments(...) +# Create the trainer +trainer = SentenceTransformerTrainer(model, args, train_dataset, train_loss) +# Start training +trainer.train() +# Save the finetuned model +trainer.model.push_to_hub(EMBEDDINGS_MODEL_ID_FINE_TUNED) +``` + +The finetuning process leverages the capabilities of the Sentence Transformers library to efficiently train the embeddings model. The Matryoshka approach allows for learning embeddings at different dimensions simultaneously, enhancing the model's performance across various embedding sizes. + +Our model is finetuned, saved in the Hugging Face Hub for easy access and +reference in subsequent steps, but also versioned and tracked within ZenML for +full observability. At this point the pipeline will evaluate the base and +finetuned embeddings and visualise the results. + + +
ZenML Scarf
+ + diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings.md index 95754ac17f2..6e68ec9a387 100644 --- a/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings.md +++ b/docs/book/user-guide/llmops-guide/finetuning-embeddings/finetuning-embeddings.md @@ -1,8 +1,43 @@ --- -description: Finetune embeddings to improve retrieval performance. +description: Finetune embeddings on custom synthetic data to improve retrieval performance. --- -🚧 This guide is a work in progress. Please check back soon for updates. +We previously learned [how to use RAG with ZenML](../rag-with-zenml/README.md) to +build a production-ready RAG pipeline. In this section, we will explore how to +optimize and maintain your embedding models through synthetic data generation and +human feedback. So far, we've been using off-the-shelf embeddings, which provide +a good baseline and decent performance on standard tasks. However, you can often +significantly improve performance by finetuning embeddings on your own domain-specific data. -Coming soon! +Our RAG pipeline uses a retrieval-based approach, where it first retrieves the +most relevant documents from our vector database, and then uses a language model +to generate a response based on those documents. By finetuning our embeddings on +a dataset of technical documentation similar to our target domain, we can improve +the retrieval step and overall performance of the RAG pipeline. + +The work of finetuning embeddings based on synthetic data and human feedback is +a multi-step process. We'll go through the following steps: + +- [generating synthetic data with `distilabel`](synthetic-data-generation.md) +- [finetuning embeddings with Sentence Transformers](finetuning-embeddings-with-sentence-transformers.md) +- [evaluating finetuned embeddings and using ZenML's model control plane to get a systematic overview](evaluating-finetuned-embeddings.md) + +Besides ZenML, we will do this by using two open source libraries: +[`argilla`](https://github.com/argilla-io/argilla/) and +[`distilabel`](https://github.com/argilla-io/distilabel). Both of these +libraries focus optimizing model outputs through improving data quality, +however, each one of them takes a different approach to tackle the same problem. +`distilabel` provides a scalable and reliable approach to distilling knowledge +from LLMs by generating synthetic data or providing AI feedback with LLMs as +judges. `argilla` enables AI engineers and domain experts to collaborate on data +projects by allowing them to organize and explore data through within an +interactive and engaging UI. Both libraries can be used individually but they +work better together. We'll showcase their use via ZenML pipelines. + +To follow along with the example explained in this guide, please follow the +instructions in [the `llm-complete-guide` repository](https://github.com/zenml-io/zenml-projects/llm-complete-guide/README.md) where the full code is also +available. This specific section on embeddings finetuning can be run locally or +using cloud compute as you prefer. + +
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings/integrating-finetuned-embeddings-into-zenml-pipelines.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings/integrating-finetuned-embeddings-into-zenml-pipelines.md deleted file mode 100644 index 71cafd0248f..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-embeddings/integrating-finetuned-embeddings-into-zenml-pipelines.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Learn how to integrate finetuned embeddings into ZenML pipelines. ---- - -# Integrating Finetuned Embeddings into ZenML Pipelines - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-embeddings/synthetic-data-generation.md b/docs/book/user-guide/llmops-guide/finetuning-embeddings/synthetic-data-generation.md new file mode 100644 index 00000000000..2627f1d13e0 --- /dev/null +++ b/docs/book/user-guide/llmops-guide/finetuning-embeddings/synthetic-data-generation.md @@ -0,0 +1,225 @@ +--- +description: Generate synthetic data with distilabel to finetune embeddings. +--- + +We already have [a dataset of technical documentation](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0) that was generated +previously while we were working on the RAG pipeline. We'll use this dataset +to generate synthetic data with `distilabel`. You can inspect the data directly +[on the Hugging Face dataset page](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0). + +![](../../../.gitbook/assets/rag-dataset-hf.png) + +As you can see, it is made up of some `page_content` (our chunks) as well as the +source URL from where the chunk was taken from. With embeddings, what we're +going to want to do is pair the `page_content` with a question that we want to +answer. In a pre-LLM world we might have actually created a new column and +worked to manually craft questions for each chunk. However, with LLMs, we can +use the `page_content` to generate questions. + +## Pipeline overview + +Our pipeline to generate synthetic data will look like this: + +![](../../../.gitbook/assets/rag-synthetic-data-pipeline.png) + +We'll load the Hugging Face dataset, then we'll use `distilabel` to generate the +synthetic data. To finish off, we'll push the newly-generated data to a new +Hugging Face dataset and also push the same data to our Argilla instance for +annotation and inspection. + +## Synthetic data generation + +[`distilabel`](https://github.com/argilla-io/distilabel) provides a scalable and +reliable approach to distilling knowledge from LLMs by generating synthetic data +or providing AI feedback with LLMs as judges. We'll be using it a relatively +simple use case to generate some queries appropriate to our documentation +chunks, but it can be used for a variety of other tasks. + +We can set up a `distilabel` pipeline easily in our ZenML step to handle the +dataset creation. We'll be using `gpt-4o` as the LLM to generate the synthetic +data so you can follow along, but `distilabel` supports a variety of other LLM +providers (including Ollama) so you can use whatever you have available. + +```python +import os +from typing import Annotated, Tuple + +import distilabel +from constants import ( + DATASET_NAME_DEFAULT, + OPENAI_MODEL_GEN, + OPENAI_MODEL_GEN_KWARGS_EMBEDDINGS, +) +from datasets import Dataset +from distilabel.llms import OpenAILLM +from distilabel.steps import LoadDataFromHub +from distilabel.steps.tasks import GenerateSentencePair +from zenml import step + +synthetic_generation_context = """ +The text is a chunk from technical documentation of ZenML. +ZenML is an MLOps + LLMOps framework that makes your infrastructure and workflow metadata accessible to data science teams. +Along with prose explanations, the text chunk may include code snippets and logs but these are identifiable from the surrounding backticks. +""" + +@step +def generate_synthetic_queries( + train_dataset: Dataset, test_dataset: Dataset +) -> Tuple[ + Annotated[Dataset, "train_with_queries"], + Annotated[Dataset, "test_with_queries"], +]: + llm = OpenAILLM( + model=OPENAI_MODEL_GEN, api_key=os.getenv("OPENAI_API_KEY") + ) + + with distilabel.pipeline.Pipeline( + name="generate_embedding_queries" + ) as pipeline: + load_dataset = LoadDataFromHub( + output_mappings={"page_content": "anchor"}, + ) + generate_sentence_pair = GenerateSentencePair( + triplet=True, # `False` to generate only positive + action="query", + llm=llm, + input_batch_size=10, + context=synthetic_generation_context, + ) + + load_dataset >> generate_sentence_pair + + train_distiset = pipeline.run( + parameters={ + load_dataset.name: { + "repo_id": DATASET_NAME_DEFAULT, + "split": "train", + }, + generate_sentence_pair.name: { + "llm": { + "generation_kwargs": OPENAI_MODEL_GEN_KWARGS_EMBEDDINGS + } + }, + }, + ) + + test_distiset = pipeline.run( + parameters={ + load_dataset.name: { + "repo_id": DATASET_NAME_DEFAULT, + "split": "test", + }, + generate_sentence_pair.name: { + "llm": { + "generation_kwargs": OPENAI_MODEL_GEN_KWARGS_EMBEDDINGS + } + }, + }, + ) + + train_dataset = train_distiset["default"]["train"] + test_dataset = test_distiset["default"]["train"] + + return train_dataset, test_dataset +``` + +As you can see, we set up the LLM, create a `distilabel` pipeline, load the +dataset, mapping the `page_content` column so that it becomes `anchor`. (This +column renaming will make things easier a bit later when we come to finetuning +the embeddings.) Then we generate the synthetic data by using the `GenerateSentencePair` +step. This will create queries for each of the chunks in the dataset, so if the +chunk was about registering a ZenML stack, the query might be "How do I register +a ZenML stack?". It will also create negative queries, which are queries that +would be inappropriate for the chunk. We do this so that the embeddings model +can learn to distinguish between appropriate and inappropriate queries. + +We add some context to the generation process to help the LLM +understand the task and the data we're working with. In particular, we explain +that some parts of the text are code snippets and logs. We found performance to +be better when we added this context. + +When this step runs within ZenML it will handle spinning up the necessary +processes to make batched LLM calls to the OpenAI API. This is really useful +when working with large datasets. `distilabel` has also implemented a caching +mechanism to avoid recomputing results for the same inputs. So in this case you +have two layers of caching: one in the `distilabel` pipeline and one in the +ZenML orchestrator. This helps [speed up the pace of iteration](https://www.zenml.io/blog/iterate-fast) and saves you money. + +## Data annotation with Argilla + +Once we've let the LLM generate the synthetic data, we'll want to inspect it +and make sure it looks good. We'll do this by pushing the data to an Argilla +instance. We add a few extra pieces of metadata to the data to make it easier to +navigate and inspect within our data annotation tool. These include: + +- `parent_section`: This will be the section of the documentation that the chunk + is from. +- `token_count`: This will be the number of tokens in the chunk. +- `similarity-positive-negative`: This will be the cosine similarity between the + positive and negative queries. +- `similarity-anchor-positive`: This will be the cosine similarity between the + anchor and positive queries. +- `similarity-anchor-negative`: This will be the cosine similarity between the + anchor and negative queries. + +We'll also add the embeddings for the anchor column so that we can use these +for retrieval. We'll use the base model (in our case, +`Snowflake/snowflake-arctic-embed-large`) to generate the embeddings. We use +this function to map the dataset and process all the metadata: + +```python +def format_data(batch): + model = SentenceTransformer( + EMBEDDINGS_MODEL_ID_BASELINE, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + def get_embeddings(batch_column): + vectors = model.encode(batch_column) + return [vector.tolist() for vector in vectors] + + batch["anchor-vector"] = get_embeddings(batch["anchor"]) + batch["question-vector"] = get_embeddings(batch["anchor"]) + batch["positive-vector"] = get_embeddings(batch["positive"]) + batch["negative-vector"] = get_embeddings(batch["negative"]) + + def get_similarities(a, b): + similarities = [] + + for pos_vec, neg_vec in zip(a, b): + similarity = cosine_similarity([pos_vec], [neg_vec])[0][0] + similarities.append(similarity) + return similarities + + batch["similarity-positive-negative"] = get_similarities( + batch["positive-vector"], batch["negative-vector"] + ) + batch["similarity-anchor-positive"] = get_similarities( + batch["anchor-vector"], batch["positive-vector"] + ) + batch["similarity-anchor-negative"] = get_similarities( + batch["anchor-vector"], batch["negative-vector"] + ) + return batch +``` + +The [rest of the `push_to_argilla` step](https://github.com/zenml-io/zenml-projects/blob/main/llm-complete-guide/steps/push_to_argilla.py) is just setting up the Argilla +dataset and pushing the data to it. + +At this point you'd move to Argilla to view the data, see which examples seem to +make sense and which don't. You can update the questions (positive and negative) +which were generated by the LLM. If you want, you can do some data cleaning and +exploration to improve the data quality, perhaps using the similarity metrics +that we calculated earlier. + +![Argilla interface for data annotation](../../../.gitbook/assets/argilla-interface-embeddings-finetuning.png) + +We'll next move to actually finetuning the embeddings, assuming you've done some +data exploration and annotation. The code will work even without the annotation, +however, since we'll just use the full generated dataset and assume that the +quality is good enough. + + +
ZenML Scarf
+ + diff --git a/docs/book/user-guide/llmops-guide/finetuning-llms.md b/docs/book/user-guide/llmops-guide/finetuning-llms.md deleted file mode 100644 index c37be565b91..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-llms.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Finetune LLMs for specific tasks or to improve performance and cost. ---- - -🚧 This guide is a work in progress. Please check back soon for updates. - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-llms/deploying-finetuned-models.md b/docs/book/user-guide/llmops-guide/finetuning-llms/deploying-finetuned-models.md deleted file mode 100644 index f497204deff..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-llms/deploying-finetuned-models.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Guide on deploying finetuned LLMs. ---- - -# Deploying Finetuned Models with ZenML - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-llms/finetuning-llms-for-specific-tasks.md b/docs/book/user-guide/llmops-guide/finetuning-llms/finetuning-llms-for-specific-tasks.md deleted file mode 100644 index 5ef4c070568..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-llms/finetuning-llms-for-specific-tasks.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Learn how to fine-tune LLMs for specific tasks using ZenML. ---- - -# Finetuning LLMs for Specific Tasks - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-llms/synthetic-data-generation.md b/docs/book/user-guide/llmops-guide/finetuning-llms/synthetic-data-generation.md deleted file mode 100644 index 0712debc309..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-llms/synthetic-data-generation.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -description: Learn how to generate synthetic data for finetuning LLMs. ---- - -# Synthetic Data Generation - - - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/finetuning-llms/tracking-finetuned-llm-models.md b/docs/book/user-guide/llmops-guide/finetuning-llms/tracking-finetuned-llm-models.md deleted file mode 100644 index 63baafd1448..00000000000 --- a/docs/book/user-guide/llmops-guide/finetuning-llms/tracking-finetuned-llm-models.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -description: Learn how to track fine-tuned LLM models using ZenML. ---- - -# Tracking Finetuned LLM Models with ZenML - - -
ZenML Scarf
diff --git a/docs/book/user-guide/llmops-guide/llmops-guide.md b/docs/book/user-guide/llmops-guide/llmops-guide.md deleted file mode 100644 index 70fbe0f2a68..00000000000 --- a/docs/book/user-guide/llmops-guide/llmops-guide.md +++ /dev/null @@ -1,38 +0,0 @@ ---- -description: Leverage the power of LLMs in your MLOps workflows with ZenML. ---- - -# 🦜 LLMOps guide - -Welcome to the ZenML LLMOps Guide, where we dive into the exciting world of Large Language Models (LLMs) and how to integrate them seamlessly into your MLOps pipelines using ZenML. This guide is designed for ML practitioners and MLOps engineers looking to harness the potential of LLMs while maintaining the robustness and scalability of their workflows. - -

ZenML simplifies the development and deployment of LLM-powered MLOps pipelines.

- -In this guide, we'll explore various aspects of working with LLMs in ZenML, including: - -* [RAG with ZenML](rag/rag-with-zenml.md) - * [RAG in 85 lines of code](rag/rag-85-loc.md) - * [Understanding Retrieval-Augmented Generation (RAG)](rag/understanding-rag.md) - * [Data ingestion and preprocessing](rag/data-ingestion.md) - * [Embeddings generation](rag/embeddings-generation.md) - * [Storing embeddings in a vector database](rag/storing-embeddings-in-a-vector-database.md) - * [Basic RAG inference pipeline](rag/basic-rag-inference-pipeline.md) -* [Evaluation and metrics](evaluation/evaluation.md) - * [Evaluation in 65 lines of code](evaluation/evaluation-in-65-loc.md) - * [Retrieval evaluation](evaluation/retrieval.md) - * [Generation evaluation](evaluation/generation.md) - * [Evaluation in practice](evaluation/evaluation-in-practice.md) -* [Reranking for better retrieval](reranking/reranking.md) - * [Understanding reranking](reranking/understanding-reranking.md) - * [Implementing reranking in ZenML](reranking/implementing-reranking.md) - * [Evaluating reranking performance](reranking/evaluating-reranking-performance.md) -* [Improve retrieval by finetuning embeddings](finetuning-embeddings/finetuning-embeddings.md) -* [Finetuning LLMs with ZenML](finetuning-llms/finetuning-llms.md) - -To follow along with the examples and tutorials in this guide, ensure you have a Python environment set up with ZenML installed. Familiarity with the concepts covered in the [Starter Guide](../starter-guide/README.md) and [Production Guide](../production-guide/README.md) is recommended. - -We'll showcase a specific application over the course of this LLM guide, showing how you can work from a simple RAG pipeline to a more complex setup that involves finetuning embeddings, reranking retrieved documents, and even finetuning the LLM itself. We'll do this all for a use case relevant to ZenML: a question answering system that can provide answers to common questions about ZenML. This will help you understand how to apply the concepts covered in this guide to your own projects. - -By the end of this guide, you'll have a solid understanding of how to leverage LLMs in your MLOps workflows using ZenML, enabling you to build powerful, scalable, and maintainable LLM-powered applications. First up, let's take a look at a super simple implementation of the RAG paradigm to get started. - -
ZenML Scarf
diff --git a/src/zenml/cli/web_login.py b/src/zenml/cli/web_login.py index 47f302d00e5..370b33999f6 100644 --- a/src/zenml/cli/web_login.py +++ b/src/zenml/cli/web_login.py @@ -28,6 +28,7 @@ DEVICE_AUTHORIZATION, LOGIN, VERSION_1, + ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT, ) from zenml.exceptions import AuthorizationException, OAuthError from zenml.logger import get_logger @@ -93,6 +94,11 @@ def web_login(url: str, verify_ssl: Union[str, bool]) -> str: # Get rid of any trailing slashes to prevent issues when having double # slashes in the URL url = url.rstrip("/") + zenml_pro_extra = "" + if ".zenml.io" in url: + zenml_pro_extra = ( + ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT + ) try: auth_url = url + API + VERSION_1 + DEVICE_AUTHORIZATION response = requests.post( @@ -111,6 +117,7 @@ def web_login(url: str, verify_ssl: Union[str, bool]) -> str: logger.info(f"Error: {response.status_code} {response.text}") raise AuthorizationException( "Could not connect to API server. Please check the URL." + + zenml_pro_extra ) except (requests.exceptions.JSONDecodeError, ValueError, TypeError): logger.exception("Bad response received from API server.") @@ -121,6 +128,7 @@ def web_login(url: str, verify_ssl: Union[str, bool]) -> str: logger.exception("Could not connect to API server.") raise AuthorizationException( "Could not connect to API server. Please check the URL." + + zenml_pro_extra ) # Open the verification URL in the user's browser diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 05b5152aae1..cf73ce691f7 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -313,6 +313,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ) DEFAULT_ZENML_SERVER_SECURE_HEADERS_REPORT_TO = "default" DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD = False +DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS = 30 # Configurations to decide which resources report their usage and check for # entitlement in the case of a cloud deployment. Expected Format is this: @@ -492,3 +493,11 @@ def handle_int_env_var(var: str, default: int = 0) -> int: STACK_DEPLOYMENT_API_TOKEN_EXPIRATION = 60 * 6 # 6 hours + +# ZenML Pro +ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT = ( + "\nHINT: Since you are trying to communicate with the ZenML Pro Tenant, " + "please make sure that your tenant is in RUNNING state on your " + "Organization page. If the tenant is PAUSED you can `Resume` it via UI " + "and try again." +) diff --git a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py index 447aa4ade4f..c5c152f5740 100644 --- a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +++ b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py @@ -574,6 +574,9 @@ def _upload_and_run_pipeline( run_name: Orchestrator run name. settings: Pipeline level settings for this orchestrator. schedule: The schedule the pipeline will run on. + + Raises: + RuntimeError: If the Vertex Orchestrator fails to provision or any other Runtime errors """ # We have to replace the hyphens in the run name with underscores # and lower case the string, because the Vertex AI Pipelines service @@ -656,13 +659,15 @@ def _upload_and_run_pipeline( run.wait() except google_exceptions.ClientError as e: - logger.warning( - "Failed to create the Vertex AI Pipelines job: %s", e + logger.error("Failed to create the Vertex AI Pipelines job: %s", e) + raise RuntimeError( + f"Failed to create the Vertex AI Pipelines job: {e}" ) except RuntimeError as e: logger.error( "The Vertex AI Pipelines job execution has failed: %s", e ) + raise def get_orchestrator_run_id(self) -> str: """Returns the active orchestrator run id. diff --git a/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py b/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py index 6d214092f59..53e432fe3ee 100644 --- a/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py +++ b/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py @@ -496,6 +496,108 @@ class GCPAuthenticationMethods(StrEnum): IMPERSONATION = "impersonation" +try: + from google.auth.aws import _DefaultAwsSecurityCredentialsSupplier + + class ZenMLAwsSecurityCredentialsSupplier( + _DefaultAwsSecurityCredentialsSupplier # type: ignore[misc] + ): + """An improved version of the GCP external account credential supplier for AWS. + + The original GCP external account credential supplier only provides + rudimentary support for extracting AWS credentials from environment + variables or the AWS metadata service. This version improves on that by + using the boto3 library itself (if available), which uses the entire range + of implicit authentication features packed into it. + + Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is + not supported for EKS pods and the EC2 attached role credentials are + used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a). + """ + + def get_aws_security_credentials( + self, context: Any, request: Any + ) -> gcp_aws.AwsSecurityCredentials: + """Get the security credentials from the local environment. + + This method is a copy of the original method from the + `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has + been modified to use the boto3 library to extract the AWS credentials + from the local environment. + + Args: + context: The context to use to get the security credentials. + request: The request to use to get the security credentials. + + Returns: + The AWS temporary security credentials. + """ + try: + import boto3 + + session = boto3.Session() + credentials = session.get_credentials() + if credentials is not None: + creds = credentials.get_frozen_credentials() + return gcp_aws.AwsSecurityCredentials( + creds.access_key, + creds.secret_key, + creds.token, + ) + except ImportError: + pass + + logger.debug( + "Failed to extract AWS credentials from the local environment " + "using the boto3 library. Falling back to the original " + "implementation." + ) + + return super().get_aws_security_credentials(context, request) + + def get_aws_region(self, context: Any, request: Any) -> str: + """Get the AWS region from the local environment. + + This method is a copy of the original method from the + `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has + been modified to use the boto3 library to extract the AWS + region from the local environment. + + Args: + context: The context to use to get the security credentials. + request: The request to use to get the security credentials. + + Returns: + The AWS region. + """ + try: + import boto3 + + session = boto3.Session() + if session.region_name: + return session.region_name # type: ignore[no-any-return] + except ImportError: + pass + + logger.debug( + "Failed to extract AWS region from the local environment " + "using the boto3 library. Falling back to the original " + "implementation." + ) + + return super().get_aws_region( # type: ignore[no-any-return] + context, request + ) + +except ImportError: + # The `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` + # class has been introduced in the `google-auth` library version 2.29.0. + # Before that, the AWS logic was part of the `google.auth.awsCredentials` + # class itself. + ZenMLAwsSecurityCredentialsSupplier = None # type: ignore[assignment,misc] + pass + + class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignore[misc] """An improved version of the GCP external account credential for AWS. @@ -508,6 +610,13 @@ class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignor Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is not supported for EKS pods and the EC2 attached role credentials are used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a). + + IMPORTANT: subclassing this class only works with the `google-auth` library + version lower than 2.29.0. Starting from version 2.29.0, the AWS logic + has been moved to a separate `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` + class that can be subclassed instead and supplied as the + `aws_security_credentials_supplier` parameter to the + `google.auth.aws.Credentials` class. """ def _get_security_credentials( @@ -539,12 +648,14 @@ def _get_security_credentials( "secret_access_key": creds.secret_key, "security_token": creds.token, } - except Exception: - logger.debug( - "Failed to extract AWS credentials from the local environment " - "using the boto3 library. Falling back to the original " - "implementation." - ) + except ImportError: + pass + + logger.debug( + "Failed to extract AWS credentials from the local environment " + "using the boto3 library. Falling back to the original " + "implementation." + ) return super()._get_security_credentials( # type: ignore[no-any-return] request, imdsv2_session_token @@ -1126,6 +1237,12 @@ def _authenticate( account_info.get("subject_token_type") == _AWS_SUBJECT_TOKEN_TYPE ): + if ZenMLAwsSecurityCredentialsSupplier is not None: + account_info["aws_security_credentials_supplier"] = ( + ZenMLAwsSecurityCredentialsSupplier( + account_info.pop("credential_source"), + ) + ) credentials = ( ZenMLGCPAWSExternalAccountCredentials.from_info( account_info, diff --git a/src/zenml/models/v2/core/server_settings.py b/src/zenml/models/v2/core/server_settings.py index 16d5c0890b8..6c85c32667c 100644 --- a/src/zenml/models/v2/core/server_settings.py +++ b/src/zenml/models/v2/core/server_settings.py @@ -82,6 +82,9 @@ class ServerSettingsResponseBody(BaseResponseBody): display_updates: Optional[bool] = Field( title="Whether to display notifications about ZenML updates in the dashboard.", ) + last_user_activity: datetime = Field( + title="The timestamp when the last user activity was detected.", + ) updated: datetime = Field( title="The timestamp when this resource was last updated." ) @@ -179,6 +182,15 @@ def active(self) -> bool: """ return self.get_body().active + @property + def last_user_activity(self) -> datetime: + """The `last_user_activity` property. + + Returns: + the value of the property. + """ + return self.get_body().last_user_activity + @property def updated(self) -> datetime: """The `updated` property. diff --git a/src/zenml/models/v2/misc/server_models.py b/src/zenml/models/v2/misc/server_models.py index b54bfecb905..662adc3506f 100644 --- a/src/zenml/models/v2/misc/server_models.py +++ b/src/zenml/models/v2/misc/server_models.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Model definitions for ZenML servers.""" -from typing import Dict +from datetime import datetime +from typing import Dict, Optional from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -103,6 +104,11 @@ class ServerModel(BaseModel): title="Flag to indicate whether the server is using the legacy dashboard.", ) + last_user_activity: Optional[datetime] = Field( + None, + title="Timestamp of latest user activity traced on the server.", + ) + def is_local(self) -> bool: """Return whether the server is running locally. diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 84b8b881697..021f4d8f98b 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -17,8 +17,10 @@ import os from functools import wraps from typing import ( + TYPE_CHECKING, Any, Callable, + List, Optional, Tuple, Type, @@ -33,7 +35,10 @@ from zenml.config.global_config import GlobalConfiguration from zenml.config.server_config import ServerConfiguration from zenml.constants import ( + API, ENV_ZENML_SERVER, + INFO, + VERSION_1, ) from zenml.enums import ServerProviderType from zenml.exceptions import IllegalOperationError, OAuthError @@ -53,6 +58,9 @@ ) from zenml.zen_stores.sql_zen_store import SqlZenStore +if TYPE_CHECKING: + from fastapi import Request + logger = get_logger(__name__) _zen_store: Optional["SqlZenStore"] = None @@ -570,3 +578,70 @@ def verify_admin_status_if_no_rbac( "without RBAC enabled.", ) return + + +def is_user_request(request: "Request") -> bool: + """Determine if the incoming request is a user request. + + This function checks various aspects of the request to determine + if it's a user-initiated request or a system request. + + Args: + request: The incoming FastAPI request object. + + Returns: + True if it's a user request, False otherwise. + """ + # Define system paths that should be excluded + system_paths: List[str] = [ + "/health", + "/metrics", + "/system", + "/docs", + "/redoc", + "/openapi.json", + ] + + user_prefix = f"{API}{VERSION_1}" + excluded_user_apis = [INFO] + # Check if this is not an excluded endpoint + if request.url.path in [ + user_prefix + suffix for suffix in excluded_user_apis + ]: + return False + + # Check if this is other user request + if request.url.path.startswith(user_prefix): + return True + + # Exclude system paths + if any(request.url.path.startswith(path) for path in system_paths): + return False + + # Exclude requests with specific headers + if request.headers.get("X-System-Request") == "true": + return False + + # Exclude requests from certain user agents (e.g., monitoring tools) + user_agent = request.headers.get("User-Agent", "").lower() + system_agents = ["prometheus", "datadog", "newrelic", "pingdom"] + if any(agent in user_agent for agent in system_agents): + return False + + # Check for internal IP addresses + client_host = request.client.host if request.client else None + if client_host and ( + client_host.startswith("10.") or client_host.startswith("192.168.") + ): + return False + + # Exclude OPTIONS requests (often used for CORS preflight) + if request.method == "OPTIONS": + return False + + # Exclude specific query parameters that might indicate system requests + if request.query_params.get("system_check"): + return False + + # If none of the above conditions are met, consider it a user request + return True diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 5acae894390..cc7f8170b64 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -22,6 +22,7 @@ import os from asyncio.log import logger +from datetime import datetime, timedelta, timezone from genericpath import isfile from typing import Any, List @@ -36,7 +37,11 @@ import zenml from zenml.analytics import source_context -from zenml.constants import API, HEALTH +from zenml.constants import ( + API, + DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS, + HEALTH, +) from zenml.enums import AuthScheme, SourceContextTypes from zenml.zen_server.exceptions import error_detail from zenml.zen_server.routers import ( @@ -80,8 +85,10 @@ initialize_secure_headers, initialize_workload_manager, initialize_zen_store, + is_user_request, secure_headers, server_config, + zen_store, ) if server_config().use_legacy_dashboard: @@ -109,6 +116,12 @@ def relative_path(rel: str) -> str: default_response_class=ORJSONResponse, ) +# Initialize last_user_activity +last_user_activity: datetime = datetime.now(timezone.utc) +last_user_activity_reported: datetime = datetime.now(timezone.utc) + timedelta( + seconds=-DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS +) + # Customize the default request validation handler that comes with FastAPI # to return a JSON response that matches the ZenML API spec. @@ -159,6 +172,44 @@ async def set_secure_headers(request: Request, call_next: Any) -> Any: return response +@app.middleware("http") +async def track_last_user_activity(request: Request, call_next: Any) -> Any: + """A middleware to track last user activity. + + This middleware checks if the incoming request is a user request and + updates the last activity timestamp if it is. + + Args: + request: The incoming request object. + call_next: A function that will receive the request as a parameter and + pass it to the corresponding path operation. + + Returns: + The response to the request. + """ + global last_user_activity + global last_user_activity_reported + + try: + if is_user_request(request): + last_user_activity = datetime.now(timezone.utc) + except Exception as e: + logger.debug( + f"An unexpected error occurred while checking user activity: {e}" + ) + if ( + ( + datetime.now(timezone.utc) - last_user_activity_reported + ).total_seconds() + > DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS + ): + last_user_activity_reported = datetime.now(timezone.utc) + zen_store()._update_last_user_activity_timestamp( + last_user_activity=last_user_activity + ) + return await call_next(request) + + @app.middleware("http") async def infer_source_context(request: Request, call_next: Any) -> Any: """A middleware to track the source of an event. diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 390d31033a8..ec520c54169 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -39,6 +39,7 @@ DEFAULT_WORKSPACE_NAME, ENV_ZENML_DEFAULT_WORKSPACE_NAME, IS_DEBUG_ENV, + ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT, ) from zenml.enums import ( SecretsStoreType, @@ -171,9 +172,14 @@ def __init__( ) except Exception as e: + zenml_pro_extra = "" + if ".zenml.io" in self.url: + zenml_pro_extra = ( + ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT + ) raise RuntimeError( f"Error initializing {self.type.value} store with URL " - f"'{self.url}': {str(e)}" + f"'{self.url}': {str(e)}" + zenml_pro_extra ) from e if not skip_default_registrations: diff --git a/src/zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py b/src/zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py new file mode 100644 index 00000000000..b7079c1f2e3 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py @@ -0,0 +1,51 @@ +"""add last_user_activity [3dcc5d20e82f]. + +Revision ID: 3dcc5d20e82f +Revises: 909550c7c4da +Create Date: 2024-08-07 14:49:07.623500 + +""" + +from datetime import datetime, timezone + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3dcc5d20e82f" +down_revision = "026d4577b6a0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + bind = op.get_bind() + session = sqlmodel.Session(bind=bind) + + with op.batch_alter_table("server_settings", schema=None) as batch_op: + batch_op.add_column( + sa.Column("last_user_activity", sa.DateTime(), nullable=True) + ) + + session.execute( + sa.text( + """ + UPDATE server_settings + SET last_user_activity = :last_user_activity + """ + ), + params=(dict(last_user_activity=datetime.now(timezone.utc))), + ) + + with op.batch_alter_table("server_settings", schema=None) as batch_op: + batch_op.alter_column( + "last_user_activity", existing_type=sa.DateTime(), nullable=False + ) + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + with op.batch_alter_table("server_settings", schema=None) as batch_op: + batch_op.drop_column("last_user_activity") diff --git a/src/zenml/zen_stores/schemas/server_settings_schemas.py b/src/zenml/zen_stores/schemas/server_settings_schemas.py index ef0dc8988e9..1ed5b03157f 100644 --- a/src/zenml/zen_stores/schemas/server_settings_schemas.py +++ b/src/zenml/zen_stores/schemas/server_settings_schemas.py @@ -42,6 +42,7 @@ class ServerSettingsSchema(SQLModel, table=True): display_announcements: Optional[bool] = Field(nullable=True) display_updates: Optional[bool] = Field(nullable=True) onboarding_state: Optional[str] = Field(nullable=True) + last_user_activity: datetime = Field(default_factory=datetime.utcnow) updated: datetime = Field(default_factory=datetime.utcnow) def update( @@ -111,6 +112,7 @@ def to_model( display_updates=self.display_updates, active=self.active, updated=self.updated, + last_user_activity=self.last_user_activity, ) metadata = None diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f47c0481334..cd50ba5ac27 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -20,7 +20,7 @@ import os import re import sys -from datetime import datetime +from datetime import datetime, timezone from functools import lru_cache from pathlib import Path from typing import ( @@ -1587,6 +1587,7 @@ def get_store_info(self) -> ServerModel: # the one fetched from the global configuration model.id = settings.server_id model.active = settings.active + model.last_user_activity = settings.last_user_activity if not handle_bool_env_var(ENV_ZENML_LOCAL_SERVER): model.analytics_enabled = settings.enable_analytics return model @@ -1689,6 +1690,29 @@ def update_server_settings( return settings.to_model(include_metadata=True) + def _update_last_user_activity_timestamp( + self, last_user_activity: datetime + ) -> None: + """Update the last user activity timestamp. + + Args: + last_user_activity: The timestamp of latest user activity + traced by server instance. + """ + with Session(self.engine) as session: + settings = self._get_server_settings(session=session) + + if last_user_activity < settings.last_user_activity.replace( + tzinfo=timezone.utc + ): + return + + settings.last_user_activity = last_user_activity + # `updated` kept intentionally unchanged here + session.add(settings) + session.commit() + session.refresh(settings) + def get_onboarding_state(self) -> List[str]: """Get the server onboarding state.