Skip to content

Commit

Permalink
feat: Add Contract Tests for new Gen AI attributes for foundational m…
Browse files Browse the repository at this point in the history
…odels (#119)

*Description of changes:*

 contract tests for new gen_ai inference parameter added in 


e8c96ae#diff-20c2ca1cb28cda6e03ec0cb986933b2abd103bee39995ad232cc2e8c2d23e4aaR368

<img width="1344" alt="image"
src="https://github.com/user-attachments/assets/1d63b019-fe49-4222-9663-34e4f10d3d5b">

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Michael He <53622546+yiyuan-he@users.noreply.github.com>
  • Loading branch information
liustve and yiyuan-he authored Nov 22, 2024
1 parent 25fa5e9 commit 27c0f80
Show file tree
Hide file tree
Showing 3 changed files with 377 additions and 24 deletions.
192 changes: 176 additions & 16 deletions contract-tests/images/applications/aws-sdk/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ const { S3Client, CreateBucketCommand, PutObjectCommand, GetObjectCommand } = re
const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk/client-dynamodb');
const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs');
const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis');
const fetch = require('node-fetch');
const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock');
const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent');
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime');
Expand Down Expand Up @@ -553,30 +552,190 @@ async function handleBedrockRequest(req, res, path) {
});
res.statusCode = 200;
} else if (path.includes('invokemodel/invoke-model')) {
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
const modelId = 'amazon.titan-text-premier-v1:0';
const userMessage = "Describe the purpose of a 'hello world' program in one line.";
const prompt = `<s>[INST] ${userMessage} [/INST]`;

const body = JSON.stringify({
inputText: prompt,
textGenerationConfig: {
maxTokenCount: 3072,
stopSequences: [],
temperature: 0.7,
topP: 0.9,
},
});
const get_model_request_response = function () {
const prompt = "Describe the purpose of a 'hello world' program in one line.";
let modelId = ''
let request_body = {}
let response_body = {}

if (path.includes('amazon.titan')) {

modelId = 'amazon.titan-text-premier-v1:0';

request_body = {
inputText: prompt,
textGenerationConfig: {
maxTokenCount: 3072,
stopSequences: [],
temperature: 0.7,
topP: 0.9,
},
};

response_body = {
inputTextTokenCount: 15,
results: [
{
tokenCount: 13,
outputText: 'text-test-response',
completionReason: 'CONTENT_FILTERED',
},
],
}

}

if (path.includes('anthropic.claude')) {

modelId = 'anthropic.claude-v2:1';

request_body = {
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 1000,
temperature: 0.99,
top_p: 1,
messages: [
{
role: 'user',
content: [{ type: 'text', text: prompt }],
},
],
};

response_body = {
stop_reason: 'end_turn',
usage: {
input_tokens: 15,
output_tokens: 13,
},
}
}

if (path.includes('meta.llama')) {
modelId = 'meta.llama2-13b-chat-v1';

request_body = {
prompt,
max_gen_len: 512,
temperature: 0.5,
top_p: 0.9
};

response_body = {
prompt_token_count: 31,
generation_token_count: 49,
stop_reason: 'stop'
}
}

if (path.includes('cohere.command')) {
modelId = 'cohere.command-light-text-v14';

request_body = {
prompt,
max_tokens: 512,
temperature: 0.5,
p: 0.65,
};

response_body = {
generations: [
{
finish_reason: 'COMPLETE',
text: 'test-generation-text',
},
],
prompt: prompt,
};
}

if (path.includes('cohere.command-r')) {
modelId = 'cohere.command-r-v1:0';

request_body = {
message: prompt,
max_tokens: 512,
temperature: 0.5,
p: 0.65,
};

response_body = {
finish_reason: 'COMPLETE',
text: 'test-generation-text',
prompt: prompt,
request: {
commandInput: {
modelId: modelId,
},
},
}
}

if (path.includes('ai21.jamba')) {
modelId = 'ai21.jamba-1-5-large-v1:0';

request_body = {
messages: [
{
role: 'user',
content: prompt,
},
],
top_p: 0.8,
temperature: 0.6,
max_tokens: 512,
};

response_body = {
stop_reason: 'end_turn',
usage: {
prompt_tokens: 21,
completion_tokens: 24,
},
choices: [
{
finish_reason: 'stop',
},
],
}
}

if (path.includes('mistral')) {
modelId = 'mistral.mistral-7b-instruct-v0:2';

request_body = {
prompt,
max_tokens: 4096,
temperature: 0.75,
top_p: 0.99,
};

response_body = {
outputs: [
{
text: 'test-output-text',
stop_reason: 'stop',
},
]
}
}

return [modelId, JSON.stringify(request_body), new TextEncoder().encode(JSON.stringify(response_body))]
}

const [modelId, request_body, response_body] = get_model_request_response();

await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], { body: response_body }, async () => {
await bedrockRuntimeClient.send(
new InvokeModelCommand({
body: body,
body: request_body,
modelId: modelId,
accept: 'application/json',
contentType: 'application/json',
})
);
});

res.statusCode = 200;
} else {
res.statusCode = 404;
Expand Down Expand Up @@ -624,3 +783,4 @@ prepareAwsServer().then(() => {
console.log('Ready');
});
});

Loading

0 comments on commit 27c0f80

Please sign in to comment.