Skip to content

Commit

Permalink
feat(aiplatform): add sample for Gen AI code model tuning (#3410)
Browse files Browse the repository at this point in the history
* feat(aiplatform): add sample for Gen AI code model tuning

* udpate comments to make it clear

* fix: skip this test temporary until the service account permission issue is resolved

* fix: skip another pipeline test temporary until the service account permission issue is resolved

* fix(aiplatform): try to get the test to run after IAM role changed

* fix(aiplatform): remove the skip for another pipeline test

* fix: address review comments

---------

Co-authored-by: Patti Shin <pattishin@users.noreply.github.com>
  • Loading branch information
2 people authored and Paulina Nguyen committed Sep 11, 2023
1 parent dba9baa commit 4c57c3b
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 0 deletions.
98 changes: 98 additions & 0 deletions ai-platform/snippets/code-model-tuning.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(
project,
pipelineJobId,
modelDisplayName,
gcsOutputDirectory,
location = 'europe-west4',
datasetUri = 'gs://cloud-samples-data/ai-platform/generative_ai/sql_create_context.jsonl',
trainSteps = 300
) {
// [START aiplatform_genai_code_model_tuning]
/**
* TODO(developer): Uncomment these variables before running the sample.\
* (Not necessary if passing values as arguments)
*/
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {PipelineServiceClient} = aiplatform.v1;

// Import the helper module for converting arbitrary protobuf.Value objects.
const {helpers} = aiplatform;

// Specifies the location of the api endpoint
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};
const model = 'code-bison@001';

const pipelineClient = new PipelineServiceClient(clientOptions);

async function tuneLLM() {
// Configure the parent resource
const parent = `projects/${project}/locations/${location}`;

const parameters = {
train_steps: helpers.toValue(trainSteps),
project: helpers.toValue(project),
location: helpers.toValue('us-central1'),
dataset_uri: helpers.toValue(datasetUri),
large_model_reference: helpers.toValue(model),
model_display_name: helpers.toValue(modelDisplayName),
};

const runtimeConfig = {
gcsOutputDirectory,
parameterValues: parameters,
};

const pipelineJob = {
templateUri:
'https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0',
displayName: 'my-tuning-job',
runtimeConfig,
};

const createPipelineRequest = {
parent,
pipelineJob,
pipelineJobId,
};

const [response] = await pipelineClient.createPipelineJob(
createPipelineRequest
);

console.log('Tuning pipeline job:');
console.log(`\tName: ${response.name}`);
console.log(
`\tCreate time: ${new Date(1970, 0, 1)
.setSeconds(response.createTime.seconds)
.toLocaleString()}`
);
console.log(`\tStatus: ${response.status}`);
}

await tuneLLM();
// [END aiplatform_genai_code_model_tuning]
}

exports.tuneModel = main;
86 changes: 86 additions & 0 deletions ai-platform/snippets/test/code-model-tuning.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* eslint-disable */

'use strict';

const {assert} = require('chai');
const {describe, it} = require('mocha');
const uuid = require('uuid');
const sinon = require('sinon');

const projectId = process.env.CAIP_PROJECT_ID;
const location = 'europe-west4';

const aiplatform = require('@google-cloud/aiplatform');
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};
const pipelineClient = new aiplatform.v1.PipelineServiceClient(clientOptions);

const {tuneModel} = require('../code-model-tuning');

const timestampId = `${new Date()
.toISOString()
.replace(/(:|\.)/g, '-')
.toLowerCase()}`;
const pipelineJobName = `my-tuning-pipeline-${timestampId}`;
const modelDisplayName = `my-tuned-model-${timestampId}`;
const bucketName = 'ucaip-samples-europe-west4/training_pipeline_output';
const bucketUri = `gs://${bucketName}/tune-model-nodejs`;

describe('Tune a code model', () => {
const stubConsole = function () {
sinon.stub(console, 'error');
sinon.stub(console, 'log');
};

const restoreConsole = function () {
console.log.restore();
console.error.restore();
};

beforeEach(stubConsole);
afterEach(restoreConsole);

it('should prompt-tune an existing code model', async () => {
// Act
await tuneModel(projectId, pipelineJobName, modelDisplayName, bucketUri);

// Assert
assert.include(console.log.firstCall.args, 'Tuning pipeline job:');
});

after(async () => {
// Cancel and delete the pipeline job
const name = pipelineClient.pipelineJobPath(
projectId,
location,
pipelineJobName
);

const cancelRequest = {
name,
};

pipelineClient.cancelPipelineJob(cancelRequest).then(() => {
const deleteRequest = {
name,
};

return pipelineClient.deletePipeline(deleteRequest);
});
});
});

0 comments on commit 4c57c3b

Please sign in to comment.