diff --git a/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_chat_completion.ipynb b/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_chat_completion.ipynb index b7faa771b5..066d752fd9 100644 --- a/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_chat_completion.ipynb +++ b/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_chat_completion.ipynb @@ -486,6 +486,7 @@ " learning_rate=0.00002,\n", " per_device_train_batch_size=1,\n", " num_train_epochs=3,\n", + " data_generation_task_type=\"NLI\",\n", " )\n", "\n", " return {\"output_model\": oss_distillation.outputs.output_model}" diff --git a/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_nlu_qa_task.ipynb b/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_nlu_qa_task.ipynb new file mode 100644 index 0000000000..f2258e1736 --- /dev/null +++ b/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/distillation_nlu_qa_task.ipynb @@ -0,0 +1,610 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distillation Q&A with Large Language Models\n", + " \n", + "### Notebook details\n", + " \n", + "This sample demonstrates how to train the selected student model using the teacher model, resulting in the creation of the distilled model.\n", + " \n", + "We will use the Meta Llama 3.1 405B Instruct as the teacher model and the Meta Llama 3.1 8B Instruct as the student model.\n", + " \n", + "**Note :**\n", + " \n", + "- Distillation offering is only available in **West US 3** regions.\n", + "- The Meta Llama 3.1 405B Instruct model can only be used as a teacher model.\n", + "- The Meta Llama 3.1 8B Instruct can only be used as a student (target) model.\n", + "- Distllation is supported for Question and Answering (QnA) task, which is a standard task in benchmarking for Natural Language Understanding.\n", + "\n", + "**Prerequisites :**\n", + "- Subscribe to the Meta Llama 3.1 405B Instruct and Meta Llama 3.1 8B Instruct, see [how to subscribe your project to the model offering in MS Learn](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio#subscribe-your-project-to-the-model-offering)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install the SDK v2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install azure-ai-ml\n", + "%pip install azure-identity\n", + "\n", + "%pip install mlflow\n", + "%pip install azureml-mlflow\n", + "%pip install -U datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import the required libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import required libraries\n", + "\n", + "import base64\n", + "import json\n", + "import os\n", + "from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n", + "from azure.ai.ml import MLClient, Input\n", + "from azure.ai.ml.constants import AssetTypes\n", + "from azure.ai.ml.dsl import pipeline\n", + "from azure.ai.ml.entities import Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "An AI Studio project in **West US 3** is required. Please follow [this](https://learn.microsoft.com/azure/ai-studio/how-to/fine-tune-model-llama?tabs=llama-two%2Cchatcompletion#prerequisites) document to setup your AI Studio project" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## AI Studio project settings\n", + "\n", + "Update following cell with the information of the AI Studio project just created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SUBSCRIPTION_ID = \"\"\n", + "RESOURCE_GROUP = \"\"\n", + "AI_PROJECT_NAME = \"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure credential\n", + "\n", + "We are using `DefaultAzureCredential` to get access to workspace. \n", + "`DefaultAzureCredential` should be capable of handling most Azure SDK authentication scenarios. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " credential = DefaultAzureCredential()\n", + " # Check if given credential can get token successfully.\n", + " credential.get_token(\"https://management.azure.com/.default\")\n", + "except Exception as ex:\n", + " # Fall back to InteractiveBrowserCredential in case DefaultAzureCredential not work\n", + " credential = InteractiveBrowserCredential()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get handle to AI Studio project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ml_client = MLClient(credential, SUBSCRIPTION_ID, RESOURCE_GROUP, AI_PROJECT_NAME)\n", + "ai_project = ml_client._workspaces.get(ml_client.workspace_name)\n", + "ai_project._workspace_id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pick a teacher model\n", + "\n", + "We support **Meta-Llama-3.1-405B-Instruct** as the teacher model. \n", + "### First deploy the teacher model in Azure AI Studio\n", + "* Deploy Meta Llama 3.1 405B Instruct as a serverless API - [link](https://aka.ms/meta-llama-3.1-405B-instruct-azure-ai-studio-docs#create-a-new-deployment)\n", + "* Once deployed successfully, you should be assigned for an API endpoint and a security key for inference.\n", + "\n", + "Update the following cell with the information of the deployment you just created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Llama-3-405B Teacher model endpoint name\n", + "# The serverless model name is the name found in ML Studio > Endpoints > Serverless endpoints > Model column\n", + "TEACHER_MODEL_NAME = \"Meta-Llama-3.1-405B-Instruct\"\n", + "\n", + "# The serverless model endpoint name is the name found in ML Studio > Endpoints > Serverless endpoints > Name column\n", + "# The endpoint URL will be resolved from this name by the MLFlow component\n", + "TEACHER_MODEL_ENDPOINT_NAME = \"Meta-Llama-3-1-405B-Instruct-vum\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pick a student model\n", + "\n", + "We will use **Meta-Llama-3.1-8B-Instruct** as student model. We only support chat completion models that are available for PayGo finetuning in Azure AI Studio." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "STUDENT_MODEL_NAME = \"Meta-Llama-3.1-8B-Instruct\"\n", + "STUDENT_MODEL_VERSION = 1\n", + "\n", + "# retrieve student model from model registry\n", + "mlclient_azureml_meta = MLClient(credential, registry_name=\"azureml-meta\")\n", + "student_model = mlclient_azureml_meta.models.get(\n", + " STUDENT_MODEL_NAME, version=STUDENT_MODEL_VERSION\n", + ")\n", + "\n", + "print(\n", + " \"\\n\\nUsing model name: {0}, version: {1}, id: {2} for fine tuning\".format(\n", + " student_model.name, student_model.version, student_model.id\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download the dataset from HuggingFace repo\n", + "\n", + "For our example, we download and use the Common Sense Q&A dataset (https://huggingface.co/datasets/tau/commonsense_qa/) from HuggingFace.\n", + "\n", + "The next few cells show basic data preparation for fine tuning:\n", + "* Visualize some data rows\n", + "* Preprocess the data and format it in required format. This is an important step for performing QandA as we add the required sequences/separators in the data. \n", + "* While fintuning, text column is concatenated with ground_truth column to produce finetuning input. Hence, the data should be prepared such that `text + ground_truth` is your actual finetuning data.\n", + "* bos and eos tokens are added to the data by finetuning pipeline, you do not need to add it explicitly \n", + "\n", + "##### Here is an example of how the data should look like\n", + "\n", + "text generation requires the training data to include at least 2 fields – one for ‘text’ and ‘ground_truth’ like in this example. The below examples are from Common Sense QandA dataset. \n", + "\n", + "Original dataset:\n", + "\n", + "{'id': '1afa02df02c908a558b4036e80242fac', 'question': 'A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?', 'question_concept': 'revolving door', 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['bank', 'library', 'department store', 'mall', 'new york']}, 'answerKey': 'A'}\n", + "\n", + "{'id': 'a7ab086045575bb497933726e4e6ad28', 'question': 'What do people aim to do at work?', 'question_concept': 'people', 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['complete job', 'learn from each other', 'kill animals', 'wear hats', 'talk to each other']}, 'answerKey': 'A'}\n", + "\n", + "Formatted dataset the user might pass:\n", + "\n", + "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant. Your output should only be one of the five choices: 'A', 'B', 'C', 'D', or 'E'.\"}, {\"role\": \"user\", \"content\": \"Answer the following multiple-choice question by selecting the correct option.\\n\\nQuestion: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?\\nAnswer Choices:\\n(A) bank\\n(B) library\\n(C) department store\\n(D) mall\\n(E) new york\"}]}\n", + "\n", + "\n", + "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant. Your output should only be one of the five choices: 'A', 'B', 'C', 'D', or 'E'.\"}, {\"role\": \"user\", \"content\": \"Answer the following multiple-choice question by selecting the correct option.\\n\\nQuestion: What do people aim to do at work?\\nAnswer Choices:\\n(A) complete job\\n(B) learn from each other\\n(C) kill animals\\n(D) wear hats\\n(E) talk to each other\"}]}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We can define train and test sample sizes here. Validation size is kept same as test sample size\n", + "import download_dataset\n", + "\n", + "train_sample_size = 100\n", + "val_sample_size = 100\n", + "\n", + "# Sample notebook using the dataset: https://huggingface.co/datasets/tau/commonsense_qa\n", + "dataset_name = \"tau/commonsense_qa\"\n", + "input_dataset = download_dataset.CQnAHuggingFaceInputDataset()\n", + "\n", + "# Note: train_split_name and test_split_name can vary by dataset. They are passed as arguments in load_hf_dataset.\n", + "# If validation_split_name is None, the below function will split the train set to create the specified sized validation set.\n", + "train, val, _ = input_dataset.load_hf_dataset(\n", + " dataset_name=dataset_name,\n", + " train_sample_size=train_sample_size,\n", + " val_sample_size=val_sample_size,\n", + " train_split_name=\"train\",\n", + " val_split_name=\"validation\",\n", + ")\n", + "\n", + "print(\"Len of train data sample is \" + str(len(train)))\n", + "print(\"Len of validation data sample is \" + str(len(val)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " path = \"data\"\n", + " os.mkdir(path, 0o755)\n", + "except FileExistsError as err:\n", + " print(f\"Error: {err}\")\n", + "except OSError as err:\n", + " print(f\"Error: {err}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_data_path = \"data/train_cqna_512.jsonl\"\n", + "valid_data_path = \"data/valid_cqna_256.jsonl\"\n", + "\n", + "for row in train:\n", + " data = {\"messages\": []}\n", + " data[\"messages\"].append(\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a helpful assistant. Your output should only be one of the five choices: 'A', 'B', 'C', 'D', or 'E'.\",\n", + " }\n", + " )\n", + " question, choices, ans_key = row[\"question\"], row[\"choices\"], row[\"answerKey\"]\n", + " labels, choice_list = choices[\"label\"], choices[\"text\"]\n", + " answer_choices = [\n", + " \"({}) {}\".format(labels[i], choice_list[i]) for i in range(len(labels))\n", + " ]\n", + " answer_choices = \"\\n\".join(answer_choices)\n", + " data[\"messages\"].append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Answer the following multiple-choice question by selecting the correct option.\\n\\nQuestion: \"\n", + " + row[\"question\"]\n", + " + \"\\nAnswer Choices:\\n\"\n", + " + answer_choices,\n", + " }\n", + " )\n", + " with open(train_data_path, \"a\") as f:\n", + " f.write(json.dumps(data) + \"\\n\")\n", + "\n", + "for row in val:\n", + " data = {\"messages\": []}\n", + " data[\"messages\"].append(\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a helpful assistant. Your output should only be one of the five choices: 'A', 'B', 'C', 'D', or 'E'.\",\n", + " }\n", + " )\n", + " question, choices, ans_key = row[\"question\"], row[\"choices\"], row[\"answerKey\"]\n", + " labels, choice_list = choices[\"label\"], choices[\"text\"]\n", + " answer_choices = [\n", + " \"({}) {}\".format(labels[i], choice_list[i]) for i in range(len(labels))\n", + " ]\n", + " answer_choices = \"\\n\".join(answer_choices)\n", + " data[\"messages\"].append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Answer the following multiple-choice question by selecting the correct option.\\n\\nQuestion: \"\n", + " + row[\"question\"]\n", + " + \"\\nAnswer Choices:\\n\"\n", + " + answer_choices,\n", + " }\n", + " )\n", + " with open(valid_data_path, \"a\") as f:\n", + " f.write(json.dumps(data) + \"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "val_df = pd.read_json(\"./data/valid_cqna_256.jsonl\", lines=True)\n", + "val_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare data inputs\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_data = None\n", + "train_data_name = \"cqna_train_70\"\n", + "\n", + "train_data = ml_client.data.create_or_update(\n", + " Data(\n", + " path=train_data_path,\n", + " type=AssetTypes.URI_FILE,\n", + " description=\"Training dataset\",\n", + " name=train_data_name,\n", + " )\n", + ")\n", + "\n", + "train_data_asset_id = f\"azureml://locations/{ai_project.location}/workspaces/{ai_project._workspace_id}/data/{train_data.name}/versions/{train_data.version}\"\n", + "train_data_asset_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "valid_data = None\n", + "valid_data_name = \"cqna_valid_70\"\n", + "\n", + "valid_data = ml_client.data.create_or_update(\n", + " Data(\n", + " path=valid_data_path,\n", + " type=AssetTypes.URI_FILE,\n", + " description=\"validation dataset\",\n", + " name=valid_data_name,\n", + " )\n", + ")\n", + "\n", + "valid_data_asset_id = f\"azureml://locations/{ai_project.location}/workspaces/{ai_project._workspace_id}/data/{valid_data.name}/versions/{valid_data.version}\"\n", + "valid_data_asset_id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Distillation strategy settings\n", + "\n", + "We provide the option to leverage Chain of Thought (CoT) reasoning for distillation. CoT leverages step by step reasoning ability of the teacher model to generate more accurate labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ENABLE_CHAIN_OF_THOUGHT = \"true\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure distillation\n", + "Define finetune parameters\n", + "\n", + "\n", + "Training parameters define the training aspects such as - \n", + "1. the learning rate\n", + "2. the number of epochs to finetune\n", + "3. the data_generation_task_type , in this case its 'NLU_QA'\n", + "and so on" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mlclient_azureml = MLClient(credential, registry_name=\"azureml\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "distillation_pipeline_name = \"oss_distillation_pipeline\"\n", + "distillation_pipeline_component = mlclient_azureml.components.get(\n", + " name=distillation_pipeline_name, label=\"latest\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@pipeline\n", + "def distillation_pipeline(\n", + " teacher_model_endpoint_name: str,\n", + " enable_chain_of_thought: str,\n", + " system_properties: str,\n", + " input_finetune_model: Input,\n", + " train_file_path: Input,\n", + " validation_file_path: Input = None,\n", + "):\n", + " oss_distillation = distillation_pipeline_component(\n", + " teacher_model_endpoint_name=teacher_model_endpoint_name,\n", + " enable_chain_of_thought=enable_chain_of_thought,\n", + " train_file_path=train_file_path,\n", + " validation_file_path=validation_file_path,\n", + " # Finetune\n", + " mlflow_model_path=input_finetune_model,\n", + " model_asset_id=student_model.id,\n", + " system_properties=system_properties,\n", + " ## hyperparams\n", + " learning_rate=0.00002,\n", + " per_device_train_batch_size=1,\n", + " num_train_epochs=3,\n", + " data_generation_task_type=\"NLU_QA\",\n", + " )\n", + "\n", + " return {\"output_model\": oss_distillation.outputs.output_model}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "system_properties = {\n", + " \"finetune_oss\": \"True\",\n", + " \"model_asset_id\": student_model.id,\n", + " \"PipelineType\": \"Finetune\",\n", + " \"azureml.PipelineType\": \"Finetune\",\n", + " \"azureml.ModelName\": student_model.name,\n", + " \"azureml.original_model_id\": student_model.id,\n", + " \"azureml.trainingData.assetId\": train_data_asset_id,\n", + "}\n", + "\n", + "json_str = json.dumps(system_properties).replace(\" \", \"\")\n", + "\n", + "system_properties_b64_encoded = base64.b64encode(json_str.encode(\"utf-8\")).decode(\n", + " \"utf-8\"\n", + ")\n", + "print(f\"System properties => {system_properties_b64_encoded}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_file_path_input = Input(type=\"uri_file\", path=train_data.path)\n", + "validation_file_path_input = Input(type=\"uri_file\", path=valid_data.path)\n", + "input_finetune_model = Input(type=\"mlflow_model\", path=student_model.id)\n", + "experiment_name = f\"distillation-{TEACHER_MODEL_NAME}\".replace(\".\", \"-\")\n", + "\n", + "finetuning_job = distillation_pipeline(\n", + " teacher_model_endpoint_name=TEACHER_MODEL_ENDPOINT_NAME,\n", + " enable_chain_of_thought=ENABLE_CHAIN_OF_THOUGHT,\n", + " system_properties=system_properties_b64_encoded,\n", + " input_finetune_model=input_finetune_model,\n", + " train_file_path=train_file_path_input,\n", + " validation_file_path=validation_file_path_input,\n", + ")\n", + "\n", + "finetuning_job.properties.update(system_properties)\n", + "print(f\"job property: {finetuning_job.properties}\")\n", + "\n", + "# pipeline_job.identity = UserIdentityConfiguration()\n", + "finetuning_job.display_name = f\"finetune-{student_model.name}\"\n", + "finetuning_job.experiment_name = experiment_name\n", + "finetuning_job.settings.default_compute_type = \"serverless\"\n", + "finetuning_job.continue_on_step_failure = False\n", + "# pipeline_job.settings.force_rerun = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Submit pipeline job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Submit pipeline job to workspace\n", + "ft_job = ml_client.jobs.create_or_update(finetuning_job)\n", + "print(f\"Submitted job, progress available at {ft_job.studio_url}\")\n", + "\n", + "\n", + "# wait for the pipeline job to complete\n", + "ml_client.jobs.stream(ft_job.name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Consuming the distilled model\n", + "\n", + "Once the above job completes, you should be able to deploy the model and use it for inferencing. To deploy this model, do the following:\n", + "\n", + "* Go to AI Studio\n", + "* Navigate to the Fine-tuning tab on the left menu\n", + "* In the list of models you see, click on the model which got created from the distillation\n", + "* This should take you to the details page where you can see the model attributes and other details\n", + "* Click on the Deploy button on top of the page\n", + "* Follow the steps to deploy the model" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/download_dataset.py b/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/download_dataset.py new file mode 100644 index 0000000000..95f1d48a98 --- /dev/null +++ b/sdk/python/foundation-models/system/finetune/Llama-notebooks/distillation/download_dataset.py @@ -0,0 +1,49 @@ +from datasets import load_dataset +from abc import ABC + + +class InputDataset(ABC): + def __init__(self): + super().__init__() + ( + self.train_data_file_name, + self.test_data_file_name, + self.eval_data_file_name, + ) = (None, None, None) + + +class CQnAHuggingFaceInputDataset(InputDataset): + """ + Loads the HuggingFace dataset + """ + + def __init__(self): + super().__init__() + + def load_hf_dataset( + self, + dataset_name, + train_sample_size=10, + val_sample_size=10, + test_sample_size=10, + train_split_name="train", + val_split_name="validation", + test_split_name="test", + ): + full_dataset = load_dataset(dataset_name) + + if val_split_name is not None: + train_data = full_dataset[train_split_name].select(range(train_sample_size)) + val_data = full_dataset[val_split_name].select(range(val_sample_size)) + test_data = full_dataset[test_split_name].select(range(test_sample_size)) + else: + train_val_data = full_dataset[train_split_name].select( + range(train_sample_size + val_sample_size) + ) + train_data = train_val_data.select(range(train_sample_size)) + val_data = train_val_data.select( + range(train_sample_size, train_sample_size + val_sample_size) + ) + test_data = full_dataset[test_split_name].select(range(test_sample_size)) + + return train_data, val_data, test_data