diff --git a/ai-platform/snippets/code-model-tuning.js b/ai-platform/snippets/code-model-tuning.js new file mode 100644 index 0000000000..fee4b19d87 --- /dev/null +++ b/ai-platform/snippets/code-model-tuning.js @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +'use strict'; + +async function main( + project, + pipelineJobId, + modelDisplayName, + gcsOutputDirectory, + location = 'europe-west4', + datasetUri = 'gs://cloud-samples-data/ai-platform/generative_ai/sql_create_context.jsonl', + trainSteps = 300 +) { + // [START aiplatform_genai_code_model_tuning] + /** + * TODO(developer): Uncomment these variables before running the sample.\ + * (Not necessary if passing values as arguments) + */ + // const project = 'YOUR_PROJECT_ID'; + // const location = 'YOUR_PROJECT_LOCATION'; + const aiplatform = require('@google-cloud/aiplatform'); + const {PipelineServiceClient} = aiplatform.v1; + + // Import the helper module for converting arbitrary protobuf.Value objects. + const {helpers} = aiplatform; + + // Specifies the location of the api endpoint + const clientOptions = { + apiEndpoint: `${location}-aiplatform.googleapis.com`, + }; + const model = 'code-bison@001'; + + const pipelineClient = new PipelineServiceClient(clientOptions); + + async function tuneLLM() { + // Configure the parent resource + const parent = `projects/${project}/locations/${location}`; + + const parameters = { + train_steps: helpers.toValue(trainSteps), + project: helpers.toValue(project), + location: helpers.toValue('us-central1'), + dataset_uri: helpers.toValue(datasetUri), + large_model_reference: helpers.toValue(model), + model_display_name: helpers.toValue(modelDisplayName), + }; + + const runtimeConfig = { + gcsOutputDirectory, + parameterValues: parameters, + }; + + const pipelineJob = { + templateUri: + 'https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0', + displayName: 'my-tuning-job', + runtimeConfig, + }; + + const createPipelineRequest = { + parent, + pipelineJob, + pipelineJobId, + }; + + const [response] = await pipelineClient.createPipelineJob( + createPipelineRequest + ); + + console.log('Tuning pipeline job:'); + console.log(`\tName: ${response.name}`); + console.log( + `\tCreate time: ${new Date(1970, 0, 1) + .setSeconds(response.createTime.seconds) + .toLocaleString()}` + ); + console.log(`\tStatus: ${response.status}`); + } + + await tuneLLM(); + // [END aiplatform_genai_code_model_tuning] +} + +exports.tuneModel = main; diff --git a/ai-platform/snippets/test/code-model-tuning.test.js b/ai-platform/snippets/test/code-model-tuning.test.js new file mode 100644 index 0000000000..bdfee631f1 --- /dev/null +++ b/ai-platform/snippets/test/code-model-tuning.test.js @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* eslint-disable */ + +'use strict'; + +const {assert} = require('chai'); +const {describe, it} = require('mocha'); +const uuid = require('uuid'); +const sinon = require('sinon'); + +const projectId = process.env.CAIP_PROJECT_ID; +const location = 'europe-west4'; + +const aiplatform = require('@google-cloud/aiplatform'); +const clientOptions = { + apiEndpoint: `${location}-aiplatform.googleapis.com`, +}; +const pipelineClient = new aiplatform.v1.PipelineServiceClient(clientOptions); + +const {tuneModel} = require('../code-model-tuning'); + +const timestampId = `${new Date() + .toISOString() + .replace(/(:|\.)/g, '-') + .toLowerCase()}`; +const pipelineJobName = `my-tuning-pipeline-${timestampId}`; +const modelDisplayName = `my-tuned-model-${timestampId}`; +const bucketName = 'ucaip-samples-europe-west4/training_pipeline_output'; +const bucketUri = `gs://${bucketName}/tune-model-nodejs`; + +describe('Tune a code model', () => { + const stubConsole = function () { + sinon.stub(console, 'error'); + sinon.stub(console, 'log'); + }; + + const restoreConsole = function () { + console.log.restore(); + console.error.restore(); + }; + + beforeEach(stubConsole); + afterEach(restoreConsole); + + it('should prompt-tune an existing code model', async () => { + // Act + await tuneModel(projectId, pipelineJobName, modelDisplayName, bucketUri); + + // Assert + assert.include(console.log.firstCall.args, 'Tuning pipeline job:'); + }); + + after(async () => { + // Cancel and delete the pipeline job + const name = pipelineClient.pipelineJobPath( + projectId, + location, + pipelineJobName + ); + + const cancelRequest = { + name, + }; + + pipelineClient.cancelPipelineJob(cancelRequest).then(() => { + const deleteRequest = { + name, + }; + + return pipelineClient.deletePipeline(deleteRequest); + }); + }); +});