From 67fb59d6ad24f83d23a5bd31e6ae8d7e1540e284 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 7 Feb 2024 16:21:08 -0800 Subject: [PATCH 01/68] WIP --- .../aws-bedrock-runtime-2.20/build.gradle | 17 ++ .../newrelic/utils/BedrockRuntimeUtil.java | 237 ++++++++++++++++++ .../utils/InvokeModelRequestWrapper.java | 145 +++++++++++ .../utils/InvokeModelResponseWrapper.java | 184 ++++++++++++++ ...ockRuntimeAsyncClient_Instrumentation.java | 67 +++++ ...tBedrockRuntimeClient_Instrumentation.java | 96 +++++++ settings.gradle | 1 + 7 files changed, 747 insertions(+) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/build.gradle create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/build.gradle b/instrumentation/aws-bedrock-runtime-2.20/build.gradle new file mode 100644 index 0000000000..b42f3d37cb --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/build.gradle @@ -0,0 +1,17 @@ +jar { + manifest { attributes 'Implementation-Title': 'com.newrelic.instrumentation.aws-bedrock-runtime-2.20' } +} + +dependencies { + implementation(project(":agent-bridge")) + implementation 'software.amazon.awssdk:bedrockruntime:2.20.157' +} + +verifyInstrumentation { + passes 'software.amazon.awssdk:bedrockruntime:[2.20.157,)' +} + +site { + title 'AWS Bedrock' + type 
'Other' +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java new file mode 100644 index 0000000000..e31a9f4acc --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java @@ -0,0 +1,237 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.utils; + +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Transaction; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; + +public class BedrockRuntimeUtil { + private static final String VENDOR = "bedrock"; + private static final String INGEST_SOURCE = "Java"; + private static final String TRACE_ID = "trace.id"; + private static final String SPAN_ID = "span.id"; + + // LLM event types + private static final String LLM_EMBEDDING = "LlmEmbedding"; + private static final String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; + private static final String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; + + /** + * This needs to be incremented for every invocation of the Bedrock SDK. + *

+ * The metric generated triggers the creation of a tag which gates the AI Response UI. The + * tag lives for 27 hours so if this metric isn't repeatedly sent the tag will disappear and + * the UI will be hidden. + */ + public static void incrementBedrockInstrumentedMetric() { + NewRelic.incrementCounter("Java/ML/Bedrock/2.20"); + } + + public static void setLlmOperationMetricName(Transaction txn, String operationType) { + txn.getTracedMethod().setMetricName("Llm", operationType, "Bedrock", "invokeModel"); + } + + // TODO add event builders??? Avoid adding null/empty attributes? + + public static void reportLlmEmbeddingEvent(Transaction txn, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, + InvokeModelResponseWrapper invokeModelResponseWrapper) { + // TODO available data + String stopSequences = invokeModelRequestWrapper.parseStopSequences(); + String maxTokensToSample = invokeModelRequestWrapper.parseMaxTokensToSample(); + String temperature = invokeModelRequestWrapper.parseTemperature(); + String prompt = invokeModelRequestWrapper.parsePrompt(); + + String completion = invokeModelResponseWrapper.parseCompletion(); + String stop = invokeModelResponseWrapper.parseStop(); + String stopReason = invokeModelResponseWrapper.parseStopReason(); + String inputTokenCount = invokeModelResponseWrapper.getInputTokenCount(); + String outputTokenCount = invokeModelResponseWrapper.getOutputTokenCount(); + String invocationLatency = invokeModelResponseWrapper.getInvocationLatency(); + String operationType = invokeModelResponseWrapper.getOperationType(); + + // TODO is it possible to do something like this to call getUserAttributes? 
+ // see com.newrelic.agent.bridge.Transaction + + Map eventAttributes = new HashMap<>(); + eventAttributes.put("id", ""); // TODO ID in the format response_id-sequence or a UUID generated by the agent if no response ID is returned by the LLM + eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); + eventAttributes.put("span_id", getSpanId(linkingMetadata)); + eventAttributes.put("transaction_id", ""); // TODO figure out how to get this from agent + eventAttributes.put("trace_id", getTraceId(linkingMetadata)); + eventAttributes.put("input", ""); // TODO figure out how to get this + eventAttributes.put("api_key_last_four_digits", ""); // TODO Final digits of API key formatted as: sk-{last_four_digits_of_api_key} + eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); + eventAttributes.put("response.model", ""); // TODO Model name returned in the response (can differ from request.model) + eventAttributes.put("response.organization", ""); // TODO Organization ID returned in the response or response headers + eventAttributes.put("response.usage.total_tokens", ""); + eventAttributes.put("response.usage.prompt_tokens", ""); + eventAttributes.put("vendor", getVendor()); + eventAttributes.put("ingest_source", getIngestSource()); + eventAttributes.put("duration", ""); + eventAttributes.put("error", ""); +// eventAttributes.put("llm.", ""); + + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); + } + + public static void reportLlmChatCompletionSummaryEvent(Transaction txn, Map linkingMetadata, + InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { + // TODO available data + String stopSequences = invokeModelRequestWrapper.parseStopSequences(); + String maxTokensToSample = invokeModelRequestWrapper.parseMaxTokensToSample(); + String temperature = invokeModelRequestWrapper.parseTemperature(); + String prompt = 
invokeModelRequestWrapper.parsePrompt(); + + String completion = invokeModelResponseWrapper.parseCompletion(); + String stop = invokeModelResponseWrapper.parseStop(); + String stopReason = invokeModelResponseWrapper.parseStopReason(); + String inputTokenCount = invokeModelResponseWrapper.getInputTokenCount(); + String outputTokenCount = invokeModelResponseWrapper.getOutputTokenCount(); + String invocationLatency = invokeModelResponseWrapper.getInvocationLatency(); + String operationType = invokeModelResponseWrapper.getOperationType(); + + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction + + Map eventAttributes = new HashMap<>(); + eventAttributes.put("id", ""); // TODO ID in the format response_id-sequence or a UUID generated by the agent if no response ID is returned by the LLM + eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); + eventAttributes.put("span_id", getSpanId(linkingMetadata)); + eventAttributes.put("transaction_id", ""); // TODO figure out how to get this from agent + eventAttributes.put("trace_id", getTraceId(linkingMetadata)); + eventAttributes.put("api_key_last_four_digits", ""); // TODO Final digits of API key formatted as: sk-{last_four_digits_of_api_key} + eventAttributes.put("request.temperature", invokeModelRequestWrapper.parseTemperature()); + eventAttributes.put("request.max_tokens", invokeModelRequestWrapper.parseMaxTokensToSample()); + eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); + eventAttributes.put("response.model", ""); // TODO Model name returned in the response (can differ from request.model) + eventAttributes.put("response.organization", ""); // TODO Organization ID returned in the response or response headers + eventAttributes.put("response.number_of_messages", ""); + eventAttributes.put("response.usage.total_tokens", ""); + eventAttributes.put("response.usage.prompt_tokens", ""); + 
eventAttributes.put("response.usage.completion_tokens", ""); + eventAttributes.put("response.choices.finish_reason", ""); + eventAttributes.put("vendor", getVendor()); + eventAttributes.put("ingest_source", getIngestSource()); + eventAttributes.put("duration", ""); + eventAttributes.put("error", ""); +// eventAttributes.put("llm.", ""); + eventAttributes.put("conversation_id", ""); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API + + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); + } + + public static void reportLlmChatCompletionMessageEvent(Transaction txn, Map linkingMetadata, + InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { + // TODO available data + String stopSequences = invokeModelRequestWrapper.parseStopSequences(); + String maxTokensToSample = invokeModelRequestWrapper.parseMaxTokensToSample(); + String temperature = invokeModelRequestWrapper.parseTemperature(); + String prompt = invokeModelRequestWrapper.parsePrompt(); + + String completion = invokeModelResponseWrapper.parseCompletion(); + String stop = invokeModelResponseWrapper.parseStop(); + String stopReason = invokeModelResponseWrapper.parseStopReason(); + String inputTokenCount = invokeModelResponseWrapper.getInputTokenCount(); + String outputTokenCount = invokeModelResponseWrapper.getOutputTokenCount(); + String invocationLatency = invokeModelResponseWrapper.getInvocationLatency(); + String operationType = invokeModelResponseWrapper.getOperationType(); + + // TODO is it possible to do something like this to call getUserAttributes? 
+ // see com.newrelic.agent.bridge.Transaction + + Map eventAttributes = new HashMap<>(); + eventAttributes.put("id", ""); // TODO ID in the format response_id-sequence or a UUID generated by the agent if no response ID is returned by the LLM + eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); + eventAttributes.put("span_id", getSpanId(linkingMetadata)); + eventAttributes.put("transaction_id", ""); // TODO figure out how to get this from agent + eventAttributes.put("trace_id", getTraceId(linkingMetadata)); + eventAttributes.put("conversation_id", ""); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API + eventAttributes.put("api_key_last_four_digits", ""); // TODO Final digits of API key formatted as: sk-{last_four_digits_of_api_key} + eventAttributes.put("response.model", ""); // TODO Model name returned in the response (can differ from request.model) + eventAttributes.put("vendor", getVendor()); + eventAttributes.put("ingest_source", getIngestSource()); + eventAttributes.put("content", invokeModelRequestWrapper.parsePrompt()); + eventAttributes.put("role", ""); // TODO Role of the message creator (ex: system, assistant, user) + eventAttributes.put("sequence", ""); + eventAttributes.put("completion_id", ""); + eventAttributes.put("is_response", ""); +// eventAttributes.put("llm.", ""); + + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); + } + + public static void debugLoggingForDevelopment(Transaction transaction, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { + System.out.println(); + System.out.println("Request: " + invokeModelRequest); + System.out.println("Request Body (UTF8 String): " + invokeModelRequest.body().asUtf8String()); + + System.out.println(); + System.out.println("Response: " + invokeModelResponse); + System.out.println("Response Body (UTF8 String): " + 
invokeModelResponse.body().asUtf8String()); + System.out.println("Response Metadata: " + invokeModelResponse.responseMetadata()); + System.out.println("Response Metadata Request ID: " + invokeModelResponse.responseMetadata().requestId()); + System.out.println("Response SdkHttpResponse Status Code: " + invokeModelResponse.sdkHttpResponse().statusCode()); + System.out.println("Response SdkHttpResponse Status Text: " + invokeModelResponse.sdkHttpResponse().statusText()); + System.out.println("Response SdkHttpResponse Is Successful: " + invokeModelResponse.sdkHttpResponse().isSuccessful()); + System.out.println(); + } + + // ========================= AGENT DATA ================================ + // Lowercased name of vendor (bedrock or openAI) + public static String getVendor() { + return VENDOR; + } + + // Name of the language agent (ex: Python, Node) + public static String getIngestSource() { + return INGEST_SOURCE; + } + + // GUID associated with the active trace + public static String getSpanId(Map linkingMetadata) { + return linkingMetadata.get(SPAN_ID); + } + + // ID of the current trace + public static String getTraceId(Map linkingMetadata) { + return linkingMetadata.get(TRACE_ID); + } + + // ID of the active transaction + public String getTransactionId() { + // FIXME not sure that this is accessible in an instrumentation module + // might need to add this to events withing the record_event API logic + // Sounds like we need to expose this on the public Transaction API + return ""; + } + + // Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API + public String getLlmUserDefinedMetadata() { + // FIXME hmm where can user attributes actually be accessed from????? AgentBridge??? Tracer??? 
+ // If we create a new AI event type with it's own endpoint this would be inherited from from AnalyticsEvent + return ""; + } + + // Optional attribute that can be added to a transaction by a customer via add_custom_attribute API + public String getLlmConversationId() { + // FIXME hmm where can user attributes actually be accessed from????? AgentBridge??? Tracer??? + // If we create a new AI event type with it's own endpoint this would be inherited from from AnalyticsEvent + return ""; + } + + // Boolean set to True if a message is the result of a chat completion and not an input message + public String getIsResponse(InvokeModelResponse invokeModelResponse) { + // TODO Should this return a boolean or string?? + return ""; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java new file mode 100644 index 0000000000..f53f3dd831 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java @@ -0,0 +1,145 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.utils; + +import com.newrelic.api.agent.NewRelic; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; + +/** + * Stores the required info from the Bedrock InvokeModelRequest + * but doesn't hold a reference to the actual request object. 
+ */ +public class InvokeModelRequestWrapper { + private final String invokeModelRequestBody; + private final String modelId; + private Map requestBodyJsonMap = null; + + // Request body (for Claude, how about other models?) + private static final String STOP_SEQUENCES = "stop_sequences"; + private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + + public InvokeModelRequestWrapper(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + invokeModelRequestBody = ""; + modelId = ""; + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelRequest"); + } + } + + // Lazy init and only parse map once + public Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + // TODO check for other types? Or will it always be Object? + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelRequest body as Map Object"); + } + + return requestBodyJsonMap != null ? requestBodyJsonMap : Collections.emptyMap(); + } + + // TODO do we potentially expect more than one entry in the stop sequence? Or is it sufficient + // to just check if it contains Human? 
+ public String parseStopSequences() { + StringBuilder stopSequences = new StringBuilder(); + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(STOP_SEQUENCES); + if (jsonNode.isArray()) { + List jsonNodeArray = jsonNode.asArray(); + for (JsonNode node : jsonNodeArray) { + if (node.isString()) { + // Don't add comma for first node + if (stopSequences.length() <= 0) { + stopSequences.append(node.asString()); + } else { + stopSequences.append(",").append(node.asString()); + } + } + } + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP_SEQUENCES); + } + return stopSequences.toString().replaceAll("[\n:]", ""); + } + + public String parseMaxTokensToSample() { + String maxTokensToSample = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_TOKENS_TO_SAMPLE); + if (jsonNode.isNumber()) { + maxTokensToSample = jsonNode.asNumber(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + MAX_TOKENS_TO_SAMPLE); + } + return maxTokensToSample; + } + + public String parseTemperature() { + String temperature = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); + if (jsonNode.isNumber()) { + temperature = jsonNode.asNumber(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + TEMPERATURE); + } + return temperature; + } + + public String parsePrompt() { + String prompt = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(PROMPT); + if (jsonNode.isString()) { + prompt = jsonNode.asString(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + PROMPT); + } + return prompt.replace("Human: ", "").replace("\n\nAssistant:", ""); + 
} + + public String getModelId() { + return modelId; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java new file mode 100644 index 0000000000..a6e0dbead4 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java @@ -0,0 +1,184 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.utils; + +import com.newrelic.api.agent.NewRelic; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; + +/** + * Stores the required info from the Bedrock InvokeModelResponse + * but doesn't hold a reference to the actual response object. + */ +public class InvokeModelResponseWrapper { + private final String invokeModelResponseBody; + private Map responseBodyJsonMap = null; + + // Response body (for Claude, how about other models?) 
+ public static final String COMPLETION = "completion"; + public static final String EMBEDDING = "embedding"; + private static final String STOP_REASON = "stop_reason"; + private static final String STOP = "stop"; + + // Response headers + private static final String X_AMZN_BEDROCK_INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; + private static final String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; + private static final String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; + private static final String X_AMZN_BEDROCK_INVOCATION_LATENCY = "X-Amzn-Bedrock-Invocation-Latency"; + private String inputTokenCount; + private String outputTokenCount; + private String amznRequestId; + private String invocationLatency; + + // LLM operation type + private String operationType; + + public InvokeModelResponseWrapper(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + extractOperationType(invokeModelResponseBody); + extractHeaders(invokeModelResponse); + } else { + invokeModelResponseBody = ""; + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + private void extractOperationType(String invokeModelResponseBody) { + // FIXME should be starts with instead of contains? 
+ if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.contains(COMPLETION)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.contains(EMBEDDING)) { + operationType = EMBEDDING; + } + } + } + + private void extractHeaders(InvokeModelResponse invokeModelResponse) { + Map> headers = invokeModelResponse.sdkHttpResponse().headers(); + if (!headers.isEmpty()) { + List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); + if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { + inputTokenCount = inputTokenCountHeaders.get(0); + } + List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); + if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { + outputTokenCount = outputTokenCountHeaders.get(0); + } + List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); + if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { + amznRequestId = amznRequestIdHeaders.get(0); // TODO does this differ from invokeModelResponse.responseMetadata().requestId() + } + List invocationLatencyHeaders = headers.get(X_AMZN_BEDROCK_INVOCATION_LATENCY); + if (invocationLatencyHeaders != null && !invocationLatencyHeaders.isEmpty()) { + invocationLatency = invocationLatencyHeaders.get(0); + } + } + } + + // Lazy init and only parse map once + public Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + private Map parseInvokeModelResponseBodyMap() { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + Map responseBodyJsonMap = null; + // TODO check for other types? Or will it always be Object? 
+ if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse body as Map Object"); + } + + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + public String parseCompletion() { + String completion = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(COMPLETION); + if (jsonNode.isString()) { + completion = jsonNode.asString(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + COMPLETION); + } + return completion; + } + + public String parseStopReason() { + String stopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(STOP_REASON); + if (jsonNode.isString()) { + stopReason = jsonNode.asString(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP_REASON); + } + return stopReason; + } + + public String parseStop() { + String stop = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(STOP); + if (jsonNode.isString()) { + stop = jsonNode.asString(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP); + } + return stop.replaceAll("[\n:]", ""); + } + + public String getInputTokenCount() { + return inputTokenCount; + } + + public String getOutputTokenCount() { + return outputTokenCount; + } + + public String getAmznRequestId() { + return amznRequestId; + } + + public String getInvocationLatency() { + return invocationLatency; + } + + public String getOperationType() { + return operationType; + } + + public String getResponseModel() { + // TODO figure out where to get this from + return "TODO"; + } +} diff 
--git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java new file mode 100644 index 0000000000..cbded326b7 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java @@ -0,0 +1,67 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.api.agent.Trace; +import com.newrelic.api.agent.weaver.MatchType; +import com.newrelic.api.agent.weaver.Weave; +import com.newrelic.api.agent.weaver.Weaver; +import software.amazon.awssdk.core.client.config.SdkClientConfiguration; +import software.amazon.awssdk.core.client.handler.AsyncClientHandler; +import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +import static com.newrelic.utils.BedrockRuntimeUtil.incrementBedrockInstrumentedMetric; + +/** + * Service client for accessing Amazon Bedrock Runtime asynchronously. 
+ */ +@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient") +final class DefaultBedrockRuntimeAsyncClient_Instrumentation { +// private static final Logger log = LoggerFactory.getLogger(DefaultBedrockRuntimeAsyncClient.class); +// +// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() +// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + + private final AsyncClientHandler clientHandler; + + private final AwsJsonProtocolFactory protocolFactory; + + private final SdkClientConfiguration clientConfiguration; + + private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; + + private final Executor executor; + + protected DefaultBedrockRuntimeAsyncClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, + SdkClientConfiguration clientConfiguration) { + this.clientHandler = Weaver.callOriginal(); + this.clientConfiguration = Weaver.callOriginal(); + this.serviceClientConfiguration = Weaver.callOriginal(); + this.protocolFactory = Weaver.callOriginal(); + this.executor = Weaver.callOriginal(); + } + + @Trace + public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { + CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); + + // FIXME needs to be incremented constantly for UI + incrementBedrockInstrumentedMetric(); + + System.out.println("Request: " + invokeModelRequest); + System.out.println("Request Body: " + invokeModelRequest.body()); + + return invokeModelResponseFuture; + } + +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java new file mode 100644 index 0000000000..7656626dad --- 
/dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -0,0 +1,96 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.bridge.NoOpTransaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import com.newrelic.api.agent.Transaction; +import com.newrelic.api.agent.weaver.MatchType; +import com.newrelic.api.agent.weaver.Weave; +import com.newrelic.api.agent.weaver.Weaver; +import com.newrelic.utils.InvokeModelRequestWrapper; +import com.newrelic.utils.InvokeModelResponseWrapper; +import software.amazon.awssdk.core.client.config.SdkClientConfiguration; +import software.amazon.awssdk.core.client.handler.SyncClientHandler; +import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Map; +import java.util.logging.Level; + +import static com.newrelic.utils.BedrockRuntimeUtil.debugLoggingForDevelopment; +import static com.newrelic.utils.BedrockRuntimeUtil.incrementBedrockInstrumentedMetric; +import static com.newrelic.utils.BedrockRuntimeUtil.reportLlmChatCompletionMessageEvent; +import static com.newrelic.utils.BedrockRuntimeUtil.reportLlmChatCompletionSummaryEvent; +import static com.newrelic.utils.BedrockRuntimeUtil.reportLlmEmbeddingEvent; +import static com.newrelic.utils.BedrockRuntimeUtil.setLlmOperationMetricName; +import static com.newrelic.utils.InvokeModelResponseWrapper.COMPLETION; +import static com.newrelic.utils.InvokeModelResponseWrapper.EMBEDDING; + +/** + * Service client for accessing Amazon Bedrock Runtime. 
+ */ +@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeClient") +final class DefaultBedrockRuntimeClient_Instrumentation { +// private static final Logger log = Logger.loggerFor(DefaultBedrockRuntimeClient.class); +// +// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() +// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + + private final SyncClientHandler clientHandler; + + private final AwsJsonProtocolFactory protocolFactory; + + private final SdkClientConfiguration clientConfiguration; + + private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; + + protected DefaultBedrockRuntimeClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, + SdkClientConfiguration clientConfiguration) { + this.clientHandler = Weaver.callOriginal(); + this.clientConfiguration = Weaver.callOriginal(); + this.serviceClientConfiguration = Weaver.callOriginal(); + this.protocolFactory = Weaver.callOriginal(); + } + + @Trace + public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { + InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); + + incrementBedrockInstrumentedMetric(); + + Transaction txn = NewRelic.getAgent().getTransaction(); + // TODO check AIM config + if (txn != null && !(txn instanceof NoOpTransaction)) { + debugLoggingForDevelopment(txn, invokeModelRequest, invokeModelResponse); // FIXME delete + + Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); + InvokeModelRequestWrapper requestWrapper = new InvokeModelRequestWrapper(invokeModelRequest); + InvokeModelResponseWrapper responseWrapper = new InvokeModelResponseWrapper(invokeModelResponse); + + String operationType = responseWrapper.getOperationType(); + // Set traced method name based on LLM operation + setLlmOperationMetricName(txn, operationType); + + // Report LLM events + if 
(operationType.equals(COMPLETION)) { + reportLlmChatCompletionMessageEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + reportLlmChatCompletionSummaryEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + } else if (operationType.equals(EMBEDDING)) { + reportLlmEmbeddingEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type"); + } + } + + return invokeModelResponse; + } + +} diff --git a/settings.gradle b/settings.gradle index 848313f152..4121a3575a 100644 --- a/settings.gradle +++ b/settings.gradle @@ -66,6 +66,7 @@ if (JavaVersion.current().isJava11Compatible()) { // Weaver Instrumentation include 'instrumentation:anorm-2.3' include 'instrumentation:anorm-2.4' +include 'instrumentation:aws-bedrock-runtime-2.20' include 'instrumentation:aws-java-sdk-sqs-1.10.44' include 'instrumentation:aws-java-sdk-s3-1.2.13' include 'instrumentation:aws-java-sdk-s3-2.0' From e2e6381c0067169189c30da910bc1eb3deb4316d Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 14 Feb 2024 16:04:28 -0800 Subject: [PATCH 02/68] WIP --- .../newrelic/utils/BedrockRuntimeUtil.java | 192 ++++++------------ .../utils/InvokeModelRequestWrapper.java | 73 +++++-- .../utils/InvokeModelResponseWrapper.java | 191 +++++++++++------ ...ockRuntimeAsyncClient_Instrumentation.java | 22 +- ...tBedrockRuntimeClient_Instrumentation.java | 14 +- 5 files changed, 278 insertions(+), 214 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java index e31a9f4acc..8310f40b09 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java @@ -9,8 +9,6 @@ import 
com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Transaction; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import java.util.HashMap; import java.util.Map; @@ -34,158 +32,122 @@ public class BedrockRuntimeUtil { * the UI will be hidden. */ public static void incrementBedrockInstrumentedMetric() { + // FIXME get library version, not instrumentation version, probably not possible NewRelic.incrementCounter("Java/ML/Bedrock/2.20"); } + /** + * Set name of the span/segment for each LLM embedding and chat completion call + * Llm/{operation_type}/{vendor_name}/{function_name} + * + * @param txn current transaction + * @param operationType operation of type completion or embedding + */ public static void setLlmOperationMetricName(Transaction txn, String operationType) { txn.getTracedMethod().setMetricName("Llm", operationType, "Bedrock", "invokeModel"); } // TODO add event builders??? Avoid adding null/empty attributes? - public static void reportLlmEmbeddingEvent(Transaction txn, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, + // TODO create a single recordLlmEvent method that can take a type. 
Always add attributes common to + // all types and add others based on conditionals + + public static void recordLlmEmbeddingEvent(Transaction txn, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - // TODO available data - String stopSequences = invokeModelRequestWrapper.parseStopSequences(); - String maxTokensToSample = invokeModelRequestWrapper.parseMaxTokensToSample(); - String temperature = invokeModelRequestWrapper.parseTemperature(); - String prompt = invokeModelRequestWrapper.parsePrompt(); - - String completion = invokeModelResponseWrapper.parseCompletion(); - String stop = invokeModelResponseWrapper.parseStop(); - String stopReason = invokeModelResponseWrapper.parseStopReason(); - String inputTokenCount = invokeModelResponseWrapper.getInputTokenCount(); - String outputTokenCount = invokeModelResponseWrapper.getOutputTokenCount(); - String invocationLatency = invokeModelResponseWrapper.getInvocationLatency(); - String operationType = invokeModelResponseWrapper.getOperationType(); // TODO is it possible to do something like this to call getUserAttributes? 
// see com.newrelic.agent.bridge.Transaction - Map eventAttributes = new HashMap<>(); - eventAttributes.put("id", ""); // TODO ID in the format response_id-sequence or a UUID generated by the agent if no response ID is returned by the LLM + Map eventAttributes = new HashMap<>(); + eventAttributes.put("id", invokeModelResponseWrapper.getLlmEmbeddingId()); eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("transaction_id", ""); // TODO figure out how to get this from agent + eventAttributes.put("transaction_id", getTraceId(linkingMetadata)); // FIXME figure out how to get txn ID from linking metadata eventAttributes.put("trace_id", getTraceId(linkingMetadata)); - eventAttributes.put("input", ""); // TODO figure out how to get this - eventAttributes.put("api_key_last_four_digits", ""); // TODO Final digits of API key formatted as: sk-{last_four_digits_of_api_key} + eventAttributes.put("input", invokeModelRequestWrapper.getInputText()); eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); - eventAttributes.put("response.model", ""); // TODO Model name returned in the response (can differ from request.model) - eventAttributes.put("response.organization", ""); // TODO Organization ID returned in the response or response headers - eventAttributes.put("response.usage.total_tokens", ""); - eventAttributes.put("response.usage.prompt_tokens", ""); + eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. 
+ eventAttributes.put("response.usage.total_tokens", invokeModelResponseWrapper.getTotalTokenCount()); + eventAttributes.put("response.usage.prompt_tokens", invokeModelResponseWrapper.getInputTokenCount()); eventAttributes.put("vendor", getVendor()); eventAttributes.put("ingest_source", getIngestSource()); - eventAttributes.put("duration", ""); - eventAttributes.put("error", ""); -// eventAttributes.put("llm.", ""); +// eventAttributes.put("duration", "NOT POSSIBLE"); // TODO Total time taken for the chat completion or embedding call to complete + if (invokeModelResponseWrapper.isErrorResponse()) { + eventAttributes.put("error", true); // TODO Bool set to True if an error occurred during creation call - omitted if no error occurred +// NewRelic.noticeError(invokeModelResponseWrapper.getStatusText()); + } +// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? 
NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); } - public static void reportLlmChatCompletionSummaryEvent(Transaction txn, Map linkingMetadata, + public static void recordLlmChatCompletionSummaryEvent(Transaction txn, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - // TODO available data - String stopSequences = invokeModelRequestWrapper.parseStopSequences(); - String maxTokensToSample = invokeModelRequestWrapper.parseMaxTokensToSample(); - String temperature = invokeModelRequestWrapper.parseTemperature(); - String prompt = invokeModelRequestWrapper.parsePrompt(); - - String completion = invokeModelResponseWrapper.parseCompletion(); - String stop = invokeModelResponseWrapper.parseStop(); - String stopReason = invokeModelResponseWrapper.parseStopReason(); - String inputTokenCount = invokeModelResponseWrapper.getInputTokenCount(); - String outputTokenCount = invokeModelResponseWrapper.getOutputTokenCount(); - String invocationLatency = invokeModelResponseWrapper.getInvocationLatency(); - String operationType = invokeModelResponseWrapper.getOperationType(); // TODO is it possible to do something like this to call getUserAttributes? 
// see com.newrelic.agent.bridge.Transaction - Map eventAttributes = new HashMap<>(); - eventAttributes.put("id", ""); // TODO ID in the format response_id-sequence or a UUID generated by the agent if no response ID is returned by the LLM + Map eventAttributes = new HashMap<>(); + eventAttributes.put("id", invokeModelResponseWrapper.getLlmChatCompletionSummaryId()); eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("transaction_id", ""); // TODO figure out how to get this from agent + eventAttributes.put("transaction_id", getTraceId(linkingMetadata)); // FIXME figure out how to get txn ID from linking metadata eventAttributes.put("trace_id", getTraceId(linkingMetadata)); - eventAttributes.put("api_key_last_four_digits", ""); // TODO Final digits of API key formatted as: sk-{last_four_digits_of_api_key} - eventAttributes.put("request.temperature", invokeModelRequestWrapper.parseTemperature()); - eventAttributes.put("request.max_tokens", invokeModelRequestWrapper.parseMaxTokensToSample()); + eventAttributes.put("request.temperature", invokeModelRequestWrapper.getTemperature()); + eventAttributes.put("request.max_tokens", invokeModelRequestWrapper.getMaxTokensToSample()); eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); - eventAttributes.put("response.model", ""); // TODO Model name returned in the response (can differ from request.model) - eventAttributes.put("response.organization", ""); // TODO Organization ID returned in the response or response headers - eventAttributes.put("response.number_of_messages", ""); - eventAttributes.put("response.usage.total_tokens", ""); - eventAttributes.put("response.usage.prompt_tokens", ""); - eventAttributes.put("response.usage.completion_tokens", ""); - eventAttributes.put("response.choices.finish_reason", ""); + eventAttributes.put("response.model", 
invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. + eventAttributes.put("response.number_of_messages", + ""); // TODO Number of messages comprising a chat completion including system, user, and assistant messages + eventAttributes.put("response.usage.total_tokens", invokeModelResponseWrapper.getTotalTokenCount()); + eventAttributes.put("response.usage.prompt_tokens", invokeModelResponseWrapper.getInputTokenCount()); + eventAttributes.put("response.usage.completion_tokens", invokeModelResponseWrapper.getOutputTokenCount()); + eventAttributes.put("response.choices.finish_reason", invokeModelResponseWrapper.getStopReason()); eventAttributes.put("vendor", getVendor()); eventAttributes.put("ingest_source", getIngestSource()); - eventAttributes.put("duration", ""); - eventAttributes.put("error", ""); -// eventAttributes.put("llm.", ""); - eventAttributes.put("conversation_id", ""); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API +// eventAttributes.put("duration", "NOT POSSIBLE"); // TODO Total time taken for the chat completion or embedding call to complete + if (invokeModelResponseWrapper.isErrorResponse()) { + eventAttributes.put("error", true); // TODO Bool set to True if an error occurred during creation call - omitted if no error occurred +// NewRelic.noticeError(invokeModelResponseWrapper.getStatusText()); + } +// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? +// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? 
NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); } - public static void reportLlmChatCompletionMessageEvent(Transaction txn, Map linkingMetadata, + public static void recordLlmChatCompletionMessageEvent(Transaction txn, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - // TODO available data - String stopSequences = invokeModelRequestWrapper.parseStopSequences(); - String maxTokensToSample = invokeModelRequestWrapper.parseMaxTokensToSample(); - String temperature = invokeModelRequestWrapper.parseTemperature(); - String prompt = invokeModelRequestWrapper.parsePrompt(); - - String completion = invokeModelResponseWrapper.parseCompletion(); - String stop = invokeModelResponseWrapper.parseStop(); - String stopReason = invokeModelResponseWrapper.parseStopReason(); - String inputTokenCount = invokeModelResponseWrapper.getInputTokenCount(); - String outputTokenCount = invokeModelResponseWrapper.getOutputTokenCount(); - String invocationLatency = invokeModelResponseWrapper.getInvocationLatency(); - String operationType = invokeModelResponseWrapper.getOperationType(); // TODO is it possible to do something like this to call getUserAttributes? 
// see com.newrelic.agent.bridge.Transaction - Map eventAttributes = new HashMap<>(); - eventAttributes.put("id", ""); // TODO ID in the format response_id-sequence or a UUID generated by the agent if no response ID is returned by the LLM + Map eventAttributes = new HashMap<>(); + eventAttributes.put("id", invokeModelResponseWrapper.getLlmChatCompletionMessageId()); eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("transaction_id", ""); // TODO figure out how to get this from agent + eventAttributes.put("transaction_id", getTraceId(linkingMetadata)); // FIXME figure out how to get txn ID from linking metadata eventAttributes.put("trace_id", getTraceId(linkingMetadata)); - eventAttributes.put("conversation_id", ""); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API - eventAttributes.put("api_key_last_four_digits", ""); // TODO Final digits of API key formatted as: sk-{last_four_digits_of_api_key} - eventAttributes.put("response.model", ""); // TODO Model name returned in the response (can differ from request.model) +// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? + eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. 
eventAttributes.put("vendor", getVendor()); eventAttributes.put("ingest_source", getIngestSource()); - eventAttributes.put("content", invokeModelRequestWrapper.parsePrompt()); - eventAttributes.put("role", ""); // TODO Role of the message creator (ex: system, assistant, user) - eventAttributes.put("sequence", ""); - eventAttributes.put("completion_id", ""); - eventAttributes.put("is_response", ""); -// eventAttributes.put("llm.", ""); + eventAttributes.put("content", invokeModelRequestWrapper.getPrompt()); + String role = invokeModelRequestWrapper.getRole(); + if (!role.isEmpty()) { + eventAttributes.put("role", role); + if (!role.contains("user")) { + eventAttributes.put("is_response", true); + } + } + eventAttributes.put("sequence", ""); // TODO Index (beginning at 0) associated with each message including the prompt and responses + eventAttributes.put("completion_id", invokeModelResponseWrapper.getLlmChatCompletionSummaryId()); + +// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? 
NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); } - public static void debugLoggingForDevelopment(Transaction transaction, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { - System.out.println(); - System.out.println("Request: " + invokeModelRequest); - System.out.println("Request Body (UTF8 String): " + invokeModelRequest.body().asUtf8String()); - - System.out.println(); - System.out.println("Response: " + invokeModelResponse); - System.out.println("Response Body (UTF8 String): " + invokeModelResponse.body().asUtf8String()); - System.out.println("Response Metadata: " + invokeModelResponse.responseMetadata()); - System.out.println("Response Metadata Request ID: " + invokeModelResponse.responseMetadata().requestId()); - System.out.println("Response SdkHttpResponse Status Code: " + invokeModelResponse.sdkHttpResponse().statusCode()); - System.out.println("Response SdkHttpResponse Status Text: " + invokeModelResponse.sdkHttpResponse().statusText()); - System.out.println("Response SdkHttpResponse Is Successful: " + invokeModelResponse.sdkHttpResponse().isSuccessful()); - System.out.println(); - } - // ========================= AGENT DATA ================================ // Lowercased name of vendor (bedrock or openAI) public static String getVendor() { @@ -206,32 +168,4 @@ public static String getSpanId(Map linkingMetadata) { public static String getTraceId(Map linkingMetadata) { return linkingMetadata.get(TRACE_ID); } - - // ID of the active transaction - public String getTransactionId() { - // FIXME not sure that this is accessible in an instrumentation module - // might need to add this to events withing the record_event API logic - // Sounds like we need to expose this on the public Transaction API - return ""; - } - - // Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API - public String getLlmUserDefinedMetadata() { - // FIXME 
hmm where can user attributes actually be accessed from????? AgentBridge??? Tracer??? - // If we create a new AI event type with it's own endpoint this would be inherited from from AnalyticsEvent - return ""; - } - - // Optional attribute that can be added to a transaction by a customer via add_custom_attribute API - public String getLlmConversationId() { - // FIXME hmm where can user attributes actually be accessed from????? AgentBridge??? Tracer??? - // If we create a new AI event type with it's own endpoint this would be inherited from from AnalyticsEvent - return ""; - } - - // Boolean set to True if a message is the result of a chat completion and not an input message - public String getIsResponse(InvokeModelResponse invokeModelResponse) { - // TODO Should this return a boolean or string?? - return ""; - } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java index f53f3dd831..42ce46b020 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java @@ -22,28 +22,38 @@ * but doesn't hold a reference to the actual request object. */ public class InvokeModelRequestWrapper { - private final String invokeModelRequestBody; - private final String modelId; - private Map requestBodyJsonMap = null; - // Request body (for Claude, how about other models?) 
private static final String STOP_SEQUENCES = "stop_sequences"; private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; private static final String TEMPERATURE = "temperature"; private static final String PROMPT = "prompt"; + private static final String INPUT_TEXT = "inputText"; + private static final String ESCAPED_NEWLINES = "\\n\\n"; + private static final String SYSTEM = "system"; + private static final String ASSISTANT = "assistant"; + private static final String USER = "user"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + public InvokeModelRequestWrapper(InvokeModelRequest invokeModelRequest) { if (invokeModelRequest != null) { invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); modelId = invokeModelRequest.modelId(); } else { - invokeModelRequestBody = ""; - modelId = ""; NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelRequest"); } } - // Lazy init and only parse map once + /** + * Get a map of the Request body contents. + *

+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ public Map getRequestBodyJsonMap() { if (requestBodyJsonMap == null) { requestBodyJsonMap = parseInvokeModelRequestBodyMap(); @@ -51,6 +61,11 @@ public Map getRequestBodyJsonMap() { return requestBodyJsonMap; } + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ private Map parseInvokeModelRequestBodyMap() { // Use AWS SDK JSON parsing to parse request body JsonNodeParser jsonNodeParser = JsonNodeParser.create(); @@ -69,7 +84,7 @@ private Map parseInvokeModelRequestBodyMap() { // TODO do we potentially expect more than one entry in the stop sequence? Or is it sufficient // to just check if it contains Human? - public String parseStopSequences() { + public String getStopSequences() { StringBuilder stopSequences = new StringBuilder(); try { if (!getRequestBodyJsonMap().isEmpty()) { @@ -94,7 +109,7 @@ public String parseStopSequences() { return stopSequences.toString().replaceAll("[\n:]", ""); } - public String parseMaxTokensToSample() { + public String getMaxTokensToSample() { String maxTokensToSample = ""; try { if (!getRequestBodyJsonMap().isEmpty()) { @@ -109,7 +124,7 @@ public String parseMaxTokensToSample() { return maxTokensToSample; } - public String parseTemperature() { + public String getTemperature() { String temperature = ""; try { if (!getRequestBodyJsonMap().isEmpty()) { @@ -124,7 +139,7 @@ public String parseTemperature() { return temperature; } - public String parsePrompt() { + public String getPrompt() { String prompt = ""; try { if (!getRequestBodyJsonMap().isEmpty()) { @@ -136,7 +151,41 @@ public String parsePrompt() { } catch (Exception e) { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + PROMPT); } - return prompt.replace("Human: ", "").replace("\n\nAssistant:", ""); + return prompt; +// return 
prompt.replace("Human: ", "").replace("\n\nAssistant:", ""); + } + + public String getRole() { + try { + if (!invokeModelRequestBody.isEmpty()) { + String invokeModelRequestBodyLowerCase = invokeModelRequestBody.toLowerCase(); + if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + SYSTEM)) { + return SYSTEM; + } else if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + USER)) { + return USER; + } else if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + ASSISTANT)) { + return ASSISTANT; + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse role from InvokeModelRequest"); + } + return ""; + } + + public String getInputText() { + String inputText = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(INPUT_TEXT); + if (jsonNode.isString()) { + inputText = jsonNode.asString(); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + INPUT_TEXT); + } + return inputText; } public String getModelId() { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java index a6e0dbead4..6a82e98111 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java @@ -15,6 +15,8 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.UUID; import java.util.logging.Level; /** @@ -22,9 +24,6 @@ * but doesn't hold a reference to the actual response object. 
*/ public class InvokeModelResponseWrapper { - private final String invokeModelResponseBody; - private Map responseBodyJsonMap = null; - // Response body (for Claude, how about other models?) public static final String COMPLETION = "completion"; public static final String EMBEDDING = "embedding"; @@ -36,83 +35,127 @@ public class InvokeModelResponseWrapper { private static final String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; private static final String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; private static final String X_AMZN_BEDROCK_INVOCATION_LATENCY = "X-Amzn-Bedrock-Invocation-Latency"; - private String inputTokenCount; - private String outputTokenCount; - private String amznRequestId; - private String invocationLatency; + private int inputTokenCount = 0; + private int outputTokenCount = 0; + private String amznRequestId = ""; + private String invocationLatency = ""; // LLM operation type - private String operationType; + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + // Random GUID for response + private String llmChatCompletionMessageId = ""; + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + private static final String JSON_START = "{\""; public InvokeModelResponseWrapper(InvokeModelResponse invokeModelResponse) { if (invokeModelResponse != null) { invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); - extractOperationType(invokeModelResponseBody); - extractHeaders(invokeModelResponse); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText 
= s); + setOperationType(invokeModelResponseBody); + setHeaderFields(invokeModelResponse); + llmChatCompletionMessageId = UUID.randomUUID().toString(); + llmChatCompletionSummaryId = UUID.randomUUID().toString(); + llmEmbeddingId = UUID.randomUUID().toString(); } else { - invokeModelResponseBody = ""; NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); } } - private void extractOperationType(String invokeModelResponseBody) { - // FIXME should be starts with instead of contains? - if (!invokeModelResponseBody.isEmpty()) { - if (invokeModelResponseBody.contains(COMPLETION)) { - operationType = COMPLETION; - } else if (invokeModelResponseBody.contains(EMBEDDING)) { - operationType = EMBEDDING; - } - } - } - - private void extractHeaders(InvokeModelResponse invokeModelResponse) { - Map> headers = invokeModelResponse.sdkHttpResponse().headers(); - if (!headers.isEmpty()) { - List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); - if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { - inputTokenCount = inputTokenCountHeaders.get(0); - } - List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); - if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { - outputTokenCount = outputTokenCountHeaders.get(0); - } - List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); - if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); // TODO does this differ from invokeModelResponse.responseMetadata().requestId() - } - List invocationLatencyHeaders = headers.get(X_AMZN_BEDROCK_INVOCATION_LATENCY); - if (invocationLatencyHeaders != null && !invocationLatencyHeaders.isEmpty()) { - invocationLatency = invocationLatencyHeaders.get(0); - } - } - } - - // Lazy init and only parse map once - public Map getResponseBodyJsonMap() { + /** + * Get a map of the Response body contents. + *

+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { if (responseBodyJsonMap == null) { responseBodyJsonMap = parseInvokeModelResponseBodyMap(); } return responseBodyJsonMap; } + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ private Map parseInvokeModelResponseBodyMap() { - // Use AWS SDK JSON parsing to parse response body - JsonNodeParser jsonNodeParser = JsonNodeParser.create(); - JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); - Map responseBodyJsonMap = null; - // TODO check for other types? Or will it always be Object? - if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { - responseBodyJsonMap = responseBodyJsonNode.asObject(); - } else { + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + // TODO check for other types? Or will it always be Object? + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } +// else { +// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse body as Map Object"); +// } + } catch (Exception e) { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse body as Map Object"); } return responseBodyJsonMap != null ? 
responseBodyJsonMap : Collections.emptyMap(); } - public String parseCompletion() { + private void setOperationType(String invokeModelResponseBody) { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.startsWith(JSON_START + COMPLETION)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.startsWith(JSON_START + EMBEDDING)) { + operationType = EMBEDDING; + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unknown operation type"); + } + } + } + + private void setHeaderFields(InvokeModelResponse invokeModelResponse) { + Map> headers = invokeModelResponse.sdkHttpResponse().headers(); + try { + if (!headers.isEmpty()) { + List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); + if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { + String result = inputTokenCountHeaders.get(0); + inputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); + if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { + String result = outputTokenCountHeaders.get(0); + outputTokenCount = result != null ? 
Integer.parseInt(result) : 0; + } + List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); + if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { + amznRequestId = amznRequestIdHeaders.get(0); // TODO does this differ from invokeModelResponse.responseMetadata().requestId() + } + List invocationLatencyHeaders = headers.get(X_AMZN_BEDROCK_INVOCATION_LATENCY); + if (invocationLatencyHeaders != null && !invocationLatencyHeaders.isEmpty()) { + invocationLatency = invocationLatencyHeaders.get(0); + } + } + } catch (Exception e) { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse headers"); + } + } + + public String getCompletion() { String completion = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { @@ -127,7 +170,7 @@ public String parseCompletion() { return completion; } - public String parseStopReason() { + public String getStopReason() { String stopReason = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { @@ -142,7 +185,7 @@ public String parseStopReason() { return stopReason; } - public String parseStop() { + public String getStop() { String stop = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { @@ -157,14 +200,18 @@ public String parseStop() { return stop.replaceAll("[\n:]", ""); } - public String getInputTokenCount() { + public int getInputTokenCount() { return inputTokenCount; } - public String getOutputTokenCount() { + public int getOutputTokenCount() { return outputTokenCount; } + public int getTotalTokenCount() { + return inputTokenCount + outputTokenCount; + } + public String getAmznRequestId() { return amznRequestId; } @@ -177,8 +224,28 @@ public String getOperationType() { return operationType; } - public String getResponseModel() { - // TODO figure out where to get this from - return "TODO"; + // TODO create errors with below info + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + public int getStatusCode() { + return statusCode; + } + + public String 
getStatusText() { + return statusText; + } + + public String getLlmChatCompletionMessageId() { + return llmChatCompletionMessageId; + } + + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + public String getLlmEmbeddingId() { + return llmEmbeddingId; } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java index cbded326b7..bf71fd298c 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java @@ -7,6 +7,9 @@ package software.amazon.awssdk.services.bedrockruntime; +import com.newrelic.agent.bridge.AgentBridge; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; import com.newrelic.api.agent.Trace; import com.newrelic.api.agent.weaver.MatchType; import com.newrelic.api.agent.weaver.Weave; @@ -53,15 +56,28 @@ protected DefaultBedrockRuntimeAsyncClient_Instrumentation(BedrockRuntimeService @Trace public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { + // TODO name "Llm/" + operationType + "/Bedrock/InvokeModelAsync" ???? 
+ Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); - // FIXME needs to be incremented constantly for UI incrementBedrockInstrumentedMetric(); - System.out.println("Request: " + invokeModelRequest); - System.out.println("Request Body: " + invokeModelRequest.body()); + // this should never happen, but protecting against bad implementations + if (invokeModelResponseFuture == null) { + segment.end(); + } else { + invokeModelResponseFuture.whenComplete((invokeModelResponse, throwable) -> { + try { + // TODO do all the stuff + segment.end(); + } catch (Throwable t) { + AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); + } + }); + } return invokeModelResponseFuture; + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index 7656626dad..57c0e1db1f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -25,11 +25,10 @@ import java.util.Map; import java.util.logging.Level; -import static com.newrelic.utils.BedrockRuntimeUtil.debugLoggingForDevelopment; import static com.newrelic.utils.BedrockRuntimeUtil.incrementBedrockInstrumentedMetric; -import static com.newrelic.utils.BedrockRuntimeUtil.reportLlmChatCompletionMessageEvent; -import static com.newrelic.utils.BedrockRuntimeUtil.reportLlmChatCompletionSummaryEvent; -import static com.newrelic.utils.BedrockRuntimeUtil.reportLlmEmbeddingEvent; 
+import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmChatCompletionMessageEvent; +import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmChatCompletionSummaryEvent; +import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmEmbeddingEvent; import static com.newrelic.utils.BedrockRuntimeUtil.setLlmOperationMetricName; import static com.newrelic.utils.InvokeModelResponseWrapper.COMPLETION; import static com.newrelic.utils.InvokeModelResponseWrapper.EMBEDDING; @@ -69,7 +68,6 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { Transaction txn = NewRelic.getAgent().getTransaction(); // TODO check AIM config if (txn != null && !(txn instanceof NoOpTransaction)) { - debugLoggingForDevelopment(txn, invokeModelRequest, invokeModelResponse); // FIXME delete Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); InvokeModelRequestWrapper requestWrapper = new InvokeModelRequestWrapper(invokeModelRequest); @@ -81,10 +79,10 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { // Report LLM events if (operationType.equals(COMPLETION)) { - reportLlmChatCompletionMessageEvent(txn, linkingMetadata, requestWrapper, responseWrapper); - reportLlmChatCompletionSummaryEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + recordLlmChatCompletionMessageEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + recordLlmChatCompletionSummaryEvent(txn, linkingMetadata, requestWrapper, responseWrapper); } else if (operationType.equals(EMBEDDING)) { - reportLlmEmbeddingEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + recordLlmEmbeddingEvent(txn, linkingMetadata, requestWrapper, responseWrapper); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type"); } From 3e08036e0894576707066c423102f6e1b2a65d0a Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 14 Feb 2024 16:43:22 -0800 Subject: [PATCH 03/68] Report Llm errors --- 
.../com/newrelic/utils/BedrockRuntimeUtil.java | 8 ++++---- .../newrelic/utils/InvokeModelResponseWrapper.java | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java index 8310f40b09..918efadf46 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java @@ -73,8 +73,8 @@ public static void recordLlmEmbeddingEvent(Transaction txn, Map eventAttributes.put("ingest_source", getIngestSource()); // eventAttributes.put("duration", "NOT POSSIBLE"); // TODO Total time taken for the chat completion or embedding call to complete if (invokeModelResponseWrapper.isErrorResponse()) { - eventAttributes.put("error", true); // TODO Bool set to True if an error occurred during creation call - omitted if no error occurred -// NewRelic.noticeError(invokeModelResponseWrapper.getStatusText()); + eventAttributes.put("error", true); + invokeModelResponseWrapper.reportLlmError(); } // eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? @@ -107,8 +107,8 @@ public static void recordLlmChatCompletionSummaryEvent(Transaction txn, Map", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? // eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? 
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java index 6a82e98111..b2a5bd0a4c 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java @@ -13,6 +13,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -248,4 +249,17 @@ public String getLlmChatCompletionSummaryId() { public String getLlmEmbeddingId() { return llmEmbeddingId; } + + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", statusCode); + errorParams.put("error.code", statusCode); + if (!llmChatCompletionSummaryId.isEmpty()) { + errorParams.put("completion_id", llmChatCompletionSummaryId); + } + if (!llmEmbeddingId.isEmpty()) { + errorParams.put("embedding_id", llmEmbeddingId); + } + NewRelic.noticeError("LlmError: " + statusText, errorParams); + } } From 5c994f39c649606255b30d08480c912d8b22729a Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 20 Feb 2024 14:11:06 -0800 Subject: [PATCH 04/68] Refactoring --- .../newrelic/utils/BedrockRuntimeUtil.java | 99 +++++++++++++------ .../utils/InvokeModelRequestWrapper.java | 32 +++++- .../utils/InvokeModelResponseWrapper.java | 60 ++++++----- ...ockRuntimeAsyncClient_Instrumentation.java | 1 + ...tBedrockRuntimeClient_Instrumentation.java | 11 +-- 5 files changed, 136 insertions(+), 67 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java index 918efadf46..5a1eaedac7 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java @@ -12,6 +12,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.UUID; public class BedrockRuntimeUtil { private static final String VENDOR = "bedrock"; @@ -26,14 +27,15 @@ public class BedrockRuntimeUtil { /** * This needs to be incremented for every invocation of the Bedrock SDK. + * Supportability/{language}/ML/{vendor_name}/{vendor_version} *

* The metric generated triggers the creation of a tag which gates the AI Response UI. The * tag lives for 27 hours so if this metric isn't repeatedly sent the tag will disappear and * the UI will be hidden. */ public static void incrementBedrockInstrumentedMetric() { - // FIXME get library version, not instrumentation version, probably not possible - NewRelic.incrementCounter("Java/ML/Bedrock/2.20"); + // Bedrock vendor_version isn't available, so set it to instrumentation version instead + NewRelic.incrementCounter("Supportability/Java/ML/Bedrock/2.20"); } /** @@ -52,17 +54,16 @@ public static void setLlmOperationMetricName(Transaction txn, String operationTy // TODO create a single recordLlmEvent method that can take a type. Always add attributes common to // all types and add others based on conditionals - public static void recordLlmEmbeddingEvent(Transaction txn, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, + public static void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction + // TODO filter based on which Bedrock model we're dealing with... 
this might need to be done before this point + invokeModelRequestWrapper.getModelId(); Map eventAttributes = new HashMap<>(); eventAttributes.put("id", invokeModelResponseWrapper.getLlmEmbeddingId()); eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("transaction_id", getTraceId(linkingMetadata)); // FIXME figure out how to get txn ID from linking metadata eventAttributes.put("trace_id", getTraceId(linkingMetadata)); eventAttributes.put("input", invokeModelRequestWrapper.getInputText()); eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); @@ -71,83 +72,116 @@ public static void recordLlmEmbeddingEvent(Transaction txn, Map eventAttributes.put("response.usage.prompt_tokens", invokeModelResponseWrapper.getInputTokenCount()); eventAttributes.put("vendor", getVendor()); eventAttributes.put("ingest_source", getIngestSource()); -// eventAttributes.put("duration", "NOT POSSIBLE"); // TODO Total time taken for the chat completion or embedding call to complete if (invokeModelResponseWrapper.isErrorResponse()) { eventAttributes.put("error", true); invokeModelResponseWrapper.reportLlmError(); } + + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction // eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? 
+ long endTime = System.currentTimeMillis(); + eventAttributes.put("duration", (endTime - startTime)); NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); } - public static void recordLlmChatCompletionSummaryEvent(Transaction txn, Map linkingMetadata, + private static void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction + // TODO filter based on which Bedrock model we're dealing with... this might need to be done before this point + invokeModelRequestWrapper.getModelId(); Map eventAttributes = new HashMap<>(); eventAttributes.put("id", invokeModelResponseWrapper.getLlmChatCompletionSummaryId()); eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("transaction_id", getTraceId(linkingMetadata)); // FIXME figure out how to get txn ID from linking metadata eventAttributes.put("trace_id", getTraceId(linkingMetadata)); eventAttributes.put("request.temperature", invokeModelRequestWrapper.getTemperature()); eventAttributes.put("request.max_tokens", invokeModelRequestWrapper.getMaxTokensToSample()); eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. 
- eventAttributes.put("response.number_of_messages", - ""); // TODO Number of messages comprising a chat completion including system, user, and assistant messages + eventAttributes.put("response.number_of_messages", numberOfMessages); eventAttributes.put("response.usage.total_tokens", invokeModelResponseWrapper.getTotalTokenCount()); eventAttributes.put("response.usage.prompt_tokens", invokeModelResponseWrapper.getInputTokenCount()); eventAttributes.put("response.usage.completion_tokens", invokeModelResponseWrapper.getOutputTokenCount()); eventAttributes.put("response.choices.finish_reason", invokeModelResponseWrapper.getStopReason()); eventAttributes.put("vendor", getVendor()); eventAttributes.put("ingest_source", getIngestSource()); -// eventAttributes.put("duration", "NOT POSSIBLE"); // TODO Total time taken for the chat completion or embedding call to complete if (invokeModelResponseWrapper.isErrorResponse()) { eventAttributes.put("error", true); invokeModelResponseWrapper.reportLlmError(); } + + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction // eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? -// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? +// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! 
+ long endTime = System.currentTimeMillis(); + eventAttributes.put("duration", (endTime - startTime)); NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); } - public static void recordLlmChatCompletionMessageEvent(Transaction txn, Map linkingMetadata, + private static void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction + // TODO filter based on which Bedrock model we're dealing with... this might need to be done before this point + invokeModelRequestWrapper.getModelId(); Map eventAttributes = new HashMap<>(); - eventAttributes.put("id", invokeModelResponseWrapper.getLlmChatCompletionMessageId()); + eventAttributes.put("content", message); + + // FIXME id, content, role, and sequence + // This parsing might only apply to Claude? + if (message.contains("Human:")) { + eventAttributes.put("role", "user"); + eventAttributes.put("is_response", false); + } else { + String role = invokeModelRequestWrapper.getRole(); + if (!role.isEmpty()) { + eventAttributes.put("role", role); + if (!role.contains("user")) { + eventAttributes.put("is_response", true); + } + } + } + + eventAttributes.put("id", getRandomGuid()); eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("transaction_id", getTraceId(linkingMetadata)); // FIXME figure out how to get txn ID from linking metadata eventAttributes.put("trace_id", getTraceId(linkingMetadata)); -// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. 
Should just be added and prefixed along with the other user attributes? eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. eventAttributes.put("vendor", getVendor()); eventAttributes.put("ingest_source", getIngestSource()); - eventAttributes.put("content", invokeModelRequestWrapper.getPrompt()); - String role = invokeModelRequestWrapper.getRole(); - if (!role.isEmpty()) { - eventAttributes.put("role", role); - if (!role.contains("user")) { - eventAttributes.put("is_response", true); - } - } - eventAttributes.put("sequence", ""); // TODO Index (beginning at 0) associated with each message including the prompt and responses + eventAttributes.put("sequence", sequence); eventAttributes.put("completion_id", invokeModelResponseWrapper.getLlmChatCompletionSummaryId()); + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction // eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? +// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! 
NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); } + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event + * + * @param linkingMetadata + * @param requestWrapper + * @param responseWrapper + */ + public static void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata, + InvokeModelRequestWrapper requestWrapper, InvokeModelResponseWrapper responseWrapper) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, requestWrapper.getRequestMessage(), linkingMetadata, requestWrapper, responseWrapper); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, responseWrapper.getResponseMessage(), linkingMetadata, requestWrapper, responseWrapper); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata, requestWrapper, responseWrapper); + } + // ========================= AGENT DATA ================================ // Lowercased name of vendor (bedrock or openAI) public static String getVendor() { @@ -168,4 +202,9 @@ public static String getSpanId(Map linkingMetadata) { public static String getTraceId(Map linkingMetadata) { return linkingMetadata.get(TRACE_ID); } + + // Returns a string representation of a random GUID + public static String getRandomGuid() { + return UUID.randomUUID().toString(); + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java index 42ce46b020..2e625d88e0 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java @@ -37,7 +37,6 @@ public class InvokeModelRequestWrapper { private String modelId = ""; private Map requestBodyJsonMap = null; - public InvokeModelRequestWrapper(InvokeModelRequest invokeModelRequest) { if (invokeModelRequest != null) { invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); @@ -54,7 +53,7 @@ public InvokeModelRequestWrapper(InvokeModelRequest invokeModelRequest) { * * @return map of String to JsonNode */ - public Map getRequestBodyJsonMap() { + private Map getRequestBodyJsonMap() { if (requestBodyJsonMap == null) { requestBodyJsonMap = parseInvokeModelRequestBodyMap(); } @@ -73,6 +72,7 @@ private Map parseInvokeModelRequestBodyMap() { Map requestBodyJsonMap = null; // TODO check for other types? Or will it always be Object? + // add try/catch? if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { requestBodyJsonMap = requestBodyJsonNode.asObject(); } else { @@ -83,7 +83,7 @@ private Map parseInvokeModelRequestBodyMap() { } // TODO do we potentially expect more than one entry in the stop sequence? Or is it sufficient - // to just check if it contains Human? + // to just check if it contains Human? DO we even need this at all? 
Doesn't look like we do public String getStopSequences() { StringBuilder stopSequences = new StringBuilder(); try { @@ -139,7 +139,7 @@ public String getTemperature() { return temperature; } - public String getPrompt() { + public String getRequestMessage() { String prompt = ""; try { if (!getRequestBodyJsonMap().isEmpty()) { @@ -152,9 +152,31 @@ public String getPrompt() { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + PROMPT); } return prompt; -// return prompt.replace("Human: ", "").replace("\n\nAssistant:", ""); } +// /** +// * Represents the user prompt messages +// * +// * @return +// */ +// public List getUserRequestMessages() { +// // FIXME shouldn't parse the request, just send the whole content +// List messageList = null; +// try { +// if (!getRequestBodyJsonMap().isEmpty()) { +// JsonNode jsonNode = getRequestBodyJsonMap().get(PROMPT); +// if (jsonNode.isString()) { +// String[] messages = jsonNode.asString().split(ESCAPED_NEWLINES); +// messageList = new ArrayList<>(Arrays.asList(messages)); +// messageList.remove("Assistant:"); +// } +// } +// } catch (Exception e) { +// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + PROMPT); +// } +// return messageList; +// } + public String getRole() { try { if (!invokeModelRequestBody.isEmpty()) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java index b2a5bd0a4c..a8010a63bd 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java @@ -17,9 +17,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.UUID; import java.util.logging.Level; +import static 
com.newrelic.utils.BedrockRuntimeUtil.getRandomGuid; + /** * Stores the required info from the Bedrock InvokeModelResponse * but doesn't hold a reference to the actual response object. @@ -50,7 +51,7 @@ public class InvokeModelResponseWrapper { private String statusText = ""; // Random GUID for response - private String llmChatCompletionMessageId = ""; +// private String llmChatCompletionMessageId = ""; private String llmChatCompletionSummaryId = ""; private String llmEmbeddingId = ""; @@ -68,9 +69,9 @@ public InvokeModelResponseWrapper(InvokeModelResponse invokeModelResponse) { statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); setHeaderFields(invokeModelResponse); - llmChatCompletionMessageId = UUID.randomUUID().toString(); - llmChatCompletionSummaryId = UUID.randomUUID().toString(); - llmEmbeddingId = UUID.randomUUID().toString(); +// llmChatCompletionMessageId = BedrockRuntimeUtil.getRandomGuid(); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); } @@ -156,7 +157,12 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { } } - public String getCompletion() { + /** + * Represents the response message + * + * @return + */ + public String getResponseMessage() { String completion = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { @@ -225,7 +231,21 @@ public String getOperationType() { return operationType; } - // TODO create errors with below info + // TODO stop saving these GUIDS, instead just create a Util method to generate a random one each call?? 
+// public String getLlmChatCompletionMessageId() { +// return llmChatCompletionMessageId; +// } + + // hmmm this one needs to be stored as it's used for completion_id in the message events + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + // hmmm also needs to be stored to be used for embedding_id in errors + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + public boolean isErrorResponse() { return !isSuccessfulResponse; } @@ -238,28 +258,16 @@ public String getStatusText() { return statusText; } - public String getLlmChatCompletionMessageId() { - return llmChatCompletionMessageId; - } - - public String getLlmChatCompletionSummaryId() { - return llmChatCompletionSummaryId; - } - - public String getLlmEmbeddingId() { - return llmEmbeddingId; - } - public void reportLlmError() { Map errorParams = new HashMap<>(); - errorParams.put("http.statusCode", statusCode); - errorParams.put("error.code", statusCode); - if (!llmChatCompletionSummaryId.isEmpty()) { - errorParams.put("completion_id", llmChatCompletionSummaryId); + errorParams.put("http.statusCode", getStatusCode()); + errorParams.put("error.code", getStatusCode()); + if (!getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", getLlmChatCompletionSummaryId()); } - if (!llmEmbeddingId.isEmpty()) { - errorParams.put("embedding_id", llmEmbeddingId); + if (!getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", getLlmEmbeddingId()); } - NewRelic.noticeError("LlmError: " + statusText, errorParams); + NewRelic.noticeError("LlmError: " + getStatusText(), errorParams); } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java index bf71fd298c..ceb5d627db 
100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java @@ -56,6 +56,7 @@ protected DefaultBedrockRuntimeAsyncClient_Instrumentation(BedrockRuntimeService @Trace public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { + long startTime = System.currentTimeMillis(); // TODO name "Llm/" + operationType + "/Bedrock/InvokeModelAsync" ???? Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index 57c0e1db1f..f3ea752709 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -26,8 +26,7 @@ import java.util.logging.Level; import static com.newrelic.utils.BedrockRuntimeUtil.incrementBedrockInstrumentedMetric; -import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmChatCompletionMessageEvent; -import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmChatCompletionSummaryEvent; +import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmChatCompletionEvents; import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmEmbeddingEvent; import static 
com.newrelic.utils.BedrockRuntimeUtil.setLlmOperationMetricName; import static com.newrelic.utils.InvokeModelResponseWrapper.COMPLETION; @@ -61,14 +60,15 @@ protected DefaultBedrockRuntimeClient_Instrumentation(BedrockRuntimeServiceClien @Trace public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { + long startTime = System.currentTimeMillis(); InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); incrementBedrockInstrumentedMetric(); Transaction txn = NewRelic.getAgent().getTransaction(); +// Transaction txn = AgentBridge.getAgent().getTransaction(); // TODO check AIM config if (txn != null && !(txn instanceof NoOpTransaction)) { - Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); InvokeModelRequestWrapper requestWrapper = new InvokeModelRequestWrapper(invokeModelRequest); InvokeModelResponseWrapper responseWrapper = new InvokeModelResponseWrapper(invokeModelResponse); @@ -79,10 +79,9 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { // Report LLM events if (operationType.equals(COMPLETION)) { - recordLlmChatCompletionMessageEvent(txn, linkingMetadata, requestWrapper, responseWrapper); - recordLlmChatCompletionSummaryEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + recordLlmChatCompletionEvents(startTime, linkingMetadata, requestWrapper, responseWrapper); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(txn, linkingMetadata, requestWrapper, responseWrapper); + recordLlmEmbeddingEvent(startTime, linkingMetadata, requestWrapper, responseWrapper); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type"); } From 170b360e5ce2351584ad5460568a015b241776d2 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 21 Feb 2024 14:05:58 -0800 Subject: [PATCH 05/68] Refactoring --- .../aws-bedrock-runtime-2.20/README.md | 10 + .../newrelic/utils/BedrockRuntimeUtil.java | 210 ----------- 
.../main/java/llm/models/ModelInvocation.java | 102 ++++++ .../AnthropicClaudeInvokeModelRequest.java} | 57 +-- .../AnthropicClaudeInvokeModelResponse.java} | 57 +-- .../AnthropicClaudeModelInvocation.java | 331 ++++++++++++++++++ ...ockRuntimeAsyncClient_Instrumentation.java | 5 +- ...tBedrockRuntimeClient_Instrumentation.java | 48 ++- 8 files changed, 526 insertions(+), 294 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/README.md delete mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java rename instrumentation/aws-bedrock-runtime-2.20/src/main/java/{com/newrelic/utils/InvokeModelRequestWrapper.java => llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java} (83%) rename instrumentation/aws-bedrock-runtime-2.20/src/main/java/{com/newrelic/utils/InvokeModelResponseWrapper.java => llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java} (86%) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md new file mode 100644 index 0000000000..d6f27b3820 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -0,0 +1,10 @@ +# AWS Bedrock Runtime Instrumentation + +## About + + +## Pieces + + +## Testing + diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java deleted file mode 100644 index 5a1eaedac7..0000000000 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/BedrockRuntimeUtil.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * - * * Copyright 2024 New Relic Corporation. 
All rights reserved. - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package com.newrelic.utils; - -import com.newrelic.api.agent.NewRelic; -import com.newrelic.api.agent.Transaction; - -import java.util.HashMap; -import java.util.Map; -import java.util.UUID; - -public class BedrockRuntimeUtil { - private static final String VENDOR = "bedrock"; - private static final String INGEST_SOURCE = "Java"; - private static final String TRACE_ID = "trace.id"; - private static final String SPAN_ID = "span.id"; - - // LLM event types - private static final String LLM_EMBEDDING = "LlmEmbedding"; - private static final String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; - private static final String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; - - /** - * This needs to be incremented for every invocation of the Bedrock SDK. - * Supportability/{language}/ML/{vendor_name}/{vendor_version} - *

- * The metric generated triggers the creation of a tag which gates the AI Response UI. The - * tag lives for 27 hours so if this metric isn't repeatedly sent the tag will disappear and - * the UI will be hidden. - */ - public static void incrementBedrockInstrumentedMetric() { - // Bedrock vendor_version isn't available, so set it to instrumentation version instead - NewRelic.incrementCounter("Supportability/Java/ML/Bedrock/2.20"); - } - - /** - * Set name of the span/segment for each LLM embedding and chat completion call - * Llm/{operation_type}/{vendor_name}/{function_name} - * - * @param txn current transaction - * @param operationType operation of type completion or embedding - */ - public static void setLlmOperationMetricName(Transaction txn, String operationType) { - txn.getTracedMethod().setMetricName("Llm", operationType, "Bedrock", "invokeModel"); - } - - // TODO add event builders??? Avoid adding null/empty attributes? - - // TODO create a single recordLlmEvent method that can take a type. Always add attributes common to - // all types and add others based on conditionals - - public static void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata, InvokeModelRequestWrapper invokeModelRequestWrapper, - InvokeModelResponseWrapper invokeModelResponseWrapper) { - - // TODO filter based on which Bedrock model we're dealing with... 
this might need to be done before this point - invokeModelRequestWrapper.getModelId(); - - Map eventAttributes = new HashMap<>(); - eventAttributes.put("id", invokeModelResponseWrapper.getLlmEmbeddingId()); - eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); - eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("trace_id", getTraceId(linkingMetadata)); - eventAttributes.put("input", invokeModelRequestWrapper.getInputText()); - eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); - eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. - eventAttributes.put("response.usage.total_tokens", invokeModelResponseWrapper.getTotalTokenCount()); - eventAttributes.put("response.usage.prompt_tokens", invokeModelResponseWrapper.getInputTokenCount()); - eventAttributes.put("vendor", getVendor()); - eventAttributes.put("ingest_source", getIngestSource()); - if (invokeModelResponseWrapper.isErrorResponse()) { - eventAttributes.put("error", true); - invokeModelResponseWrapper.reportLlmError(); - } - - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction -// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? - - long endTime = System.currentTimeMillis(); - eventAttributes.put("duration", (endTime - startTime)); - NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); - } - - private static void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata, - InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - - // TODO filter based on which Bedrock model we're dealing with... 
this might need to be done before this point - invokeModelRequestWrapper.getModelId(); - - Map eventAttributes = new HashMap<>(); - eventAttributes.put("id", invokeModelResponseWrapper.getLlmChatCompletionSummaryId()); - eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); - eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("trace_id", getTraceId(linkingMetadata)); - eventAttributes.put("request.temperature", invokeModelRequestWrapper.getTemperature()); - eventAttributes.put("request.max_tokens", invokeModelRequestWrapper.getMaxTokensToSample()); - eventAttributes.put("request.model", invokeModelRequestWrapper.getModelId()); - eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. - eventAttributes.put("response.number_of_messages", numberOfMessages); - eventAttributes.put("response.usage.total_tokens", invokeModelResponseWrapper.getTotalTokenCount()); - eventAttributes.put("response.usage.prompt_tokens", invokeModelResponseWrapper.getInputTokenCount()); - eventAttributes.put("response.usage.completion_tokens", invokeModelResponseWrapper.getOutputTokenCount()); - eventAttributes.put("response.choices.finish_reason", invokeModelResponseWrapper.getStopReason()); - eventAttributes.put("vendor", getVendor()); - eventAttributes.put("ingest_source", getIngestSource()); - if (invokeModelResponseWrapper.isErrorResponse()) { - eventAttributes.put("error", true); - invokeModelResponseWrapper.reportLlmError(); - } - - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction -// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? 
-// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! - - long endTime = System.currentTimeMillis(); - eventAttributes.put("duration", (endTime - startTime)); - NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); - } - - private static void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata, - InvokeModelRequestWrapper invokeModelRequestWrapper, InvokeModelResponseWrapper invokeModelResponseWrapper) { - - // TODO filter based on which Bedrock model we're dealing with... this might need to be done before this point - invokeModelRequestWrapper.getModelId(); - - Map eventAttributes = new HashMap<>(); - eventAttributes.put("content", message); - - // FIXME id, content, role, and sequence - // This parsing might only apply to Claude? - if (message.contains("Human:")) { - eventAttributes.put("role", "user"); - eventAttributes.put("is_response", false); - } else { - String role = invokeModelRequestWrapper.getRole(); - if (!role.isEmpty()) { - eventAttributes.put("role", role); - if (!role.contains("user")) { - eventAttributes.put("is_response", true); - } - } - } - - eventAttributes.put("id", getRandomGuid()); - eventAttributes.put("request_id", invokeModelResponseWrapper.getAmznRequestId()); - eventAttributes.put("span_id", getSpanId(linkingMetadata)); - eventAttributes.put("trace_id", getTraceId(linkingMetadata)); - eventAttributes.put("response.model", invokeModelRequestWrapper.getModelId()); // For Bedrock it is the same as the request model. 
- eventAttributes.put("vendor", getVendor()); - eventAttributes.put("ingest_source", getIngestSource()); - eventAttributes.put("sequence", sequence); - eventAttributes.put("completion_id", invokeModelResponseWrapper.getLlmChatCompletionSummaryId()); - - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction -// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? -// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! - - NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); - } - - /** - * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event - * - * @param linkingMetadata - * @param requestWrapper - * @param responseWrapper - */ - public static void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata, - InvokeModelRequestWrapper requestWrapper, InvokeModelResponseWrapper responseWrapper) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, requestWrapper.getRequestMessage(), linkingMetadata, requestWrapper, responseWrapper); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, responseWrapper.getResponseMessage(), linkingMetadata, requestWrapper, responseWrapper); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata, requestWrapper, responseWrapper); - } - - // ========================= AGENT DATA ================================ - // Lowercased name of vendor (bedrock or openAI) 
- public static String getVendor() { - return VENDOR; - } - - // Name of the language agent (ex: Python, Node) - public static String getIngestSource() { - return INGEST_SOURCE; - } - - // GUID associated with the active trace - public static String getSpanId(Map linkingMetadata) { - return linkingMetadata.get(SPAN_ID); - } - - // ID of the current trace - public static String getTraceId(Map linkingMetadata) { - return linkingMetadata.get(TRACE_ID); - } - - // Returns a string representation of a random GUID - public static String getRandomGuid() { - return UUID.randomUUID().toString(); - } -} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java new file mode 100644 index 0000000000..82b5955d78 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -0,0 +1,102 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models; + +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; + +import java.util.Map; +import java.util.UUID; + +// TODO make this an interface called LlmModel with some default methods, some methods that need to be implemented, and some useful constants +//public class Model { +public interface ModelInvocation { + String VENDOR = "bedrock"; + String BEDROCK = "Bedrock"; + String INGEST_SOURCE = "Java"; + String TRACE_ID = "trace.id"; + String SPAN_ID = "span.id"; + + // LLM event types + String LLM_EMBEDDING = "LlmEmbedding"; + String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; + String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; + + // Support models + String ANTHROPIC_CLAUDE = "claude"; + String AMAZON_TITAN = "titan"; + String META_LLAMA_2 = "llama"; + String COHERE_COMMAND = "cohere"; + String AI_21_LABS_JURASSIC = "jurassic"; + + /** + * Set name of the span/segment for each LLM embedding and chat completion call + * Llm/{operation_type}/{vendor_name}/{function_name} + * + * @param txn current transaction + */ + void setLlmOperationMetricName(Transaction txn, String functionName); + + void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata); + + void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata); + + void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata); + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. + */ + void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata); + + void recordLlmEvents(long startTime, Map linkingMetadata); + + + /** + * This needs to be incremented for every invocation of the SDK. + * Supportability/{language}/ML/{vendor_name}/{vendor_version} + *

+ * The metric generated triggers the creation of a tag which gates the AI Response UI. The + * tag lives for 27 hours so if this metric isn't repeatedly sent the tag will disappear and + * the UI will be hidden. + */ + static void incrementInstrumentedSupportabilityMetric() { + // Bedrock vendor_version isn't available, so set it to instrumentation version instead + NewRelic.incrementCounter("Supportability/Java/ML/Bedrock/2.20"); + } + + static void setLlmTrueAgentAttribute(Transaction txn) { + // If in a txn with LLM-related spans + txn.getAgentAttributes().put("llm", true); + } + + // Lowercased name of vendor (bedrock or openAI) + static String getVendor() { + return VENDOR; + } + + // Name of the language agent (ex: Python, Node) + static String getIngestSource() { + return INGEST_SOURCE; + } + + // GUID associated with the active trace + static String getSpanId(Map linkingMetadata) { + return linkingMetadata.get(SPAN_ID); + } + + // ID of the current trace + static String getTraceId(Map linkingMetadata) { + return linkingMetadata.get(TRACE_ID); + } + + // Returns a string representation of a random GUID + static String getRandomGuid() { + return UUID.randomUUID().toString(); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java similarity index 83% rename from instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java rename to instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java index 2e625d88e0..890f49dcf7 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelRequestWrapper.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java @@ 
-5,7 +5,7 @@ * */ -package com.newrelic.utils; +package llm.models.anthropic.claude; import com.newrelic.api.agent.NewRelic; import software.amazon.awssdk.protocols.jsoncore.JsonNode; @@ -21,9 +21,10 @@ * Stores the required info from the Bedrock InvokeModelRequest * but doesn't hold a reference to the actual request object. */ -public class InvokeModelRequestWrapper { +// TODO create an interface +public class AnthropicClaudeInvokeModelRequest { // Request body (for Claude, how about other models?) - private static final String STOP_SEQUENCES = "stop_sequences"; +// private static final String STOP_SEQUENCES = "stop_sequences"; private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; private static final String TEMPERATURE = "temperature"; private static final String PROMPT = "prompt"; @@ -37,7 +38,7 @@ public class InvokeModelRequestWrapper { private String modelId = ""; private Map requestBodyJsonMap = null; - public InvokeModelRequestWrapper(InvokeModelRequest invokeModelRequest) { + public AnthropicClaudeInvokeModelRequest(InvokeModelRequest invokeModelRequest) { if (invokeModelRequest != null) { invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); modelId = invokeModelRequest.modelId(); @@ -84,30 +85,30 @@ private Map parseInvokeModelRequestBodyMap() { // TODO do we potentially expect more than one entry in the stop sequence? Or is it sufficient // to just check if it contains Human? DO we even need this at all? 
Doesn't look like we do - public String getStopSequences() { - StringBuilder stopSequences = new StringBuilder(); - try { - if (!getRequestBodyJsonMap().isEmpty()) { - JsonNode jsonNode = getRequestBodyJsonMap().get(STOP_SEQUENCES); - if (jsonNode.isArray()) { - List jsonNodeArray = jsonNode.asArray(); - for (JsonNode node : jsonNodeArray) { - if (node.isString()) { - // Don't add comma for first node - if (stopSequences.length() <= 0) { - stopSequences.append(node.asString()); - } else { - stopSequences.append(",").append(node.asString()); - } - } - } - } - } - } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP_SEQUENCES); - } - return stopSequences.toString().replaceAll("[\n:]", ""); - } +// public String getStopSequences() { +// StringBuilder stopSequences = new StringBuilder(); +// try { +// if (!getRequestBodyJsonMap().isEmpty()) { +// JsonNode jsonNode = getRequestBodyJsonMap().get(STOP_SEQUENCES); +// if (jsonNode.isArray()) { +// List jsonNodeArray = jsonNode.asArray(); +// for (JsonNode node : jsonNodeArray) { +// if (node.isString()) { +// // Don't add comma for first node +// if (stopSequences.length() <= 0) { +// stopSequences.append(node.asString()); +// } else { +// stopSequences.append(",").append(node.asString()); +// } +// } +// } +// } +// } +// } catch (Exception e) { +// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP_SEQUENCES); +// } +// return stopSequences.toString().replaceAll("[\n:]", ""); +// } public String getMaxTokensToSample() { String maxTokensToSample = ""; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java similarity index 86% rename from instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java rename to 
instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java index a8010a63bd..927e13b2b2 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/com/newrelic/utils/InvokeModelResponseWrapper.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java @@ -5,7 +5,7 @@ * */ -package com.newrelic.utils; +package llm.models.anthropic.claude; import com.newrelic.api.agent.NewRelic; import software.amazon.awssdk.protocols.jsoncore.JsonNode; @@ -19,28 +19,29 @@ import java.util.Optional; import java.util.logging.Level; -import static com.newrelic.utils.BedrockRuntimeUtil.getRandomGuid; +import static llm.models.ModelInvocation.getRandomGuid; /** * Stores the required info from the Bedrock InvokeModelResponse * but doesn't hold a reference to the actual response object. */ -public class InvokeModelResponseWrapper { +// TODO create an interface +public class AnthropicClaudeInvokeModelResponse { // Response body (for Claude, how about other models?) 
public static final String COMPLETION = "completion"; public static final String EMBEDDING = "embedding"; private static final String STOP_REASON = "stop_reason"; - private static final String STOP = "stop"; +// private static final String STOP = "stop"; // Response headers private static final String X_AMZN_BEDROCK_INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; private static final String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; private static final String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; - private static final String X_AMZN_BEDROCK_INVOCATION_LATENCY = "X-Amzn-Bedrock-Invocation-Latency"; +// private static final String X_AMZN_BEDROCK_INVOCATION_LATENCY = "X-Amzn-Bedrock-Invocation-Latency"; private int inputTokenCount = 0; private int outputTokenCount = 0; private String amznRequestId = ""; - private String invocationLatency = ""; +// private String invocationLatency = ""; // LLM operation type private String operationType = ""; @@ -60,7 +61,7 @@ public class InvokeModelResponseWrapper { private static final String JSON_START = "{\""; - public InvokeModelResponseWrapper(InvokeModelResponse invokeModelResponse) { + public AnthropicClaudeInvokeModelResponse(InvokeModelResponse invokeModelResponse) { if (invokeModelResponse != null) { invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); @@ -147,10 +148,10 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { amznRequestId = amznRequestIdHeaders.get(0); // TODO does this differ from invokeModelResponse.responseMetadata().requestId() } - List invocationLatencyHeaders = headers.get(X_AMZN_BEDROCK_INVOCATION_LATENCY); - if (invocationLatencyHeaders != null && !invocationLatencyHeaders.isEmpty()) { - invocationLatency = invocationLatencyHeaders.get(0); - } +// List invocationLatencyHeaders = 
headers.get(X_AMZN_BEDROCK_INVOCATION_LATENCY); +// if (invocationLatencyHeaders != null && !invocationLatencyHeaders.isEmpty()) { +// invocationLatency = invocationLatencyHeaders.get(0); +// } } } catch (Exception e) { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse headers"); @@ -192,20 +193,20 @@ public String getStopReason() { return stopReason; } - public String getStop() { - String stop = ""; - try { - if (!getResponseBodyJsonMap().isEmpty()) { - JsonNode jsonNode = getResponseBodyJsonMap().get(STOP); - if (jsonNode.isString()) { - stop = jsonNode.asString(); - } - } - } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP); - } - return stop.replaceAll("[\n:]", ""); - } +// public String getStop() { +// String stop = ""; +// try { +// if (!getResponseBodyJsonMap().isEmpty()) { +// JsonNode jsonNode = getResponseBodyJsonMap().get(STOP); +// if (jsonNode.isString()) { +// stop = jsonNode.asString(); +// } +// } +// } catch (Exception e) { +// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP); +// } +// return stop.replaceAll("[\n:]", ""); +// } public int getInputTokenCount() { return inputTokenCount; @@ -223,9 +224,9 @@ public String getAmznRequestId() { return amznRequestId; } - public String getInvocationLatency() { - return invocationLatency; - } +// public String getInvocationLatency() { +// return invocationLatency; +// } public String getOperationType() { return operationType; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java new file mode 100644 index 0000000000..9d28528a68 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -0,0 +1,331 @@ 
+package llm.models.anthropic.claude; + +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelInvocation; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.COMPLETION; +import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.EMBEDDING; + +public class AnthropicClaudeModelInvocation implements ModelInvocation { + AnthropicClaudeInvokeModelRequest anthropicClaudeInvokeModelRequest; + AnthropicClaudeInvokeModelResponse anthropicClaudeInvokeModelResponse; + + public AnthropicClaudeModelInvocation(InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { + anthropicClaudeInvokeModelRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); + anthropicClaudeInvokeModelResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); + } + + @Override + public void setLlmOperationMetricName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", anthropicClaudeInvokeModelResponse.getOperationType(), BEDROCK, functionName); + } + + // TODO add event builders??? 
+ @Override + public void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata) { + Map eventAttributes = new HashMap<>(); + // Generic attributes that are constant for all Bedrock Models + addSpanId(eventAttributes, linkingMetadata); + addTraceId(eventAttributes, linkingMetadata); + addVendor(eventAttributes); + addIngestSource(eventAttributes); + + // Attributes dependent on the request/response + addId(eventAttributes, anthropicClaudeInvokeModelResponse.getLlmEmbeddingId()); + addRequestId(eventAttributes); + addInput(eventAttributes); + addRequestModel(eventAttributes); + addResponseModel(eventAttributes); + addResponseUsageTotalTokens(eventAttributes); + addResponseUsagePromptTokens(eventAttributes); + + // Error attributes + if (anthropicClaudeInvokeModelResponse.isErrorResponse()) { + addError(eventAttributes); + anthropicClaudeInvokeModelResponse.reportLlmError(); + } + + // Duration attribute from manual timing as we don't have a way of getting timing from a tracer/segment within a method that is in the process of being timed + long endTime = System.currentTimeMillis(); + addDuration(eventAttributes, (endTime - startTime)); + + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); + + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction +// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? +// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! 
+ + } + + @Override + public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata) { + Map eventAttributes = new HashMap<>(); + // Generic attributes that are constant for all Bedrock Models + addSpanId(eventAttributes, linkingMetadata); + addTraceId(eventAttributes, linkingMetadata); + addVendor(eventAttributes); + addIngestSource(eventAttributes); + + // Attributes dependent on the request/response + addId(eventAttributes, anthropicClaudeInvokeModelResponse.getLlmChatCompletionSummaryId()); + addRequestId(eventAttributes); + addRequestTemperature(eventAttributes); + addRequestMaxTokens(eventAttributes); + addRequestModel(eventAttributes); + addResponseModel(eventAttributes); + addResponseNumberOfMessages(eventAttributes, numberOfMessages); + addResponseUsageTotalTokens(eventAttributes); + addResponseUsagePromptTokens(eventAttributes); + addResponseUsageCompletionTokens(eventAttributes); + addResponseChoicesFinishReason(eventAttributes); + + // Error attributes + if (anthropicClaudeInvokeModelResponse.isErrorResponse()) { + addError(eventAttributes); + anthropicClaudeInvokeModelResponse.reportLlmError(); + } + + // Duration attribute from manual timing as we don't have a way of getting timing from a tracer/segment within a method that is in the process of being timed + long endTime = System.currentTimeMillis(); + addDuration(eventAttributes, (endTime - startTime)); + + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); + + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction +// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? 
+// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata) { + Map eventAttributes = new HashMap<>(); + // Generic attributes that are constant for all Bedrock Models + addSpanId(eventAttributes, linkingMetadata); + addTraceId(eventAttributes, linkingMetadata); + addVendor(eventAttributes); + addIngestSource(eventAttributes); + + // Multiple completion message events can be created per transaction so generate an id on the fly instead of storing each in the response/request wrapper + addId(eventAttributes, ModelInvocation.getRandomGuid()); + + // Attributes dependent on the request/response + addContent(eventAttributes, message); + if (message.contains("Human:")) { + addRole(eventAttributes, "user"); + addIsResponse(eventAttributes, false); + } else { + String role = anthropicClaudeInvokeModelRequest.getRole(); + if (!role.isEmpty()) { + addRole(eventAttributes, role); + if (!role.contains("user")) { + addIsResponse(eventAttributes, true); + } + } + } + addRequestId(eventAttributes); + addResponseModel(eventAttributes); + addSequence(eventAttributes, sequence); + addCompletionId(eventAttributes); + + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); + + // TODO is it possible to do something like this to call getUserAttributes? + // see com.newrelic.agent.bridge.Transaction +// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? 
+// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! + } + + @Override + public void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, anthropicClaudeInvokeModelRequest.getRequestMessage(), linkingMetadata); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, anthropicClaudeInvokeModelResponse.getResponseMessage(), linkingMetadata); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata); + } + + @Override + public void recordLlmEvents(long startTime, Map linkingMetadata) { + String operationType = anthropicClaudeInvokeModelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime, linkingMetadata); + } else if (operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvent(startTime, linkingMetadata); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + // TODO can all of these helper methods be moved to the ModelInvocation interface??? 
+ private void addSpanId(Map eventAttributes, Map linkingMetadata) { + String spanId = ModelInvocation.getSpanId(linkingMetadata); + if (spanId != null && !spanId.isEmpty()) { + eventAttributes.put("span_id", spanId); + } + } + + private void addTraceId(Map eventAttributes, Map linkingMetadata) { + String traceId = ModelInvocation.getTraceId(linkingMetadata); + if (traceId != null && !traceId.isEmpty()) { + eventAttributes.put("trace_id", traceId); + } + } + + private void addVendor(Map eventAttributes) { + String vendor = ModelInvocation.getVendor(); + if (vendor != null && !vendor.isEmpty()) { + eventAttributes.put("vendor", vendor); + } + } + + private void addIngestSource(Map eventAttributes) { + String ingestSource = ModelInvocation.getIngestSource(); + if (ingestSource != null && !ingestSource.isEmpty()) { + eventAttributes.put("ingest_source", ingestSource); + } + } + + private void addId(Map eventAttributes, String id) { + if (id != null && !id.isEmpty()) { + eventAttributes.put("id", id); + } + } + + private void addContent(Map eventAttributes, String message) { + if (message != null && !message.isEmpty()) { + eventAttributes.put("content", message); + } + } + + private void addRole(Map eventAttributes, String role) { + if (role != null && !role.isEmpty()) { + eventAttributes.put("role", role); + } + } + + private void addIsResponse(Map eventAttributes, boolean isResponse) { + eventAttributes.put("is_response", isResponse); + } + + private void addSequence(Map eventAttributes, int sequence) { + if (sequence >= 0) { + eventAttributes.put("sequence", sequence); + } + } + + private void addResponseNumberOfMessages(Map eventAttributes, int numberOfMessages) { + if (numberOfMessages >= 0) { + eventAttributes.put("response.number_of_messages", numberOfMessages); + } + } + + private void addDuration(Map eventAttributes, long duration) { + if (duration >= 0) { + eventAttributes.put("duration", duration); + } + } + + private void addError(Map eventAttributes) { + 
eventAttributes.put("error", true); + } + + private void addInput(Map eventAttributes) { + // TODO modify to pass in Request interface if moving to the ModelInvocation interface + String inputText = anthropicClaudeInvokeModelRequest.getInputText(); + if (inputText != null && !inputText.isEmpty()) { + eventAttributes.put("input", inputText); + } + } + + private void addRequestTemperature(Map eventAttributes) { + // TODO modify to pass in Request interface if moving to the ModelInvocation interface + String temperature = anthropicClaudeInvokeModelRequest.getTemperature(); + if (temperature != null && !temperature.isEmpty()) { + eventAttributes.put("request.temperature", temperature); + } + } + + private void addRequestMaxTokens(Map eventAttributes) { + // TODO modify to pass in Request interface if moving to the ModelInvocation interface + String maxTokensToSample = anthropicClaudeInvokeModelRequest.getMaxTokensToSample(); + if (maxTokensToSample != null && !maxTokensToSample.isEmpty()) { + eventAttributes.put("request.max_tokens", maxTokensToSample); + } + } + + private void addRequestModel(Map eventAttributes) { + // TODO modify to pass in Request interface if moving to the ModelInvocation interface + String modelId = anthropicClaudeInvokeModelRequest.getModelId(); + if (modelId != null && !modelId.isEmpty()) { + eventAttributes.put("request.model", modelId); + } + } + + private void addResponseModel(Map eventAttributes) { + // TODO modify to pass in Request interface if moving to the ModelInvocation interface + // For Bedrock the response model is the same as the request model. 
+ String modelId = anthropicClaudeInvokeModelRequest.getModelId(); + if (modelId != null && !modelId.isEmpty()) { + eventAttributes.put("response.model", modelId); + } + } + + private void addRequestId(Map eventAttributes) { + // TODO modify to pass in Response interface if moving to the ModelInvocation interface + String requestId = anthropicClaudeInvokeModelResponse.getAmznRequestId(); + if (requestId != null && !requestId.isEmpty()) { + eventAttributes.put("request_id", requestId); + } + } + + private void addCompletionId(Map eventAttributes) { + // TODO modify to pass in Response interface if moving to the ModelInvocation interface + String llmChatCompletionSummaryId = anthropicClaudeInvokeModelResponse.getLlmChatCompletionSummaryId(); + if (llmChatCompletionSummaryId != null && !llmChatCompletionSummaryId.isEmpty()) { + eventAttributes.put("completion_id", llmChatCompletionSummaryId); + } + } + + private void addResponseUsageTotalTokens(Map eventAttributes) { + // TODO modify to pass in Response interface if moving to the ModelInvocation interface + int totalTokenCount = anthropicClaudeInvokeModelResponse.getTotalTokenCount(); + if (totalTokenCount >= 0) { + eventAttributes.put("response.usage.total_tokens", totalTokenCount); + } + } + + private void addResponseUsagePromptTokens(Map eventAttributes) { + // TODO modify to pass in Response interface if moving to the ModelInvocation interface + int inputTokenCount = anthropicClaudeInvokeModelResponse.getInputTokenCount(); + if (inputTokenCount >= 0) { + eventAttributes.put("response.usage.prompt_tokens", inputTokenCount); + } + } + + private void addResponseUsageCompletionTokens(Map eventAttributes) { + // TODO modify to pass in Response interface if moving to the ModelInvocation interface + int outputTokenCount = anthropicClaudeInvokeModelResponse.getOutputTokenCount(); + if (outputTokenCount >= 0) { + eventAttributes.put("response.usage.completion_tokens", outputTokenCount); + } + } + + private void 
addResponseChoicesFinishReason(Map eventAttributes) { + // TODO modify to pass in Response interface if moving to the ModelInvocation interface + String stopReason = anthropicClaudeInvokeModelResponse.getStopReason(); + if (stopReason != null && !stopReason.isEmpty()) { + eventAttributes.put("response.choices.finish_reason", stopReason); + } + } + +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java index ceb5d627db..bff0955fb3 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java @@ -23,11 +23,12 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import static com.newrelic.utils.BedrockRuntimeUtil.incrementBedrockInstrumentedMetric; +import static llm.models.ModelInvocation.incrementInstrumentedSupportabilityMetric; /** * Service client for accessing Amazon Bedrock Runtime asynchronously. 
*/ +// TODO switch back to instrumenting the BedrockRuntimeAsyncClient interface instead of this implementation class @Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient") final class DefaultBedrockRuntimeAsyncClient_Instrumentation { // private static final Logger log = LoggerFactory.getLogger(DefaultBedrockRuntimeAsyncClient.class); @@ -61,7 +62,7 @@ public CompletableFuture invokeModel(InvokeModelRequest inv Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); - incrementBedrockInstrumentedMetric(); + incrementInstrumentedSupportabilityMetric(); // this should never happen, but protecting against bad implementations if (invokeModelResponseFuture == null) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index f3ea752709..2a9c815c23 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -7,15 +7,16 @@ package software.amazon.awssdk.services.bedrockruntime; +import com.newrelic.agent.bridge.AgentBridge; import com.newrelic.agent.bridge.NoOpTransaction; +import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Trace; -import com.newrelic.api.agent.Transaction; import com.newrelic.api.agent.weaver.MatchType; import com.newrelic.api.agent.weaver.Weave; import 
com.newrelic.api.agent.weaver.Weaver; -import com.newrelic.utils.InvokeModelRequestWrapper; -import com.newrelic.utils.InvokeModelResponseWrapper; +import llm.models.ModelInvocation; +import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; @@ -25,16 +26,14 @@ import java.util.Map; import java.util.logging.Level; -import static com.newrelic.utils.BedrockRuntimeUtil.incrementBedrockInstrumentedMetric; -import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmChatCompletionEvents; -import static com.newrelic.utils.BedrockRuntimeUtil.recordLlmEmbeddingEvent; -import static com.newrelic.utils.BedrockRuntimeUtil.setLlmOperationMetricName; -import static com.newrelic.utils.InvokeModelResponseWrapper.COMPLETION; -import static com.newrelic.utils.InvokeModelResponseWrapper.EMBEDDING; +import static llm.models.ModelInvocation.ANTHROPIC_CLAUDE; +import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.COMPLETION; +import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.EMBEDDING; /** * Service client for accessing Amazon Bedrock Runtime. 
*/ +// TODO switch back to instrumenting the BedrockRuntimeClient interface instead of this implementation class @Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeClient") final class DefaultBedrockRuntimeClient_Instrumentation { // private static final Logger log = Logger.loggerFor(DefaultBedrockRuntimeClient.class); @@ -63,27 +62,24 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { long startTime = System.currentTimeMillis(); InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); - incrementBedrockInstrumentedMetric(); + ModelInvocation.incrementInstrumentedSupportabilityMetric(); - Transaction txn = NewRelic.getAgent().getTransaction(); -// Transaction txn = AgentBridge.getAgent().getTransaction(); +// Transaction txn = NewRelic.getAgent().getTransaction(); + Transaction txn = AgentBridge.getAgent().getTransaction(); // TODO check AIM config if (txn != null && !(txn instanceof NoOpTransaction)) { Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); - InvokeModelRequestWrapper requestWrapper = new InvokeModelRequestWrapper(invokeModelRequest); - InvokeModelResponseWrapper responseWrapper = new InvokeModelResponseWrapper(invokeModelResponse); - - String operationType = responseWrapper.getOperationType(); - // Set traced method name based on LLM operation - setLlmOperationMetricName(txn, operationType); - - // Report LLM events - if (operationType.equals(COMPLETION)) { - recordLlmChatCompletionEvents(startTime, linkingMetadata, requestWrapper, responseWrapper); - } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime, linkingMetadata, requestWrapper, responseWrapper); - } else { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type"); + Map userAttributes = txn.getUserAttributes(); + + String modelId = invokeModelRequest.modelId(); + if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { 
+ AnthropicClaudeModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation + anthropicClaudeModelInvocation.setLlmOperationMetricName(txn, "invokeModel"); + // Set llm = true agent attribute + ModelInvocation.setLlmTrueAgentAttribute(txn); + anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); } } From db6b3590cb6cce15c7c42d89487232c91b7b878b Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 21 Feb 2024 15:18:55 -0800 Subject: [PATCH 06/68] Refactoring --- .../main/java/llm/models/ModelInvocation.java | 10 +- .../main/java/llm/models/ModelRequest.java | 15 ++ .../main/java/llm/models/ModelResponse.java | 32 ++++ .../AnthropicClaudeInvokeModelRequest.java | 64 ++------ .../AnthropicClaudeInvokeModelResponse.java | 58 ++----- .../AnthropicClaudeModelInvocation.java | 142 +++++++++--------- ...tBedrockRuntimeClient_Instrumentation.java | 5 +- 7 files changed, 142 insertions(+), 184 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 82b5955d78..bb043d02fa 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -13,8 +13,6 @@ import java.util.Map; import java.util.UUID; -// TODO make this an interface called LlmModel with some default methods, some methods that need to be implemented, and some useful constants -//public class Model { public interface ModelInvocation { String VENDOR = "bedrock"; String BEDROCK = "Bedrock"; @@ 
-27,7 +25,7 @@ public interface ModelInvocation { String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; - // Support models + // Supported models String ANTHROPIC_CLAUDE = "claude"; String AMAZON_TITAN = "titan"; String META_LLAMA_2 = "llama"; @@ -48,12 +46,6 @@ public interface ModelInvocation { void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata); - /** - * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. - * The number of LlmChatCompletionMessage events produced can differ based on vendor. - */ - void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata); - void recordLlmEvents(long startTime, Map linkingMetadata); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java new file mode 100644 index 0000000000..9575cddf81 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -0,0 +1,15 @@ +package llm.models; + +public interface ModelRequest { + String getMaxTokensToSample(); + + String getTemperature(); + + String getRequestMessage(); + + String getRole(); + + String getInputText(); + + String getModelId(); +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java new file mode 100644 index 0000000000..9f016fb101 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -0,0 +1,32 @@ +package llm.models; + +public interface ModelResponse { + String COMPLETION = "completion"; + String EMBEDDING = "embedding"; + + String getResponseMessage(); + + String getStopReason(); + + int getInputTokenCount(); + + int getOutputTokenCount(); + + int 
getTotalTokenCount(); + + String getAmznRequestId(); + + String getOperationType(); + + String getLlmChatCompletionSummaryId(); + + String getLlmEmbeddingId(); + + boolean isErrorResponse(); + + int getStatusCode(); + + String getStatusText(); + + void reportLlmError(); +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java index 890f49dcf7..b6b7658dd4 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java @@ -13,7 +13,6 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.logging.Level; @@ -21,10 +20,9 @@ * Stores the required info from the Bedrock InvokeModelRequest * but doesn't hold a reference to the actual request object. */ -// TODO create an interface -public class AnthropicClaudeInvokeModelRequest { - // Request body (for Claude, how about other models?) -// private static final String STOP_SEQUENCES = "stop_sequences"; +public class AnthropicClaudeInvokeModelRequest implements llm.models.ModelRequest { + // TODO might be able to move some of these constants to the ModelRequest interface + // need to figure out if they are consistent across all models private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; private static final String TEMPERATURE = "temperature"; private static final String PROMPT = "prompt"; @@ -83,33 +81,7 @@ private Map parseInvokeModelRequestBodyMap() { return requestBodyJsonMap != null ? 
requestBodyJsonMap : Collections.emptyMap(); } - // TODO do we potentially expect more than one entry in the stop sequence? Or is it sufficient - // to just check if it contains Human? DO we even need this at all? Doesn't look like we do -// public String getStopSequences() { -// StringBuilder stopSequences = new StringBuilder(); -// try { -// if (!getRequestBodyJsonMap().isEmpty()) { -// JsonNode jsonNode = getRequestBodyJsonMap().get(STOP_SEQUENCES); -// if (jsonNode.isArray()) { -// List jsonNodeArray = jsonNode.asArray(); -// for (JsonNode node : jsonNodeArray) { -// if (node.isString()) { -// // Don't add comma for first node -// if (stopSequences.length() <= 0) { -// stopSequences.append(node.asString()); -// } else { -// stopSequences.append(",").append(node.asString()); -// } -// } -// } -// } -// } -// } catch (Exception e) { -// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP_SEQUENCES); -// } -// return stopSequences.toString().replaceAll("[\n:]", ""); -// } - + @Override public String getMaxTokensToSample() { String maxTokensToSample = ""; try { @@ -125,6 +97,7 @@ public String getMaxTokensToSample() { return maxTokensToSample; } + @Override public String getTemperature() { String temperature = ""; try { @@ -140,6 +113,7 @@ public String getTemperature() { return temperature; } + @Override public String getRequestMessage() { String prompt = ""; try { @@ -155,29 +129,7 @@ public String getRequestMessage() { return prompt; } -// /** -// * Represents the user prompt messages -// * -// * @return -// */ -// public List getUserRequestMessages() { -// // FIXME shouldn't parse the request, just send the whole content -// List messageList = null; -// try { -// if (!getRequestBodyJsonMap().isEmpty()) { -// JsonNode jsonNode = getRequestBodyJsonMap().get(PROMPT); -// if (jsonNode.isString()) { -// String[] messages = jsonNode.asString().split(ESCAPED_NEWLINES); -// messageList = new ArrayList<>(Arrays.asList(messages)); -// 
messageList.remove("Assistant:"); -// } -// } -// } catch (Exception e) { -// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + PROMPT); -// } -// return messageList; -// } - + @Override public String getRole() { try { if (!invokeModelRequestBody.isEmpty()) { @@ -196,6 +148,7 @@ public String getRole() { return ""; } + @Override public String getInputText() { String inputText = ""; try { @@ -211,6 +164,7 @@ public String getInputText() { return inputText; } + @Override public String getModelId() { return modelId; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java index 927e13b2b2..e275cb62f5 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java @@ -25,23 +25,17 @@ * Stores the required info from the Bedrock InvokeModelResponse * but doesn't hold a reference to the actual response object. */ -// TODO create an interface -public class AnthropicClaudeInvokeModelResponse { - // Response body (for Claude, how about other models?) 
- public static final String COMPLETION = "completion"; - public static final String EMBEDDING = "embedding"; +public class AnthropicClaudeInvokeModelResponse implements llm.models.ModelResponse { private static final String STOP_REASON = "stop_reason"; -// private static final String STOP = "stop"; // Response headers private static final String X_AMZN_BEDROCK_INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; private static final String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; private static final String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; -// private static final String X_AMZN_BEDROCK_INVOCATION_LATENCY = "X-Amzn-Bedrock-Invocation-Latency"; + private int inputTokenCount = 0; private int outputTokenCount = 0; private String amznRequestId = ""; -// private String invocationLatency = ""; // LLM operation type private String operationType = ""; @@ -51,8 +45,6 @@ public class AnthropicClaudeInvokeModelResponse { private int statusCode = 0; private String statusText = ""; - // Random GUID for response -// private String llmChatCompletionMessageId = ""; private String llmChatCompletionSummaryId = ""; private String llmEmbeddingId = ""; @@ -70,7 +62,6 @@ public AnthropicClaudeInvokeModelResponse(InvokeModelResponse invokeModelRespons statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); setHeaderFields(invokeModelResponse); -// llmChatCompletionMessageId = BedrockRuntimeUtil.getRandomGuid(); llmChatCompletionSummaryId = getRandomGuid(); llmEmbeddingId = getRandomGuid(); } else { @@ -146,12 +137,8 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { } List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); // TODO does this differ from invokeModelResponse.responseMetadata().requestId() + amznRequestId = amznRequestIdHeaders.get(0); } -// List 
invocationLatencyHeaders = headers.get(X_AMZN_BEDROCK_INVOCATION_LATENCY); -// if (invocationLatencyHeaders != null && !invocationLatencyHeaders.isEmpty()) { -// invocationLatency = invocationLatencyHeaders.get(0); -// } } } catch (Exception e) { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse headers"); @@ -163,6 +150,7 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { * * @return */ + @Override public String getResponseMessage() { String completion = ""; try { @@ -178,6 +166,7 @@ public String getResponseMessage() { return completion; } + @Override public String getStopReason() { String stopReason = ""; try { @@ -193,72 +182,57 @@ public String getStopReason() { return stopReason; } -// public String getStop() { -// String stop = ""; -// try { -// if (!getResponseBodyJsonMap().isEmpty()) { -// JsonNode jsonNode = getResponseBodyJsonMap().get(STOP); -// if (jsonNode.isString()) { -// stop = jsonNode.asString(); -// } -// } -// } catch (Exception e) { -// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP); -// } -// return stop.replaceAll("[\n:]", ""); -// } - + @Override public int getInputTokenCount() { return inputTokenCount; } + @Override public int getOutputTokenCount() { return outputTokenCount; } + @Override public int getTotalTokenCount() { return inputTokenCount + outputTokenCount; } + @Override public String getAmznRequestId() { return amznRequestId; } -// public String getInvocationLatency() { -// return invocationLatency; -// } - + @Override public String getOperationType() { return operationType; } - // TODO stop saving these GUIDS, instead just create a Util method to generate a random one each call?? 
-// public String getLlmChatCompletionMessageId() { -// return llmChatCompletionMessageId; -// } - - // hmmm this one needs to be stored as it's used for completion_id in the message events + @Override public String getLlmChatCompletionSummaryId() { return llmChatCompletionSummaryId; } - // hmmm also needs to be stored to be used for embedding_id in errors + @Override public String getLlmEmbeddingId() { return llmEmbeddingId; } + @Override public boolean isErrorResponse() { return !isSuccessfulResponse; } + @Override public int getStatusCode() { return statusCode; } + @Override public String getStatusText() { return statusText; } + @Override public void reportLlmError() { Map errorParams = new HashMap<>(); errorParams.put("http.statusCode", getStatusCode()); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index 9d28528a68..1e3c446bf8 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -3,6 +3,8 @@ import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; @@ -14,17 +16,17 @@ import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.EMBEDDING; public class AnthropicClaudeModelInvocation implements ModelInvocation { - AnthropicClaudeInvokeModelRequest anthropicClaudeInvokeModelRequest; - AnthropicClaudeInvokeModelResponse anthropicClaudeInvokeModelResponse; + 
ModelRequest claudeRequest; + ModelResponse claudeResponse; public AnthropicClaudeModelInvocation(InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { - anthropicClaudeInvokeModelRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); - anthropicClaudeInvokeModelResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); + claudeRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); + claudeResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); } @Override public void setLlmOperationMetricName(Transaction txn, String functionName) { - txn.getTracedMethod().setMetricName("Llm", anthropicClaudeInvokeModelResponse.getOperationType(), BEDROCK, functionName); + txn.getTracedMethod().setMetricName("Llm", claudeResponse.getOperationType(), BEDROCK, functionName); } // TODO add event builders??? @@ -38,18 +40,18 @@ public void recordLlmEmbeddingEvent(long startTime, Map linkingM addIngestSource(eventAttributes); // Attributes dependent on the request/response - addId(eventAttributes, anthropicClaudeInvokeModelResponse.getLlmEmbeddingId()); - addRequestId(eventAttributes); - addInput(eventAttributes); - addRequestModel(eventAttributes); - addResponseModel(eventAttributes); - addResponseUsageTotalTokens(eventAttributes); - addResponseUsagePromptTokens(eventAttributes); + addId(eventAttributes, claudeResponse.getLlmEmbeddingId()); + addRequestId(eventAttributes, claudeResponse); + addInput(eventAttributes, claudeRequest); + addRequestModel(eventAttributes, claudeRequest); + addResponseModel(eventAttributes, claudeRequest); + addResponseUsageTotalTokens(eventAttributes, claudeResponse); + addResponseUsagePromptTokens(eventAttributes, claudeResponse); // Error attributes - if (anthropicClaudeInvokeModelResponse.isErrorResponse()) { + if (claudeResponse.isErrorResponse()) { addError(eventAttributes); - anthropicClaudeInvokeModelResponse.reportLlmError(); + claudeResponse.reportLlmError(); } // 
Duration attribute from manual timing as we don't have a way of getting timing from a tracer/segment within a method that is in the process of being timed @@ -75,22 +77,22 @@ public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long start addIngestSource(eventAttributes); // Attributes dependent on the request/response - addId(eventAttributes, anthropicClaudeInvokeModelResponse.getLlmChatCompletionSummaryId()); - addRequestId(eventAttributes); - addRequestTemperature(eventAttributes); - addRequestMaxTokens(eventAttributes); - addRequestModel(eventAttributes); - addResponseModel(eventAttributes); + addId(eventAttributes, claudeResponse.getLlmChatCompletionSummaryId()); + addRequestId(eventAttributes, claudeResponse); + addRequestTemperature(eventAttributes, claudeRequest); + addRequestMaxTokens(eventAttributes, claudeRequest); + addRequestModel(eventAttributes, claudeRequest); + addResponseModel(eventAttributes, claudeRequest); addResponseNumberOfMessages(eventAttributes, numberOfMessages); - addResponseUsageTotalTokens(eventAttributes); - addResponseUsagePromptTokens(eventAttributes); - addResponseUsageCompletionTokens(eventAttributes); - addResponseChoicesFinishReason(eventAttributes); + addResponseUsageTotalTokens(eventAttributes, claudeResponse); + addResponseUsagePromptTokens(eventAttributes, claudeResponse); + addResponseUsageCompletionTokens(eventAttributes, claudeResponse); + addResponseChoicesFinishReason(eventAttributes, claudeResponse); // Error attributes - if (anthropicClaudeInvokeModelResponse.isErrorResponse()) { + if (claudeResponse.isErrorResponse()) { addError(eventAttributes); - anthropicClaudeInvokeModelResponse.reportLlmError(); + claudeResponse.reportLlmError(); } // Duration attribute from manual timing as we don't have a way of getting timing from a tracer/segment within a method that is in the process of being timed @@ -123,7 +125,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, Ma 
addRole(eventAttributes, "user"); addIsResponse(eventAttributes, false); } else { - String role = anthropicClaudeInvokeModelRequest.getRole(); + String role = claudeRequest.getRole(); if (!role.isEmpty()) { addRole(eventAttributes, role); if (!role.contains("user")) { @@ -131,10 +133,10 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, Ma } } } - addRequestId(eventAttributes); - addResponseModel(eventAttributes); + addRequestId(eventAttributes, claudeResponse); + addResponseModel(eventAttributes, claudeRequest); addSequence(eventAttributes, sequence); - addCompletionId(eventAttributes); + addCompletionId(eventAttributes, claudeResponse); NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); @@ -144,19 +146,9 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, Ma // eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! 
} - @Override - public void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, anthropicClaudeInvokeModelRequest.getRequestMessage(), linkingMetadata); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, anthropicClaudeInvokeModelResponse.getResponseMessage(), linkingMetadata); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata); - } - @Override public void recordLlmEvents(long startTime, Map linkingMetadata) { - String operationType = anthropicClaudeInvokeModelResponse.getOperationType(); + String operationType = claudeResponse.getOperationType(); if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime, linkingMetadata); } else if (operationType.equals(EMBEDDING)) { @@ -166,6 +158,19 @@ public void recordLlmEvents(long startTime, Map linkingMetadata) } } + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. + */ + private void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, claudeRequest.getRequestMessage(), linkingMetadata); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, claudeResponse.getResponseMessage(), linkingMetadata); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata); + } + // TODO can all of these helper methods be moved to the ModelInvocation interface??? 
private void addSpanId(Map eventAttributes, Map linkingMetadata) { String spanId = ModelInvocation.getSpanId(linkingMetadata); @@ -239,90 +244,79 @@ private void addError(Map eventAttributes) { eventAttributes.put("error", true); } - private void addInput(Map eventAttributes) { - // TODO modify to pass in Request interface if moving to the ModelInvocation interface - String inputText = anthropicClaudeInvokeModelRequest.getInputText(); + private void addInput(Map eventAttributes, ModelRequest modelRequest) { + String inputText = modelRequest.getInputText(); if (inputText != null && !inputText.isEmpty()) { eventAttributes.put("input", inputText); } } - private void addRequestTemperature(Map eventAttributes) { - // TODO modify to pass in Request interface if moving to the ModelInvocation interface - String temperature = anthropicClaudeInvokeModelRequest.getTemperature(); + private void addRequestTemperature(Map eventAttributes, ModelRequest modelRequest) { + String temperature = modelRequest.getTemperature(); if (temperature != null && !temperature.isEmpty()) { eventAttributes.put("request.temperature", temperature); } } - private void addRequestMaxTokens(Map eventAttributes) { - // TODO modify to pass in Request interface if moving to the ModelInvocation interface - String maxTokensToSample = anthropicClaudeInvokeModelRequest.getMaxTokensToSample(); + private void addRequestMaxTokens(Map eventAttributes, ModelRequest modelRequest) { + String maxTokensToSample = modelRequest.getMaxTokensToSample(); if (maxTokensToSample != null && !maxTokensToSample.isEmpty()) { eventAttributes.put("request.max_tokens", maxTokensToSample); } } - private void addRequestModel(Map eventAttributes) { - // TODO modify to pass in Request interface if moving to the ModelInvocation interface - String modelId = anthropicClaudeInvokeModelRequest.getModelId(); + private void addRequestModel(Map eventAttributes, ModelRequest modelRequest) { + String modelId = modelRequest.getModelId(); if 
(modelId != null && !modelId.isEmpty()) { eventAttributes.put("request.model", modelId); } } - private void addResponseModel(Map eventAttributes) { - // TODO modify to pass in Request interface if moving to the ModelInvocation interface + private void addResponseModel(Map eventAttributes, ModelRequest modelRequest) { // For Bedrock the response model is the same as the request model. - String modelId = anthropicClaudeInvokeModelRequest.getModelId(); + String modelId = modelRequest.getModelId(); if (modelId != null && !modelId.isEmpty()) { eventAttributes.put("response.model", modelId); } } - private void addRequestId(Map eventAttributes) { - // TODO modify to pass in Response interface if moving to the ModelInvocation interface - String requestId = anthropicClaudeInvokeModelResponse.getAmznRequestId(); + private void addRequestId(Map eventAttributes, ModelResponse modelResponse) { + String requestId = modelResponse.getAmznRequestId(); if (requestId != null && !requestId.isEmpty()) { eventAttributes.put("request_id", requestId); } } - private void addCompletionId(Map eventAttributes) { - // TODO modify to pass in Response interface if moving to the ModelInvocation interface - String llmChatCompletionSummaryId = anthropicClaudeInvokeModelResponse.getLlmChatCompletionSummaryId(); + private void addCompletionId(Map eventAttributes, ModelResponse modelResponse) { + String llmChatCompletionSummaryId = modelResponse.getLlmChatCompletionSummaryId(); if (llmChatCompletionSummaryId != null && !llmChatCompletionSummaryId.isEmpty()) { eventAttributes.put("completion_id", llmChatCompletionSummaryId); } } - private void addResponseUsageTotalTokens(Map eventAttributes) { - // TODO modify to pass in Response interface if moving to the ModelInvocation interface - int totalTokenCount = anthropicClaudeInvokeModelResponse.getTotalTokenCount(); + private void addResponseUsageTotalTokens(Map eventAttributes, ModelResponse modelResponse) { + int totalTokenCount = 
modelResponse.getTotalTokenCount(); if (totalTokenCount >= 0) { eventAttributes.put("response.usage.total_tokens", totalTokenCount); } } - private void addResponseUsagePromptTokens(Map eventAttributes) { - // TODO modify to pass in Response interface if moving to the ModelInvocation interface - int inputTokenCount = anthropicClaudeInvokeModelResponse.getInputTokenCount(); + private void addResponseUsagePromptTokens(Map eventAttributes, ModelResponse modelResponse) { + int inputTokenCount = modelResponse.getInputTokenCount(); if (inputTokenCount >= 0) { eventAttributes.put("response.usage.prompt_tokens", inputTokenCount); } } - private void addResponseUsageCompletionTokens(Map eventAttributes) { - // TODO modify to pass in Response interface if moving to the ModelInvocation interface - int outputTokenCount = anthropicClaudeInvokeModelResponse.getOutputTokenCount(); + private void addResponseUsageCompletionTokens(Map eventAttributes, ModelResponse modelResponse) { + int outputTokenCount = modelResponse.getOutputTokenCount(); if (outputTokenCount >= 0) { eventAttributes.put("response.usage.completion_tokens", outputTokenCount); } } - private void addResponseChoicesFinishReason(Map eventAttributes) { - // TODO modify to pass in Response interface if moving to the ModelInvocation interface - String stopReason = anthropicClaudeInvokeModelResponse.getStopReason(); + private void addResponseChoicesFinishReason(Map eventAttributes, ModelResponse modelResponse) { + String stopReason = modelResponse.getStopReason(); if (stopReason != null && !stopReason.isEmpty()) { eventAttributes.put("response.choices.finish_reason", stopReason); } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index 2a9c815c23..cdc299c4e4 
100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -24,11 +24,8 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import java.util.Map; -import java.util.logging.Level; import static llm.models.ModelInvocation.ANTHROPIC_CLAUDE; -import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.COMPLETION; -import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.EMBEDDING; /** * Service client for accessing Amazon Bedrock Runtime. @@ -73,7 +70,7 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { String modelId = invokeModelRequest.modelId(); if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - AnthropicClaudeModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(invokeModelRequest, + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(invokeModelRequest, invokeModelResponse); // Set traced method name based on LLM operation anthropicClaudeModelInvocation.setLlmOperationMetricName(txn, "invokeModel"); From 443b600c81eee054b4902c12796b4cea0d94b8ef Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 22 Feb 2024 16:35:21 -0800 Subject: [PATCH 07/68] Refactoring --- .../aws-bedrock-runtime-2.20/README.md | 71 +++- .../src/main/java/llm/events/LlmEvent.java | 361 ++++++++++++++++++ .../main/java/llm/models/ModelInvocation.java | 8 +- .../main/java/llm/models/ModelRequest.java | 4 +- .../main/java/llm/models/ModelResponse.java | 2 - .../AnthropicClaudeInvokeModelRequest.java | 14 +- .../AnthropicClaudeInvokeModelResponse.java | 15 - .../AnthropicClaudeModelInvocation.java | 324 +++++----------- ...tBedrockRuntimeClient_Instrumentation.java | 
2 +-
 9 files changed, 531 insertions(+), 270 deletions(-)
 create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java

diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md
index d6f27b3820..15f0034725 100644
--- a/instrumentation/aws-bedrock-runtime-2.20/README.md
+++ b/instrumentation/aws-bedrock-runtime-2.20/README.md
@@ -2,8 +2,77 @@

 ## About

+Instruments invocations of LLMs via AWS Bedrock Runtime.
+
+## Support
+
+### Supported Clients/APIs
+
+The following AWS Bedrock Runtime clients and APIs are supported:
+
+* `BedrockRuntimeClient`
+  * `invokeModel`
+* `BedrockRuntimeAsyncClient`
+  * `invokeModel`
+
+### Supported Models
+
+Currently, only the following text-based foundation models are supported:
+
+* Anthropic Claude
+* Amazon Titan
+* Meta Llama 2
+* Cohere Command
+* AI21 Labs Jurassic
+
+## Involved Pieces
+
+### LLM Events
+
+The main goal of this instrumentation is to generate the following LLM events to drive the UI. These events are custom events sent via the public `recordCustomEvent` API.
+
+* `LlmEmbedding`
+* `LlmChatCompletionSummary`
+* `LlmChatCompletionMessage`
+
+Currently, they contribute towards the following Custom Insights Events limits (this will likely change in the future). Because of this, it is recommended to increase `custom_insights_events.max_samples_stored` to the maximum value of 100,000 to best avoid sampling issues.
+
+```yaml
+  custom_insights_events:
+    max_samples_stored: 100000
+```
+
+LLM events also have some unique limits for their attributes...
+
+they are also bucketed into a unique namespace separate from other custom events on the backend...
+
+```
+Regardless of which implementation(s) are built, there are consistent changes within the agents and the UX to support AI Monitoring.
+ +Agents should send the entire content; do not truncate it to 256 or 4096 characters + +Agents should move known token counts to the LlmChatCompletionMessage + +Agents should remove token counts from the LlmChatCompletionSummary +``` + + +call out llm. behavior + + + +### Model Invocation/Request/Response + +* `ModelInvocation` +* `ModelRequest` +* `ModelResponse` + +### Metrics + + + +## Config -## Pieces ## Testing diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java new file mode 100644 index 0000000000..59c47dde3a --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -0,0 +1,361 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.events; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; + +import java.util.HashMap; +import java.util.Map; + +/** + * Class for building an LlmEvent + */ +public class LlmEvent { + private final Map eventAttributes = new HashMap<>(); + + // LLM event types + public static final String LLM_EMBEDDING = "LlmEmbedding"; + public static final String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; + public static final String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; + + // Optional LlmEvent attributes + private final String spanId; + private final String traceId; + private final String vendor; + private final String ingestSource; + private final String id; + private final String content; + private final String role; + private final Boolean isResponse; + private final String requestId; + private final String responseModel; + private final Integer sequence; + private final String completionId; + private final Integer responseNumberOfMessages; + private final 
Float duration; + private final Boolean error; + private final String input; + private final Float requestTemperature; + private final Integer requestMaxTokens; + private final String requestModel; + private final Integer responseUsageTotalTokens; + private final Integer responseUsagePromptTokens; + private final Integer responseUsageCompletionTokens; + private final String responseChoicesFinishReason; + + public static class Builder { + // Required builder parameters + private final Map linkingMetadata; + private final ModelRequest modelRequest; + private final ModelResponse modelResponse; + + /* + * All optional builder attributes are defaulted to null so that they won't be added + * to the eventAttributes map unless they are explicitly set via the builder + * methods when constructing an LlmEvent. This allows the builder to create + * any type of LlmEvent with any combination of attributes solely determined + * by the builder methods that are called while omitting all unused attributes. 
+ */ + + // Optional builder parameters + private String spanId = null; + private String traceId = null; + private String vendor = null; + private String ingestSource = null; + private String id = null; + private String content = null; + private String role = null; + private Boolean isResponse = null; + private String requestId = null; + private String responseModel = null; + private Integer sequence = null; + private String completionId = null; + private Integer responseNumberOfMessages = null; + private Float duration = null; + private Boolean error = null; + private String input = null; + private Float requestTemperature = null; + private Integer requestMaxTokens = null; + private String requestModel = null; + private Integer responseUsageTotalTokens = null; + private Integer responseUsagePromptTokens = null; + private Integer responseUsageCompletionTokens = null; + private String responseChoicesFinishReason = null; + + public Builder(Map linkingMetadata, ModelRequest modelRequest, ModelResponse modelResponse) { + this.linkingMetadata = linkingMetadata; + this.modelRequest = modelRequest; + this.modelResponse = modelResponse; + } + + public Builder spanId() { + spanId = ModelInvocation.getSpanId(linkingMetadata); + return this; + } + + public Builder traceId() { + traceId = ModelInvocation.getTraceId(linkingMetadata); + return this; + } + + public Builder vendor() { + vendor = ModelInvocation.getVendor(); + return this; + } + + public Builder ingestSource() { + ingestSource = ModelInvocation.getIngestSource(); + return this; + } + + public Builder id(String modelId) { + id = modelId; + return this; + } + + public Builder content(String message) { + content = message; + return this; + } + + public Builder role(boolean isUser) { + if (isUser) { + role = "user"; + } else { + role = modelRequest.getRole(); + } + return this; + } + + public Builder isResponse(boolean isUser) { + isResponse = !isUser; + return this; + } + + public Builder requestId() { + requestId = 
modelResponse.getAmznRequestId(); + return this; + } + + public Builder responseModel() { + responseModel = modelRequest.getModelId(); + return this; + } + + public Builder sequence(int eventSequence) { + sequence = eventSequence; + return this; + } + + public Builder completionId() { + completionId = modelResponse.getLlmChatCompletionSummaryId(); + return this; + } + + public Builder responseNumberOfMessages(int numberOfMessages) { + responseNumberOfMessages = numberOfMessages; + return this; + } + + public Builder duration(float callDuration) { + duration = callDuration; + return this; + } + + public Builder error() { + error = modelResponse.isErrorResponse(); + return this; + } + + public Builder input() { + input = modelRequest.getInputText(); + return this; + } + + public Builder requestTemperature() { + requestTemperature = modelRequest.getTemperature(); + return this; + } + + public Builder requestMaxTokens() { + requestMaxTokens = modelRequest.getMaxTokensToSample(); + return this; + } + + public Builder requestModel() { + requestModel = modelRequest.getModelId(); + return this; + } + + public Builder responseUsageTotalTokens() { + responseUsageTotalTokens = modelResponse.getTotalTokenCount(); + return this; + } + + public Builder responseUsagePromptTokens() { + responseUsagePromptTokens = modelResponse.getInputTokenCount(); + return this; + } + + public Builder responseUsageCompletionTokens() { + responseUsageCompletionTokens = modelResponse.getOutputTokenCount(); + return this; + } + + public Builder responseChoicesFinishReason() { + responseChoicesFinishReason = modelResponse.getStopReason(); + return this; + } + + public LlmEvent build() { + return new LlmEvent(this); + } + } + + // This populates the LlmEvent attributes map with only the attributes that were explicitly set on the builder. 
+ private LlmEvent(Builder builder) { + spanId = builder.spanId; + if (spanId != null && !spanId.isEmpty()) { + eventAttributes.put("span_id", spanId); + } + + traceId = builder.traceId; + if (traceId != null && !traceId.isEmpty()) { + eventAttributes.put("trace_id", traceId); + } + + vendor = builder.vendor; + if (vendor != null && !vendor.isEmpty()) { + eventAttributes.put("vendor", vendor); + } + + ingestSource = builder.ingestSource; + if (ingestSource != null && !ingestSource.isEmpty()) { + eventAttributes.put("ingest_source", ingestSource); + } + + id = builder.id; + if (id != null && !id.isEmpty()) { + eventAttributes.put("id", id); + } + + content = builder.content; + if (content != null && !content.isEmpty()) { + eventAttributes.put("content", content); + } + + role = builder.role; + if (role != null && !role.isEmpty()) { + eventAttributes.put("role", role); + } + + isResponse = builder.isResponse; + if (isResponse != null) { + eventAttributes.put("is_response", isResponse); + } + + requestId = builder.requestId; + if (requestId != null && !requestId.isEmpty()) { + eventAttributes.put("request_id", requestId); + } + + responseModel = builder.responseModel; + if (responseModel != null && !responseModel.isEmpty()) { + eventAttributes.put("response.model", responseModel); + } + + sequence = builder.sequence; + if (sequence != null && sequence >= 0) { + eventAttributes.put("sequence", sequence); + } + + completionId = builder.completionId; + if (completionId != null && !completionId.isEmpty()) { + eventAttributes.put("completion_id", completionId); + } + + responseNumberOfMessages = builder.responseNumberOfMessages; + if (responseNumberOfMessages != null && responseNumberOfMessages >= 0) { + eventAttributes.put("response.number_of_messages", responseNumberOfMessages); + } + + duration = builder.duration; + if (duration != null && duration >= 0) { + eventAttributes.put("duration", duration); + } + + error = builder.error; + if (error != null && error) { + 
eventAttributes.put("error", true); + } + + input = builder.input; + if (input != null && !input.isEmpty()) { + eventAttributes.put("input", input); + } + + requestTemperature = builder.requestTemperature; + if (requestTemperature != null && requestTemperature >= 0) { + eventAttributes.put("request.temperature", requestTemperature); + } + + requestMaxTokens = builder.requestMaxTokens; + if (requestMaxTokens != null && requestMaxTokens >= 0) { + eventAttributes.put("request.max_tokens", requestMaxTokens); + } + + requestModel = builder.requestModel; + if (requestModel != null && !requestModel.isEmpty()) { + eventAttributes.put("request.model", requestModel); + } + + responseUsageTotalTokens = builder.responseUsageTotalTokens; + if (responseUsageTotalTokens != null && responseUsageTotalTokens >= 0) { + eventAttributes.put("response.usage.total_tokens", responseUsageTotalTokens); + } + + responseUsagePromptTokens = builder.responseUsagePromptTokens; + if (responseUsagePromptTokens != null && responseUsagePromptTokens >= 0) { + eventAttributes.put("response.usage.prompt_tokens", responseUsagePromptTokens); + } + + responseUsageCompletionTokens = builder.responseUsageCompletionTokens; + if (responseUsageCompletionTokens != null && responseUsageCompletionTokens >= 0) { + eventAttributes.put("response.usage.completion_tokens", responseUsageCompletionTokens); + } + + responseChoicesFinishReason = builder.responseChoicesFinishReason; + if (responseChoicesFinishReason != null && !responseChoicesFinishReason.isEmpty()) { + eventAttributes.put("response.choices.finish_reason", responseChoicesFinishReason); + } + } + + /** + * Record a LlmChatCompletionMessage custom event + */ + public void recordLlmChatCompletionMessageEvent() { + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); + } + + /** + * Record a LlmChatCompletionSummary custom event + */ + public void recordLlmChatCompletionSummaryEvent() { + 
NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); + } + + /** + * Record a LlmEmbedding custom event + */ + public void recordLlmEmbeddingEvent() { + NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index bb043d02fa..60acfaf524 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -20,11 +20,6 @@ public interface ModelInvocation { String TRACE_ID = "trace.id"; String SPAN_ID = "span.id"; - // LLM event types - String LLM_EMBEDDING = "LlmEmbedding"; - String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; - String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; - // Supported models String ANTHROPIC_CLAUDE = "claude"; String AMAZON_TITAN = "titan"; @@ -36,7 +31,7 @@ public interface ModelInvocation { * Set name of the span/segment for each LLM embedding and chat completion call * Llm/{operation_type}/{vendor_name}/{function_name} * - * @param txn current transaction + * @param txn current transaction */ void setLlmOperationMetricName(Transaction txn, String functionName); @@ -48,6 +43,7 @@ public interface ModelInvocation { void recordLlmEvents(long startTime, Map linkingMetadata); + void reportLlmError(); /** * This needs to be incremented for every invocation of the SDK. 
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java index 9575cddf81..b66ca6c548 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -1,9 +1,9 @@ package llm.models; public interface ModelRequest { - String getMaxTokensToSample(); + int getMaxTokensToSample(); - String getTemperature(); + float getTemperature(); String getRequestMessage(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java index 9f016fb101..ff94d24e51 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -27,6 +27,4 @@ public interface ModelResponse { int getStatusCode(); String getStatusText(); - - void reportLlmError(); } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java index b6b7658dd4..e43b905ec3 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java @@ -82,13 +82,14 @@ private Map parseInvokeModelRequestBodyMap() { } @Override - public String getMaxTokensToSample() { - String maxTokensToSample = ""; + public int getMaxTokensToSample() { + int maxTokensToSample = 0; try { if (!getRequestBodyJsonMap().isEmpty()) { JsonNode jsonNode = 
getRequestBodyJsonMap().get(MAX_TOKENS_TO_SAMPLE); if (jsonNode.isNumber()) { - maxTokensToSample = jsonNode.asNumber(); + String maxTokensToSampleString = jsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokensToSampleString); } } } catch (Exception e) { @@ -98,13 +99,14 @@ public String getMaxTokensToSample() { } @Override - public String getTemperature() { - String temperature = ""; + public float getTemperature() { + float temperature = 0f; try { if (!getRequestBodyJsonMap().isEmpty()) { JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); if (jsonNode.isNumber()) { - temperature = jsonNode.asNumber(); + String temperatureString = jsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); } } } catch (Exception e) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java index e275cb62f5..3cf341a586 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java @@ -13,7 +13,6 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -231,18 +230,4 @@ public int getStatusCode() { public String getStatusText() { return statusText; } - - @Override - public void reportLlmError() { - Map errorParams = new HashMap<>(); - errorParams.put("http.statusCode", getStatusCode()); - errorParams.put("error.code", getStatusCode()); - if (!getLlmChatCompletionSummaryId().isEmpty()) { - errorParams.put("completion_id", getLlmChatCompletionSummaryId()); - } - if 
(!getLlmEmbeddingId().isEmpty()) { - errorParams.put("embedding_id", getLlmEmbeddingId()); - } - NewRelic.noticeError("LlmError: " + getStatusText(), errorParams); - } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index 1e3c446bf8..4277b9ff37 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -1,7 +1,15 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + package llm.models.anthropic.claude; import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; +import llm.events.LlmEvent; import llm.models.ModelInvocation; import llm.models.ModelRequest; import llm.models.ModelResponse; @@ -29,77 +37,67 @@ public void setLlmOperationMetricName(Transaction txn, String functionName) { txn.getTracedMethod().setMetricName("Llm", claudeResponse.getOperationType(), BEDROCK, functionName); } - // TODO add event builders??? 
@Override public void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata) { - Map eventAttributes = new HashMap<>(); - // Generic attributes that are constant for all Bedrock Models - addSpanId(eventAttributes, linkingMetadata); - addTraceId(eventAttributes, linkingMetadata); - addVendor(eventAttributes); - addIngestSource(eventAttributes); - - // Attributes dependent on the request/response - addId(eventAttributes, claudeResponse.getLlmEmbeddingId()); - addRequestId(eventAttributes, claudeResponse); - addInput(eventAttributes, claudeRequest); - addRequestModel(eventAttributes, claudeRequest); - addResponseModel(eventAttributes, claudeRequest); - addResponseUsageTotalTokens(eventAttributes, claudeResponse); - addResponseUsagePromptTokens(eventAttributes, claudeResponse); - - // Error attributes if (claudeResponse.isErrorResponse()) { - addError(eventAttributes); - claudeResponse.reportLlmError(); + reportLlmError(); } - // Duration attribute from manual timing as we don't have a way of getting timing from a tracer/segment within a method that is in the process of being timed - long endTime = System.currentTimeMillis(); - addDuration(eventAttributes, (endTime - startTime)); + LlmEvent.Builder builder = new LlmEvent.Builder(linkingMetadata, claudeRequest, claudeResponse); - NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes); + LlmEvent llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(claudeResponse.getLlmEmbeddingId()) + .requestId() + .input() + .requestModel() + .responseModel() + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); // TODO is it possible to do something like this to call getUserAttributes? 
// see com.newrelic.agent.bridge.Transaction // eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? // eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! - } @Override public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata) { - Map eventAttributes = new HashMap<>(); - // Generic attributes that are constant for all Bedrock Models - addSpanId(eventAttributes, linkingMetadata); - addTraceId(eventAttributes, linkingMetadata); - addVendor(eventAttributes); - addIngestSource(eventAttributes); - - // Attributes dependent on the request/response - addId(eventAttributes, claudeResponse.getLlmChatCompletionSummaryId()); - addRequestId(eventAttributes, claudeResponse); - addRequestTemperature(eventAttributes, claudeRequest); - addRequestMaxTokens(eventAttributes, claudeRequest); - addRequestModel(eventAttributes, claudeRequest); - addResponseModel(eventAttributes, claudeRequest); - addResponseNumberOfMessages(eventAttributes, numberOfMessages); - addResponseUsageTotalTokens(eventAttributes, claudeResponse); - addResponseUsagePromptTokens(eventAttributes, claudeResponse); - addResponseUsageCompletionTokens(eventAttributes, claudeResponse); - addResponseChoicesFinishReason(eventAttributes, claudeResponse); - - // Error attributes if (claudeResponse.isErrorResponse()) { - addError(eventAttributes); - claudeResponse.reportLlmError(); - } - - // Duration attribute from manual timing as we don't have a way of getting timing from a tracer/segment within a method that is in the process of being timed - long endTime = System.currentTimeMillis(); - addDuration(eventAttributes, (endTime - startTime)); - - 
NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes); + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(linkingMetadata, claudeRequest, claudeResponse); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(claudeResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .responseUsageCompletionTokens() + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); // TODO is it possible to do something like this to call getUserAttributes? // see com.newrelic.agent.bridge.Transaction @@ -109,36 +107,26 @@ public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long start @Override public void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata) { - Map eventAttributes = new HashMap<>(); - // Generic attributes that are constant for all Bedrock Models - addSpanId(eventAttributes, linkingMetadata); - addTraceId(eventAttributes, linkingMetadata); - addVendor(eventAttributes); - addIngestSource(eventAttributes); - - // Multiple completion message events can be created per transaction so generate an id on the fly instead of storing each in the response/request wrapper - addId(eventAttributes, ModelInvocation.getRandomGuid()); - - // Attributes dependent on the request/response - addContent(eventAttributes, message); - if (message.contains("Human:")) { - addRole(eventAttributes, "user"); - addIsResponse(eventAttributes, false); - } else { - String role = claudeRequest.getRole(); - if (!role.isEmpty()) { - addRole(eventAttributes, role); - if (!role.contains("user")) { - 
addIsResponse(eventAttributes, true); - } - } - } - addRequestId(eventAttributes, claudeResponse); - addResponseModel(eventAttributes, claudeRequest); - addSequence(eventAttributes, sequence); - addCompletionId(eventAttributes, claudeResponse); - - NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes); + boolean isUser = message.contains("Human:"); + + LlmEvent.Builder builder = new LlmEvent.Builder(linkingMetadata, claudeRequest, claudeResponse); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); // TODO is it possible to do something like this to call getUserAttributes? // see com.newrelic.agent.bridge.Transaction @@ -158,6 +146,20 @@ public void recordLlmEvents(long startTime, Map linkingMetadata) } } + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", claudeResponse.getStatusCode()); + errorParams.put("error.code", claudeResponse.getStatusCode()); + if (!claudeResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", claudeResponse.getLlmChatCompletionSummaryId()); + } + if (!claudeResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", claudeResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + claudeResponse.getStatusText(), errorParams); + } + /** * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
@@ -170,156 +172,4 @@ private void recordLlmChatCompletionEvents(long startTime, Map l // A summary of all LlmChatCompletionMessage events recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata); } - - // TODO can all of these helper methods be moved to the ModelInvocation interface??? - private void addSpanId(Map eventAttributes, Map linkingMetadata) { - String spanId = ModelInvocation.getSpanId(linkingMetadata); - if (spanId != null && !spanId.isEmpty()) { - eventAttributes.put("span_id", spanId); - } - } - - private void addTraceId(Map eventAttributes, Map linkingMetadata) { - String traceId = ModelInvocation.getTraceId(linkingMetadata); - if (traceId != null && !traceId.isEmpty()) { - eventAttributes.put("trace_id", traceId); - } - } - - private void addVendor(Map eventAttributes) { - String vendor = ModelInvocation.getVendor(); - if (vendor != null && !vendor.isEmpty()) { - eventAttributes.put("vendor", vendor); - } - } - - private void addIngestSource(Map eventAttributes) { - String ingestSource = ModelInvocation.getIngestSource(); - if (ingestSource != null && !ingestSource.isEmpty()) { - eventAttributes.put("ingest_source", ingestSource); - } - } - - private void addId(Map eventAttributes, String id) { - if (id != null && !id.isEmpty()) { - eventAttributes.put("id", id); - } - } - - private void addContent(Map eventAttributes, String message) { - if (message != null && !message.isEmpty()) { - eventAttributes.put("content", message); - } - } - - private void addRole(Map eventAttributes, String role) { - if (role != null && !role.isEmpty()) { - eventAttributes.put("role", role); - } - } - - private void addIsResponse(Map eventAttributes, boolean isResponse) { - eventAttributes.put("is_response", isResponse); - } - - private void addSequence(Map eventAttributes, int sequence) { - if (sequence >= 0) { - eventAttributes.put("sequence", sequence); - } - } - - private void addResponseNumberOfMessages(Map eventAttributes, int numberOfMessages) { - if 
(numberOfMessages >= 0) { - eventAttributes.put("response.number_of_messages", numberOfMessages); - } - } - - private void addDuration(Map eventAttributes, long duration) { - if (duration >= 0) { - eventAttributes.put("duration", duration); - } - } - - private void addError(Map eventAttributes) { - eventAttributes.put("error", true); - } - - private void addInput(Map eventAttributes, ModelRequest modelRequest) { - String inputText = modelRequest.getInputText(); - if (inputText != null && !inputText.isEmpty()) { - eventAttributes.put("input", inputText); - } - } - - private void addRequestTemperature(Map eventAttributes, ModelRequest modelRequest) { - String temperature = modelRequest.getTemperature(); - if (temperature != null && !temperature.isEmpty()) { - eventAttributes.put("request.temperature", temperature); - } - } - - private void addRequestMaxTokens(Map eventAttributes, ModelRequest modelRequest) { - String maxTokensToSample = modelRequest.getMaxTokensToSample(); - if (maxTokensToSample != null && !maxTokensToSample.isEmpty()) { - eventAttributes.put("request.max_tokens", maxTokensToSample); - } - } - - private void addRequestModel(Map eventAttributes, ModelRequest modelRequest) { - String modelId = modelRequest.getModelId(); - if (modelId != null && !modelId.isEmpty()) { - eventAttributes.put("request.model", modelId); - } - } - - private void addResponseModel(Map eventAttributes, ModelRequest modelRequest) { - // For Bedrock the response model is the same as the request model. 
- String modelId = modelRequest.getModelId(); - if (modelId != null && !modelId.isEmpty()) { - eventAttributes.put("response.model", modelId); - } - } - - private void addRequestId(Map eventAttributes, ModelResponse modelResponse) { - String requestId = modelResponse.getAmznRequestId(); - if (requestId != null && !requestId.isEmpty()) { - eventAttributes.put("request_id", requestId); - } - } - - private void addCompletionId(Map eventAttributes, ModelResponse modelResponse) { - String llmChatCompletionSummaryId = modelResponse.getLlmChatCompletionSummaryId(); - if (llmChatCompletionSummaryId != null && !llmChatCompletionSummaryId.isEmpty()) { - eventAttributes.put("completion_id", llmChatCompletionSummaryId); - } - } - - private void addResponseUsageTotalTokens(Map eventAttributes, ModelResponse modelResponse) { - int totalTokenCount = modelResponse.getTotalTokenCount(); - if (totalTokenCount >= 0) { - eventAttributes.put("response.usage.total_tokens", totalTokenCount); - } - } - - private void addResponseUsagePromptTokens(Map eventAttributes, ModelResponse modelResponse) { - int inputTokenCount = modelResponse.getInputTokenCount(); - if (inputTokenCount >= 0) { - eventAttributes.put("response.usage.prompt_tokens", inputTokenCount); - } - } - - private void addResponseUsageCompletionTokens(Map eventAttributes, ModelResponse modelResponse) { - int outputTokenCount = modelResponse.getOutputTokenCount(); - if (outputTokenCount >= 0) { - eventAttributes.put("response.usage.completion_tokens", outputTokenCount); - } - } - - private void addResponseChoicesFinishReason(Map eventAttributes, ModelResponse modelResponse) { - String stopReason = modelResponse.getStopReason(); - if (stopReason != null && !stopReason.isEmpty()) { - eventAttributes.put("response.choices.finish_reason", stopReason); - } - } - } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index cdc299c4e4..834eb1395a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -66,7 +66,7 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { // TODO check AIM config if (txn != null && !(txn instanceof NoOpTransaction)) { Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); - Map userAttributes = txn.getUserAttributes(); +// Map userAttributes = txn.getUserAttributes(); String modelId = invokeModelRequest.modelId(); if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { From a2ba5e1a94e51cffe4796f9c3de3d572d41102f1 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Fri, 23 Feb 2024 16:57:35 -0800 Subject: [PATCH 08/68] Refactoring. Add llm custom attributes. 
--- .../agent/bridge/NoOpTransaction.java | 6 +++ .../newrelic/agent/bridge/Transaction.java | 1 + .../agent/model/LlmCustomInsightsEvent.java | 30 ++++++++++++ .../aws-bedrock-runtime-2.20/README.md | 25 +++++++--- .../src/main/java/llm/events/LlmEvent.java | 41 +++++++++++++--- .../main/java/llm/models/ModelInvocation.java | 38 +++----------- .../main/java/llm/models/ModelRequest.java | 7 +++ .../main/java/llm/models/ModelResponse.java | 7 +++ .../main/java/llm/models/SupportedModels.java | 16 ++++++ .../AnthropicClaudeModelInvocation.java | 32 +++++------- .../src/main/java/llm/vendor/Vendor.java | 16 ++++++ ...ockRuntimeAsyncClient_Instrumentation.java | 3 +- ...tBedrockRuntimeClient_Instrumentation.java | 9 ++-- .../newrelic/agent/TransactionApiImpl.java | 6 +++ .../CustomEventAttributeValidator.java | 7 ++- .../LlmEventAttributeValidator.java | 36 ++++++++++++++ .../analytics/InsightsServiceImpl.java | 49 +++++++++++++++++-- 17 files changed, 251 insertions(+), 78 deletions(-) create mode 100644 agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java create mode 100644 newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java index 4449df5d57..9c79f6964b 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java @@ -27,6 +27,7 @@ public class NoOpTransaction implements Transaction { public static final Transaction INSTANCE = new NoOpTransaction(); public static final NoOpMap AGENT_ATTRIBUTES = new NoOpMap<>(); + public static final NoOpMap 
USER_ATTRIBUTES = new NoOpMap<>(); @Override public void beforeSendResponseHeaders() { @@ -153,6 +154,11 @@ public Map getAgentAttributes() { return AGENT_ATTRIBUTES; } + @Override + public Map getUserAttributes() { + return USER_ATTRIBUTES; + } + @Override public boolean registerAsyncActivity(Object activityContext) { return false; diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java index 9f89f2064d..fc4061973e 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java @@ -23,6 +23,7 @@ public interface Transaction extends com.newrelic.api.agent.Transaction { Map getAgentAttributes(); + Map getUserAttributes(); /** * Sets the current transaction's name. diff --git a/agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java b/agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java new file mode 100644 index 0000000000..44798d0f76 --- /dev/null +++ b/agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java @@ -0,0 +1,30 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.agent.model; + +/** + * Represents an internal subtype of a CustomInsightsEvent that is sent to the + * custom_event_data collector endpoint but potentially subject to different + * validation rules and agent configuration. 
+ */ +public class LlmCustomInsightsEvent { + // LLM event types + private static final String LLM_EMBEDDING = "LlmEmbedding"; + private static final String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary"; + private static final String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage"; + + /** + * Determines if a CustomInsightsEvent should be treated as a LlmEvent + * + * @param eventType type of the current event + * @return true if eventType is an LlmEvent, else false + */ + public static boolean isLlmEvent(String eventType) { + return eventType.equals(LLM_EMBEDDING) || eventType.equals(LLM_CHAT_COMPLETION_MESSAGE) || eventType.equals(LLM_CHAT_COMPLETION_SUMMARY); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index 15f0034725..fbb9fd453a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -29,13 +29,13 @@ Currently, only the following text-based foundation models are supported: ### LLM Events -The main goal of this instrumentation is to generate the following LLM events to drive the UI. These events are custom events sent via the public `recordCustomEvent` API. +The main goal of this instrumentation is to generate the following LLM events to drive the UI. -* `LlmEmbedding` -* `LlmChatCompletionSummary` -* `LlmChatCompletionMessage` +* `LlmEmbedding`: An event that captures data specific to the creation of an embedding. +* `LlmChatCompletionSummary`: An event that captures high-level data about the creation of a chat completion including request, response, and call information. +* `LlmChatCompletionMessage`: An event that corresponds to each message (sent and received) from a chat completion call including those created by the user, assistant, and the system. -Currently, they contribute towards the following Custom Insights Events limits (this will likely change in the future). 
Because of this it is recommended to increase `custom_insights_events.max_samples_stored` to the maximum value of 100,000 to best avoid sampling issue. +These events are custom events sent via the public `recordCustomEvent` API. Currently, they contribute towards the following Custom Insights Events limits (this will likely change in the future). Because of this, it is recommended to increase `custom_insights_events.max_samples_stored` to the maximum value of 100,000 to best avoid sampling issue. LLM events are sent to the `custom_event_data` collector endpoint but the backend will assign them a unique namespace to distinguish them from other custom events. ```yaml custom_insights_events: @@ -44,8 +44,6 @@ Currently, they contribute towards the following Custom Insights Events limits ( LLM events also have some unique limits for their attributes... -they are also bucketed into a unique namespace separate from other custom events on the backend... - ``` Regardless of which implementation(s) are built, there are consistent changes within the agents and the UX to support AI Monitoring. @@ -59,7 +57,7 @@ Agents should remove token counts from the LlmChatCompletionSummary call out llm. behavior - +Can be built via `LlmEvent` builder ### Model Invocation/Request/Response @@ -77,3 +75,14 @@ call out llm. behavior ## Testing + +## TODO +* Add custom `llm.` attributes +* Wire up async client +* Switch instrumentation back to BedrockRuntimeClient/BedrockRuntimeAsyncClient interfaces +* Clean up request/response parsing logic +* Wire up Config + * Generate `Supportability/{language}/ML/Streaming/Disabled` metric? 
+* Set up and test new models +* Write instrumentation tests +* Finish readme diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index 59c47dde3a..07c11794c4 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -7,10 +7,12 @@ package llm.events; +import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import llm.models.ModelInvocation; import llm.models.ModelRequest; import llm.models.ModelResponse; +import llm.vendor.Vendor; import java.util.HashMap; import java.util.Map; @@ -19,7 +21,7 @@ * Class for building an LlmEvent */ public class LlmEvent { - private final Map eventAttributes = new HashMap<>(); + private final Map eventAttributes; // LLM event types public static final String LLM_EMBEDDING = "LlmEmbedding"; @@ -54,6 +56,7 @@ public class LlmEvent { public static class Builder { // Required builder parameters private final Map linkingMetadata; + private final Transaction txn; private final ModelRequest modelRequest; private final ModelResponse modelResponse; @@ -90,7 +93,8 @@ public static class Builder { private Integer responseUsageCompletionTokens = null; private String responseChoicesFinishReason = null; - public Builder(Map linkingMetadata, ModelRequest modelRequest, ModelResponse modelResponse) { + public Builder(Transaction txn, Map linkingMetadata, ModelRequest modelRequest, ModelResponse modelResponse) { + this.txn = txn; this.linkingMetadata = linkingMetadata; this.modelRequest = modelRequest; this.modelResponse = modelResponse; @@ -107,12 +111,12 @@ public Builder traceId() { } public Builder vendor() { - vendor = ModelInvocation.getVendor(); + vendor = Vendor.VENDOR; return this; } public Builder ingestSource() { - ingestSource = ModelInvocation.getIngestSource(); + 
ingestSource = Vendor.INGEST_SOURCE; return this; } @@ -216,12 +220,15 @@ public Builder responseChoicesFinishReason() { } public LlmEvent build() { - return new LlmEvent(this); + return new LlmEvent(this, txn); } } // This populates the LlmEvent attributes map with only the attributes that were explicitly set on the builder. - private LlmEvent(Builder builder) { + private LlmEvent(Builder builder, Transaction txn) { + // Init map with any user attributes containing the llm. prefix + eventAttributes = new HashMap<>(getUserLlmAttributes(txn)); + spanId = builder.spanId; if (spanId != null && !spanId.isEmpty()) { eventAttributes.put("span_id", spanId); @@ -338,6 +345,28 @@ private LlmEvent(Builder builder) { } } + /** + * Takes a map of all attributes added by the customer via the addCustomParameter API and returns a map + * containing only custom attributes with a llm. prefix to be added to LlmEvents. + * + * @param txn current transaction + * @return Map of user attributes prefixed with llm. 
+ */ + private Map getUserLlmAttributes(Transaction txn) { + Map userAttributes = txn.getUserAttributes(); + Map userLlmAttributes = new HashMap<>(); + + if (userAttributes != null && !userAttributes.isEmpty()) { + for (Map.Entry entry : userAttributes.entrySet()) { + String key = entry.getKey(); + if (key.startsWith("llm.")) { + userLlmAttributes.put(key, entry.getValue()); + } + } + } + return userLlmAttributes; + } + /** * Record a LlmChatCompletionMessage custom event */ diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 60acfaf524..024211ae25 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -13,27 +13,14 @@ import java.util.Map; import java.util.UUID; -public interface ModelInvocation { - String VENDOR = "bedrock"; - String BEDROCK = "Bedrock"; - String INGEST_SOURCE = "Java"; - String TRACE_ID = "trace.id"; - String SPAN_ID = "span.id"; - - // Supported models - String ANTHROPIC_CLAUDE = "claude"; - String AMAZON_TITAN = "titan"; - String META_LLAMA_2 = "llama"; - String COHERE_COMMAND = "cohere"; - String AI_21_LABS_JURASSIC = "jurassic"; +import static llm.vendor.Vendor.VENDOR; +public interface ModelInvocation { /** * Set name of the span/segment for each LLM embedding and chat completion call * Llm/{operation_type}/{vendor_name}/{function_name} - * - * @param txn current transaction */ - void setLlmOperationMetricName(Transaction txn, String functionName); + void setLlmOperationMetricName(String functionName); void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata); @@ -53,9 +40,8 @@ public interface ModelInvocation { * tag lives for 27 hours so if this metric isn't repeatedly sent the tag will disappear and * the UI will be hidden. 
*/ - static void incrementInstrumentedSupportabilityMetric() { - // Bedrock vendor_version isn't available, so set it to instrumentation version instead - NewRelic.incrementCounter("Supportability/Java/ML/Bedrock/2.20"); + static void incrementInstrumentedSupportabilityMetric(String vendorVersion) { + NewRelic.incrementCounter("Supportability/Java/ML/" + VENDOR + "/" + vendorVersion); } static void setLlmTrueAgentAttribute(Transaction txn) { @@ -63,24 +49,14 @@ static void setLlmTrueAgentAttribute(Transaction txn) { txn.getAgentAttributes().put("llm", true); } - // Lowercased name of vendor (bedrock or openAI) - static String getVendor() { - return VENDOR; - } - - // Name of the language agent (ex: Python, Node) - static String getIngestSource() { - return INGEST_SOURCE; - } - // GUID associated with the active trace static String getSpanId(Map linkingMetadata) { - return linkingMetadata.get(SPAN_ID); + return linkingMetadata.get("span.id"); } // ID of the current trace static String getTraceId(Map linkingMetadata) { - return linkingMetadata.get(TRACE_ID); + return linkingMetadata.get("trace.id"); } // Returns a string representation of a random GUID diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java index b66ca6c548..44dda2d2d5 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -1,3 +1,10 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + package llm.models; public interface ModelRequest { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java index ff94d24e51..0c95c1942f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -1,3 +1,10 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + package llm.models; public interface ModelResponse { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java new file mode 100644 index 0000000000..8f94975185 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java @@ -0,0 +1,16 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models; + +public class SupportedModels { + public static final String ANTHROPIC_CLAUDE = "claude"; + public static final String AMAZON_TITAN = "titan"; + public static final String META_LLAMA_2 = "llama"; + public static final String COHERE_COMMAND = "cohere"; + public static final String AI_21_LABS_JURASSIC = "jurassic"; +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index 4277b9ff37..4d1bebe542 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -22,18 +22,21 @@ import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.COMPLETION; import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; public class AnthropicClaudeModelInvocation implements ModelInvocation { + Transaction txn; ModelRequest claudeRequest; ModelResponse claudeResponse; - public AnthropicClaudeModelInvocation(InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { + public AnthropicClaudeModelInvocation(Transaction currentTransaction, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { + txn = currentTransaction; claudeRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); claudeResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); } @Override - public void setLlmOperationMetricName(Transaction txn, String functionName) { + public void setLlmOperationMetricName(String functionName) { txn.getTracedMethod().setMetricName("Llm", 
claudeResponse.getOperationType(), BEDROCK, functionName); } @@ -42,8 +45,10 @@ public void recordLlmEmbeddingEvent(long startTime, Map linkingM if (claudeResponse.isErrorResponse()) { reportLlmError(); } - - LlmEvent.Builder builder = new LlmEvent.Builder(linkingMetadata, claudeRequest, claudeResponse); + // TODO should the builder just take a ModelInvocation instance and pull all of this stuff from it? All it would + // require is storing the linking metadata on the ModelInvocation instance and adding getters for the + // txn, linkingMetadata, claudeRequest, claudeResponse. + LlmEvent.Builder builder = new LlmEvent.Builder(txn, linkingMetadata, claudeRequest, claudeResponse); LlmEvent llmEmbeddingEvent = builder .spanId() @@ -62,11 +67,6 @@ public void recordLlmEmbeddingEvent(long startTime, Map linkingM .build(); llmEmbeddingEvent.recordLlmEmbeddingEvent(); - - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction -// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? -// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! 
} @Override @@ -75,7 +75,7 @@ public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long start reportLlmError(); } - LlmEvent.Builder builder = new LlmEvent.Builder(linkingMetadata, claudeRequest, claudeResponse); + LlmEvent.Builder builder = new LlmEvent.Builder(txn, linkingMetadata, claudeRequest, claudeResponse); LlmEvent llmChatCompletionSummaryEvent = builder .spanId() @@ -98,18 +98,13 @@ public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long start .build(); llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); - - // TODO is it possible to do something like this to call getUserAttributes? - // see com.newrelic.agent.bridge.Transaction -// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? -// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! } @Override public void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata) { boolean isUser = message.contains("Human:"); - LlmEvent.Builder builder = new LlmEvent.Builder(linkingMetadata, claudeRequest, claudeResponse); + LlmEvent.Builder builder = new LlmEvent.Builder(txn, linkingMetadata, claudeRequest, claudeResponse); LlmEvent llmChatCompletionMessageEvent = builder .spanId() @@ -127,11 +122,6 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, Ma .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); - - // TODO is it possible to do something like this to call getUserAttributes? 
- // see com.newrelic.agent.bridge.Transaction -// eventAttributes.put("llm.", ""); // TODO Optional metadata attributes that can be added to a transaction by a customer via add_custom_attribute API. Done internally when event is created? -// eventAttributes.put("llm.conversation_id", "NEW API"); // TODO Optional attribute that can be added to a transaction by a customer via add_custom_attribute API. Should just be added and prefixed along with the other user attributes? YES! } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java new file mode 100644 index 0000000000..0381ffe818 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java @@ -0,0 +1,16 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.vendor; + +public class Vendor { + public static final String VENDOR = "bedrock"; + // Bedrock vendor_version isn't obtainable, so set it to instrumentation version instead + public static final String VENDOR_VERSION = "2.20"; + public static final String BEDROCK = "Bedrock"; + public static final String INGEST_SOURCE = "Java"; +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java index bff0955fb3..0c5137a80e 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java @@ -24,6 
+24,7 @@ import java.util.concurrent.Executor; import static llm.models.ModelInvocation.incrementInstrumentedSupportabilityMetric; +import static llm.vendor.Vendor.VENDOR_VERSION; /** * Service client for accessing Amazon Bedrock Runtime asynchronously. @@ -62,7 +63,7 @@ public CompletableFuture invokeModel(InvokeModelRequest inv Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); - incrementInstrumentedSupportabilityMetric(); + incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); // this should never happen, but protecting against bad implementations if (invokeModelResponseFuture == null) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index 834eb1395a..8a8dd96389 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -25,7 +25,8 @@ import java.util.Map; -import static llm.models.ModelInvocation.ANTHROPIC_CLAUDE; +import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.vendor.Vendor.VENDOR_VERSION; /** * Service client for accessing Amazon Bedrock Runtime. 
@@ -59,7 +60,7 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { long startTime = System.currentTimeMillis(); InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); - ModelInvocation.incrementInstrumentedSupportabilityMetric(); + ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); // Transaction txn = NewRelic.getAgent().getTransaction(); Transaction txn = AgentBridge.getAgent().getTransaction(); @@ -70,10 +71,10 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { String modelId = invokeModelRequest.modelId(); if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(invokeModelRequest, + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(txn, invokeModelRequest, invokeModelResponse); // Set traced method name based on LLM operation - anthropicClaudeModelInvocation.setLlmOperationMetricName(txn, "invokeModel"); + anthropicClaudeModelInvocation.setLlmOperationMetricName("invokeModel"); // Set llm = true agent attribute ModelInvocation.setLlmTrueAgentAttribute(txn); anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java index 292091654c..fbc7839994 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java @@ -313,6 +313,12 @@ public Map getAgentAttributes() { return (tx != null) ? tx.getAgentAttributes() : NoOpTransaction.INSTANCE.getAgentAttributes(); } + @Override + public Map getUserAttributes() { + Transaction tx = getTransactionIfExists(); + return (tx != null) ? 
tx.getUserAttributes() : NoOpTransaction.INSTANCE.getUserAttributes(); + } + @Override public void provideHeaders(InboundHeaders headers) { Transaction tx = getTransactionIfExists(); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java index 650beeefb9..70e0ac3429 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java @@ -12,8 +12,11 @@ /** * Attribute validator with truncation rules specific to custom events. */ -public class CustomEventAttributeValidator extends AttributeValidator{ - private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService().getDefaultAgentConfig().getInsightsConfig().getMaxAttributeValue(); +public class CustomEventAttributeValidator extends AttributeValidator { + private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService() + .getDefaultAgentConfig() + .getInsightsConfig() + .getMaxAttributeValue(); public CustomEventAttributeValidator(String attributeType) { super(attributeType); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java new file mode 100644 index 0000000000..6350df9c54 --- /dev/null +++ b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java @@ -0,0 +1,36 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.agent.attributes; + +import com.newrelic.agent.service.ServiceFactory; + +/** + * Attribute validator with truncation rules specific to LLM events. 
+ */ +public class LlmEventAttributeValidator extends AttributeValidator { + // FIXME different size attribute limits for LLM events InsightsConfigImpl.MAX_MAX_ATTRIBUTE_VALUE ? + private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService() + .getDefaultAgentConfig() + .getInsightsConfig() + .getMaxAttributeValue(); + + public LlmEventAttributeValidator(String attributeType) { + super(attributeType); + } + + @Override + protected String truncateValue(String key, String value, String methodCalled) { + // TODO make sure that this behavior is accepted into the agent spec + if (key.equals("content")) { + return value; + } + String truncatedVal = truncateString(value, MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE); + logTruncatedValue(key, value, truncatedVal, methodCalled, MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE); + return truncatedVal; + } +} diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java index 5ed82abb4a..c08a373a15 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java @@ -18,10 +18,12 @@ import com.newrelic.agent.TransactionData; import com.newrelic.agent.attributes.AttributeSender; import com.newrelic.agent.attributes.CustomEventAttributeValidator; +import com.newrelic.agent.attributes.LlmEventAttributeValidator; import com.newrelic.agent.config.AgentConfig; import com.newrelic.agent.config.AgentConfigListener; import com.newrelic.agent.model.AnalyticsEvent; import com.newrelic.agent.model.CustomInsightsEvent; +import com.newrelic.agent.model.LlmCustomInsightsEvent; import com.newrelic.agent.service.AbstractService; import com.newrelic.agent.service.ServiceFactory; import com.newrelic.agent.stats.StatsEngine; @@ -332,7 +334,7 @@ public String 
getEventHarvestLimitMetric() { } private void recordSupportabilityMetrics(StatsEngine statsEngine, long durationInNanoseconds, - DistributedSamplingPriorityQueue reservoir) { + DistributedSamplingPriorityQueue reservoir) { statsEngine.getStats(MetricNames.SUPPORTABILITY_INSIGHTS_SERVICE_CUSTOMER_SENT) .incrementCallCount(reservoir.size()); statsEngine.getStats(MetricNames.SUPPORTABILITY_INSIGHTS_SERVICE_CUSTOMER_SEEN) @@ -366,16 +368,25 @@ private static String mapInternString(String value) { private static CustomInsightsEvent createValidatedEvent(String eventType, Map attributes) { Map userAttributes = new HashMap<>(attributes.size()); - CustomInsightsEvent event = new CustomInsightsEvent(mapInternString(eventType), System.currentTimeMillis(), userAttributes, DistributedTraceServiceImpl.nextTruncatedFloat()); + CustomInsightsEvent event = new CustomInsightsEvent(mapInternString(eventType), System.currentTimeMillis(), userAttributes, + DistributedTraceServiceImpl.nextTruncatedFloat()); // Now add the attributes from the argument map to the event using an AttributeSender. // An AttributeSender is the way to reuse all the existing attribute validations. We // also locally "intern" Strings because we anticipate a lot of reuse of the keys and, // possibly, the values. But there's an interaction: if the key or value is chopped // within the attribute sender, the modified value won't be "interned" in our map. 
+ AttributeSender sender; + final String method; - AttributeSender sender = new CustomEventAttributeSender(userAttributes); - final String method = "add custom event attribute"; + // CustomInsightsEvents are being overloaded to support some internal event types being sent to the same agent endpoint + if (LlmCustomInsightsEvent.isLlmEvent(eventType)) { + sender = new LlmEventAttributeSender(userAttributes); + method = "add llm event attribute"; + } else { + sender = new CustomEventAttributeSender(userAttributes); + method = "add custom event attribute"; + } for (Map.Entry entry : attributes.entrySet()) { String key = entry.getKey(); @@ -384,7 +395,7 @@ private static CustomInsightsEvent createValidatedEvent(String eventType, Map getAttributeMap() { } } + /** + * LlmEvent attribute validation rules differ from those of a standard CustomInsightsEvent + */ + private static class LlmEventAttributeSender extends AttributeSender { + private static final String ATTRIBUTE_TYPE = "llm"; + + private final Map userAttributes; + + public LlmEventAttributeSender(Map userAttributes) { + super(new LlmEventAttributeValidator(ATTRIBUTE_TYPE)); + this.userAttributes = userAttributes; + setTransactional(true); + } + + @Override + protected String getAttributeType() { + return ATTRIBUTE_TYPE; + } + + @Override + protected Map getAttributeMap() { + if (ServiceFactory.getConfigService().getDefaultAgentConfig().isCustomParametersAllowed()) { + return userAttributes; + } + return null; + } + } + @Override public Insights getTransactionInsights(AgentConfig config) { return new TransactionInsights(config); From a3c2f2ae3f165b594c8afb7a8f3c17f3f0f975bd Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Mon, 26 Feb 2024 15:33:14 -0800 Subject: [PATCH 09/68] Refactoring. Wire up async client. 
--- .../aws-bedrock-runtime-2.20/README.md | 2 - .../src/main/java/llm/events/LlmEvent.java | 24 +-- .../main/java/llm/models/ModelInvocation.java | 15 +- .../AnthropicClaudeModelInvocation.java | 22 ++- ...ockRuntimeAsyncClient_Instrumentation.java | 71 ++++++++ .../BedrockRuntimeClient_Instrumentation.java | 61 +++++++ ...ockRuntimeAsyncClient_Instrumentation.java | 172 +++++++++--------- ...tBedrockRuntimeClient_Instrumentation.java | 172 +++++++++--------- 8 files changed, 343 insertions(+), 196 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index fbb9fd453a..17744dc6d4 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -77,9 +77,7 @@ Can be built via `LlmEvent` builder ## TODO -* Add custom `llm.` attributes * Wire up async client -* Switch instrumentation back to BedrockRuntimeClient/BedrockRuntimeAsyncClient interfaces * Clean up request/response parsing logic * Wire up Config * Generate `Supportability/{language}/ML/Streaming/Disabled` metric? 
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index 07c11794c4..0ff1a583d1 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -7,7 +7,6 @@ package llm.events; -import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import llm.models.ModelInvocation; import llm.models.ModelRequest; @@ -22,6 +21,7 @@ */ public class LlmEvent { private final Map eventAttributes; + private final Map userLlmAttributes; // LLM event types public static final String LLM_EMBEDDING = "LlmEmbedding"; @@ -55,8 +55,8 @@ public class LlmEvent { public static class Builder { // Required builder parameters + private final Map userAttributes; private final Map linkingMetadata; - private final Transaction txn; private final ModelRequest modelRequest; private final ModelResponse modelResponse; @@ -93,8 +93,8 @@ public static class Builder { private Integer responseUsageCompletionTokens = null; private String responseChoicesFinishReason = null; - public Builder(Transaction txn, Map linkingMetadata, ModelRequest modelRequest, ModelResponse modelResponse) { - this.txn = txn; + public Builder(Map userAttributes, Map linkingMetadata, ModelRequest modelRequest, ModelResponse modelResponse) { + this.userAttributes = userAttributes; this.linkingMetadata = linkingMetadata; this.modelRequest = modelRequest; this.modelResponse = modelResponse; @@ -220,14 +220,17 @@ public Builder responseChoicesFinishReason() { } public LlmEvent build() { - return new LlmEvent(this, txn); + return new LlmEvent(this); } } // This populates the LlmEvent attributes map with only the attributes that were explicitly set on the builder. 
- private LlmEvent(Builder builder, Transaction txn) { - // Init map with any user attributes containing the llm. prefix - eventAttributes = new HashMap<>(getUserLlmAttributes(txn)); + private LlmEvent(Builder builder) { + // Map of custom user attributes with the llm prefix + userLlmAttributes = getUserLlmAttributes(builder.userAttributes); + + // Map of all LLM event attributes + eventAttributes = new HashMap<>(userLlmAttributes); spanId = builder.spanId; if (spanId != null && !spanId.isEmpty()) { @@ -349,11 +352,10 @@ private LlmEvent(Builder builder, Transaction txn) { * Takes a map of all attributes added by the customer via the addCustomParameter API and returns a map * containing only custom attributes with a llm. prefix to be added to LlmEvents. * - * @param txn current transaction + * @param userAttributes Map of all custom user attributes * @return Map of user attributes prefixed with llm. */ - private Map getUserLlmAttributes(Transaction txn) { - Map userAttributes = txn.getUserAttributes(); + private Map getUserLlmAttributes(Map userAttributes) { Map userLlmAttributes = new HashMap<>(); if (userAttributes != null && !userAttributes.isEmpty()) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 024211ae25..7d0e3371bf 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -9,6 +9,7 @@ import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; import java.util.Map; import java.util.UUID; @@ -17,10 +18,20 @@ public interface ModelInvocation { /** - * Set name of the span/segment for each LLM embedding and chat completion call + * Set name of the traced method for each LLM embedding and chat completion call * 
Llm/{operation_type}/{vendor_name}/{function_name} + *

+ * Used with the sync client + */ + void setTracedMethodName(Transaction txn, String functionName); + + /** + * Set name of the async segment for each LLM embedding and chat completion call + * Llm/{operation_type}/{vendor_name}/{function_name} + *

+ * Used with the async client */ - void setLlmOperationMetricName(String functionName); + void setSegmentName(Segment segment, String functionName); void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index 4d1bebe542..69404f7d8f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -9,6 +9,7 @@ import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; import llm.events.LlmEvent; import llm.models.ModelInvocation; import llm.models.ModelRequest; @@ -25,21 +26,26 @@ import static llm.vendor.Vendor.BEDROCK; public class AnthropicClaudeModelInvocation implements ModelInvocation { - Transaction txn; + Map userAttributes; ModelRequest claudeRequest; ModelResponse claudeResponse; - public AnthropicClaudeModelInvocation(Transaction currentTransaction, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { - txn = currentTransaction; + public AnthropicClaudeModelInvocation(Map userCustomAttributes, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { + userAttributes = userCustomAttributes; claudeRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); claudeResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); } @Override - public void setLlmOperationMetricName(String functionName) { + public void setTracedMethodName(Transaction txn, String functionName) { txn.getTracedMethod().setMetricName("Llm", claudeResponse.getOperationType(), BEDROCK, 
functionName); } + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", claudeResponse.getOperationType(), BEDROCK, functionName); + } + @Override public void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata) { if (claudeResponse.isErrorResponse()) { @@ -47,8 +53,8 @@ public void recordLlmEmbeddingEvent(long startTime, Map linkingM } // TODO should the builder just take a ModelInvocation instance and pull all of this stuff from it? All it would // require is storing the linking metadata on the ModelInvocation instance and adding getters for the - // txn, linkingMetadata, claudeRequest, claudeResponse. - LlmEvent.Builder builder = new LlmEvent.Builder(txn, linkingMetadata, claudeRequest, claudeResponse); + // userAttributes, linkingMetadata, claudeRequest, claudeResponse. + LlmEvent.Builder builder = new LlmEvent.Builder(userAttributes, linkingMetadata, claudeRequest, claudeResponse); LlmEvent llmEmbeddingEvent = builder .spanId() @@ -75,7 +81,7 @@ public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long start reportLlmError(); } - LlmEvent.Builder builder = new LlmEvent.Builder(txn, linkingMetadata, claudeRequest, claudeResponse); + LlmEvent.Builder builder = new LlmEvent.Builder(userAttributes, linkingMetadata, claudeRequest, claudeResponse); LlmEvent llmChatCompletionSummaryEvent = builder .spanId() @@ -104,7 +110,7 @@ public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long start public void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata) { boolean isUser = message.contains("Human:"); - LlmEvent.Builder builder = new LlmEvent.Builder(txn, linkingMetadata, claudeRequest, claudeResponse); + LlmEvent.Builder builder = new LlmEvent.Builder(userAttributes, linkingMetadata, claudeRequest, claudeResponse); LlmEvent llmChatCompletionMessageEvent = builder .spanId() diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java new file mode 100644 index 0000000000..8ea71f473d --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -0,0 +1,71 @@ +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.bridge.AgentBridge; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import com.newrelic.api.agent.weaver.MatchType; +import com.newrelic.api.agent.weaver.Weave; +import com.newrelic.api.agent.weaver.Weaver; +import llm.models.ModelInvocation; +import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; + +import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.vendor.Vendor.VENDOR_VERSION; + +/** + * Service client for accessing Amazon Bedrock Runtime asynchronously. 
+ */ +@Weave(type = MatchType.Interface, originalName = "software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient") +public abstract class BedrockRuntimeAsyncClient_Instrumentation { + + @Trace + public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { + // TODO check AIM config + long startTime = System.currentTimeMillis(); + Transaction txn = AgentBridge.getAgent().getTransaction(); + // Segment will be named later when the response is available + Segment segment = txn.startSegment(""); + + CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); + + // Set llm = true agent attribute + ModelInvocation.setLlmTrueAgentAttribute(txn); + ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); + Map userAttributes = txn.getUserAttributes(); + Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); + + // this should never happen, but protecting against bad implementations + if (invokeModelResponseFuture == null) { + segment.end(); + } else { + invokeModelResponseFuture.whenComplete(new BiConsumer() { + @Override + public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { + try { + // TODO check AIM config + String modelId = invokeModelRequest.modelId(); + if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, + invokeModelResponse); + // Set segment name based on LLM operation + anthropicClaudeModelInvocation.setSegmentName(segment, "invokeModel"); + anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + } + segment.end(); + } catch (Throwable t) { + AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); + } + } + }); + } + return invokeModelResponseFuture; + } +} diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java new file mode 100644 index 0000000000..3ee2ae5a34 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java @@ -0,0 +1,61 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.bridge.AgentBridge; +import com.newrelic.agent.bridge.NoOpTransaction; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import com.newrelic.api.agent.weaver.MatchType; +import com.newrelic.api.agent.weaver.Weave; +import com.newrelic.api.agent.weaver.Weaver; +import llm.models.ModelInvocation; +import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Map; + +import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.vendor.Vendor.VENDOR_VERSION; + +/** + * Service client for accessing Amazon Bedrock Runtime. 
+ */ +@Weave(type = MatchType.Interface, originalName = "software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient") +public abstract class BedrockRuntimeClient_Instrumentation { + + @Trace + public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { + long startTime = System.currentTimeMillis(); + InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); + + ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); + Transaction txn = AgentBridge.getAgent().getTransaction(); + + // TODO check AIM config + if (!(txn instanceof NoOpTransaction)) { + Map userAttributes = txn.getUserAttributes(); + Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); + + String modelId = invokeModelRequest.modelId(); + if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation + anthropicClaudeModelInvocation.setTracedMethodName(txn, "invokeModel"); + // Set llm = true agent attribute + ModelInvocation.setLlmTrueAgentAttribute(txn); + anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + } + } + + return invokeModelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java index 0c5137a80e..167352fdac 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java 
@@ -1,86 +1,86 @@ -/* - * - * * Copyright 2024 New Relic Corporation. All rights reserved. - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package software.amazon.awssdk.services.bedrockruntime; - -import com.newrelic.agent.bridge.AgentBridge; -import com.newrelic.api.agent.NewRelic; -import com.newrelic.api.agent.Segment; -import com.newrelic.api.agent.Trace; -import com.newrelic.api.agent.weaver.MatchType; -import com.newrelic.api.agent.weaver.Weave; -import com.newrelic.api.agent.weaver.Weaver; -import software.amazon.awssdk.core.client.config.SdkClientConfiguration; -import software.amazon.awssdk.core.client.handler.AsyncClientHandler; -import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; - -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; - -import static llm.models.ModelInvocation.incrementInstrumentedSupportabilityMetric; -import static llm.vendor.Vendor.VENDOR_VERSION; - -/** - * Service client for accessing Amazon Bedrock Runtime asynchronously. 
- */ -// TODO switch back to instrumenting the BedrockRuntimeAsyncClient interface instead of this implementation class -@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient") -final class DefaultBedrockRuntimeAsyncClient_Instrumentation { -// private static final Logger log = LoggerFactory.getLogger(DefaultBedrockRuntimeAsyncClient.class); -// -// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() -// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); - - private final AsyncClientHandler clientHandler; - - private final AwsJsonProtocolFactory protocolFactory; - - private final SdkClientConfiguration clientConfiguration; - - private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; - - private final Executor executor; - - protected DefaultBedrockRuntimeAsyncClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, - SdkClientConfiguration clientConfiguration) { - this.clientHandler = Weaver.callOriginal(); - this.clientConfiguration = Weaver.callOriginal(); - this.serviceClientConfiguration = Weaver.callOriginal(); - this.protocolFactory = Weaver.callOriginal(); - this.executor = Weaver.callOriginal(); - } - - @Trace - public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { - long startTime = System.currentTimeMillis(); - // TODO name "Llm/" + operationType + "/Bedrock/InvokeModelAsync" ???? 
- Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); - CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); - - incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); - - // this should never happen, but protecting against bad implementations - if (invokeModelResponseFuture == null) { - segment.end(); - } else { - invokeModelResponseFuture.whenComplete((invokeModelResponse, throwable) -> { - try { - // TODO do all the stuff - segment.end(); - } catch (Throwable t) { - AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); - } - }); - } - - return invokeModelResponseFuture; - - } - -} +///* +// * +// * * Copyright 2024 New Relic Corporation. All rights reserved. +// * * SPDX-License-Identifier: Apache-2.0 +// * +// */ +// +//package software.amazon.awssdk.services.bedrockruntime; +// +//import com.newrelic.agent.bridge.AgentBridge; +//import com.newrelic.api.agent.NewRelic; +//import com.newrelic.api.agent.Segment; +//import com.newrelic.api.agent.Trace; +//import com.newrelic.api.agent.weaver.MatchType; +//import com.newrelic.api.agent.weaver.Weave; +//import com.newrelic.api.agent.weaver.Weaver; +//import software.amazon.awssdk.core.client.config.SdkClientConfiguration; +//import software.amazon.awssdk.core.client.handler.AsyncClientHandler; +//import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; +//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +// +//import java.util.concurrent.CompletableFuture; +//import java.util.concurrent.Executor; +// +//import static llm.models.ModelInvocation.incrementInstrumentedSupportabilityMetric; +//import static llm.vendor.Vendor.VENDOR_VERSION; +// +///** +// * Service client for accessing Amazon Bedrock Runtime asynchronously. 
+// */ +//// TODO switch back to instrumenting the BedrockRuntimeAsyncClient interface instead of this implementation class +//@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient") +//final class DefaultBedrockRuntimeAsyncClient_Instrumentation { +//// private static final Logger log = LoggerFactory.getLogger(DefaultBedrockRuntimeAsyncClient.class); +//// +//// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() +//// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); +// +// private final AsyncClientHandler clientHandler; +// +// private final AwsJsonProtocolFactory protocolFactory; +// +// private final SdkClientConfiguration clientConfiguration; +// +// private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; +// +// private final Executor executor; +// +// protected DefaultBedrockRuntimeAsyncClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, +// SdkClientConfiguration clientConfiguration) { +// this.clientHandler = Weaver.callOriginal(); +// this.clientConfiguration = Weaver.callOriginal(); +// this.serviceClientConfiguration = Weaver.callOriginal(); +// this.protocolFactory = Weaver.callOriginal(); +// this.executor = Weaver.callOriginal(); +// } +// +// @Trace +// public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { +// long startTime = System.currentTimeMillis(); +// // TODO name "Llm/" + operationType + "/Bedrock/InvokeModelAsync" ???? 
+// Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); +// CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); +// +// incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); +// +// // this should never happen, but protecting against bad implementations +// if (invokeModelResponseFuture == null) { +// segment.end(); +// } else { +// invokeModelResponseFuture.whenComplete((invokeModelResponse, throwable) -> { +// try { +// // TODO do all the stuff +// segment.end(); +// } catch (Throwable t) { +// AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); +// } +// }); +// } +// +// return invokeModelResponseFuture; +// +// } +// +//} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java index 8a8dd96389..d810db8ecf 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java @@ -1,87 +1,85 @@ -/* - * - * * Copyright 2024 New Relic Corporation. All rights reserved. 
- * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package software.amazon.awssdk.services.bedrockruntime; - -import com.newrelic.agent.bridge.AgentBridge; -import com.newrelic.agent.bridge.NoOpTransaction; -import com.newrelic.agent.bridge.Transaction; -import com.newrelic.api.agent.NewRelic; -import com.newrelic.api.agent.Trace; -import com.newrelic.api.agent.weaver.MatchType; -import com.newrelic.api.agent.weaver.Weave; -import com.newrelic.api.agent.weaver.Weaver; -import llm.models.ModelInvocation; -import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; -import software.amazon.awssdk.core.client.config.SdkClientConfiguration; -import software.amazon.awssdk.core.client.handler.SyncClientHandler; -import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; - -import java.util.Map; - -import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; -import static llm.vendor.Vendor.VENDOR_VERSION; - -/** - * Service client for accessing Amazon Bedrock Runtime. 
- */ -// TODO switch back to instrumenting the BedrockRuntimeClient interface instead of this implementation class -@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeClient") -final class DefaultBedrockRuntimeClient_Instrumentation { -// private static final Logger log = Logger.loggerFor(DefaultBedrockRuntimeClient.class); -// -// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() -// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); - - private final SyncClientHandler clientHandler; - - private final AwsJsonProtocolFactory protocolFactory; - - private final SdkClientConfiguration clientConfiguration; - - private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; - - protected DefaultBedrockRuntimeClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, - SdkClientConfiguration clientConfiguration) { - this.clientHandler = Weaver.callOriginal(); - this.clientConfiguration = Weaver.callOriginal(); - this.serviceClientConfiguration = Weaver.callOriginal(); - this.protocolFactory = Weaver.callOriginal(); - } - - @Trace - public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { - long startTime = System.currentTimeMillis(); - InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); - - ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); - -// Transaction txn = NewRelic.getAgent().getTransaction(); - Transaction txn = AgentBridge.getAgent().getTransaction(); - // TODO check AIM config - if (txn != null && !(txn instanceof NoOpTransaction)) { - Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); -// Map userAttributes = txn.getUserAttributes(); - - String modelId = invokeModelRequest.modelId(); - if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new 
AnthropicClaudeModelInvocation(txn, invokeModelRequest, - invokeModelResponse); - // Set traced method name based on LLM operation - anthropicClaudeModelInvocation.setLlmOperationMetricName("invokeModel"); - // Set llm = true agent attribute - ModelInvocation.setLlmTrueAgentAttribute(txn); - anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); - } - } - - return invokeModelResponse; - } - -} +///* +// * +// * * Copyright 2024 New Relic Corporation. All rights reserved. +// * * SPDX-License-Identifier: Apache-2.0 +// * +// */ +// +//package software.amazon.awssdk.services.bedrockruntime; +// +//import com.newrelic.agent.bridge.AgentBridge; +//import com.newrelic.agent.bridge.NoOpTransaction; +//import com.newrelic.agent.bridge.Transaction; +//import com.newrelic.api.agent.NewRelic; +//import com.newrelic.api.agent.Trace; +//import com.newrelic.api.agent.weaver.MatchType; +//import com.newrelic.api.agent.weaver.Weave; +//import com.newrelic.api.agent.weaver.Weaver; +//import llm.models.ModelInvocation; +//import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; +//import software.amazon.awssdk.core.client.config.SdkClientConfiguration; +//import software.amazon.awssdk.core.client.handler.SyncClientHandler; +//import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; +//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +// +//import java.util.Map; +// +//import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +//import static llm.vendor.Vendor.VENDOR_VERSION; +// +///** +// * Service client for accessing Amazon Bedrock Runtime. 
+// */ +//// TODO switch back to instrumenting the BedrockRuntimeClient interface instead of this implementation class +//@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeClient") +//final class DefaultBedrockRuntimeClient_Instrumentation { +//// private static final Logger log = Logger.loggerFor(DefaultBedrockRuntimeClient.class); +//// +//// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() +//// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); +// +// private final SyncClientHandler clientHandler; +// +// private final AwsJsonProtocolFactory protocolFactory; +// +// private final SdkClientConfiguration clientConfiguration; +// +// private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; +// +// protected DefaultBedrockRuntimeClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, +// SdkClientConfiguration clientConfiguration) { +// this.clientHandler = Weaver.callOriginal(); +// this.clientConfiguration = Weaver.callOriginal(); +// this.serviceClientConfiguration = Weaver.callOriginal(); +// this.protocolFactory = Weaver.callOriginal(); +// } +// +// @Trace +// public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { +// long startTime = System.currentTimeMillis(); +// InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); +// +// ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); +// +// Transaction txn = AgentBridge.getAgent().getTransaction(); +// // TODO check AIM config +// if (txn != null && !(txn instanceof NoOpTransaction)) { +// Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); +// +// String modelId = invokeModelRequest.modelId(); +// if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { +// ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(txn, invokeModelRequest, +// 
invokeModelResponse); +// // Set traced method name based on LLM operation +// anthropicClaudeModelInvocation.setLlmOperationMetricName("invokeModel"); +// // Set llm = true agent attribute +// ModelInvocation.setLlmTrueAgentAttribute(txn); +// anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); +// } +// } +// +// return invokeModelResponse; +// } +// +//} From cfc58235a39b5407dd8670e65f3d7d8c1730dd85 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 27 Feb 2024 10:53:46 -0800 Subject: [PATCH 10/68] Add config. Some cleanup. --- .../aimonitoring/AiMonitoringUtils.java | 35 ++++++ .../aws-bedrock-runtime-2.20/README.md | 46 +++++++- .../main/java/llm/models/ModelInvocation.java | 4 - ...ockRuntimeAsyncClient_Instrumentation.java | 106 +++++++++++++----- .../BedrockRuntimeClient_Instrumentation.java | 50 ++++++--- 5 files changed, 183 insertions(+), 58 deletions(-) create mode 100644 agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java new file mode 100644 index 0000000000..8168925a04 --- /dev/null +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java @@ -0,0 +1,35 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.agent.bridge.aimonitoring; + +import com.newrelic.api.agent.NewRelic; + +public class AiMonitoringUtils { + // Enabled defaults + private static final boolean AI_MONITORING_ENABLED_DEFAULT = false; + private static final boolean AI_MONITORING_STREAMING_ENABLED_DEFAULT = true; + + /** + * Check if ai_monitoring features are enabled. + * Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. 
+ * + * @return true if enabled, else false + */ + public static boolean isAiMonitoringEnabled() { + return NewRelic.getAgent().getConfig().getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); + } + + /** + * Check if ai_monitoring.streaming features are enabled. + * + * @return true if enabled, else false + */ + public static boolean isAiMonitoringStreamingEnabled() { + return NewRelic.getAgent().getConfig().getValue("ai_monitoring.streaming.enabled", AI_MONITORING_STREAMING_ENABLED_DEFAULT); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index 17744dc6d4..6f91d64ee2 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -2,7 +2,7 @@ ## About -Instruments invocations of LLMs via AWS Bedrock Runtime. +Instruments invocations of LLMs via the AWS Bedrock Runtime SDK. ## Support @@ -15,6 +15,8 @@ The following AWS Bedrock Runtime clients and APIs are supported: * `BedrockRuntimeAsyncClient` * `invokeModel` +Note: Currently, `invokeModelWithResponseStream` is not supported. + ### Supported Models Currently, only the following text-based foundation models are supported: @@ -42,7 +44,7 @@ These events are custom events sent via the public `recordCustomEvent` API. Curr max_samples_stored: 100000 ``` -LLM events also have some unique limits for their attributes... +LLM events also have some unique limits for the content attribute... ``` Regardless of which implementation(s) are built, there are consistent changes within the agents and the UX to support AI Monitoring. @@ -55,7 +57,6 @@ Agents should remove token counts from the LlmChatCompletionSummary ``` -call out llm. 
behavior Can be built via `LlmEvent` builder @@ -65,19 +66,54 @@ Can be built via `LlmEvent` builder * `ModelRequest` * `ModelResponse` +### Custom LLM Attributes + +Any custom attributes added by customers using the `addCustomParameters` API that are prefixed with `llm.` will automatically be copied to `LlmEvent`s. For custom attributes added by the `addCustomParameters` API to be added to `LlmEvent`s the API calls must occur before invoking the Bedrock SDK. + +One potential custom attribute with special meaning that customers are encouraged to add is `llm.conversation_id`, which has implications in the UI and can be used to group LLM messages into specific conversations. + ### Metrics +When in an active transaction a named span/segment for each LLM embedding and chat completion call is created using the following format: +`Llm/{operation_type}/{vendor_name}/{function_name}` -## Config +* `operation_type`: `completion` or `embedding` +* `vendor_name`: Name of LLM vendor (ex: `OpenAI`, `Bedrock`) +* `function_name`: Name of instrumented function (ex: `invokeModel`, `create`) + +A supportability metric is reported each time an instrumented framework method is invoked. These metrics are detected and parsed by APM Services to support entity tagging in the UI, if a metric isn't reported within the past day the LLM UI will not display in APM. The metric uses the following format: + +`Supportability/{language}/ML/{vendor_name}/{vendor_version}` + +* `language`: Name of language agent (ex: `Java`) +* `vendor_name`: Name of LLM vendor (ex: `Bedrock`) +* `vendor_version`: Version of instrumented LLM library (ex: `2.20`) + +Note: The vendor version isn't obtainable from the AWS Bedrock SDK for Java so the instrumentation version is used instead. +Additionally, a supportability metric is recorded to indicate if streaming is disabled. Streaming is considered disabled if the value of the `ai_monitoring.streaming.enabled` configuration setting is `false`. 
If streaming is enabled, no supportability metric will be sent. The metric uses the following format: + +`Supportability/{language}/ML/Streaming/Disabled` + +* `language`: Name of language agent (ex: `Java`) + +Note: Streaming is not currently supported. + + + + // Set llm = true agent attribute required on TransactionEvents + + +## Config +`ai_monitoring.enabled`: Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. +`ai_monitoring.streaming.enabled`: NOT SUPPORTED ## Testing ## TODO -* Wire up async client * Clean up request/response parsing logic * Wire up Config * Generate `Supportability/{language}/ML/Streaming/Disabled` metric? diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 7d0e3371bf..c6de96b32d 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -46,10 +46,6 @@ public interface ModelInvocation { /** * This needs to be incremented for every invocation of the SDK. * Supportability/{language}/ML/{vendor_name}/{vendor_version} - *

- * The metric generated triggers the creation of a tag which gates the AI Response UI. The - * tag lives for 27 hours so if this metric isn't repeatedly sent the tag will disappear and - * the UI will be hidden. */ static void incrementInstrumentedSupportabilityMetric(String vendorVersion) { NewRelic.incrementCounter("Supportability/Java/ML/" + VENDOR + "/" + vendorVersion); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 8ea71f473d..061ac2365d 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -1,6 +1,14 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + package software.amazon.awssdk.services.bedrockruntime; import com.newrelic.agent.bridge.AgentBridge; +import com.newrelic.agent.bridge.NoOpTransaction; import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; @@ -12,12 +20,21 @@ import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.logging.Level; +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringEnabled; +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringStreamingEnabled; +import static llm.models.SupportedModels.AI_21_LABS_JURASSIC; +import static llm.models.SupportedModels.AMAZON_TITAN; import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.models.SupportedModels.COHERE_COMMAND; +import static llm.models.SupportedModels.META_LLAMA_2; import static llm.vendor.Vendor.VENDOR_VERSION; /** @@ -28,44 +45,71 @@ public abstract class BedrockRuntimeAsyncClient_Instrumentation { @Trace public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { - // TODO check AIM config long startTime = System.currentTimeMillis(); - Transaction txn = AgentBridge.getAgent().getTransaction(); - // Segment will be named later when the response is available - Segment segment = txn.startSegment(""); - CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); - // Set llm = true agent attribute - 
ModelInvocation.setLlmTrueAgentAttribute(txn); - ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); - Map userAttributes = txn.getUserAttributes(); - Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); + if (isAiMonitoringEnabled()) { + Transaction txn = AgentBridge.getAgent().getTransaction(); + ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); + + if (!(txn instanceof NoOpTransaction)) { + // Segment will be renamed later when the response is available + Segment segment = txn.startSegment(""); + // Set llm = true agent attribute, this is required on transaction events + ModelInvocation.setLlmTrueAgentAttribute(txn); + + // This should never happen, but protecting against bad implementations + if (invokeModelResponseFuture == null) { + segment.end(); + } else { + Map userAttributes = txn.getUserAttributes(); + Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); + String modelId = invokeModelRequest.modelId(); + + invokeModelResponseFuture.whenComplete(new BiConsumer() { + @Override + public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { + try { + if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, + invokeModelResponse); + // Set segment name based on LLM operation from response + anthropicClaudeModelInvocation.setSegmentName(segment, "invokeModel"); + anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { - // this should never happen, but protecting against bad implementations - if (invokeModelResponseFuture == null) { - segment.end(); - } else { - invokeModelResponseFuture.whenComplete(new BiConsumer() { - @Override - public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { - try { - // TODO check AIM config - String modelId = 
invokeModelRequest.modelId(); - if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, - invokeModelResponse); - // Set segment name based on LLM operation - anthropicClaudeModelInvocation.setSegmentName(segment, "invokeModel"); - anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { + + } else if (modelId.toLowerCase().contains(COHERE_COMMAND)) { + + } else if (modelId.toLowerCase().contains(AI_21_LABS_JURASSIC)) { + + } + segment.end(); + } catch (Throwable t) { + AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); + } } - segment.end(); - } catch (Throwable t) { - AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); - } + }); } - }); + } } return invokeModelResponseFuture; } + + public CompletableFuture invokeModelWithResponseStream( + InvokeModelWithResponseStreamRequest invokeModelWithResponseStreamRequest, + InvokeModelWithResponseStreamResponseHandler asyncResponseHandler) { + if (isAiMonitoringEnabled()) { + if (isAiMonitoringStreamingEnabled()) { + NewRelic.getAgent() + .getLogger() + .log(Level.FINER, + "aws-bedrock-runtime-2.20 instrumentation does not currently support response streaming. 
Enabling ai_monitoring.streaming will have no effect."); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); + } + } + return Weaver.callOriginal(); + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java index 3ee2ae5a34..7c5b068c20 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java @@ -22,11 +22,16 @@ import java.util.Map; +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringEnabled; +import static llm.models.SupportedModels.AI_21_LABS_JURASSIC; +import static llm.models.SupportedModels.AMAZON_TITAN; import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.models.SupportedModels.COHERE_COMMAND; +import static llm.models.SupportedModels.META_LLAMA_2; import static llm.vendor.Vendor.VENDOR_VERSION; /** - * Service client for accessing Amazon Bedrock Runtime. + * Service client for accessing Amazon Bedrock Runtime synchronously. 
*/ @Weave(type = MatchType.Interface, originalName = "software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient") public abstract class BedrockRuntimeClient_Instrumentation { @@ -36,26 +41,35 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { long startTime = System.currentTimeMillis(); InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); - ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); - Transaction txn = AgentBridge.getAgent().getTransaction(); - - // TODO check AIM config - if (!(txn instanceof NoOpTransaction)) { - Map userAttributes = txn.getUserAttributes(); - Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); - - String modelId = invokeModelRequest.modelId(); - if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, - invokeModelResponse); - // Set traced method name based on LLM operation - anthropicClaudeModelInvocation.setTracedMethodName(txn, "invokeModel"); - // Set llm = true agent attribute + if (isAiMonitoringEnabled()) { + Transaction txn = AgentBridge.getAgent().getTransaction(); + ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); + + if (!(txn instanceof NoOpTransaction)) { + // Set llm = true agent attribute, this is required on transaction events ModelInvocation.setLlmTrueAgentAttribute(txn); - anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + + Map userAttributes = txn.getUserAttributes(); + Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); + String modelId = invokeModelRequest.modelId(); + + if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + 
anthropicClaudeModelInvocation.setTracedMethodName(txn, "invokeModel"); + anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { + + } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { + + } else if (modelId.toLowerCase().contains(COHERE_COMMAND)) { + + } else if (modelId.toLowerCase().contains(AI_21_LABS_JURASSIC)) { + + } } } - return invokeModelResponse; } } From 78b9166e7a1429c8c8a1743c303333fe1b073946 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 27 Feb 2024 13:13:23 -0800 Subject: [PATCH 11/68] Cleanup --- instrumentation/aws-bedrock-runtime-2.20/README.md | 2 -- .../src/main/java/llm/models/ModelInvocation.java | 2 +- .../anthropic/claude/AnthropicClaudeModelInvocation.java | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index 6f91d64ee2..baeb0bd332 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -115,8 +115,6 @@ Note: Streaming is not currently supported. ## TODO * Clean up request/response parsing logic -* Wire up Config - * Generate `Supportability/{language}/ML/Streaming/Disabled` metric? 
* Set up and test new models * Write instrumentation tests * Finish readme diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index c6de96b32d..157dd40b90 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -35,7 +35,7 @@ public interface ModelInvocation { void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata); - void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata); + void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages, Map linkingMetadata); void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index 69404f7d8f..d84e178f65 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -76,7 +76,7 @@ public void recordLlmEmbeddingEvent(long startTime, Map linkingM } @Override - public void recordLlmChatCompletionSummaryEvent(int numberOfMessages, long startTime, Map linkingMetadata) { + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages, Map linkingMetadata) { if (claudeResponse.isErrorResponse()) { reportLlmError(); } @@ -166,6 +166,6 @@ private void recordLlmChatCompletionEvents(long startTime, Map l // Second LlmChatCompletionMessage represents the completion message from the LLM response 
recordLlmChatCompletionMessageEvent(1, claudeResponse.getResponseMessage(), linkingMetadata); // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(2, startTime, linkingMetadata); + recordLlmChatCompletionSummaryEvent(startTime, 2, linkingMetadata); } } From a0ad252ad20c75592410669af71e2f30f4237cd1 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 27 Feb 2024 16:13:49 -0800 Subject: [PATCH 12/68] Cleanup --- .../src/main/java/llm/events/LlmEvent.java | 10 +-- .../main/java/llm/models/ModelInvocation.java | 31 +++++++--- .../AnthropicClaudeModelInvocation.java | 61 +++++++++++++------ ...ockRuntimeAsyncClient_Instrumentation.java | 7 ++- .../BedrockRuntimeClient_Instrumentation.java | 4 +- 5 files changed, 76 insertions(+), 37 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index 0ff1a583d1..adaf02f9f2 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -93,11 +93,11 @@ public static class Builder { private Integer responseUsageCompletionTokens = null; private String responseChoicesFinishReason = null; - public Builder(Map userAttributes, Map linkingMetadata, ModelRequest modelRequest, ModelResponse modelResponse) { - this.userAttributes = userAttributes; - this.linkingMetadata = linkingMetadata; - this.modelRequest = modelRequest; - this.modelResponse = modelResponse; + public Builder(ModelInvocation modelInvocation) { + userAttributes = modelInvocation.getUserAttributes(); + linkingMetadata = modelInvocation.getLinkingMetadata(); + modelRequest = modelInvocation.getModelRequest(); + modelResponse = modelInvocation.getModelResponse(); } public Builder spanId() { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 157dd40b90..96ce152bdf 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -33,16 +33,24 @@ public interface ModelInvocation { */ void setSegmentName(Segment segment, String functionName); - void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata); + void recordLlmEmbeddingEvent(long startTime); - void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages, Map linkingMetadata); + void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages); - void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata); + void recordLlmChatCompletionMessageEvent(int sequence, String message); - void recordLlmEvents(long startTime, Map linkingMetadata); + void recordLlmEvents(long startTime); void reportLlmError(); + Map getLinkingMetadata(); + + Map getUserAttributes(); + + ModelRequest getModelRequest(); + + ModelResponse getModelResponse(); + /** * This needs to be incremented for every invocation of the SDK. 
* Supportability/{language}/ML/{vendor_name}/{vendor_version} @@ -51,19 +59,28 @@ static void incrementInstrumentedSupportabilityMetric(String vendorVersion) { NewRelic.incrementCounter("Supportability/Java/ML/" + VENDOR + "/" + vendorVersion); } + static void incrementStreamingDisabledSupportabilityMetric() { + NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); + } + static void setLlmTrueAgentAttribute(Transaction txn) { // If in a txn with LLM-related spans txn.getAgentAttributes().put("llm", true); } - // GUID associated with the active trace static String getSpanId(Map linkingMetadata) { - return linkingMetadata.get("span.id"); + if (linkingMetadata != null && !linkingMetadata.isEmpty()) { + return linkingMetadata.get("span.id"); + } + return ""; } // ID of the current trace static String getTraceId(Map linkingMetadata) { - return linkingMetadata.get("trace.id"); + if (linkingMetadata != null && !linkingMetadata.isEmpty()) { + return linkingMetadata.get("trace.id"); + } + return ""; } // Returns a string representation of a random GUID diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index d84e178f65..0af4828d43 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -26,14 +26,17 @@ import static llm.vendor.Vendor.BEDROCK; public class AnthropicClaudeModelInvocation implements ModelInvocation { + Map linkingMetadata; Map userAttributes; ModelRequest claudeRequest; ModelResponse claudeResponse; - public AnthropicClaudeModelInvocation(Map userCustomAttributes, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { 
- userAttributes = userCustomAttributes; - claudeRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); - claudeResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); + public AnthropicClaudeModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.claudeRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); + this.claudeResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); } @Override @@ -47,14 +50,12 @@ public void setSegmentName(Segment segment, String functionName) { } @Override - public void recordLlmEmbeddingEvent(long startTime, Map linkingMetadata) { + public void recordLlmEmbeddingEvent(long startTime) { if (claudeResponse.isErrorResponse()) { reportLlmError(); } - // TODO should the builder just take a ModelInvocation instance and pull all of this stuff from it? All it would - // require is storing the linking metadata on the ModelInvocation instance and adding getters for the - // userAttributes, linkingMetadata, claudeRequest, claudeResponse. 
- LlmEvent.Builder builder = new LlmEvent.Builder(userAttributes, linkingMetadata, claudeRequest, claudeResponse); + + LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmEmbeddingEvent = builder .spanId() @@ -76,12 +77,12 @@ public void recordLlmEmbeddingEvent(long startTime, Map linkingM } @Override - public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages, Map linkingMetadata) { + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { if (claudeResponse.isErrorResponse()) { reportLlmError(); } - LlmEvent.Builder builder = new LlmEvent.Builder(userAttributes, linkingMetadata, claudeRequest, claudeResponse); + LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmChatCompletionSummaryEvent = builder .spanId() @@ -107,10 +108,10 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess } @Override - public void recordLlmChatCompletionMessageEvent(int sequence, String message, Map linkingMetadata) { + public void recordLlmChatCompletionMessageEvent(int sequence, String message) { boolean isUser = message.contains("Human:"); - LlmEvent.Builder builder = new LlmEvent.Builder(userAttributes, linkingMetadata, claudeRequest, claudeResponse); + LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmChatCompletionMessageEvent = builder .spanId() @@ -131,12 +132,12 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, Ma } @Override - public void recordLlmEvents(long startTime, Map linkingMetadata) { + public void recordLlmEvents(long startTime) { String operationType = claudeResponse.getOperationType(); if (operationType.equals(COMPLETION)) { - recordLlmChatCompletionEvents(startTime, linkingMetadata); + recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime, linkingMetadata); + recordLlmEmbeddingEvent(startTime); } else { 
NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); } @@ -160,12 +161,32 @@ public void reportLlmError() { * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. * The number of LlmChatCompletionMessage events produced can differ based on vendor. */ - private void recordLlmChatCompletionEvents(long startTime, Map linkingMetadata) { + private void recordLlmChatCompletionEvents(long startTime) { // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, claudeRequest.getRequestMessage(), linkingMetadata); + recordLlmChatCompletionMessageEvent(0, claudeRequest.getRequestMessage()); // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, claudeResponse.getResponseMessage(), linkingMetadata); + recordLlmChatCompletionMessageEvent(1, claudeResponse.getResponseMessage()); // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(startTime, 2, linkingMetadata); + recordLlmChatCompletionSummaryEvent(startTime, 2); + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return claudeRequest; + } + + @Override + public ModelResponse getModelResponse() { + return claudeResponse; } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 061ac2365d..64d28a0542 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -71,11 +71,12 @@ public CompletableFuture invokeModel(InvokeModelRequest inv public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { try { if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(linkingMetadata, userAttributes, + invokeModelRequest, invokeModelResponse); // Set segment name based on LLM operation from response anthropicClaudeModelInvocation.setSegmentName(segment, "invokeModel"); - anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + anthropicClaudeModelInvocation.recordLlmEvents(startTime); } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { @@ -107,7 +108,7 @@ public CompletableFuture invokeModelWithResponseStream( .log(Level.FINER, "aws-bedrock-runtime-2.20 instrumentation does not currently support response streaming. 
Enabling ai_monitoring.streaming will have no effect."); } else { - NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); + ModelInvocation.incrementStreamingDisabledSupportabilityMetric(); } } return Weaver.callOriginal(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java index 7c5b068c20..b962cc56da 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java @@ -54,11 +54,11 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { String modelId = invokeModelRequest.modelId(); if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(userAttributes, invokeModelRequest, + ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); // Set traced method name based on LLM operation from response anthropicClaudeModelInvocation.setTracedMethodName(txn, "invokeModel"); - anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); + anthropicClaudeModelInvocation.recordLlmEvents(startTime); } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { From fe6268db6579ccea5ceebc9820f83559bb7ac487 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 27 Feb 2024 17:27:10 -0800 Subject: [PATCH 13/68] Cleanup logging around parsing errors --- .../main/java/llm/models/ModelRequest.java | 12 
+++++ .../main/java/llm/models/ModelResponse.java | 12 +++++ .../AnthropicClaudeInvokeModelRequest.java | 46 ++++++++++++------- .../AnthropicClaudeInvokeModelResponse.java | 46 +++++++++++-------- 4 files changed, 82 insertions(+), 34 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java index 44dda2d2d5..9339ff0a8c 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -7,6 +7,10 @@ package llm.models; +import com.newrelic.api.agent.NewRelic; + +import java.util.logging.Level; + public interface ModelRequest { int getMaxTokensToSample(); @@ -19,4 +23,12 @@ public interface ModelRequest { String getInputText(); String getModelId(); + + static void logParsingFailure(Exception e, String fieldBeingParsed) { + if (e != null) { + NewRelic.getAgent().getLogger().log(Level.FINEST, e, "AIM: Error parsing " + fieldBeingParsed + " from ModelRequest"); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Unable to parse empty/null " + fieldBeingParsed + " from ModelRequest"); + } + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java index 0c95c1942f..9c7be62592 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -7,6 +7,10 @@ package llm.models; +import com.newrelic.api.agent.NewRelic; + +import java.util.logging.Level; + public interface ModelResponse { String COMPLETION = "completion"; String EMBEDDING = "embedding"; @@ -34,4 +38,12 @@ public interface ModelResponse { int getStatusCode(); String getStatusText(); + 
+ static void logParsingFailure(Exception e, String fieldBeingParsed) { + if (e != null) { + NewRelic.getAgent().getLogger().log(Level.FINEST, e, "AIM: Error parsing " + fieldBeingParsed + " from ModelResponse"); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Unable to parse empty/null " + fieldBeingParsed + " from ModelResponse"); + } + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java index e43b905ec3..f78b9fc63a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java @@ -8,6 +8,7 @@ package llm.models.anthropic.claude; import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; import software.amazon.awssdk.protocols.jsoncore.JsonNode; import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; @@ -16,11 +17,13 @@ import java.util.Map; import java.util.logging.Level; +import static llm.models.ModelRequest.logParsingFailure; + /** - * Stores the required info from the Bedrock InvokeModelRequest - * but doesn't hold a reference to the actual request object. + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. 
*/ -public class AnthropicClaudeInvokeModelRequest implements llm.models.ModelRequest { +public class AnthropicClaudeInvokeModelRequest implements ModelRequest { // TODO might be able to move some of these constants to the ModelRequest interface // need to figure out if they are consistent across all models private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; @@ -41,7 +44,7 @@ public AnthropicClaudeInvokeModelRequest(InvokeModelRequest invokeModelRequest) invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); modelId = invokeModelRequest.modelId(); } else { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelRequest"); + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); } } @@ -70,14 +73,15 @@ private Map parseInvokeModelRequestBodyMap() { JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); Map requestBodyJsonMap = null; - // TODO check for other types? Or will it always be Object? - // add try/catch? - if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { - requestBodyJsonMap = requestBodyJsonNode.asObject(); - } else { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelRequest body as Map Object"); + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); } - return requestBodyJsonMap != null ? 
requestBodyJsonMap : Collections.emptyMap(); } @@ -91,9 +95,11 @@ public int getMaxTokensToSample() { String maxTokensToSampleString = jsonNode.asNumber(); maxTokensToSample = Integer.parseInt(maxTokensToSampleString); } + } else { + logParsingFailure(null, MAX_TOKENS_TO_SAMPLE); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + MAX_TOKENS_TO_SAMPLE); + logParsingFailure(e, MAX_TOKENS_TO_SAMPLE); } return maxTokensToSample; } @@ -108,9 +114,11 @@ public float getTemperature() { String temperatureString = jsonNode.asNumber(); temperature = Float.parseFloat(temperatureString); } + } else { + logParsingFailure(null, TEMPERATURE); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + TEMPERATURE); + logParsingFailure(e, TEMPERATURE); } return temperature; } @@ -124,9 +132,11 @@ public String getRequestMessage() { if (jsonNode.isString()) { prompt = jsonNode.asString(); } + } else { + logParsingFailure(null, PROMPT); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + PROMPT); + logParsingFailure(e, PROMPT); } return prompt; } @@ -143,9 +153,11 @@ public String getRole() { } else if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + ASSISTANT)) { return ASSISTANT; } + } else { + logParsingFailure(null, "role"); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse role from InvokeModelRequest"); + logParsingFailure(e, "role"); } return ""; } @@ -159,9 +171,11 @@ public String getInputText() { if (jsonNode.isString()) { inputText = jsonNode.asString(); } + } else { + logParsingFailure(null, INPUT_TEXT); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + INPUT_TEXT); + logParsingFailure(e, INPUT_TEXT); } return inputText; } diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java index 3cf341a586..d320dd0750 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java @@ -8,6 +8,7 @@ package llm.models.anthropic.claude; import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; import software.amazon.awssdk.protocols.jsoncore.JsonNode; import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; @@ -19,12 +20,13 @@ import java.util.logging.Level; import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; /** - * Stores the required info from the Bedrock InvokeModelResponse - * but doesn't hold a reference to the actual response object. + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual response object to avoid potential memory issues. */ -public class AnthropicClaudeInvokeModelResponse implements llm.models.ModelResponse { +public class AnthropicClaudeInvokeModelResponse implements ModelResponse { private static final String STOP_REASON = "stop_reason"; // Response headers @@ -97,26 +99,28 @@ private Map parseInvokeModelResponseBodyMap() { // TODO check for other types? Or will it always be Object?
if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); } -// else { -// NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse body as Map Object"); -// } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse body as Map Object"); + logParsingFailure(e, "response body"); } - return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); } private void setOperationType(String invokeModelResponseBody) { - if (!invokeModelResponseBody.isEmpty()) { - if (invokeModelResponseBody.startsWith(JSON_START + COMPLETION)) { - operationType = COMPLETION; - } else if (invokeModelResponseBody.startsWith(JSON_START + EMBEDDING)) { - operationType = EMBEDDING; - } else { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unknown operation type"); + try { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.startsWith(JSON_START + COMPLETION)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.startsWith(JSON_START + EMBEDDING)) { + operationType = EMBEDDING; + } else { + logParsingFailure(null, "operation type"); + } } + } catch (Exception e) { + logParsingFailure(e, "operation type"); } } @@ -138,9 +142,11 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { amznRequestId = amznRequestIdHeaders.get(0); } + } else { + logParsingFailure(null, "response headers"); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse InvokeModelResponse headers"); + logParsingFailure(e, "response headers"); } } @@ -158,9 +164,11 @@ public String getResponseMessage() { if (jsonNode.isString()) { completion = jsonNode.asString(); } + } else { + logParsingFailure(null, COMPLETION); } } 
catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + COMPLETION); + logParsingFailure(e, COMPLETION); } return completion; } @@ -174,9 +182,11 @@ public String getStopReason() { if (jsonNode.isString()) { stopReason = jsonNode.asString(); } + } else { + logParsingFailure(null, STOP_REASON); } } catch (Exception e) { - NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unable to parse " + STOP_REASON); + logParsingFailure(e, STOP_REASON); } return stopReason; } From 7c09a4d5c2ec66891973e3974ed86fe9a97ac239 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 28 Feb 2024 15:26:26 -0800 Subject: [PATCH 14/68] Fix some async client issues --- .../aws-bedrock-runtime-2.20/README.md | 62 +++++++++++++++++-- .../main/java/llm/models/ModelInvocation.java | 4 ++ .../AnthropicClaudeInvokeModelRequest.java | 31 ++++------ .../AnthropicClaudeInvokeModelResponse.java | 31 ++++------ .../AnthropicClaudeModelInvocation.java | 11 ++++ ...ockRuntimeAsyncClient_Instrumentation.java | 7 ++- .../analytics/InsightsServiceImpl.java | 1 + 7 files changed, 100 insertions(+), 47 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index baeb0bd332..c3117872ba 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -66,12 +66,19 @@ Can be built via `LlmEvent` builder * `ModelRequest` * `ModelResponse` -### Custom LLM Attributes +### Attributes + +#### Custom LLM Attributes Any custom attributes added by customers using the `addCustomParameters` API that are prefixed with `llm.` will automatically be copied to `LlmEvent`s. For custom attributes added by the `addCustomParameters` API to be added to `LlmEvent`s the API calls must occur before invoking the Bedrock SDK. 
One potential custom attribute with special meaning that customers are encouraged to add is `llm.conversation_id`, which has implications in the UI and can be used to group LLM messages into specific conversations. +#### Agent Attributes + + // Set llm = true agent attribute required on TransactionEvents + + ### Metrics When in an active transaction a named span/segment for each LLM embedding and chat completion call is created using the following format: @@ -100,11 +107,6 @@ Additionally, a supportability metric is recorded to indicate if streaming is di Note: Streaming is not currently supported. - - - // Set llm = true agent attribute required on TransactionEvents - - ## Config `ai_monitoring.enabled`: Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. @@ -112,9 +114,57 @@ Note: Streaming is not currently supported. ## Testing +## Known Issues + +When using the `BedrockRuntimeAsyncClient`, which returns the response as a `CompletableFuture`, the external call to AWS isn't being captured. This likely requires deeper instrumentation of the awssdk core classes, perhaps the `software.amazon.awssdk.core.internal.http.AmazonAsyncHttpClient` or `software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient`.
The external call is actually made by `NettyRequestExecutor(ctx)).execute()` + +```java +"http-nio-8081-exec-9@16674" tid=0x56 nid=NA runnable + java.lang.Thread.State: RUNNABLE + at software.amazon.awssdk.http.nio.netty.internal.NettyRequestExecutor.execute(NettyRequestExecutor.java:92) + at software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient.execute(NettyNioAsyncHttpClient.java:123) + at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.doExecuteHttpRequest(MakeAsyncHttpRequestStage.java:189) + at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.executeHttpRequest(MakeAsyncHttpRequestStage.java:147) + at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.lambda$execute$1(MakeAsyncHttpRequestStage.java:99) + at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage$$Lambda/0x0000000800aefa78.accept(Unknown Source:-1) + at java.util.concurrent.CompletableFuture.uniAcceptNow(CompletableFuture.java:757) + at java.util.concurrent.CompletableFuture.uniAcceptStage(CompletableFuture.java:735) + at java.util.concurrent.CompletableFuture.thenAccept(CompletableFuture.java:2214) + at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.execute(MakeAsyncHttpRequestStage.java:95) + at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.execute(MakeAsyncHttpRequestStage.java:60) + at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallAttemptMetricCollectionStage.execute(AsyncApiCallAttemptMetricCollectionStage.java:56) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallAttemptMetricCollectionStage.execute(AsyncApiCallAttemptMetricCollectionStage.java:38) + at 
software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage$RetryingExecutor.attemptExecute(AsyncRetryableStage.java:144) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage$RetryingExecutor.maybeAttemptExecute(AsyncRetryableStage.java:136) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage$RetryingExecutor.execute(AsyncRetryableStage.java:95) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage.execute(AsyncRetryableStage.java:79) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage.execute(AsyncRetryableStage.java:44) + at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206) + at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage.execute(AsyncExecutionFailureExceptionReportingStage.java:41) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage.execute(AsyncExecutionFailureExceptionReportingStage.java:29) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallTimeoutTrackingStage.execute(AsyncApiCallTimeoutTrackingStage.java:64) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallTimeoutTrackingStage.execute(AsyncApiCallTimeoutTrackingStage.java:36) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallMetricCollectionStage.execute(AsyncApiCallMetricCollectionStage.java:49) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallMetricCollectionStage.execute(AsyncApiCallMetricCollectionStage.java:32) + at 
software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206) + at software.amazon.awssdk.core.internal.http.AmazonAsyncHttpClient$RequestExecutionBuilderImpl.execute(AmazonAsyncHttpClient.java:190) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.invoke(BaseAsyncClientHandler.java:285) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.doExecute(BaseAsyncClientHandler.java:227) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.lambda$execute$1(BaseAsyncClientHandler.java:82) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler$$Lambda/0x0000000800ab3088.get(Unknown Source:-1) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.measureApiCallSuccess(BaseAsyncClientHandler.java:291) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.execute(BaseAsyncClientHandler.java:75) + at software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler.execute(AwsAsyncClientHandler.java:52) + at software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient.invokeModel(DefaultBedrockRuntimeAsyncClient.java:161) +``` + ## TODO * Clean up request/response parsing logic +* Add Javadoc comments to interfaces * Set up and test new models * Write instrumentation tests * Finish readme +* Figure out how to get external call linked with async client diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 96ce152bdf..ece741885e 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -7,6 +7,7 @@ package llm.models; +import com.newrelic.agent.bridge.Token; import 
com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; @@ -41,6 +42,9 @@ public interface ModelInvocation { void recordLlmEvents(long startTime); + // This causes the txn to be active on the thread where the LlmEvents are created so that they properly added to the event reservoir on the txn. This is used when the model response is returned asynchronously. + void recordLlmEventsAsync(long startTime, Token token); + void reportLlmError(); Map getLinkingMetadata(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java index f78b9fc63a..04ee370eee 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java @@ -125,20 +125,7 @@ public float getTemperature() { @Override public String getRequestMessage() { - String prompt = ""; - try { - if (!getRequestBodyJsonMap().isEmpty()) { - JsonNode jsonNode = getRequestBodyJsonMap().get(PROMPT); - if (jsonNode.isString()) { - prompt = jsonNode.asString(); - } - } else { - logParsingFailure(null, PROMPT); - } - } catch (Exception e) { - logParsingFailure(e, PROMPT); - } - return prompt; + return parseStringValue(PROMPT); } @Override @@ -164,20 +151,24 @@ public String getRole() { @Override public String getInputText() { - String inputText = ""; + return parseStringValue(INPUT_TEXT); + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; try { if (!getRequestBodyJsonMap().isEmpty()) { - JsonNode jsonNode = getRequestBodyJsonMap().get(INPUT_TEXT); + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); if 
(jsonNode.isString()) { - inputText = jsonNode.asString(); + parsedStringValue = jsonNode.asString(); } } else { - logParsingFailure(null, INPUT_TEXT); + logParsingFailure(null, fieldToParse); } } catch (Exception e) { - logParsingFailure(e, INPUT_TEXT); + logParsingFailure(e, fieldToParse); } - return inputText; + return parsedStringValue; } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java index d320dd0750..e260cd7da6 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java @@ -157,38 +157,29 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { */ @Override public String getResponseMessage() { - String completion = ""; - try { - if (!getResponseBodyJsonMap().isEmpty()) { - JsonNode jsonNode = getResponseBodyJsonMap().get(COMPLETION); - if (jsonNode.isString()) { - completion = jsonNode.asString(); - } - } else { - logParsingFailure(null, COMPLETION); - } - } catch (Exception e) { - logParsingFailure(e, COMPLETION); - } - return completion; + return parseStringValue(COMPLETION); } @Override public String getStopReason() { - String stopReason = ""; + return parseStringValue(STOP_REASON); + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { - JsonNode jsonNode = getResponseBodyJsonMap().get(STOP_REASON); + JsonNode jsonNode = getResponseBodyJsonMap().get(fieldToParse); if (jsonNode.isString()) { - stopReason = jsonNode.asString(); + parsedStringValue = jsonNode.asString(); } } else { - logParsingFailure(null, STOP_REASON); + 
logParsingFailure(null, fieldToParse); } } catch (Exception e) { - logParsingFailure(e, STOP_REASON); + logParsingFailure(e, fieldToParse); } - return stopReason; + return parsedStringValue; } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java index 0af4828d43..ec07252b3f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java @@ -7,9 +7,11 @@ package llm.models.anthropic.claude; +import com.newrelic.agent.bridge.Token; import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; import llm.events.LlmEvent; import llm.models.ModelInvocation; import llm.models.ModelRequest; @@ -143,6 +145,15 @@ public void recordLlmEvents(long startTime) { } } + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + @Override public void reportLlmError() { Map errorParams = new HashMap<>(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 64d28a0542..8faa9c0bab 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -9,6 +9,7 @@ import com.newrelic.agent.bridge.AgentBridge; import com.newrelic.agent.bridge.NoOpTransaction; +import com.newrelic.agent.bridge.Token; import com.newrelic.agent.bridge.Transaction; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; @@ -66,9 +67,13 @@ public CompletableFuture invokeModel(InvokeModelRequest inv Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); String modelId = invokeModelRequest.modelId(); + Token token = txn.getToken(); + + // TODO instrumentation fails if the BiConsumer is replaced with a lambda invokeModelResponseFuture.whenComplete(new BiConsumer() { @Override public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { + try { if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(linkingMetadata, userAttributes, @@ -76,7 +81,7 @@ public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) invokeModelResponse); // Set segment name based on LLM operation from response anthropicClaudeModelInvocation.setSegmentName(segment, "invokeModel"); - anthropicClaudeModelInvocation.recordLlmEvents(startTime); + anthropicClaudeModelInvocation.recordLlmEventsAsync(startTime, token); } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java index c08a373a15..39679a3a14 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java @@ -454,6 
+454,7 @@ private static class LlmEventAttributeSender extends AttributeSender { public LlmEventAttributeSender(Map userAttributes) { super(new LlmEventAttributeValidator(ATTRIBUTE_TYPE)); this.userAttributes = userAttributes; + // This will have the effect of only copying attributes onto LlmEvents if there is an active transaction setTransactional(true); } From 549692cec4fde1e0282bc216d5144ac2592360f8 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 5 Mar 2024 17:31:01 -0800 Subject: [PATCH 15/68] Add support for Jurassic and Titan models. Some refactoring. --- .../aws-bedrock-runtime-2.20/README.md | 58 +++- .../src/main/java/llm/events/LlmEvent.java | 3 + .../main/java/llm/models/ModelInvocation.java | 99 ++++++- .../main/java/llm/models/ModelRequest.java | 36 +++ .../main/java/llm/models/ModelResponse.java | 72 +++++ .../main/java/llm/models/SupportedModels.java | 31 +- .../jurassic/JurassicModelInvocation.java | 203 +++++++++++++ .../jurassic/JurassicModelRequest.java | 157 ++++++++++ .../jurassic/JurassicModelResponse.java | 280 ++++++++++++++++++ .../amazon/titan/TitanModelInvocation.java | 203 +++++++++++++ .../amazon/titan/TitanModelRequest.java | 170 +++++++++++ .../amazon/titan/TitanModelResponse.java | 268 +++++++++++++++++ ...cation.java => ClaudeModelInvocation.java} | 54 ++-- ...elRequest.java => ClaudeModelRequest.java} | 6 +- ...Response.java => ClaudeModelResponse.java} | 25 +- ...ockRuntimeAsyncClient_Instrumentation.java | 40 ++- .../BedrockRuntimeClient_Instrumentation.java | 38 ++- 17 files changed, 1662 insertions(+), 81 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java create mode 100644 
instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java rename instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/{AnthropicClaudeModelInvocation.java => ClaudeModelInvocation.java} (73%) rename instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/{AnthropicClaudeInvokeModelRequest.java => ClaudeModelRequest.java} (95%) rename instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/{AnthropicClaudeInvokeModelResponse.java => ClaudeModelResponse.java} (92%) diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index c3117872ba..462ed28389 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -19,13 +19,27 @@ Note: Currently, `invokeModelWithResponseStream` is not supported. 
### Supported Models -Currently, only the following text-based foundation models are supported: - -* Anthropic Claude -* Amazon Titan -* Meta Llama 2 -* Cohere Command -* AI21 Labs Jurassic +At the time of the instrumentation being published, only the following text-based foundation models have been tested and confirmed as supported: + +* AI21 Labs + * Jurassic-2 Ultra (ai21.j2-ultra-v1) + * Jurassic-2 Mid (ai21.j2-mid-v1) +* Amazon + * Titan Embeddings G1 - Text (amazon.titan-embed-text-v1) + * Titan Text G1 - Lite (amazon.titan-text-lite-v1) + * Titan Text G1 - Express (amazon.titan-text-express-v1) + * Titan Multimodal Embeddings G1 (amazon.titan-embed-image-v1) +* Anthropic + * Claude (anthropic.claude-v2, anthropic.claude-v2:1) + * Claude Instant (anthropic.claude-instant-v1) +* Cohere + * Command (cohere.command-text-v14) + * Command Light (cohere.command-light-text-v14) + * Embed English (cohere.embed-english-v3) + * Embed Multilingual (cohere.embed-multilingual-v3) +* Meta + * Llama 2 Chat 13B (meta.llama2-13b-chat-v1) + * Llama 2 Chat 70B (meta.llama2-70b-chat-v1) ## Involved Pieces @@ -112,6 +126,12 @@ Note: Streaming is not currently supported. `ai_monitoring.enabled`: Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. 
`ai_monitoring.streaming.enabled`: NOT SUPPORTED +## Related Agent APIs + +feedback +callback +addCustomParameter + ## Testing ## Known Issues @@ -162,9 +182,33 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com ## TODO +* Make all LLM event attribute values un-truncated https://source.datanerd.us/agents/agent-specs/pull/664 +* Add new `ai_monitoring.record_content.enabled` config https://source.datanerd.us/agents/agent-specs/pull/663 +* Refactoring related to token count, new callback API https://source.datanerd.us/agents/agent-specs/pull/662 * Clean up request/response parsing logic * Add Javadoc comments to interfaces * Set up and test new models + * AI21 Labs + * Jurassic-2 Ultra (~~ai21.j2-ultra-v1~~) + * Jurassic-2 Mid (~~ai21.j2-mid-v1~~) + * Amazon + * Titan Embeddings G1 - Text (~~amazon.titan-embed-text-v1~~) + * Titan Text G1 - Lite (~~amazon.titan-text-lite-v1~~) + * Titan Text G1 - Express (~~amazon.titan-text-express-v1~~) + * Titan Multimodal Embeddings G1 (~~amazon.titan-embed-image-v1~~) + * Anthropic + * Claude (~~anthropic.claude-v2~~, ~~anthropic.claude-v2:1~~) + * Claude Instant (~~anthropic.claude-instant-v1~~) + * Cohere + * Command (cohere.command-text-v14) + * Command Light (cohere.command-light-text-v14) + * Embed English (cohere.embed-english-v3) + * Embed Multilingual (cohere.embed-multilingual-v3) + * Meta + * Llama 2 Chat 13B (meta.llama2-13b-chat-v1) + * Llama 2 Chat 70B (meta.llama2-70b-chat-v1) +* Test env var and sys prop config * Write instrumentation tests * Finish readme +* Refactor test app to have multiple invokeMethods for a single transaction... 
* Figure out how to get external call linked with async client diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index adaf02f9f2..2b49efce58 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -135,6 +135,9 @@ public Builder role(boolean isUser) { role = "user"; } else { role = modelRequest.getRole(); + if (role.isEmpty()) { + role = "assistant"; + } } return this; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index ece741885e..e867d6b3da 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -15,7 +15,7 @@ import java.util.Map; import java.util.UUID; -import static llm.vendor.Vendor.VENDOR; +import static llm.vendor.Vendor.BEDROCK; public interface ModelInvocation { /** @@ -23,6 +23,9 @@ public interface ModelInvocation { * Llm/{operation_type}/{vendor_name}/{function_name} *

* Used with the sync client + * + * @param txn current transaction + * @param functionName name of SDK function being invoked */ void setTracedMethodName(Transaction txn, String functionName); @@ -31,47 +34,124 @@ public interface ModelInvocation { * Llm/{operation_type}/{vendor_name}/{function_name} *

* Used with the async client + * + * @param segment active segment for async timing + * @param functionName name of SDK function being invoked */ void setSegmentName(Segment segment, String functionName); + /** + * Record an LlmEmbedding event that captures data specific to the creation of an embedding. + * + * @param startTime start time of SDK invoke method + */ void recordLlmEmbeddingEvent(long startTime); + /** + * Record an LlmChatCompletionSummary event that captures high-level data about + * the creation of a chat completion including request, response, and call information. + * + * @param startTime start time of SDK invoke method + * @param numberOfMessages total number of LlmChatCompletionMessage events associated with the summary + */ void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages); + /** + * Record an LlmChatCompletionMessage event that corresponds to each message (sent and received) + * from a chat completion call including those created by the user, assistant, and the system. + * + * @param sequence index starting at 0 associated with each message + * @param message String representing the input/output message + */ void recordLlmChatCompletionMessageEvent(int sequence, String message); + /** + * Record all LLM events when using the sync client. + * + * @param startTime start time of SDK invoke method + */ void recordLlmEvents(long startTime); - // This causes the txn to be active on the thread where the LlmEvents are created so that they properly added to the event reservoir on the txn. This is used when the model response is returned asynchronously. + /** + * Record all LLM events when using the async client. + *

+ * This causes the txn to be active on the thread where the LlmEvents are created so + * that they are properly added to the event reservoir on the txn. This is used when the + * model response is returned asynchronously via CompletableFuture. + * + * @param startTime start time of SDK invoke method + * @param token Token used to link the transaction to the thread that produces the response + */ void recordLlmEventsAsync(long startTime, Token token); + /** + * Report an LLM error. + */ void reportLlmError(); + /** + * Get a map of linking metadata. + * + * @return Map of linking metadata + */ Map getLinkingMetadata(); + /** + * Get a map of user custom attributes. + * + * @return Map of user custom attributes + */ Map getUserAttributes(); + /** + * Get a ModelRequest wrapper class for the SDK Request object. + * + * @return ModelRequest + */ ModelRequest getModelRequest(); + /** + * Get a ModelResponse wrapper class for the SDK Response object. + * + * @return ModelResponse + */ ModelResponse getModelResponse(); /** + * Increment a Supportability metric indicating that the SDK was instrumented. + *

* This needs to be incremented for every invocation of the SDK. * Supportability/{language}/ML/{vendor_name}/{vendor_version} + * + * @param vendorVersion version of vendor */ static void incrementInstrumentedSupportabilityMetric(String vendorVersion) { - NewRelic.incrementCounter("Supportability/Java/ML/" + VENDOR + "/" + vendorVersion); + NewRelic.incrementCounter("Supportability/Java/ML/" + BEDROCK + "/" + vendorVersion); } + /** + * Increment a Supportability metric indicating that streaming support is disabled. + */ static void incrementStreamingDisabledSupportabilityMetric() { NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); } + /** + * Set the llm:true attribute on the active transaction. + * + * @param txn current transaction + */ static void setLlmTrueAgentAttribute(Transaction txn) { // If in a txn with LLM-related spans txn.getAgentAttributes().put("llm", true); } + /** + * Get the span.id attribute from the map of linking metadata. + * + * @param linkingMetadata Map of linking metadata + * @return String representing the span.id + */ static String getSpanId(Map linkingMetadata) { if (linkingMetadata != null && !linkingMetadata.isEmpty()) { return linkingMetadata.get("span.id"); @@ -79,7 +159,12 @@ static String getSpanId(Map linkingMetadata) { return ""; } - // ID of the current trace + /** + * Get the trace.id attribute from the map of linking metadata. 
+ * + * @param linkingMetadata Map of linking metadata + * @return String representing the trace.id + */ static String getTraceId(Map linkingMetadata) { if (linkingMetadata != null && !linkingMetadata.isEmpty()) { return linkingMetadata.get("trace.id"); @@ -87,7 +172,11 @@ static String getTraceId(Map linkingMetadata) { return ""; } - // Returns a string representation of a random GUID + /** + * Generate a string representation of a random GUID + * + * @return String representation of a GUID + */ static String getRandomGuid() { return UUID.randomUUID().toString(); } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java index 9339ff0a8c..d30c589bea 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -12,18 +12,54 @@ import java.util.logging.Level; public interface ModelRequest { + /** + * Get the max tokens allowed for the request. + * + * @return int representing the max tokens allowed for the request + */ int getMaxTokensToSample(); + /** + * Get the temperature of the request. + * + * @return float representing the temperature of the request + */ float getTemperature(); + /** + * Get the content of the request message. + * + * @return String representing the content of the request message + */ String getRequestMessage(); + /** + * Get the role of the requester. + * + * @return String representing the role of the requester + */ String getRole(); + /** + * Get the input to the embedding creation call. + * + * @return String representing the input to the embedding creation call + */ String getInputText(); + /** + * Get the LLM model ID. + * + * @return String representing the LLM model ID + */ String getModelId(); + /** + * Log when a parsing error occurs. 
+ * + * @param e Exception encountered when parsing the request + * @param fieldBeingParsed field that was being parsed + */ static void logParsingFailure(Exception e, String fieldBeingParsed) { if (e != null) { NewRelic.getAgent().getLogger().log(Level.FINEST, e, "AIM: Error parsing " + fieldBeingParsed + " from ModelRequest"); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java index 9c7be62592..150c47552d 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -12,33 +12,105 @@ import java.util.logging.Level; public interface ModelResponse { + // Response headers + String X_AMZN_BEDROCK_INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; + String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; + String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; + + // Operation types String COMPLETION = "completion"; String EMBEDDING = "embedding"; + /** + * Get the response message. + * + * @return String representing the response message + */ String getResponseMessage(); + /** + * Get the stop reason. + * + * @return String representing the stop reason + */ String getStopReason(); + /** + * Get the count of input tokens used. + * + * @return int representing the count of input tokens used + */ int getInputTokenCount(); + /** + * Get the count of output tokens used. + * + * @return int representing the count of output tokens used + */ int getOutputTokenCount(); + /** + * Get the count of total tokens used. + * + * @return int representing the count of total tokens used + */ int getTotalTokenCount(); + /** + * Get the Amazon Request ID. + * + * @return String representing the Amazon Request ID + */ String getAmznRequestId(); + /** + * Get the operation type. 
+ * + * @return String representing the operation type + */ String getOperationType(); + /** + * Get the ID for the associated LlmChatCompletionSummary event. + * + * @return String representing the ID for the associated LlmChatCompletionSummary event + */ String getLlmChatCompletionSummaryId(); + /** + * Get the ID for the associated LlmEmbedding event. + * + * @return String representing the ID for the associated LlmEmbedding event + */ String getLlmEmbeddingId(); + /** + * Determine whether the response resulted in an error or not. + * + * @return boolean true when the LLM response is an error, false when the response was successful + */ boolean isErrorResponse(); + /** + * Get the response status code. + * + * @return int representing the response status code + */ int getStatusCode(); + /** + * Get the response status text. + * + * @return String representing the response status text + */ String getStatusText(); + /** + * Log when a parsing error occurs. + * + * @param e Exception encountered when parsing the response + * @param fieldBeingParsed field that was being parsed + */ static void logParsingFailure(Exception e, String fieldBeingParsed) { if (e != null) { NewRelic.getAgent().getLogger().log(Level.FINEST, e, "AIM: Error parsing " + fieldBeingParsed + " from ModelResponse"); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java index 8f94975185..6c5cdb7972 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java @@ -8,9 +8,30 @@ package llm.models; public class SupportedModels { - public static final String ANTHROPIC_CLAUDE = "claude"; - public static final String AMAZON_TITAN = "titan"; - public static final String META_LLAMA_2 = "llama"; - public static final String COHERE_COMMAND = 
"cohere"; - public static final String AI_21_LABS_JURASSIC = "jurassic"; + public static final String ANTHROPIC_CLAUDE = "anthropic.claude"; + public static final String AMAZON_TITAN = "amazon.titan"; + public static final String META_LLAMA_2 = "meta.llama2"; + public static final String COHERE_COMMAND = "cohere.command"; + public static final String COHERE_EMBED = "cohere.embed"; + public static final String AI_21_LABS_JURASSIC = "ai21.j2"; } + +//*AI21 Labs +// *Jurassic-2Ultra(ai21.j2-ultra-v1) +// *Jurassic-2Mid(ai21.j2-mid-v1) +//*Amazon +// *Titan Embeddings G1-Text(amazon.titan-embed-text-v1) +// *Titan Text G1-Lite(amazon.titan-text-lite-v1) +// *Titan Text G1-Express(amazon.titan-text-express-v1) +// *Titan Multimodal Embeddings G1(amazon.titan-embed-image-v1) +//*Anthropic +// *Claude(anthropic.claude-v2,anthropic.claude-v2:1) +// *Claude Instant(anthropic.claude-instant-v1) +//*Cohere +// *Command(cohere.command-text-v14) +// *Command Light(cohere.command-light-text-v14) +// *Embed English(cohere.embed-english-v3) +// *Embed Multilingual(cohere.embed-multilingual-v3) +//*Meta +// *Llama 2Chat 13B(meta.llama2-13b-chat-v1) +// *Llama 2Chat 70B(meta.llama2-70b-chat-v1) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java new file mode 100644 index 0000000000..77ce6d96c9 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java @@ -0,0 +1,203 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.ai21labs.jurassic; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class JurassicModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public JurassicModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new JurassicModelRequest(invokeModelRequest); + this.modelResponse = new JurassicModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmEmbeddingEvent = 
builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input() + .requestModel() + .responseModel() + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .responseUsageCompletionTokens() + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message) { + boolean isUser = sequence % 2 == 0; + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + 
recordLlmEmbeddingEvent(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, 2); + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java new file mode 100644 index 0000000000..4a3e223458 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java @@ -0,0 +1,157 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.ai21labs.jurassic; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. + */ +public class JurassicModelRequest implements ModelRequest { + private static final String MAX_TOKENS = "maxTokens"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + private static final String INPUT_TEXT = "inputText"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public JurassicModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *

+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? 
requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_TOKENS); + if (jsonNode.isNumber()) { + String maxTokensToSampleString = jsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokensToSampleString); + } + } else { + logParsingFailure(null, MAX_TOKENS); + } + } catch (Exception e) { + logParsingFailure(e, MAX_TOKENS); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); + if (jsonNode.isNumber()) { + String temperatureString = jsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public String getRequestMessage() { + return parseStringValue(PROMPT); + } + + @Override + public String getRole() { + // This is effectively a NoOp for Jurassic as the request doesn't contain any signifier of the role + return ""; + } + + @Override + public String getInputText() { + return parseStringValue(INPUT_TEXT); + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } else { + logParsingFailure(null, fieldToParse); + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getModelId() { + return modelId; + } +} diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java new file mode 100644 index 0000000000..f3f5b3be07 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java @@ -0,0 +1,280 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.ai21labs.jurassic; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class JurassicModelResponse implements ModelResponse { + private static final String FINISH_REASON = "finishReason"; + private static final String REASON = "reason"; + private static final String COMPLETIONS = "completions"; + private static final String DATA = "data"; + private static final String TEXT = "text"; + + private int inputTokenCount = 0; + private int outputTokenCount = 0; + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + public JurassicModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + setHeaderFields(invokeModelResponse); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *

+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.contains(COMPLETION)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.contains(EMBEDDING)) { + operationType = EMBEDDING; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + /** + * Parses header values from the response object and assigns them to fields. 
+ * + * @param invokeModelResponse response object + */ + private void setHeaderFields(InvokeModelResponse invokeModelResponse) { + Map> headers = invokeModelResponse.sdkHttpResponse().headers(); + try { + if (!headers.isEmpty()) { + List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); + if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { + String result = inputTokenCountHeaders.get(0); + inputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); + if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { + String result = outputTokenCountHeaders.get(0); + outputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); + if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { + amznRequestId = amznRequestIdHeaders.get(0); + } + } else { + logParsingFailure(null, "response headers"); + } + } catch (Exception e) { + logParsingFailure(e, "response headers"); + } + } + + @Override + public String getResponseMessage() { + String parsedResponseMessage = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); + if (completionsJsonNode.isArray()) { + List jsonNodeArray = completionsJsonNode.asArray(); + if (!jsonNodeArray.isEmpty()) { + JsonNode jsonNode = jsonNodeArray.get(0); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode dataJsonNode = jsonNodeObject.get(DATA); + if (dataJsonNode.isObject()) { + Map dataJsonNodeObject = dataJsonNode.asObject(); + if (!dataJsonNodeObject.isEmpty()) { + JsonNode textJsonNode = dataJsonNodeObject.get(TEXT); + if (textJsonNode.isString()) { + parsedResponseMessage = textJsonNode.asString(); + } + } + } + } + } + } + } + } else { + logParsingFailure(null, 
TEXT); + } + } catch (Exception e) { + logParsingFailure(e, TEXT); + } + return parsedResponseMessage; + } + + @Override + public String getStopReason() { + String parsedStopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); + if (completionsJsonNode.isArray()) { + List jsonNodeArray = completionsJsonNode.asArray(); + if (!jsonNodeArray.isEmpty()) { + JsonNode jsonNode = jsonNodeArray.get(0); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode dataJsonNode = jsonNodeObject.get(FINISH_REASON); + if (dataJsonNode.isObject()) { + Map dataJsonNodeObject = dataJsonNode.asObject(); + if (!dataJsonNodeObject.isEmpty()) { + JsonNode textJsonNode = dataJsonNodeObject.get(REASON); + if (textJsonNode.isString()) { + parsedStopReason = textJsonNode.asString(); + } + } + } + } + } + } + } + } else { + logParsingFailure(null, FINISH_REASON); + } + } catch (Exception e) { + logParsingFailure(e, FINISH_REASON); + } + return parsedStopReason; + } + + @Override + public int getInputTokenCount() { + return inputTokenCount; + } + + @Override + public int getOutputTokenCount() { + return outputTokenCount; + } + + @Override + public int getTotalTokenCount() { + return inputTokenCount + outputTokenCount; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java new file mode 100644 index 0000000000..87883a5119 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java @@ -0,0 +1,203 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class TitanModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public TitanModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new TitanModelRequest(invokeModelRequest); + this.modelResponse = new TitanModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", 
modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input() + .requestModel() + .responseModel() + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .responseUsageCompletionTokens() + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message) { + boolean isUser = sequence % 2 == 0; + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + 
.role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvent(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, 2); + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java new file mode 100644 index 0000000000..ab2e44f243 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java @@ -0,0 +1,170 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class TitanModelRequest implements ModelRequest { + private static final String MAX_TOKEN_COUNT = "maxTokenCount"; + private static final String TEMPERATURE = "temperature"; + private static final String TEXT_GENERATION_CONFIG = "textGenerationConfig"; + private static final String INPUT_TEXT = "inputText"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public TitanModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *

+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? 
requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textGenConfigJsonNode = getRequestBodyJsonMap().get(TEXT_GENERATION_CONFIG); + if (textGenConfigJsonNode.isObject()) { + Map textGenConfigJsonNodeObject = textGenConfigJsonNode.asObject(); + if (!textGenConfigJsonNodeObject.isEmpty()) { + JsonNode maxTokenCountJsonNode = textGenConfigJsonNodeObject.get(MAX_TOKEN_COUNT); + if (maxTokenCountJsonNode.isNumber()) { + String maxTokenCountString = maxTokenCountJsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokenCountString); + } + } + } + + } else { + logParsingFailure(null, MAX_TOKEN_COUNT); + } + } catch (Exception e) { + logParsingFailure(e, MAX_TOKEN_COUNT); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textGenConfigJsonNode = getRequestBodyJsonMap().get(TEXT_GENERATION_CONFIG); + if (textGenConfigJsonNode.isObject()) { + Map textGenConfigJsonNodeObject = textGenConfigJsonNode.asObject(); + if (!textGenConfigJsonNodeObject.isEmpty()) { + JsonNode temperatureJsonNode = textGenConfigJsonNodeObject.get(TEMPERATURE); + if (temperatureJsonNode.isNumber()) { + String temperatureString = temperatureJsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public String getRequestMessage() { + return parseStringValue(INPUT_TEXT); + } + + @Override + public String getRole() { + // This is effectively a NoOp for Titan as the request doesn't contain any signifier of the role + return ""; + } + + @Override + public String getInputText() { + return parseStringValue(INPUT_TEXT); + } + + private String 
parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } else { + logParsingFailure(null, fieldToParse); + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getModelId() { + return modelId; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java new file mode 100644 index 0000000000..8a321ad2a2 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -0,0 +1,268 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class TitanModelResponse implements ModelResponse { + private static final String COMPLETION_REASON = "completionReason"; + private static final String RESULTS = "results"; + private static final String OUTPUT_TEXT = "outputText"; + + private int inputTokenCount = 0; + private int outputTokenCount = 0; + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + private static final String JSON_START = "{\""; + + public TitanModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + setHeaderFields(invokeModelResponse); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *

+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.contains(COMPLETION_REASON)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.startsWith(JSON_START + EMBEDDING)) { + operationType = EMBEDDING; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + /** + * Parses header values from the response object and assigns them to fields. 
+ * + * @param invokeModelResponse response object + */ + private void setHeaderFields(InvokeModelResponse invokeModelResponse) { + Map> headers = invokeModelResponse.sdkHttpResponse().headers(); + try { + if (!headers.isEmpty()) { + List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); + if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { + String result = inputTokenCountHeaders.get(0); + inputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); + if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { + String result = outputTokenCountHeaders.get(0); + outputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); + if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { + amznRequestId = amznRequestIdHeaders.get(0); + } + } else { + logParsingFailure(null, "response headers"); + } + } catch (Exception e) { + logParsingFailure(e, "response headers"); + } + } + + @Override + public String getResponseMessage() { + String parsedResponseMessage = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(RESULTS); + if (jsonNode.isArray()) { + List resultsJsonNodeArray = jsonNode.asArray(); + if (!resultsJsonNodeArray.isEmpty()) { + JsonNode resultsJsonNode = resultsJsonNodeArray.get(0); + if (resultsJsonNode.isObject()) { + Map resultsJsonNodeObject = resultsJsonNode.asObject(); + if (!resultsJsonNodeObject.isEmpty()) { + JsonNode outputTextJsonNode = resultsJsonNodeObject.get(OUTPUT_TEXT); + if (outputTextJsonNode.isString()) { + parsedResponseMessage = outputTextJsonNode.asString(); + } + } + } + } + } + } else { + logParsingFailure(null, OUTPUT_TEXT); + } + } catch (Exception e) { + logParsingFailure(e, OUTPUT_TEXT); + } + return parsedResponseMessage; + } + + @Override + 
public String getStopReason() { + String parsedStopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(RESULTS); + if (jsonNode.isArray()) { + List resultsJsonNodeArray = jsonNode.asArray(); + if (!resultsJsonNodeArray.isEmpty()) { + JsonNode resultsJsonNode = resultsJsonNodeArray.get(0); + if (resultsJsonNode.isObject()) { + Map resultsJsonNodeObject = resultsJsonNode.asObject(); + if (!resultsJsonNodeObject.isEmpty()) { + JsonNode outputTextJsonNode = resultsJsonNodeObject.get(COMPLETION_REASON); + if (outputTextJsonNode.isString()) { + parsedStopReason = outputTextJsonNode.asString(); + } + } + } + } + } + } else { + logParsingFailure(null, COMPLETION_REASON); + } + } catch (Exception e) { + logParsingFailure(e, COMPLETION_REASON); + } + return parsedStopReason; + } + + @Override + public int getInputTokenCount() { + return inputTokenCount; + } + + @Override + public int getOutputTokenCount() { + return outputTokenCount; + } + + @Override + public int getTotalTokenCount() { + return inputTokenCount + outputTokenCount; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java similarity index 73% rename from 
instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java rename to instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java index ec07252b3f..7118548ff2 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java @@ -23,37 +23,37 @@ import java.util.Map; import java.util.logging.Level; -import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.COMPLETION; -import static llm.models.anthropic.claude.AnthropicClaudeInvokeModelResponse.EMBEDDING; +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; import static llm.vendor.Vendor.BEDROCK; -public class AnthropicClaudeModelInvocation implements ModelInvocation { +public class ClaudeModelInvocation implements ModelInvocation { Map linkingMetadata; Map userAttributes; - ModelRequest claudeRequest; - ModelResponse claudeResponse; + ModelRequest modelRequest; + ModelResponse modelResponse; - public AnthropicClaudeModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + public ClaudeModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, InvokeModelResponse invokeModelResponse) { this.linkingMetadata = linkingMetadata; this.userAttributes = userCustomAttributes; - this.claudeRequest = new AnthropicClaudeInvokeModelRequest(invokeModelRequest); - this.claudeResponse = new AnthropicClaudeInvokeModelResponse(invokeModelResponse); + this.modelRequest = new ClaudeModelRequest(invokeModelRequest); + this.modelResponse = new ClaudeModelResponse(invokeModelResponse); } @Override public void setTracedMethodName(Transaction txn, String functionName) { - 
txn.getTracedMethod().setMetricName("Llm", claudeResponse.getOperationType(), BEDROCK, functionName); + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); } @Override public void setSegmentName(Segment segment, String functionName) { - segment.setMetricName("Llm", claudeResponse.getOperationType(), BEDROCK, functionName); + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); } @Override public void recordLlmEmbeddingEvent(long startTime) { - if (claudeResponse.isErrorResponse()) { + if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -64,7 +64,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .traceId() .vendor() .ingestSource() - .id(claudeResponse.getLlmEmbeddingId()) + .id(modelResponse.getLlmEmbeddingId()) .requestId() .input() .requestModel() @@ -80,7 +80,7 @@ public void recordLlmEmbeddingEvent(long startTime) { @Override public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { - if (claudeResponse.isErrorResponse()) { + if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -91,7 +91,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess .traceId() .vendor() .ingestSource() - .id(claudeResponse.getLlmChatCompletionSummaryId()) + .id(modelResponse.getLlmChatCompletionSummaryId()) .requestId() .requestTemperature() .requestMaxTokens() @@ -111,7 +111,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess @Override public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = message.contains("Human:"); + boolean isUser = sequence % 2 == 0; LlmEvent.Builder builder = new LlmEvent.Builder(this); @@ -135,7 +135,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message) { @Override public void recordLlmEvents(long startTime) { - String operationType = claudeResponse.getOperationType(); + String 
operationType = modelResponse.getOperationType(); if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { @@ -157,15 +157,15 @@ public void recordLlmEventsAsync(long startTime, Token token) { @Override public void reportLlmError() { Map errorParams = new HashMap<>(); - errorParams.put("http.statusCode", claudeResponse.getStatusCode()); - errorParams.put("error.code", claudeResponse.getStatusCode()); - if (!claudeResponse.getLlmChatCompletionSummaryId().isEmpty()) { - errorParams.put("completion_id", claudeResponse.getLlmChatCompletionSummaryId()); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); } - if (!claudeResponse.getLlmEmbeddingId().isEmpty()) { - errorParams.put("embedding_id", claudeResponse.getLlmEmbeddingId()); + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); } - NewRelic.noticeError("LlmError: " + claudeResponse.getStatusText(), errorParams); + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); } /** @@ -174,9 +174,9 @@ public void reportLlmError() { */ private void recordLlmChatCompletionEvents(long startTime) { // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, claudeRequest.getRequestMessage()); + recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, claudeResponse.getResponseMessage()); + recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); // A summary of all LlmChatCompletionMessage events 
recordLlmChatCompletionSummaryEvent(startTime, 2); } @@ -193,11 +193,11 @@ public Map getUserAttributes() { @Override public ModelRequest getModelRequest() { - return claudeRequest; + return modelRequest; } @Override public ModelResponse getModelResponse() { - return claudeResponse; + return modelResponse; } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java similarity index 95% rename from instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java rename to instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java index 04ee370eee..f6cc4d6058 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java @@ -23,9 +23,7 @@ * Stores the required info from the Bedrock InvokeModelRequest without holding * a reference to the actual request object to avoid potential memory issues. 
*/ -public class AnthropicClaudeInvokeModelRequest implements ModelRequest { - // TODO might be able to move some of these constants to the ModelRequest interface - // need to figure out if they are consistent across all models +public class ClaudeModelRequest implements ModelRequest { private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; private static final String TEMPERATURE = "temperature"; private static final String PROMPT = "prompt"; @@ -39,7 +37,7 @@ public class AnthropicClaudeInvokeModelRequest implements ModelRequest { private String modelId = ""; private Map requestBodyJsonMap = null; - public AnthropicClaudeInvokeModelRequest(InvokeModelRequest invokeModelRequest) { + public ClaudeModelRequest(InvokeModelRequest invokeModelRequest) { if (invokeModelRequest != null) { invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); modelId = invokeModelRequest.modelId(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java similarity index 92% rename from instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java rename to instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java index e260cd7da6..85cb37186b 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/AnthropicClaudeInvokeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java @@ -26,14 +26,9 @@ * Stores the required info from the Bedrock InvokeModelResponse without holding * a reference to the actual request object to avoid potential memory issues. 
*/ -public class AnthropicClaudeInvokeModelResponse implements ModelResponse { +public class ClaudeModelResponse implements ModelResponse { private static final String STOP_REASON = "stop_reason"; - // Response headers - private static final String X_AMZN_BEDROCK_INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; - private static final String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; - private static final String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; - private int inputTokenCount = 0; private int outputTokenCount = 0; private String amznRequestId = ""; @@ -54,7 +49,7 @@ public class AnthropicClaudeInvokeModelResponse implements ModelResponse { private static final String JSON_START = "{\""; - public AnthropicClaudeInvokeModelResponse(InvokeModelResponse invokeModelResponse) { + public ClaudeModelResponse(InvokeModelResponse invokeModelResponse) { if (invokeModelResponse != null) { invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); @@ -96,7 +91,6 @@ private Map parseInvokeModelResponseBodyMap() { JsonNodeParser jsonNodeParser = JsonNodeParser.create(); JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); - // TODO check for other types? Or will it always be Object? if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { responseBodyJsonMap = responseBodyJsonNode.asObject(); } else { @@ -108,6 +102,11 @@ private Map parseInvokeModelResponseBodyMap() { return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); } + /** + * Parses the operation type from the response body and assigns it to a field. 
+ * + * @param invokeModelResponseBody response body String + */ private void setOperationType(String invokeModelResponseBody) { try { if (!invokeModelResponseBody.isEmpty()) { @@ -124,6 +123,11 @@ private void setOperationType(String invokeModelResponseBody) { } } + /** + * Parses header values from the response object and assigns them to fields. + * + * @param invokeModelResponse response object + */ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { Map> headers = invokeModelResponse.sdkHttpResponse().headers(); try { @@ -150,11 +154,6 @@ private void setHeaderFields(InvokeModelResponse invokeModelResponse) { } } - /** - * Represents the response message - * - * @return - */ @Override public String getResponseMessage() { return parseStringValue(COMPLETION); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 8faa9c0bab..09a8956a53 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -18,7 +18,11 @@ import com.newrelic.api.agent.weaver.Weave; import com.newrelic.api.agent.weaver.Weaver; import llm.models.ModelInvocation; -import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; +import llm.models.ai21labs.jurassic.JurassicModelInvocation; +import llm.models.amazon.titan.TitanModelInvocation; +import llm.models.anthropic.claude.ClaudeModelInvocation; +import llm.models.cohere.command.CommandModelInvocation; +import llm.models.meta.llama2.Llama2ModelInvocation; import 
software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; @@ -35,6 +39,7 @@ import static llm.models.SupportedModels.AMAZON_TITAN; import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; import static llm.models.SupportedModels.COHERE_COMMAND; +import static llm.models.SupportedModels.COHERE_EMBED; import static llm.models.SupportedModels.META_LLAMA_2; import static llm.vendor.Vendor.VENDOR_VERSION; @@ -76,20 +81,35 @@ public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) try { if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(linkingMetadata, userAttributes, - invokeModelRequest, + ModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); // Set segment name based on LLM operation from response - anthropicClaudeModelInvocation.setSegmentName(segment, "invokeModel"); - anthropicClaudeModelInvocation.recordLlmEventsAsync(startTime, token); + claudeModelInvocation.setSegmentName(segment, "invokeModel"); + claudeModelInvocation.recordLlmEventsAsync(startTime, token); } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { - + ModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + titanModelInvocation.setTracedMethodName(txn, "invokeModel"); + titanModelInvocation.recordLlmEventsAsync(startTime, token); } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { - - } else if (modelId.toLowerCase().contains(COHERE_COMMAND)) { - + ModelInvocation llama2ModelInvocation = new Llama2ModelInvocation(linkingMetadata, 
userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + llama2ModelInvocation.setTracedMethodName(txn, "invokeModel"); + llama2ModelInvocation.recordLlmEventsAsync(startTime, token); + } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) { // TODO can be combined with COHERE_EMBED? OR should these be separate? + ModelInvocation commandModelInvocation = new CommandModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + commandModelInvocation.setTracedMethodName(txn, "invokeModel"); + commandModelInvocation.recordLlmEventsAsync(startTime, token); } else if (modelId.toLowerCase().contains(AI_21_LABS_JURASSIC)) { - + ModelInvocation jurassicModelInvocation = new JurassicModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + jurassicModelInvocation.setTracedMethodName(txn, "invokeModel"); + jurassicModelInvocation.recordLlmEventsAsync(startTime, token); } segment.end(); } catch (Throwable t) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java index b962cc56da..63264c9acd 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java @@ -16,7 +16,11 @@ import com.newrelic.api.agent.weaver.Weave; import com.newrelic.api.agent.weaver.Weaver; import 
llm.models.ModelInvocation; -import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; +import llm.models.ai21labs.jurassic.JurassicModelInvocation; +import llm.models.amazon.titan.TitanModelInvocation; +import llm.models.anthropic.claude.ClaudeModelInvocation; +import llm.models.cohere.command.CommandModelInvocation; +import llm.models.meta.llama2.Llama2ModelInvocation; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; @@ -27,6 +31,7 @@ import static llm.models.SupportedModels.AMAZON_TITAN; import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; import static llm.models.SupportedModels.COHERE_COMMAND; +import static llm.models.SupportedModels.COHERE_EMBED; import static llm.models.SupportedModels.META_LLAMA_2; import static llm.vendor.Vendor.VENDOR_VERSION; @@ -54,19 +59,32 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { String modelId = invokeModelRequest.modelId(); if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { - ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, - invokeModelResponse); + ModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); // Set traced method name based on LLM operation from response - anthropicClaudeModelInvocation.setTracedMethodName(txn, "invokeModel"); - anthropicClaudeModelInvocation.recordLlmEvents(startTime); + claudeModelInvocation.setTracedMethodName(txn, "invokeModel"); + claudeModelInvocation.recordLlmEvents(startTime); } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) { - + ModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); + // Set traced method name based on LLM operation from response + 
titanModelInvocation.setTracedMethodName(txn, "invokeModel"); + titanModelInvocation.recordLlmEvents(startTime); } else if (modelId.toLowerCase().contains(META_LLAMA_2)) { - - } else if (modelId.toLowerCase().contains(COHERE_COMMAND)) { - + ModelInvocation llama2ModelInvocation = new Llama2ModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); + // Set traced method name based on LLM operation from response + llama2ModelInvocation.setTracedMethodName(txn, "invokeModel"); + llama2ModelInvocation.recordLlmEvents(startTime); + } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) { // TODO can be combined with COHERE_EMBED? OR should these be separate? + ModelInvocation commandModelInvocation = new CommandModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + commandModelInvocation.setTracedMethodName(txn, "invokeModel"); + commandModelInvocation.recordLlmEvents(startTime); } else if (modelId.toLowerCase().contains(AI_21_LABS_JURASSIC)) { - + ModelInvocation jurassicModelInvocation = new JurassicModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, + invokeModelResponse); + // Set traced method name based on LLM operation from response + jurassicModelInvocation.setTracedMethodName(txn, "invokeModel"); + jurassicModelInvocation.recordLlmEvents(startTime); } } } From 48fbd7d1a3f8f9b3b2f0a705a2bd920dc382cd7b Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 6 Mar 2024 14:07:30 -0800 Subject: [PATCH 16/68] Add support for Cohere models --- .../aws-bedrock-runtime-2.20/README.md | 13 +- .../main/java/llm/models/ModelInvocation.java | 13 + .../jurassic/JurassicModelInvocation.java | 2 +- .../jurassic/JurassicModelRequest.java | 10 +- .../jurassic/JurassicModelResponse.java | 10 +- .../amazon/titan/TitanModelInvocation.java | 2 +- 
.../amazon/titan/TitanModelRequest.java | 11 +- .../amazon/titan/TitanModelResponse.java | 10 +- .../claude/ClaudeModelInvocation.java | 2 +- .../anthropic/claude/ClaudeModelRequest.java | 10 +- .../anthropic/claude/ClaudeModelResponse.java | 5 +- .../command/CommandModelInvocation.java | 203 +++++++++++++ .../cohere/command/CommandModelRequest.java | 180 ++++++++++++ .../cohere/command/CommandModelResponse.java | 269 ++++++++++++++++++ ...ockRuntimeAsyncClient_Instrumentation.java | 4 +- 15 files changed, 709 insertions(+), 35 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index 462ed28389..dfc66670f7 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -182,11 +182,9 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com ## TODO -* Make all LLM event attribute values un-truncated https://source.datanerd.us/agents/agent-specs/pull/664 +* Make all LLM event attribute values un-truncated???? 
https://source.datanerd.us/agents/agent-specs/pull/664 * Add new `ai_monitoring.record_content.enabled` config https://source.datanerd.us/agents/agent-specs/pull/663 * Refactoring related to token count, new callback API https://source.datanerd.us/agents/agent-specs/pull/662 -* Clean up request/response parsing logic -* Add Javadoc comments to interfaces * Set up and test new models * AI21 Labs * Jurassic-2 Ultra (~~ai21.j2-ultra-v1~~) @@ -200,14 +198,15 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com * Claude (~~anthropic.claude-v2~~, ~~anthropic.claude-v2:1~~) * Claude Instant (~~anthropic.claude-instant-v1~~) * Cohere - * Command (cohere.command-text-v14) - * Command Light (cohere.command-light-text-v14) - * Embed English (cohere.embed-english-v3) - * Embed Multilingual (cohere.embed-multilingual-v3) + * Command (~~cohere.command-text-v14~~) + * Command Light (~~cohere.command-light-text-v14~~) + * Embed English (~~cohere.embed-english-v3~~) + * Embed Multilingual (~~cohere.embed-multilingual-v3~~) * Meta * Llama 2 Chat 13B (meta.llama2-13b-chat-v1) * Llama 2 Chat 70B (meta.llama2-70b-chat-v1) * Test env var and sys prop config +* Update default yaml * Write instrumentation tests * Finish readme * Refactor test app to have multiple invokeMethods for a single transaction... diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index e867d6b3da..ceb230774d 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -180,4 +180,17 @@ static String getTraceId(Map linkingMetadata) { static String getRandomGuid() { return UUID.randomUUID().toString(); } + + /** + * Determine if the LLM is initiated by the user or assistant. + *

+ * Assuming that one user request is always followed by one assistant + * response, an even sequence value is the user, while odd is the assistant. + * + * @param sequence index starting at 0 associated with each message + * @return true if is user, false if not + */ + default boolean isUser(int sequence) { + return sequence % 2 == 0; + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java index 77ce6d96c9..4f900c3dc9 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java @@ -111,7 +111,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess @Override public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = sequence % 2 == 0; + boolean isUser = isUser(sequence); LlmEvent.Builder builder = new LlmEvent.Builder(this); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java index 4a3e223458..e14b6a6acb 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java @@ -89,12 +89,13 @@ public int getMaxTokensToSample() { String maxTokensToSampleString = jsonNode.asNumber(); maxTokensToSample = Integer.parseInt(maxTokensToSampleString); } - } else { - logParsingFailure(null, MAX_TOKENS); } } catch (Exception e) { logParsingFailure(e, MAX_TOKENS); } + if (maxTokensToSample == 0) 
{ + logParsingFailure(null, MAX_TOKENS); + } return maxTokensToSample; } @@ -141,12 +142,13 @@ private String parseStringValue(String fieldToParse) { if (jsonNode.isString()) { parsedStringValue = jsonNode.asString(); } - } else { - logParsingFailure(null, fieldToParse); } } catch (Exception e) { logParsingFailure(e, fieldToParse); } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } return parsedStringValue; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java index f3f5b3be07..d4f03a10f1 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java @@ -183,12 +183,13 @@ public String getResponseMessage() { } } } - } else { - logParsingFailure(null, TEXT); } } catch (Exception e) { logParsingFailure(e, TEXT); } + if (parsedResponseMessage.isEmpty()) { + logParsingFailure(null, TEXT); + } return parsedResponseMessage; } @@ -219,12 +220,13 @@ public String getStopReason() { } } } - } else { - logParsingFailure(null, FINISH_REASON); } } catch (Exception e) { logParsingFailure(e, FINISH_REASON); } + if (parsedStopReason.isEmpty()) { + logParsingFailure(null, FINISH_REASON); + } return parsedStopReason; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java index 87883a5119..86bc96ee8b 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java 
@@ -111,7 +111,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess @Override public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = sequence % 2 == 0; + boolean isUser = isUser(sequence); LlmEvent.Builder builder = new LlmEvent.Builder(this); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java index ab2e44f243..ccc5f04939 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java @@ -95,13 +95,13 @@ public int getMaxTokensToSample() { } } } - - } else { - logParsingFailure(null, MAX_TOKEN_COUNT); } } catch (Exception e) { logParsingFailure(e, MAX_TOKEN_COUNT); } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_TOKEN_COUNT); + } return maxTokensToSample; } @@ -154,12 +154,13 @@ private String parseStringValue(String fieldToParse) { if (jsonNode.isString()) { parsedStringValue = jsonNode.asString(); } - } else { - logParsingFailure(null, fieldToParse); } } catch (Exception e) { logParsingFailure(e, fieldToParse); } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } return parsedStringValue; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java index 8a321ad2a2..56f8d4a8cc 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -177,12 +177,13 @@ public String getResponseMessage() { } } } 
- } else { - logParsingFailure(null, OUTPUT_TEXT); } } catch (Exception e) { logParsingFailure(e, OUTPUT_TEXT); } + if (parsedResponseMessage.isEmpty()) { + logParsingFailure(null, OUTPUT_TEXT); + } return parsedResponseMessage; } @@ -207,12 +208,13 @@ public String getStopReason() { } } } - } else { - logParsingFailure(null, COMPLETION_REASON); } } catch (Exception e) { logParsingFailure(e, COMPLETION_REASON); } + if (parsedStopReason.isEmpty()) { + logParsingFailure(null, COMPLETION_REASON); + } return parsedStopReason; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java index 7118548ff2..aa6e25a035 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java @@ -111,7 +111,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess @Override public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = sequence % 2 == 0; + boolean isUser = isUser(sequence); LlmEvent.Builder builder = new LlmEvent.Builder(this); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java index f6cc4d6058..dc6a0eceea 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java @@ -93,12 +93,13 @@ public int getMaxTokensToSample() { String maxTokensToSampleString = jsonNode.asNumber(); maxTokensToSample = 
Integer.parseInt(maxTokensToSampleString); } - } else { - logParsingFailure(null, MAX_TOKENS_TO_SAMPLE); } } catch (Exception e) { logParsingFailure(e, MAX_TOKENS_TO_SAMPLE); } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_TOKENS_TO_SAMPLE); + } return maxTokensToSample; } @@ -160,12 +161,13 @@ private String parseStringValue(String fieldToParse) { if (jsonNode.isString()) { parsedStringValue = jsonNode.asString(); } - } else { - logParsingFailure(null, fieldToParse); } } catch (Exception e) { logParsingFailure(e, fieldToParse); } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } return parsedStringValue; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java index 85cb37186b..0ff9992061 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java @@ -172,12 +172,13 @@ private String parseStringValue(String fieldToParse) { if (jsonNode.isString()) { parsedStringValue = jsonNode.asString(); } - } else { - logParsingFailure(null, fieldToParse); } } catch (Exception e) { logParsingFailure(e, fieldToParse); } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } return parsedStringValue; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java new file mode 100644 index 0000000000..961aa41041 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java @@ -0,0 +1,203 @@ +/* + * + * * Copyright 2024 New Relic 
Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.cohere.command; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class CommandModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public CommandModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new CommandModelRequest(invokeModelRequest); + this.modelResponse = new CommandModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + 
LlmEvent llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input() + .requestModel() + .responseModel() + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .responseUsageCompletionTokens() + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message) { + boolean isUser = isUser(sequence); + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if 
(operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvent(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, 2); + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java new file mode 100644 index 0000000000..1815581054 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java @@ -0,0 +1,180 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.cohere.command; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. + */ +public class CommandModelRequest implements ModelRequest { + private static final String MAX_TOKENS = "max_tokens"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + private static final String TEXTS = "texts"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public CommandModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *

+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? 
requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_TOKENS); + if (jsonNode.isNumber()) { + String maxTokensToSampleString = jsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokensToSampleString); + } + } + } catch (Exception e) { + logParsingFailure(e, MAX_TOKENS); + } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_TOKENS); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); + if (jsonNode.isNumber()) { + String temperatureString = jsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public String getRequestMessage() { + return parseStringValue(PROMPT); + } + + @Override + public String getRole() { + // This is effectively a NoOp for Jurassic as the request doesn't contain any signifier of the role + return ""; + } + + @Override + public String getInputText() { + String parsedInputText = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textsJsonNode = getRequestBodyJsonMap().get(TEXTS); + if (textsJsonNode.isArray()) { + List textsJsonNodeArray = textsJsonNode.asArray(); + if (!textsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = textsJsonNodeArray.get(0); + if (jsonNode.isString()) { + parsedInputText = jsonNode.asString(); + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXTS); + } + if (parsedInputText.isEmpty()) { + logParsingFailure(null, TEXTS); + } + return parsedInputText; + } + + private String parseStringValue(String fieldToParse) { + String 
parsedStringValue = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getModelId() { + return modelId; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java new file mode 100644 index 0000000000..32951e7ea4 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java @@ -0,0 +1,269 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.cohere.command; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class CommandModelResponse implements ModelResponse { + private static final String FINISH_REASON = "finish_reason"; + private static final String GENERATIONS = "generations"; + private static final String EMBEDDINGS = "embeddings"; + private static final String TEXT = "text"; + + private int inputTokenCount = 0; + private int outputTokenCount = 0; + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + public CommandModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + setHeaderFields(invokeModelResponse); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *

+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.contains(GENERATIONS)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.contains(EMBEDDINGS)) { + operationType = EMBEDDING; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + /** + * Parses header values from the response object and assigns them to fields. 
+ * + * @param invokeModelResponse response object + */ + private void setHeaderFields(InvokeModelResponse invokeModelResponse) { + Map> headers = invokeModelResponse.sdkHttpResponse().headers(); + try { + if (!headers.isEmpty()) { + List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); + if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { + String result = inputTokenCountHeaders.get(0); + inputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); + if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { + String result = outputTokenCountHeaders.get(0); + outputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); + if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { + amznRequestId = amznRequestIdHeaders.get(0); + } + } else { + logParsingFailure(null, "response headers"); + } + } catch (Exception e) { + logParsingFailure(e, "response headers"); + } + } + + @Override + public String getResponseMessage() { + String parsedResponseMessage = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode generationsJsonNode = getResponseBodyJsonMap().get(GENERATIONS); + if (generationsJsonNode.isArray()) { + List generationsJsonNodeArray = generationsJsonNode.asArray(); + if (!generationsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = generationsJsonNodeArray.get(0); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode textJsonNode = jsonNodeObject.get(TEXT); + if (textJsonNode.isString()) { + parsedResponseMessage = textJsonNode.asString(); + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXT); + } + if (parsedResponseMessage.isEmpty()) { + logParsingFailure(null, TEXT); + } + return parsedResponseMessage; + } + + @Override + 
public String getStopReason() { + String parsedStopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode generationsJsonNode = getResponseBodyJsonMap().get(GENERATIONS); + if (generationsJsonNode.isArray()) { + List generationsJsonNodeArray = generationsJsonNode.asArray(); + if (!generationsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = generationsJsonNodeArray.get(0); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode finishReasonJsonNode = jsonNodeObject.get(FINISH_REASON); + if (finishReasonJsonNode.isString()) { + parsedStopReason = finishReasonJsonNode.asString(); + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, FINISH_REASON); + } + if (parsedStopReason.isEmpty()) { + logParsingFailure(null, FINISH_REASON); + } + return parsedStopReason; + } + + @Override + public int getInputTokenCount() { + return inputTokenCount; + } + + @Override + public int getOutputTokenCount() { + return outputTokenCount; + } + + @Override + public int getTotalTokenCount() { + return inputTokenCount + outputTokenCount; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 09a8956a53..32f2890566 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -74,7 +74,7 @@ public CompletableFuture invokeModel(InvokeModelRequest inv Token token = txn.getToken(); - // TODO instrumentation fails if the BiConsumer is replaced with a lambda + // instrumentation fails if the BiConsumer is replaced with a lambda invokeModelResponseFuture.whenComplete(new BiConsumer() { @Override public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { @@ -98,7 +98,7 @@ public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) // Set traced method name based on LLM operation from response llama2ModelInvocation.setTracedMethodName(txn, "invokeModel"); llama2ModelInvocation.recordLlmEventsAsync(startTime, token); - } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) { // TODO can be combined with COHERE_EMBED? OR should these be separate? 
+ } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) { ModelInvocation commandModelInvocation = new CommandModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); // Set traced method name based on LLM operation from response From 842bcc2097fd3c2c02bd118b0399a1c0c301b57f Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 6 Mar 2024 15:20:42 -0800 Subject: [PATCH 17/68] Add support for Meta Llama 2 models --- .../aws-bedrock-runtime-2.20/README.md | 4 +- .../jurassic/JurassicModelRequest.java | 2 +- .../amazon/titan/TitanModelRequest.java | 2 +- .../cohere/command/CommandModelRequest.java | 2 +- .../meta/llama2/Llama2ModelInvocation.java | 203 +++++++++++++++ .../meta/llama2/Llama2ModelRequest.java | 159 ++++++++++++ .../meta/llama2/Llama2ModelResponse.java | 232 ++++++++++++++++++ 7 files changed, 599 insertions(+), 5 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index dfc66670f7..fb778568f5 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -203,8 +203,8 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com * Embed English (~~cohere.embed-english-v3~~) * Embed Multilingual (~~cohere.embed-multilingual-v3~~) * Meta - * Llama 2 Chat 13B (meta.llama2-13b-chat-v1) - * Llama 2 Chat 70B (meta.llama2-70b-chat-v1) + * Llama 2 Chat 13B (~~meta.llama2-13b-chat-v1~~) + * Llama 2 Chat 70B (~~meta.llama2-70b-chat-v1~~) * Test env var and 
sys prop config * Update default yaml * Write instrumentation tests diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java index e14b6a6acb..fe63b6f4dd 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java @@ -125,7 +125,7 @@ public String getRequestMessage() { @Override public String getRole() { - // This is effectively a NoOp for Jurassic as the request doesn't contain any signifier of the role + // This is a NoOp for Jurassic as the request doesn't contain any signifier of the role return ""; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java index ccc5f04939..a2e8b69771 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java @@ -137,7 +137,7 @@ public String getRequestMessage() { @Override public String getRole() { - // This is effectively a NoOp for Titan as the request doesn't contain any signifier of the role + // This is a NoOp for Titan as the request doesn't contain any signifier of the role return ""; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java index 1815581054..7e819017a5 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java @@ -126,7 +126,7 @@ public String getRequestMessage() { @Override public String getRole() { - // This is effectively a NoOp for Jurassic as the request doesn't contain any signifier of the role + // This is a NoOp for Command as the request doesn't contain any signifier of the role return ""; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java new file mode 100644 index 0000000000..15fc08bd35 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java @@ -0,0 +1,203 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.meta.llama2; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class Llama2ModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public Llama2ModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new Llama2ModelRequest(invokeModelRequest); + this.modelResponse = new Llama2ModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmEmbeddingEvent = builder + 
.spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input() + .requestModel() + .responseModel() + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseUsageTotalTokens() + .responseUsagePromptTokens() + .responseUsageCompletionTokens() + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message) { + boolean isUser = isUser(sequence); + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + 
recordLlmEmbeddingEvent(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + // First LlmChatCompletionMessage represents the user input prompt + recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); + // Second LlmChatCompletionMessage represents the completion message from the LLM response + recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); + // A summary of all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, 2); + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java new file mode 100644 index 0000000000..a3b793fade --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java @@ -0,0 +1,159 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.meta.llama2; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class Llama2ModelRequest implements ModelRequest { + private static final String MAX_GEN_LEN = "max_gen_len"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public Llama2ModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *

+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? 
requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_GEN_LEN); + if (jsonNode.isNumber()) { + String maxTokensToSampleString = jsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokensToSampleString); + } + } + } catch (Exception e) { + logParsingFailure(e, MAX_GEN_LEN); + } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_GEN_LEN); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); + if (jsonNode.isNumber()) { + String temperatureString = jsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public String getRequestMessage() { + return parseStringValue(PROMPT); + } + + @Override + public String getRole() { + // This is a NoOp for Llama as the request doesn't contain any signifier of the role + return ""; + } + + @Override + public String getInputText() { + // This is a NoOp for Llama as it doesn't support embeddings + return ""; + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getModelId() { + return modelId; + } +} diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java new file mode 100644 index 0000000000..c206e093ba --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java @@ -0,0 +1,232 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.meta.llama2; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class Llama2ModelResponse implements ModelResponse { + private static final String STOP_REASON = "stop_reason"; + private static final String GENERATION = "generation"; + + private int inputTokenCount = 0; + private int outputTokenCount = 0; + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + public Llama2ModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + setHeaderFields(invokeModelResponse); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *

+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + // Meta Llama 2 for Bedrock doesn't support embedding operations + if (invokeModelResponseBody.contains(GENERATION)) { + operationType = COMPLETION; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + /** + * Parses header values from the response object and assigns them to fields. 
+ * + * @param invokeModelResponse response object + */ + private void setHeaderFields(InvokeModelResponse invokeModelResponse) { + Map> headers = invokeModelResponse.sdkHttpResponse().headers(); + try { + if (!headers.isEmpty()) { + List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); + if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { + String result = inputTokenCountHeaders.get(0); + inputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); + if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { + String result = outputTokenCountHeaders.get(0); + outputTokenCount = result != null ? Integer.parseInt(result) : 0; + } + List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); + if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { + amznRequestId = amznRequestIdHeaders.get(0); + } + } else { + logParsingFailure(null, "response headers"); + } + } catch (Exception e) { + logParsingFailure(e, "response headers"); + } + } + + @Override + public String getResponseMessage() { + return parseStringValue(GENERATION); + } + + @Override + public String getStopReason() { + return parseStringValue(STOP_REASON); + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public int getInputTokenCount() { + return inputTokenCount; + } + + @Override + public int getOutputTokenCount() { + return outputTokenCount; + } + + @Override + public int getTotalTokenCount() { + return 
inputTokenCount + outputTokenCount; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} From bafd7cf4c886809495d022e521faf43c58e0e101 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 6 Mar 2024 15:22:04 -0800 Subject: [PATCH 18/68] Cleanup --- .../BedrockRuntimeClient_Instrumentation.java | 2 +- ...ockRuntimeAsyncClient_Instrumentation.java | 86 ------------------- ...tBedrockRuntimeClient_Instrumentation.java | 85 ------------------ 3 files changed, 1 insertion(+), 172 deletions(-) delete mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java delete mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java index 63264c9acd..c1411bbec1 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java @@ -73,7 +73,7 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { // Set traced method name based on LLM operation from response llama2ModelInvocation.setTracedMethodName(txn, "invokeModel"); llama2ModelInvocation.recordLlmEvents(startTime); - } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) { // TODO can be combined with COHERE_EMBED? OR should these be separate? + } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) { ModelInvocation commandModelInvocation = new CommandModelInvocation(linkingMetadata, userAttributes, invokeModelRequest, invokeModelResponse); // Set traced method name based on LLM operation from response diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java deleted file mode 100644 index 167352fdac..0000000000 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeAsyncClient_Instrumentation.java +++ /dev/null @@ -1,86 +0,0 @@ -///* -// * -// * * Copyright 2024 New Relic Corporation. All rights reserved. 
-// * * SPDX-License-Identifier: Apache-2.0 -// * -// */ -// -//package software.amazon.awssdk.services.bedrockruntime; -// -//import com.newrelic.agent.bridge.AgentBridge; -//import com.newrelic.api.agent.NewRelic; -//import com.newrelic.api.agent.Segment; -//import com.newrelic.api.agent.Trace; -//import com.newrelic.api.agent.weaver.MatchType; -//import com.newrelic.api.agent.weaver.Weave; -//import com.newrelic.api.agent.weaver.Weaver; -//import software.amazon.awssdk.core.client.config.SdkClientConfiguration; -//import software.amazon.awssdk.core.client.handler.AsyncClientHandler; -//import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; -//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; -//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; -// -//import java.util.concurrent.CompletableFuture; -//import java.util.concurrent.Executor; -// -//import static llm.models.ModelInvocation.incrementInstrumentedSupportabilityMetric; -//import static llm.vendor.Vendor.VENDOR_VERSION; -// -///** -// * Service client for accessing Amazon Bedrock Runtime asynchronously. 
-// */ -//// TODO switch back to instrumenting the BedrockRuntimeAsyncClient interface instead of this implementation class -//@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient") -//final class DefaultBedrockRuntimeAsyncClient_Instrumentation { -//// private static final Logger log = LoggerFactory.getLogger(DefaultBedrockRuntimeAsyncClient.class); -//// -//// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() -//// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); -// -// private final AsyncClientHandler clientHandler; -// -// private final AwsJsonProtocolFactory protocolFactory; -// -// private final SdkClientConfiguration clientConfiguration; -// -// private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; -// -// private final Executor executor; -// -// protected DefaultBedrockRuntimeAsyncClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, -// SdkClientConfiguration clientConfiguration) { -// this.clientHandler = Weaver.callOriginal(); -// this.clientConfiguration = Weaver.callOriginal(); -// this.serviceClientConfiguration = Weaver.callOriginal(); -// this.protocolFactory = Weaver.callOriginal(); -// this.executor = Weaver.callOriginal(); -// } -// -// @Trace -// public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { -// long startTime = System.currentTimeMillis(); -// // TODO name "Llm/" + operationType + "/Bedrock/InvokeModelAsync" ???? 
-// Segment segment = NewRelic.getAgent().getTransaction().startSegment("LLM", "InvokeModelAsync"); -// CompletableFuture invokeModelResponseFuture = Weaver.callOriginal(); -// -// incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); -// -// // this should never happen, but protecting against bad implementations -// if (invokeModelResponseFuture == null) { -// segment.end(); -// } else { -// invokeModelResponseFuture.whenComplete((invokeModelResponse, throwable) -> { -// try { -// // TODO do all the stuff -// segment.end(); -// } catch (Throwable t) { -// AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); -// } -// }); -// } -// -// return invokeModelResponseFuture; -// -// } -// -//} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java deleted file mode 100644 index d810db8ecf..0000000000 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/DefaultBedrockRuntimeClient_Instrumentation.java +++ /dev/null @@ -1,85 +0,0 @@ -///* -// * -// * * Copyright 2024 New Relic Corporation. All rights reserved. 
-// * * SPDX-License-Identifier: Apache-2.0 -// * -// */ -// -//package software.amazon.awssdk.services.bedrockruntime; -// -//import com.newrelic.agent.bridge.AgentBridge; -//import com.newrelic.agent.bridge.NoOpTransaction; -//import com.newrelic.agent.bridge.Transaction; -//import com.newrelic.api.agent.NewRelic; -//import com.newrelic.api.agent.Trace; -//import com.newrelic.api.agent.weaver.MatchType; -//import com.newrelic.api.agent.weaver.Weave; -//import com.newrelic.api.agent.weaver.Weaver; -//import llm.models.ModelInvocation; -//import llm.models.anthropic.claude.AnthropicClaudeModelInvocation; -//import software.amazon.awssdk.core.client.config.SdkClientConfiguration; -//import software.amazon.awssdk.core.client.handler.SyncClientHandler; -//import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; -//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; -//import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; -// -//import java.util.Map; -// -//import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; -//import static llm.vendor.Vendor.VENDOR_VERSION; -// -///** -// * Service client for accessing Amazon Bedrock Runtime. 
-// */ -//// TODO switch back to instrumenting the BedrockRuntimeClient interface instead of this implementation class -//@Weave(type = MatchType.ExactClass, originalName = "software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeClient") -//final class DefaultBedrockRuntimeClient_Instrumentation { -//// private static final Logger log = Logger.loggerFor(DefaultBedrockRuntimeClient.class); -//// -//// private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() -//// .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); -// -// private final SyncClientHandler clientHandler; -// -// private final AwsJsonProtocolFactory protocolFactory; -// -// private final SdkClientConfiguration clientConfiguration; -// -// private final BedrockRuntimeServiceClientConfiguration serviceClientConfiguration; -// -// protected DefaultBedrockRuntimeClient_Instrumentation(BedrockRuntimeServiceClientConfiguration serviceClientConfiguration, -// SdkClientConfiguration clientConfiguration) { -// this.clientHandler = Weaver.callOriginal(); -// this.clientConfiguration = Weaver.callOriginal(); -// this.serviceClientConfiguration = Weaver.callOriginal(); -// this.protocolFactory = Weaver.callOriginal(); -// } -// -// @Trace -// public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) { -// long startTime = System.currentTimeMillis(); -// InvokeModelResponse invokeModelResponse = Weaver.callOriginal(); -// -// ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION); -// -// Transaction txn = AgentBridge.getAgent().getTransaction(); -// // TODO check AIM config -// if (txn != null && !(txn instanceof NoOpTransaction)) { -// Map linkingMetadata = NewRelic.getAgent().getLinkingMetadata(); -// -// String modelId = invokeModelRequest.modelId(); -// if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) { -// ModelInvocation anthropicClaudeModelInvocation = new AnthropicClaudeModelInvocation(txn, invokeModelRequest, -// 
invokeModelResponse); -// // Set traced method name based on LLM operation -// anthropicClaudeModelInvocation.setLlmOperationMetricName("invokeModel"); -// // Set llm = true agent attribute -// ModelInvocation.setLlmTrueAgentAttribute(txn); -// anthropicClaudeModelInvocation.recordLlmEvents(startTime, linkingMetadata); -// } -// } -// -// return invokeModelResponse; -// } -// -//} From 9160fff759c253fed62c065cb5d84e162eb1186c Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 6 Mar 2024 16:41:54 -0800 Subject: [PATCH 19/68] Add readme for each model --- .../main/java/llm/models/SupportedModels.java | 26 +- .../jurassic/JurassicModelRequest.java | 4 +- .../jurassic/JurassicModelResponse.java | 3 +- .../llm/models/ai21labs/jurassic/README.md | 713 ++++++++++++++++++ .../java/llm/models/amazon/titan/README.md | 74 ++ .../amazon/titan/TitanModelResponse.java | 4 +- .../anthropic/claude/ClaudeModelRequest.java | 25 +- .../anthropic/claude/ClaudeModelResponse.java | 7 +- .../llm/models/anthropic/claude/README.md | 41 + .../java/llm/models/cohere/command/README.md | 82 ++ .../java/llm/models/meta/llama2/README.md | 40 + 11 files changed, 965 insertions(+), 54 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java index 6c5cdb7972..70300901f1 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java @@ -7,6 +7,12 @@ package llm.models; +/** + * Prefixes for supported models. As long as the model ID for an invoked LLM model contains + * one of these prefixes the instrumentation should attempt to process the request/response. + *

+ * See the README for each model in llm.models.* for more details on supported models. + */ public class SupportedModels { public static final String ANTHROPIC_CLAUDE = "anthropic.claude"; public static final String AMAZON_TITAN = "amazon.titan"; @@ -15,23 +21,3 @@ public class SupportedModels { public static final String COHERE_EMBED = "cohere.embed"; public static final String AI_21_LABS_JURASSIC = "ai21.j2"; } - -//*AI21 Labs -// *Jurassic-2Ultra(ai21.j2-ultra-v1) -// *Jurassic-2Mid(ai21.j2-mid-v1) -//*Amazon -// *Titan Embeddings G1-Text(amazon.titan-embed-text-v1) -// *Titan Text G1-Lite(amazon.titan-text-lite-v1) -// *Titan Text G1-Express(amazon.titan-text-express-v1) -// *Titan Multimodal Embeddings G1(amazon.titan-embed-image-v1) -//*Anthropic -// *Claude(anthropic.claude-v2,anthropic.claude-v2:1) -// *Claude Instant(anthropic.claude-instant-v1) -//*Cohere -// *Command(cohere.command-text-v14) -// *Command Light(cohere.command-light-text-v14) -// *Embed English(cohere.embed-english-v3) -// *Embed Multilingual(cohere.embed-multilingual-v3) -//*Meta -// *Llama 2Chat 13B(meta.llama2-13b-chat-v1) -// *Llama 2Chat 70B(meta.llama2-70b-chat-v1) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java index fe63b6f4dd..7462e5d34e 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java @@ -27,7 +27,6 @@ public class JurassicModelRequest implements ModelRequest { private static final String MAX_TOKENS = "maxTokens"; private static final String TEMPERATURE = "temperature"; private static final String PROMPT = "prompt"; - private static final String INPUT_TEXT = "inputText"; private String invokeModelRequestBody 
= ""; private String modelId = ""; @@ -131,7 +130,8 @@ public String getRole() { @Override public String getInputText() { - return parseStringValue(INPUT_TEXT); + // This is a NoOp for Jurassic as it doesn't support embeddings + return ""; } private String parseStringValue(String fieldToParse) { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java index d4f03a10f1..54ea510729 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java @@ -112,10 +112,9 @@ private Map parseInvokeModelResponseBodyMap() { private void setOperationType(String invokeModelResponseBody) { try { if (!invokeModelResponseBody.isEmpty()) { + // Jurassic for Bedrock doesn't support embedding operations if (invokeModelResponseBody.contains(COMPLETION)) { operationType = COMPLETION; - } else if (invokeModelResponseBody.contains(EMBEDDING)) { - operationType = EMBEDDING; } else { logParsingFailure(null, "operation type"); } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md new file mode 100644 index 0000000000..6d2d1d71aa --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md @@ -0,0 +1,713 @@ +# AI21 Labs + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. 
+ +## Jurassic Models + +#### Text Completion Models + +The following models have been tested: + +* Jurassic-2 Mid (`ai21.j2-mid-v1`) +* Jurassic-2 Ultra (`ai21.j2-ultra-v1`) + +#### Sample Request + +```json +{ + "temperature": 0.5, + "maxTokens": 1000, + "prompt": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "id": 1234, + "prompt": { + "text": "What is the color of the sky?", + "tokens": [ + { + "generatedToken": { + "token": "▁What▁is▁the", + "logprob": -8.316551208496094, + "raw_logprob": -8.316551208496094 + }, + "topTokens": null, + "textRange": { + "start": 0, + "end": 11 + } + }, + { + "generatedToken": { + "token": "▁color", + "logprob": -7.189708709716797, + "raw_logprob": -7.189708709716797 + }, + "topTokens": null, + "textRange": { + "start": 11, + "end": 17 + } + }, + { + "generatedToken": { + "token": "▁of▁the▁sky", + "logprob": -5.750617027282715, + "raw_logprob": -5.750617027282715 + }, + "topTokens": null, + "textRange": { + "start": 17, + "end": 28 + } + }, + { + "generatedToken": { + "token": "?", + "logprob": -5.858178615570068, + "raw_logprob": -5.858178615570068 + }, + "topTokens": null, + "textRange": { + "start": 28, + "end": 29 + } + } + ] + }, + "completions": [ + { + "data": { + "text": "\nThe color of the sky on Earth is blue. This is because Earth's atmosphere scatters short-wavelength light more efficiently than long-wavelength light. When sunlight enters Earth's atmosphere, most of the blue light is scattered, leaving mostly red light to illuminate the sky. 
The scattering of blue light is more efficient because it travels as shorter, smaller waves.", + "tokens": [ + { + "generatedToken": { + "token": "<|newline|>", + "logprob": 0.0, + "raw_logprob": -6.305972783593461E-5 + }, + "topTokens": null, + "textRange": { + "start": 0, + "end": 1 + } + }, + { + "generatedToken": { + "token": "▁The▁color", + "logprob": -0.007753042038530111, + "raw_logprob": -0.18397575616836548 + }, + "topTokens": null, + "textRange": { + "start": 1, + "end": 10 + } + }, + { + "generatedToken": { + "token": "▁of▁the▁sky", + "logprob": -6.770858453819528E-5, + "raw_logprob": -0.0130088459700346 + }, + "topTokens": null, + "textRange": { + "start": 10, + "end": 21 + } + }, + { + "generatedToken": { + "token": "▁on▁Earth", + "logprob": -6.189814303070307E-4, + "raw_logprob": -0.06852064281702042 + }, + "topTokens": null, + "textRange": { + "start": 21, + "end": 30 + } + }, + { + "generatedToken": { + "token": "▁is", + "logprob": -0.5599813461303711, + "raw_logprob": -1.532042145729065 + }, + "topTokens": null, + "textRange": { + "start": 30, + "end": 33 + } + }, + { + "generatedToken": { + "token": "▁blue", + "logprob": -0.0358763113617897, + "raw_logprob": -0.2531339228153229 + }, + "topTokens": null, + "textRange": { + "start": 33, + "end": 38 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -0.0022088908590376377, + "raw_logprob": -0.11807831376791 + }, + "topTokens": null, + "textRange": { + "start": 38, + "end": 39 + } + }, + { + "generatedToken": { + "token": "▁This▁is▁because", + "logprob": -0.7582850456237793, + "raw_logprob": -1.6503678560256958 + }, + "topTokens": null, + "textRange": { + "start": 39, + "end": 55 + } + }, + { + "generatedToken": { + "token": "▁Earth's▁atmosphere", + "logprob": -0.37150290608406067, + "raw_logprob": -1.086639404296875 + }, + "topTokens": null, + "textRange": { + "start": 55, + "end": 74 + } + }, + { + "generatedToken": { + "token": "▁scatter", + "logprob": -1.4662635294371285E-5, + 
"raw_logprob": -0.011443688534200191 + }, + "topTokens": null, + "textRange": { + "start": 74, + "end": 82 + } + }, + { + "generatedToken": { + "token": "s", + "logprob": -9.929640509653836E-5, + "raw_logprob": -0.01099079567939043 + }, + "topTokens": null, + "textRange": { + "start": 82, + "end": 83 + } + }, + { + "generatedToken": { + "token": "▁short", + "logprob": -2.97943115234375, + "raw_logprob": -1.8346563577651978 + }, + "topTokens": null, + "textRange": { + "start": 83, + "end": 89 + } + }, + { + "generatedToken": { + "token": "-wavelength", + "logprob": -1.5722469834145159E-4, + "raw_logprob": -0.020076051354408264 + }, + "topTokens": null, + "textRange": { + "start": 89, + "end": 100 + } + }, + { + "generatedToken": { + "token": "▁light", + "logprob": -1.8000440832111053E-5, + "raw_logprob": -0.008328350260853767 + }, + "topTokens": null, + "textRange": { + "start": 100, + "end": 106 + } + }, + { + "generatedToken": { + "token": "▁more▁efficiently", + "logprob": -0.11763446033000946, + "raw_logprob": -0.6382070779800415 + }, + "topTokens": null, + "textRange": { + "start": 106, + "end": 123 + } + }, + { + "generatedToken": { + "token": "▁than", + "logprob": -0.0850396677851677, + "raw_logprob": -0.4660969078540802 + }, + "topTokens": null, + "textRange": { + "start": 123, + "end": 128 + } + }, + { + "generatedToken": { + "token": "▁long", + "logprob": -0.21488533914089203, + "raw_logprob": -0.43275904655456543 + }, + "topTokens": null, + "textRange": { + "start": 128, + "end": 133 + } + }, + { + "generatedToken": { + "token": "-wavelength", + "logprob": -3.576272320060525E-6, + "raw_logprob": -0.0032024311367422342 + }, + "topTokens": null, + "textRange": { + "start": 133, + "end": 144 + } + }, + { + "generatedToken": { + "token": "▁light", + "logprob": -6.603976362384856E-5, + "raw_logprob": -0.021542951464653015 + }, + "topTokens": null, + "textRange": { + "start": 144, + "end": 150 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": 
-0.03969373181462288, + "raw_logprob": -0.24834078550338745 + }, + "topTokens": null, + "textRange": { + "start": 150, + "end": 151 + } + }, + { + "generatedToken": { + "token": "▁When", + "logprob": -0.8459960222244263, + "raw_logprob": -1.758193016052246 + }, + "topTokens": null, + "textRange": { + "start": 151, + "end": 156 + } + }, + { + "generatedToken": { + "token": "▁sunlight", + "logprob": -0.043000709265470505, + "raw_logprob": -0.413555383682251 + }, + "topTokens": null, + "textRange": { + "start": 156, + "end": 165 + } + }, + { + "generatedToken": { + "token": "▁enters", + "logprob": -2.2813825607299805, + "raw_logprob": -1.975184440612793 + }, + "topTokens": null, + "textRange": { + "start": 165, + "end": 172 + } + }, + { + "generatedToken": { + "token": "▁Earth's▁atmosphere", + "logprob": -0.04206264019012451, + "raw_logprob": -0.22090668976306915 + }, + "topTokens": null, + "textRange": { + "start": 172, + "end": 191 + } + }, + { + "generatedToken": { + "token": ",", + "logprob": -2.1300431399140507E-4, + "raw_logprob": -0.04065611585974693 + }, + "topTokens": null, + "textRange": { + "start": 191, + "end": 192 + } + }, + { + "generatedToken": { + "token": "▁most▁of▁the", + "logprob": -1.0895559787750244, + "raw_logprob": -1.4258980751037598 + }, + "topTokens": null, + "textRange": { + "start": 192, + "end": 204 + } + }, + { + "generatedToken": { + "token": "▁blue▁light", + "logprob": -2.7195115089416504, + "raw_logprob": -2.069707155227661 + }, + "topTokens": null, + "textRange": { + "start": 204, + "end": 215 + } + }, + { + "generatedToken": { + "token": "▁is", + "logprob": -3.036991402041167E-4, + "raw_logprob": -0.036258988082408905 + }, + "topTokens": null, + "textRange": { + "start": 215, + "end": 218 + } + }, + { + "generatedToken": { + "token": "▁scattered", + "logprob": -1.1086402082582936E-5, + "raw_logprob": -0.007142604328691959 + }, + "topTokens": null, + "textRange": { + "start": 218, + "end": 228 + } + }, + { + "generatedToken": { + 
"token": ",", + "logprob": -0.8132423162460327, + "raw_logprob": -1.204469919204712 + }, + "topTokens": null, + "textRange": { + "start": 228, + "end": 229 + } + }, + { + "generatedToken": { + "token": "▁leaving", + "logprob": -0.028648898005485535, + "raw_logprob": -0.24427929520606995 + }, + "topTokens": null, + "textRange": { + "start": 229, + "end": 237 + } + }, + { + "generatedToken": { + "token": "▁mostly", + "logprob": -0.012762418016791344, + "raw_logprob": -0.18833962082862854 + }, + "topTokens": null, + "textRange": { + "start": 237, + "end": 244 + } + }, + { + "generatedToken": { + "token": "▁red▁light", + "logprob": -0.3875422477722168, + "raw_logprob": -0.9608176350593567 + }, + "topTokens": null, + "textRange": { + "start": 244, + "end": 254 + } + }, + { + "generatedToken": { + "token": "▁to▁illuminate", + "logprob": -1.2177848815917969, + "raw_logprob": -1.6379175186157227 + }, + "topTokens": null, + "textRange": { + "start": 254, + "end": 268 + } + }, + { + "generatedToken": { + "token": "▁the▁sky", + "logprob": -0.004821578972041607, + "raw_logprob": -0.1349806934595108 + }, + "topTokens": null, + "textRange": { + "start": 268, + "end": 276 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -2.7894584491150454E-5, + "raw_logprob": -0.01649152860045433 + }, + "topTokens": null, + "textRange": { + "start": 276, + "end": 277 + } + }, + { + "generatedToken": { + "token": "▁The", + "logprob": -4.816740989685059, + "raw_logprob": -3.04256534576416 + }, + "topTokens": null, + "textRange": { + "start": 277, + "end": 281 + } + }, + { + "generatedToken": { + "token": "▁scattering", + "logprob": -0.07598043233156204, + "raw_logprob": -0.4935254752635956 + }, + "topTokens": null, + "textRange": { + "start": 281, + "end": 292 + } + }, + { + "generatedToken": { + "token": "▁of", + "logprob": -2.1653952598571777, + "raw_logprob": -2.153515338897705 + }, + "topTokens": null, + "textRange": { + "start": 292, + "end": 295 + } + }, + { + 
"generatedToken": { + "token": "▁blue▁light", + "logprob": -0.0025517542380839586, + "raw_logprob": -0.0987434908747673 + }, + "topTokens": null, + "textRange": { + "start": 295, + "end": 306 + } + }, + { + "generatedToken": { + "token": "▁is", + "logprob": -0.04848421365022659, + "raw_logprob": -0.5477231740951538 + }, + "topTokens": null, + "textRange": { + "start": 306, + "end": 309 + } + }, + { + "generatedToken": { + "token": "▁more▁efficient", + "logprob": -1.145136833190918, + "raw_logprob": -1.6279737949371338 + }, + "topTokens": null, + "textRange": { + "start": 309, + "end": 324 + } + }, + { + "generatedToken": { + "token": "▁because▁it", + "logprob": -0.7712448835372925, + "raw_logprob": -1.402230143547058 + }, + "topTokens": null, + "textRange": { + "start": 324, + "end": 335 + } + }, + { + "generatedToken": { + "token": "▁travels", + "logprob": -1.0001159535022452E-4, + "raw_logprob": -0.03441037982702255 + }, + "topTokens": null, + "textRange": { + "start": 335, + "end": 343 + } + }, + { + "generatedToken": { + "token": "▁as", + "logprob": -2.169585604860913E-5, + "raw_logprob": -0.008925186470150948 + }, + "topTokens": null, + "textRange": { + "start": 343, + "end": 346 + } + }, + { + "generatedToken": { + "token": "▁shorter", + "logprob": -0.0026372435968369246, + "raw_logprob": -0.054399896413087845 + }, + "topTokens": null, + "textRange": { + "start": 346, + "end": 354 + } + }, + { + "generatedToken": { + "token": ",", + "logprob": -3.576214658096433E-5, + "raw_logprob": -0.011654269881546497 + }, + "topTokens": null, + "textRange": { + "start": 354, + "end": 355 + } + }, + { + "generatedToken": { + "token": "▁smaller", + "logprob": -1.0609570381348021E-5, + "raw_logprob": -0.007282733917236328 + }, + "topTokens": null, + "textRange": { + "start": 355, + "end": 363 + } + }, + { + "generatedToken": { + "token": "▁waves", + "logprob": -2.7418097943154862E-6, + "raw_logprob": -0.0030873988289386034 + }, + "topTokens": null, + "textRange": { + 
"start": 363, + "end": 369 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -0.19333261251449585, + "raw_logprob": -0.535153865814209 + }, + "topTokens": null, + "textRange": { + "start": 369, + "end": 370 + } + }, + { + "generatedToken": { + "token": "<|endoftext|>", + "logprob": -0.03163028880953789, + "raw_logprob": -0.6691970229148865 + }, + "topTokens": null, + "textRange": { + "start": 370, + "end": 370 + } + } + ] + }, + "finishReason": { + "reason": "endoftext" + } + } + ] +} +``` + +### Embedding Models + +Not supported by Jurassic. \ No newline at end of file diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md new file mode 100644 index 0000000000..8947646586 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md @@ -0,0 +1,74 @@ +# Amazon + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. + +## Titan Models + +### Text Completion Models + +The following models have been tested: + +* Titan Text G1-Lite (`amazon.titan-text-lite-v1`) +* Titan Text G1-Express (`amazon.titan-text-express-v1`) + +#### Sample Request + +```json +{ + "inputText": "What is the color of the sky?", + "textGenerationConfig": { + "maxTokenCount": 1000, + "stopSequences": [ + "User:" + ], + "temperature": 0.5, + "topP": 0.9 + } +} +``` + +#### Sample Response + +```json +{ + "inputTextTokenCount": 8, + "results": [ + { + "tokenCount": 39, + "outputText": "\nThe color of the sky depends on the time of day, weather conditions, and location. 
It can range from blue to gray, depending on the presence of clouds and pollutants in the air.", + "completionReason": "FINISH" + } + ] +} +``` + +### Embedding Models + +The following models have been tested: + +* Titan Embeddings G1-Text (`amazon.titan-embed-text-v1`) +* Titan Multimodal Embeddings G1 (`amazon.titan-embed-image-v1`) + +#### Sample Request + +```json +{ + "inputText": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "embedding": [ + 0.328125, + ..., + 0.44335938 + ], + "inputTextTokenCount": 8 +} +``` + diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java index 56f8d4a8cc..ea800c5196 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -49,8 +49,6 @@ public class TitanModelResponse implements ModelResponse { private String invokeModelResponseBody = ""; private Map responseBodyJsonMap = null; - private static final String JSON_START = "{\""; - public TitanModelResponse(InvokeModelResponse invokeModelResponse) { if (invokeModelResponse != null) { invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); @@ -114,7 +112,7 @@ private void setOperationType(String invokeModelResponseBody) { if (!invokeModelResponseBody.isEmpty()) { if (invokeModelResponseBody.contains(COMPLETION_REASON)) { operationType = COMPLETION; - } else if (invokeModelResponseBody.startsWith(JSON_START + EMBEDDING)) { + } else if (invokeModelResponseBody.startsWith(EMBEDDING)) { operationType = EMBEDDING; } else { logParsingFailure(null, "operation type"); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java index dc6a0eceea..da461b4849 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java @@ -27,11 +27,6 @@ public class ClaudeModelRequest implements ModelRequest { private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; private static final String TEMPERATURE = "temperature"; private static final String PROMPT = "prompt"; - private static final String INPUT_TEXT = "inputText"; - private static final String ESCAPED_NEWLINES = "\\n\\n"; - private static final String SYSTEM = "system"; - private static final String ASSISTANT = "assistant"; - private static final String USER = "user"; private String invokeModelRequestBody = ""; private String modelId = ""; @@ -129,28 +124,14 @@ public String getRequestMessage() { @Override public String getRole() { - try { - if (!invokeModelRequestBody.isEmpty()) { - String invokeModelRequestBodyLowerCase = invokeModelRequestBody.toLowerCase(); - if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + SYSTEM)) { - return SYSTEM; - } else if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + USER)) { - return USER; - } else if (invokeModelRequestBodyLowerCase.contains(ESCAPED_NEWLINES + ASSISTANT)) { - return ASSISTANT; - } - } else { - logParsingFailure(null, "role"); - } - } catch (Exception e) { - logParsingFailure(e, "role"); - } + // This is a NoOp for Claude as the request doesn't contain any signifier of the role return ""; } @Override public String getInputText() { - return parseStringValue(INPUT_TEXT); + // This is a NoOp for Claude as it doesn't support embeddings + return ""; } private String parseStringValue(String fieldToParse) { diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java index 0ff9992061..851ba9cebb 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java @@ -47,8 +47,6 @@ public class ClaudeModelResponse implements ModelResponse { private String invokeModelResponseBody = ""; private Map responseBodyJsonMap = null; - private static final String JSON_START = "{\""; - public ClaudeModelResponse(InvokeModelResponse invokeModelResponse) { if (invokeModelResponse != null) { invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); @@ -110,10 +108,9 @@ private Map parseInvokeModelResponseBodyMap() { private void setOperationType(String invokeModelResponseBody) { try { if (!invokeModelResponseBody.isEmpty()) { - if (invokeModelResponseBody.startsWith(JSON_START + COMPLETION)) { + // Claude for Bedrock doesn't support embedding operations + if (invokeModelResponseBody.contains(COMPLETION)) { operationType = COMPLETION; - } else if (invokeModelResponseBody.startsWith(JSON_START + EMBEDDING)) { - operationType = EMBEDDING; } else { logParsingFailure(null, "operation type"); } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md new file mode 100644 index 0000000000..345695d03c --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md @@ -0,0 +1,41 @@ +# Anthropic + +Examples of the request/response bodies for models that have been tested and verified to work. 
The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. + +## Claude Models + +### Text Completion Models + +The following models have been tested: + +* Claude(`anthropic.claude-v2`, `anthropic.claude-v2:1`) +* Claude Instant(`anthropic.claude-instant-v1`) + +#### Sample Request + +```json +{ + "stop_sequences": [ + "\n\nHuman:" + ], + "max_tokens_to_sample": 1000, + "temperature": 0.5, + "prompt": "Human: What is the color of the sky?\n\nAssistant:" +} +``` + +#### Sample Response + +```json +{ + "completion": " The sky appears blue during the day because molecules in the air scatter blue light from the sun more than they scatter red light. The actual color of the sky varies some based on atmospheric conditions, but the primary color we perceive is blue.", + "stop_reason": "stop_sequence", + "stop": "\n\nHuman:" +} +``` + +### Embedding Models + +Not supported by Claude. diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md new file mode 100644 index 0000000000..c50a56a1cf --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md @@ -0,0 +1,82 @@ +# Cohere + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed here. 
+ +## Command Models + +### Text Completion Models + +The following models have been tested: +* Command(`cohere.command-text-v14`) +* Command Light(`cohere.command-light-text-v14`) + +#### Sample Request + +```json +{ + "p": 0.9, + "stop_sequences": [ + "User:" + ], + "truncate": "END", + "max_tokens": 1000, + "stream": false, + "temperature": 0.5, + "k": 0, + "return_likelihoods": "NONE", + "prompt": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "generations": [ + { + "finish_reason": "COMPLETE", + "id": "f5700a48-0730-49f1-9756-227a993963aa", + "text": " The color of the sky can vary depending on the time of day, weather conditions, and location. In general, the color of the sky is a pale blue. During the day, the sky can appear to be a lighter shade of blue, while at night, it may appear to be a darker shade of blue or even black. The color of the sky can also be affected by the presence of clouds, which can appear as white, grey, or even pink or red in the morning or evening light. \n\nIt is important to note that the color of the sky is not a static or fixed color, but rather a dynamic and ever-changing one, which can be influenced by a variety of factors." + } + ], + "id": "c548295f-9064-49c5-a05f-c754e4c5c9f8", + "prompt": "What is the color of the sky?" +} +``` + +### Embedding Models + +The following models have been tested: +* Embed English(`cohere.embed-english-v3`) +* Embed Multilingual(`cohere.embed-multilingual-v3`) + +#### Sample Request + +```json +{ + "texts": [ + "What is the color of the sky?" + ], + "truncate": "NONE", + "input_type": "search_document" +} +``` + +#### Sample Response + +```json +{ + "embeddings": [ + [ + -0.002828598, + ..., + 0.00541687 + ] + ], + "id": "e1e969ba-d526-4c76-aa92-a8a705288f6d", + "response_type": "embeddings_floats", + "texts": [ + "what is the color of the sky?" 
+ ] +} +``` diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md new file mode 100644 index 0000000000..6bfba79a5a --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md @@ -0,0 +1,40 @@ +# Meta + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. + +## Llama 2 Models + +### Text Completion Models + +The following models have been tested: + +* Llama 2Chat 13B (`meta.llama2-13b-chat-v1`) +* Llama 2Chat 70B (`meta.llama2-70b-chat-v1`) + +#### Sample Request + +```json +{ + "top_p": 0.9, + "max_gen_len": 1000, + "temperature": 0.5, + "prompt": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "generation": "\n\nThe color of the sky can vary depending on the time of day and atmospheric conditions. During the daytime, the sky typically appears blue, which is caused by a phenomenon called Rayleigh scattering, in which shorter (blue) wavelengths of light are scattered more than longer (red) wavelengths by the tiny molecules of gases in the atmosphere.\n\nIn the evening, as the sun sets, the sky can take on a range of colors, including orange, pink, and purple, due to the scattering of light by atmospheric particles. During sunrise and sunset, the sky can also appear more red or orange due to the longer wavelengths of light being scattered.\n\nAt night, the sky can appear dark, but it can also be illuminated by the moon, stars, and artificial light sources such as city lights. 
In areas with minimal light pollution, the night sky can be a deep indigo or black, with the stars and constellations visible as points of light.\n\nOverall, the color of the sky can vary greatly depending on the time of day, atmospheric conditions, and the observer's location.", + "prompt_token_count": 9, + "generation_token_count": 256, + "stop_reason": "stop" +} +``` + +### Embedding Models + +Not supported by Llama 2. From 1c353e65213a232bad5da99b0e0b7811e52c8074 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 10:49:55 -0800 Subject: [PATCH 20/68] Cleanup --- .../aimonitoring/AiMonitoringUtils.java | 14 ++++- .../aws-bedrock-runtime-2.20/README.md | 51 ++++++------------- .../src/main/java/llm/events/LlmEvent.java | 6 ++- .../amazon/titan/TitanModelResponse.java | 2 +- .../LlmEventAttributeValidator.java | 10 ++-- 5 files changed, 39 insertions(+), 44 deletions(-) diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java index 8168925a04..161cfd1aa9 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java @@ -13,12 +13,13 @@ public class AiMonitoringUtils { // Enabled defaults private static final boolean AI_MONITORING_ENABLED_DEFAULT = false; private static final boolean AI_MONITORING_STREAMING_ENABLED_DEFAULT = true; + private static final boolean AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT = true; /** * Check if ai_monitoring features are enabled. * Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. 
* - * @return true if enabled, else false + * @return true if AI monitoring is enabled, else false */ public static boolean isAiMonitoringEnabled() { return NewRelic.getAgent().getConfig().getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); @@ -27,9 +28,18 @@ public static boolean isAiMonitoringEnabled() { /** * Check if ai_monitoring.streaming features are enabled. * - * @return true if enabled, else false + * @return true if streaming is enabled, else false */ public static boolean isAiMonitoringStreamingEnabled() { return NewRelic.getAgent().getConfig().getValue("ai_monitoring.streaming.enabled", AI_MONITORING_STREAMING_ENABLED_DEFAULT); } + + /** + * Check if the input and output content should be added to LLM events. + * + * @return true if adding content is enabled, else false + */ + public static boolean isAiMonitoringRecordContentEnabled() { + return NewRelic.getAgent().getConfig().getValue("ai_monitoring.record_content.enabled", AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT); + } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index fb778568f5..b8e4606692 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -19,27 +19,27 @@ Note: Currently, `invokeModelWithResponseStream` is not supported. ### Supported Models -At the time of the instrumentation being published, only the following text-based foundation models have been tested and confirmed as supported: +At the time of the instrumentation being published, only the following text-based foundation models have been tested and confirmed as supported. As long as the model ID for an invoked LLM model contains one of these prefixes defined in `SupportedModels`, the instrumentation should attempt to process the request/response. However, if the request/response structure significantly changes the processing may fail. 
See the `README` for each model in `llm.models.*` for more details on each. * AI21 Labs - * Jurassic-2 Ultra (ai21.j2-ultra-v1) - * Jurassic-2 Mid (ai21.j2-mid-v1) + * Jurassic-2 Ultra (`ai21.j2-ultra-v1`) + * Jurassic-2 Mid (`ai21.j2-mid-v1`) * Amazon - * Titan Embeddings G1 - Text (amazon.titan-embed-text-v1) - * Titan Text G1 - Lite (amazon.titan-text-lite-v1) - * Titan Text G1 - Express (amazon.titan-text-express-v1) - * Titan Multimodal Embeddings G1 (amazon.titan-embed-image-v1) + * Titan Embeddings G1 - Text (`amazon.titan-embed-text-v1`) + * Titan Text G1 - Lite (`amazon.titan-text-lite-v1`) + * Titan Text G1 - Express (`amazon.titan-text-express-v1`) + * Titan Multimodal Embeddings G1 (`amazon.titan-embed-image-v1`) * Anthropic - * Claude (anthropic.claude-v2, anthropic.claude-v2:1) - * Claude Instant (anthropic.claude-instant-v1) + * Claude (`anthropic.claude-v2`, `anthropic.claude-v2:1`) + * Claude Instant (`anthropic.claude-instant-v1`) * Cohere - * Command (cohere.command-text-v14) - * Command Light (cohere.command-light-text-v14) - * Embed English (cohere.embed-english-v3) - * Embed Multilingual (cohere.embed-multilingual-v3) + * Command (`cohere.command-text-v14`) + * Command Light (`cohere.command-light-text-v14`) + * Embed English (`cohere.embed-english-v3`) + * Embed Multilingual (`cohere.embed-multilingual-v3`) * Meta - * Llama 2 Chat 13B (meta.llama2-13b-chat-v1) - * Llama 2 Chat 70B (meta.llama2-70b-chat-v1) + * Llama 2 Chat 13B (`meta.llama2-13b-chat-v1`) + * Llama 2 Chat 70B (`meta.llama2-70b-chat-v1`) ## Involved Pieces @@ -183,28 +183,7 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com ## TODO * Make all LLM event attribute values un-truncated???? 
https://source.datanerd.us/agents/agent-specs/pull/664 -* Add new `ai_monitoring.record_content.enabled` config https://source.datanerd.us/agents/agent-specs/pull/663 * Refactoring related to token count, new callback API https://source.datanerd.us/agents/agent-specs/pull/662 -* Set up and test new models - * AI21 Labs - * Jurassic-2 Ultra (~~ai21.j2-ultra-v1~~) - * Jurassic-2 Mid (~~ai21.j2-mid-v1~~) - * Amazon - * Titan Embeddings G1 - Text (~~amazon.titan-embed-text-v1~~) - * Titan Text G1 - Lite (~~amazon.titan-text-lite-v1~~) - * Titan Text G1 - Express (~~amazon.titan-text-express-v1~~) - * Titan Multimodal Embeddings G1 (~~amazon.titan-embed-image-v1~~) - * Anthropic - * Claude (~~anthropic.claude-v2~~, ~~anthropic.claude-v2:1~~) - * Claude Instant (~~anthropic.claude-instant-v1~~) - * Cohere - * Command (~~cohere.command-text-v14~~) - * Command Light (~~cohere.command-light-text-v14~~) - * Embed English (~~cohere.embed-english-v3~~) - * Embed Multilingual (~~cohere.embed-multilingual-v3~~) - * Meta - * Llama 2 Chat 13B (~~meta.llama2-13b-chat-v1~~) - * Llama 2 Chat 70B (~~meta.llama2-70b-chat-v1~~) * Test env var and sys prop config * Update default yaml * Write instrumentation tests diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index 2b49efce58..cc9f1775c3 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -16,6 +16,8 @@ import java.util.HashMap; import java.util.Map; +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringRecordContentEnabled; + /** * Class for building an LlmEvent */ @@ -261,7 +263,7 @@ private LlmEvent(Builder builder) { } content = builder.content; - if (content != null && !content.isEmpty()) { + if (isAiMonitoringRecordContentEnabled() && content != 
null && !content.isEmpty()) { eventAttributes.put("content", content); } @@ -311,7 +313,7 @@ private LlmEvent(Builder builder) { } input = builder.input; - if (input != null && !input.isEmpty()) { + if (isAiMonitoringRecordContentEnabled() && input != null && !input.isEmpty()) { eventAttributes.put("input", input); } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java index ea800c5196..40ce928b15 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -112,7 +112,7 @@ private void setOperationType(String invokeModelResponseBody) { if (!invokeModelResponseBody.isEmpty()) { if (invokeModelResponseBody.contains(COMPLETION_REASON)) { operationType = COMPLETION; - } else if (invokeModelResponseBody.startsWith(EMBEDDING)) { + } else if (invokeModelResponseBody.contains(EMBEDDING)) { operationType = EMBEDDING; } else { logParsingFailure(null, "operation type"); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java index 6350df9c54..fc1a2dcf0b 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java @@ -13,7 +13,6 @@ * Attribute validator with truncation rules specific to LLM events. */ public class LlmEventAttributeValidator extends AttributeValidator { - // FIXME different size attribute limits for LLM events InsightsConfigImpl.MAX_MAX_ATTRIBUTE_VALUE ? 
private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService() .getDefaultAgentConfig() .getInsightsConfig() @@ -25,8 +24,13 @@ public LlmEventAttributeValidator(String attributeType) { @Override protected String truncateValue(String key, String value, String methodCalled) { - // TODO make sure that this behavior is accepted into the agent spec - if (key.equals("content")) { + /* + * The 'input' and output 'content' attributes should be added to LLM events + * without being truncated as per the LLMs agent spec. This is because the + * backend will use these attributes to calculate LLM token usage in cases + * where token counts aren't available on LLM events. + */ + if (key.equals("content") || key.equals("input")) { return value; } String truncatedVal = truncateString(value, MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE); From 100f04711ab5697cc945ca95a6222f9dc2cd527f Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 11:23:46 -0800 Subject: [PATCH 21/68] Add ai_monitoring stanza to default yaml config file --- newrelic-agent/src/main/resources/newrelic.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/newrelic-agent/src/main/resources/newrelic.yml b/newrelic-agent/src/main/resources/newrelic.yml index f8eda0e5c6..5dfb1c16ad 100644 --- a/newrelic-agent/src/main/resources/newrelic.yml +++ b/newrelic-agent/src/main/resources/newrelic.yml @@ -83,6 +83,20 @@ common: &default_settings # Default is the logs directory in the newrelic.jar parent directory. #log_file_path: + # AI Monitoring captures insights on the performance, quality, and cost of interactions with LLM models made with instrumented SDKs. + ai_monitoring: + + # Provides control over all AI Monitoring functionality. Set as true to enable all AI Monitoring features. + # Default is false. + enabled: false + + # Provides control over whether attributes for the input and output content should be added to LLM events. 
+ record_content: + + # Set as false to disable attributes for the input and output content. + # Default is true. + enabled: true + # Provides the ability to forward application logs to New Relic, generate log usage metrics, # and decorate local application log files with agent metadata for use with third party log forwarders. # The application_logging.forwarding and application_logging.local_decorating should not be used together. From 6b0b88c3b26bf887119b68376b51aad392da7c2a Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 14:10:37 -0800 Subject: [PATCH 22/68] Cleanup readme --- .../aws-bedrock-runtime-2.20/README.md | 51 +++++++------------ 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index b8e4606692..0818748a66 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -2,7 +2,7 @@ ## About -Instruments invocations of LLMs via the AWS Bedrock Runtime SDK. +Instruments invocations of LLMs made by the AWS Bedrock Runtime SDK. ## Support @@ -19,7 +19,7 @@ Note: Currently, `invokeModelWithResponseStream` is not supported. ### Supported Models -At the time of the instrumentation being published, only the following text-based foundation models have been tested and confirmed as supported. As long as the model ID for an invoked LLM model contains one of these prefixes defined in `SupportedModels`, the instrumentation should attempt to process the request/response. However, if the request/response structure significantly changes the processing may fail. See the `README` for each model in `llm.models.*` for more details on each. +At the time of the instrumentation being published, the following text-based foundation models have been tested and confirmed as supported. 
As long as the model ID for an invoked LLM model contains one of the prefixes defined in `SupportedModels`, the instrumentation should attempt to process the request/response. However, if the request/response structure significantly changes the processing may fail. See the `README` for each model in `llm.models.*` for more details on each. * AI21 Labs * Jurassic-2 Ultra (`ai21.j2-ultra-v1`) @@ -51,36 +51,28 @@ The main goal of this instrumentation is to generate the following LLM events to * `LlmChatCompletionSummary`: An event that captures high-level data about the creation of a chat completion including request, response, and call information. * `LlmChatCompletionMessage`: An event that corresponds to each message (sent and received) from a chat completion call including those created by the user, assistant, and the system. -These events are custom events sent via the public `recordCustomEvent` API. Currently, they contribute towards the following Custom Insights Events limits (this will likely change in the future). Because of this, it is recommended to increase `custom_insights_events.max_samples_stored` to the maximum value of 100,000 to best avoid sampling issue. LLM events are sent to the `custom_event_data` collector endpoint but the backend will assign them a unique namespace to distinguish them from other custom events. +These events are custom events sent via the public `recordCustomEvent` API. Currently, they contribute towards the following Custom Insights Events limits (this will likely change in the future). ```yaml custom_insights_events: max_samples_stored: 100000 ``` -LLM events also have some unique limits for the content attribute... - -``` -Regardless of which implementation(s) are built, there are consistent changes within the agents and the UX to support AI Monitoring. 
- -Agents should send the entire content; do not truncate it to 256 or 4096 characters - -Agents should move known token counts to the LlmChatCompletionMessage - -Agents should remove token counts from the LlmChatCompletionSummary -``` +Because of this, it is recommended to increase `custom_insights_events.max_samples_stored` to the maximum value of 100,000 to best avoid sampling issues. LLM events are sent to the `custom_event_data` collector endpoint but the backend will assign them a unique namespace to distinguish them from other custom events. +### Attributes +#### Agent Attributes -Can be built via `LlmEvent` builder +An `llm: true` agent attribute will be set on all Transaction events where one of the supported Bedrock methods is invoked within an active transaction. -### Model Invocation/Request/Response +#### LLM Event Attributes -* `ModelInvocation` -* `ModelRequest` -* `ModelResponse` +Attributes on LLM events use the same configuration and size limits as `custom_insights_events` with two notable exceptions being that the following two LLM event attributes will not be truncated at all: +* `content` +* `input` -### Attributes +This is done so that token usage can be calculated on the backend based on the full input and output content. #### Custom LLM Attributes @@ -88,11 +80,6 @@ Any custom attributes added by customers using the `addCustomParameters` API tha One potential custom attribute with special meaning that customers are encouraged to add is `llm.conversation_id`, which has implications in the UI and can be used to group LLM messages into specific conversations. -#### Agent Attributes - - // Set llm = true agent attribute required on TransactionEvents - - ### Metrics When in an active transaction a named span/segment for each LLM embedding and chat completion call is created using the following format: @@ -123,16 +110,16 @@ Note: Streaming is not currently supported. 
## Config -`ai_monitoring.enabled`: Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. +`ai_monitoring.enabled`: Provides control over all AI Monitoring functionality. Set as true to enable all AI Monitoring features. +`ai_monitoring.record_content.enabled`: Provides control over whether attributes for the input and output content should be added to LLM events. Set as false to disable attributes for the input and output content. `ai_monitoring.streaming.enabled`: NOT SUPPORTED ## Related Agent APIs -feedback -callback -addCustomParameter - -## Testing +AI monitoring can be enhanced by using the following agent APIs: +* `recordLlmFeedbackEvent` - Can be used to record an LlmFeedback event to associate user feedback with a specific distributed trace. +* `setLlmTokenCountCallback` +* `addCustomParameter` - Used to add custom attributed to LLM events. See [Custom LLM Attributes](#custom-llm-attributes) ## Known Issues @@ -182,10 +169,8 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com ## TODO -* Make all LLM event attribute values un-truncated???? https://source.datanerd.us/agents/agent-specs/pull/664 * Refactoring related to token count, new callback API https://source.datanerd.us/agents/agent-specs/pull/662 * Test env var and sys prop config -* Update default yaml * Write instrumentation tests * Finish readme * Refactor test app to have multiple invokeMethods for a single transaction... 
From 55b18abeba3f4b5c699075dcf6d43b10f69b7096 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 15:15:32 -0800 Subject: [PATCH 23/68] Remove logic for parsing input/output token counts --- .../aws-bedrock-runtime-2.20/README.md | 4 +- .../src/main/java/llm/events/LlmEvent.java | 40 +++------------ .../main/java/llm/models/ModelResponse.java | 26 ---------- .../jurassic/JurassicModelInvocation.java | 7 +-- .../jurassic/JurassicModelResponse.java | 50 +----------------- .../amazon/titan/TitanModelInvocation.java | 7 +-- .../amazon/titan/TitanModelResponse.java | 50 +----------------- .../claude/ClaudeModelInvocation.java | 7 +-- .../anthropic/claude/ClaudeModelResponse.java | 51 +------------------ .../command/CommandModelInvocation.java | 7 +-- .../cohere/command/CommandModelResponse.java | 50 +----------------- .../meta/llama2/Llama2ModelInvocation.java | 7 +-- .../meta/llama2/Llama2ModelResponse.java | 51 +------------------ 13 files changed, 24 insertions(+), 333 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index 0818748a66..b17963b135 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -118,7 +118,7 @@ Note: Streaming is not currently supported. AI monitoring can be enhanced by using the following agent APIs: * `recordLlmFeedbackEvent` - Can be used to record an LlmFeedback event to associate user feedback with a specific distributed trace. -* `setLlmTokenCountCallback` +* `setLlmTokenCountCallback` - Can be used to register a Callback that provides a token count. * `addCustomParameter` - Used to add custom attributed to LLM events. 
See [Custom LLM Attributes](#custom-llm-attributes) ## Known Issues @@ -169,9 +169,7 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com ## TODO -* Refactoring related to token count, new callback API https://source.datanerd.us/agents/agent-specs/pull/662 * Test env var and sys prop config * Write instrumentation tests -* Finish readme * Refactor test app to have multiple invokeMethods for a single transaction... * Figure out how to get external call linked with async client diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index cc9f1775c3..e1fe8d8d71 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -50,9 +50,7 @@ public class LlmEvent { private final Float requestTemperature; private final Integer requestMaxTokens; private final String requestModel; - private final Integer responseUsageTotalTokens; - private final Integer responseUsagePromptTokens; - private final Integer responseUsageCompletionTokens; + private final Integer tokenCount; private final String responseChoicesFinishReason; public static class Builder { @@ -90,9 +88,7 @@ public static class Builder { private Float requestTemperature = null; private Integer requestMaxTokens = null; private String requestModel = null; - private Integer responseUsageTotalTokens = null; - private Integer responseUsagePromptTokens = null; - private Integer responseUsageCompletionTokens = null; + private Integer tokenCount = null; private String responseChoicesFinishReason = null; public Builder(ModelInvocation modelInvocation) { @@ -204,18 +200,8 @@ public Builder requestModel() { return this; } - public Builder responseUsageTotalTokens() { - responseUsageTotalTokens = modelResponse.getTotalTokenCount(); - return this; - } - - public 
Builder responseUsagePromptTokens() { - responseUsagePromptTokens = modelResponse.getInputTokenCount(); - return this; - } - - public Builder responseUsageCompletionTokens() { - responseUsageCompletionTokens = modelResponse.getOutputTokenCount(); + public Builder tokenCount(int count) { + tokenCount = count; return this; } @@ -323,7 +309,7 @@ private LlmEvent(Builder builder) { } requestMaxTokens = builder.requestMaxTokens; - if (requestMaxTokens != null && requestMaxTokens >= 0) { + if (requestMaxTokens != null && requestMaxTokens > 0) { eventAttributes.put("request.max_tokens", requestMaxTokens); } @@ -332,19 +318,9 @@ private LlmEvent(Builder builder) { eventAttributes.put("request.model", requestModel); } - responseUsageTotalTokens = builder.responseUsageTotalTokens; - if (responseUsageTotalTokens != null && responseUsageTotalTokens >= 0) { - eventAttributes.put("response.usage.total_tokens", responseUsageTotalTokens); - } - - responseUsagePromptTokens = builder.responseUsagePromptTokens; - if (responseUsagePromptTokens != null && responseUsagePromptTokens >= 0) { - eventAttributes.put("response.usage.prompt_tokens", responseUsagePromptTokens); - } - - responseUsageCompletionTokens = builder.responseUsageCompletionTokens; - if (responseUsageCompletionTokens != null && responseUsageCompletionTokens >= 0) { - eventAttributes.put("response.usage.completion_tokens", responseUsageCompletionTokens); + tokenCount = builder.tokenCount; + if (tokenCount != null && tokenCount > 0) { + eventAttributes.put("token_count", tokenCount); } responseChoicesFinishReason = builder.responseChoicesFinishReason; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java index 150c47552d..47134ec73f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -12,11 +12,6 @@ import java.util.logging.Level; public interface ModelResponse { - // Response headers - String X_AMZN_BEDROCK_INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; - String X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; - String X_AMZN_REQUEST_ID = "x-amzn-RequestId"; - // Operation types String COMPLETION = "completion"; String EMBEDDING = "embedding"; @@ -35,27 +30,6 @@ public interface ModelResponse { */ String getStopReason(); - /** - * Get the count of input tokens used. - * - * @return int representing the count of input tokens used - */ - int getInputTokenCount(); - - /** - * Get the count of output tokens used. - * - * @return int representing the count of output tokens used - */ - int getOutputTokenCount(); - - /** - * Get the count of total tokens used. - * - * @return int representing the count of total tokens used - */ - int getTotalTokenCount(); - /** * Get the Amazon Request ID. 
* diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java index 4f900c3dc9..69ccf86b1a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java @@ -69,8 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .input() .requestModel() .responseModel() - .responseUsageTotalTokens() - .responseUsagePromptTokens() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -98,9 +97,6 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess .requestModel() .responseModel() .responseNumberOfMessages(numberOfMessages) - .responseUsageTotalTokens() - .responseUsagePromptTokens() - .responseUsageCompletionTokens() .responseChoicesFinishReason() .error() .duration(System.currentTimeMillis() - startTime) @@ -128,6 +124,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message) { .responseModel() .sequence(sequence) .completionId() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java index 54ea510729..3354118737 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java @@ -33,8 +33,6 @@ public class JurassicModelResponse implements ModelResponse { private static final String DATA = "data"; private static final String TEXT = "text"; - private int inputTokenCount = 0; - private int outputTokenCount = 0; private String amznRequestId = ""; // LLM operation type @@ -59,7 +57,7 @@ public JurassicModelResponse(InvokeModelResponse invokeModelResponse) { Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); - setHeaderFields(invokeModelResponse); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); llmChatCompletionSummaryId = getRandomGuid(); llmEmbeddingId = getRandomGuid(); } else { @@ -124,37 +122,6 @@ private void setOperationType(String invokeModelResponseBody) { } } - /** - * Parses header values from the response object and assigns them to fields. - * - * @param invokeModelResponse response object - */ - private void setHeaderFields(InvokeModelResponse invokeModelResponse) { - Map> headers = invokeModelResponse.sdkHttpResponse().headers(); - try { - if (!headers.isEmpty()) { - List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); - if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { - String result = inputTokenCountHeaders.get(0); - inputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); - if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { - String result = outputTokenCountHeaders.get(0); - outputTokenCount = result != null ? 
Integer.parseInt(result) : 0; - } - List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); - if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); - } - } else { - logParsingFailure(null, "response headers"); - } - } catch (Exception e) { - logParsingFailure(e, "response headers"); - } - } - @Override public String getResponseMessage() { String parsedResponseMessage = ""; @@ -229,21 +196,6 @@ public String getStopReason() { return parsedStopReason; } - @Override - public int getInputTokenCount() { - return inputTokenCount; - } - - @Override - public int getOutputTokenCount() { - return outputTokenCount; - } - - @Override - public int getTotalTokenCount() { - return inputTokenCount + outputTokenCount; - } - @Override public String getAmznRequestId() { return amznRequestId; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java index 86bc96ee8b..fd0b629950 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java @@ -69,8 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .input() .requestModel() .responseModel() - .responseUsageTotalTokens() - .responseUsagePromptTokens() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -98,9 +97,6 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess .requestModel() .responseModel() .responseNumberOfMessages(numberOfMessages) - .responseUsageTotalTokens() - .responseUsagePromptTokens() - .responseUsageCompletionTokens() .responseChoicesFinishReason() .error() 
.duration(System.currentTimeMillis() - startTime) @@ -128,6 +124,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message) { .responseModel() .sequence(sequence) .completionId() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java index 40ce928b15..26b424d1fd 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -31,8 +31,6 @@ public class TitanModelResponse implements ModelResponse { private static final String RESULTS = "results"; private static final String OUTPUT_TEXT = "outputText"; - private int inputTokenCount = 0; - private int outputTokenCount = 0; private String amznRequestId = ""; // LLM operation type @@ -57,7 +55,7 @@ public TitanModelResponse(InvokeModelResponse invokeModelResponse) { Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); - setHeaderFields(invokeModelResponse); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); llmChatCompletionSummaryId = getRandomGuid(); llmEmbeddingId = getRandomGuid(); } else { @@ -123,37 +121,6 @@ private void setOperationType(String invokeModelResponseBody) { } } - /** - * Parses header values from the response object and assigns them to fields. 
- * - * @param invokeModelResponse response object - */ - private void setHeaderFields(InvokeModelResponse invokeModelResponse) { - Map> headers = invokeModelResponse.sdkHttpResponse().headers(); - try { - if (!headers.isEmpty()) { - List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); - if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { - String result = inputTokenCountHeaders.get(0); - inputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); - if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { - String result = outputTokenCountHeaders.get(0); - outputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); - if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); - } - } else { - logParsingFailure(null, "response headers"); - } - } catch (Exception e) { - logParsingFailure(e, "response headers"); - } - } - @Override public String getResponseMessage() { String parsedResponseMessage = ""; @@ -216,21 +183,6 @@ public String getStopReason() { return parsedStopReason; } - @Override - public int getInputTokenCount() { - return inputTokenCount; - } - - @Override - public int getOutputTokenCount() { - return outputTokenCount; - } - - @Override - public int getTotalTokenCount() { - return inputTokenCount + outputTokenCount; - } - @Override public String getAmznRequestId() { return amznRequestId; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java index aa6e25a035..e31666a380 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java @@ -69,8 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .input() .requestModel() .responseModel() - .responseUsageTotalTokens() - .responseUsagePromptTokens() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -98,9 +97,6 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess .requestModel() .responseModel() .responseNumberOfMessages(numberOfMessages) - .responseUsageTotalTokens() - .responseUsagePromptTokens() - .responseUsageCompletionTokens() .responseChoicesFinishReason() .error() .duration(System.currentTimeMillis() - startTime) @@ -128,6 +124,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message) { .responseModel() .sequence(sequence) .completionId() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java index 851ba9cebb..42d7526b59 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java @@ -14,7 +14,6 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.logging.Level; @@ -29,8 +28,6 @@ public class ClaudeModelResponse implements ModelResponse { private static final String STOP_REASON = "stop_reason"; - private int inputTokenCount = 
0; - private int outputTokenCount = 0; private String amznRequestId = ""; // LLM operation type @@ -55,7 +52,7 @@ public ClaudeModelResponse(InvokeModelResponse invokeModelResponse) { Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); - setHeaderFields(invokeModelResponse); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); llmChatCompletionSummaryId = getRandomGuid(); llmEmbeddingId = getRandomGuid(); } else { @@ -120,37 +117,6 @@ private void setOperationType(String invokeModelResponseBody) { } } - /** - * Parses header values from the response object and assigns them to fields. - * - * @param invokeModelResponse response object - */ - private void setHeaderFields(InvokeModelResponse invokeModelResponse) { - Map> headers = invokeModelResponse.sdkHttpResponse().headers(); - try { - if (!headers.isEmpty()) { - List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); - if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { - String result = inputTokenCountHeaders.get(0); - inputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); - if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { - String result = outputTokenCountHeaders.get(0); - outputTokenCount = result != null ? 
Integer.parseInt(result) : 0; - } - List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); - if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); - } - } else { - logParsingFailure(null, "response headers"); - } - } catch (Exception e) { - logParsingFailure(e, "response headers"); - } - } - @Override public String getResponseMessage() { return parseStringValue(COMPLETION); @@ -179,21 +145,6 @@ private String parseStringValue(String fieldToParse) { return parsedStringValue; } - @Override - public int getInputTokenCount() { - return inputTokenCount; - } - - @Override - public int getOutputTokenCount() { - return outputTokenCount; - } - - @Override - public int getTotalTokenCount() { - return inputTokenCount + outputTokenCount; - } - @Override public String getAmznRequestId() { return amznRequestId; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java index 961aa41041..4ebfaf2231 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java @@ -69,8 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .input() .requestModel() .responseModel() - .responseUsageTotalTokens() - .responseUsagePromptTokens() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -98,9 +97,6 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess .requestModel() .responseModel() .responseNumberOfMessages(numberOfMessages) - .responseUsageTotalTokens() - .responseUsagePromptTokens() - .responseUsageCompletionTokens() .responseChoicesFinishReason() 
.error() .duration(System.currentTimeMillis() - startTime) @@ -128,6 +124,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message) { .responseModel() .sequence(sequence) .completionId() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java index 32951e7ea4..6a53cfdeea 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java @@ -32,8 +32,6 @@ public class CommandModelResponse implements ModelResponse { private static final String EMBEDDINGS = "embeddings"; private static final String TEXT = "text"; - private int inputTokenCount = 0; - private int outputTokenCount = 0; private String amznRequestId = ""; // LLM operation type @@ -58,7 +56,7 @@ public CommandModelResponse(InvokeModelResponse invokeModelResponse) { Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); - setHeaderFields(invokeModelResponse); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); llmChatCompletionSummaryId = getRandomGuid(); llmEmbeddingId = getRandomGuid(); } else { @@ -124,37 +122,6 @@ private void setOperationType(String invokeModelResponseBody) { } } - /** - * Parses header values from the response object and assigns them to fields. 
- * - * @param invokeModelResponse response object - */ - private void setHeaderFields(InvokeModelResponse invokeModelResponse) { - Map> headers = invokeModelResponse.sdkHttpResponse().headers(); - try { - if (!headers.isEmpty()) { - List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); - if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { - String result = inputTokenCountHeaders.get(0); - inputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); - if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { - String result = outputTokenCountHeaders.get(0); - outputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); - if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); - } - } else { - logParsingFailure(null, "response headers"); - } - } catch (Exception e) { - logParsingFailure(e, "response headers"); - } - } - @Override public String getResponseMessage() { String parsedResponseMessage = ""; @@ -217,21 +184,6 @@ public String getStopReason() { return parsedStopReason; } - @Override - public int getInputTokenCount() { - return inputTokenCount; - } - - @Override - public int getOutputTokenCount() { - return outputTokenCount; - } - - @Override - public int getTotalTokenCount() { - return inputTokenCount + outputTokenCount; - } - @Override public String getAmznRequestId() { return amznRequestId; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java index 15fc08bd35..5de3a55233 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java @@ -69,8 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .input() .requestModel() .responseModel() - .responseUsageTotalTokens() - .responseUsagePromptTokens() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -98,9 +97,6 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess .requestModel() .responseModel() .responseNumberOfMessages(numberOfMessages) - .responseUsageTotalTokens() - .responseUsagePromptTokens() - .responseUsageCompletionTokens() .responseChoicesFinishReason() .error() .duration(System.currentTimeMillis() - startTime) @@ -128,6 +124,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message) { .responseModel() .sequence(sequence) .completionId() + .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java index c206e093ba..b65247397f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java @@ -14,7 +14,6 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.logging.Level; @@ -30,8 +29,6 @@ public class Llama2ModelResponse implements ModelResponse { private static final String STOP_REASON = "stop_reason"; private static final String GENERATION = "generation"; - 
private int inputTokenCount = 0; - private int outputTokenCount = 0; private String amznRequestId = ""; // LLM operation type @@ -56,7 +53,7 @@ public Llama2ModelResponse(InvokeModelResponse invokeModelResponse) { Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); statusTextOptional.ifPresent(s -> statusText = s); setOperationType(invokeModelResponseBody); - setHeaderFields(invokeModelResponse); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); llmChatCompletionSummaryId = getRandomGuid(); llmEmbeddingId = getRandomGuid(); } else { @@ -121,37 +118,6 @@ private void setOperationType(String invokeModelResponseBody) { } } - /** - * Parses header values from the response object and assigns them to fields. - * - * @param invokeModelResponse response object - */ - private void setHeaderFields(InvokeModelResponse invokeModelResponse) { - Map> headers = invokeModelResponse.sdkHttpResponse().headers(); - try { - if (!headers.isEmpty()) { - List inputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_INPUT_TOKEN_COUNT); - if (inputTokenCountHeaders != null && !inputTokenCountHeaders.isEmpty()) { - String result = inputTokenCountHeaders.get(0); - inputTokenCount = result != null ? Integer.parseInt(result) : 0; - } - List outputTokenCountHeaders = headers.get(X_AMZN_BEDROCK_OUTPUT_TOKEN_COUNT); - if (outputTokenCountHeaders != null && !outputTokenCountHeaders.isEmpty()) { - String result = outputTokenCountHeaders.get(0); - outputTokenCount = result != null ? 
Integer.parseInt(result) : 0; - } - List amznRequestIdHeaders = headers.get(X_AMZN_REQUEST_ID); - if (amznRequestIdHeaders != null && !amznRequestIdHeaders.isEmpty()) { - amznRequestId = amznRequestIdHeaders.get(0); - } - } else { - logParsingFailure(null, "response headers"); - } - } catch (Exception e) { - logParsingFailure(e, "response headers"); - } - } - @Override public String getResponseMessage() { return parseStringValue(GENERATION); @@ -180,21 +146,6 @@ private String parseStringValue(String fieldToParse) { return parsedStringValue; } - @Override - public int getInputTokenCount() { - return inputTokenCount; - } - - @Override - public int getOutputTokenCount() { - return outputTokenCount; - } - - @Override - public int getTotalTokenCount() { - return inputTokenCount + outputTokenCount; - } - @Override public String getAmznRequestId() { return amznRequestId; From 60b520a068bfefe997b4afc2b6405f2f407d15b3 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 15:27:19 -0800 Subject: [PATCH 24/68] Cleanup readme --- .../aws-bedrock-runtime-2.20/README.md | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index b17963b135..c2bd43eb0e 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -110,10 +110,28 @@ Note: Streaming is not currently supported. ## Config +### Yaml + `ai_monitoring.enabled`: Provides control over all AI Monitoring functionality. Set as true to enable all AI Monitoring features. `ai_monitoring.record_content.enabled`: Provides control over whether attributes for the input and output content should be added to LLM events. Set as false to disable attributes for the input and output content. 
`ai_monitoring.streaming.enabled`: NOT SUPPORTED +### Environment Variable + +``` +NEW_RELIC_AI_MONITORING_ENABLED +NEW_RELIC_AI_MONITORING_RECORD_CONTENT_ENABLED +NEW_RELIC_AI_MONITORING_STREAMING_ENABLED +``` + +### System Property + +``` +-Dnewrelic.config.ai_monitoring.enabled +-Dnewrelic.config.ai_monitoring.record_content.enabled +-Dnewrelic.config.ai_monitoring.streaming.enabled +``` + ## Related Agent APIs AI monitoring can be enhanced by using the following agent APIs: @@ -166,10 +184,3 @@ When using the `BedrockRuntimeAsyncClient`, which returns the response as a `Com at software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler.execute(AwsAsyncClientHandler.java:52) at software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient.invokeModel(DefaultBedrockRuntimeAsyncClient.java:161) ``` - - -## TODO -* Test env var and sys prop config -* Write instrumentation tests -* Refactor test app to have multiple invokeMethods for a single transaction... -* Figure out how to get external call linked with async client From 7ac856830fafe886fd594e2ba03e512a19e9bb41 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 18:38:51 -0800 Subject: [PATCH 25/68] Refactor to handle cases that should generate multiple events --- .../src/main/java/llm/events/LlmEvent.java | 4 +- .../main/java/llm/models/ModelInvocation.java | 19 ++------ .../main/java/llm/models/ModelRequest.java | 25 ++++++++-- .../main/java/llm/models/ModelResponse.java | 13 ++++- .../jurassic/JurassicModelInvocation.java | 48 ++++++++++++++----- .../jurassic/JurassicModelRequest.java | 16 ++++++- .../jurassic/JurassicModelResponse.java | 30 ++++++++++-- .../amazon/titan/TitanModelInvocation.java | 48 ++++++++++++++----- .../amazon/titan/TitanModelRequest.java | 16 ++++++- .../amazon/titan/TitanModelResponse.java | 26 +++++++++- .../claude/ClaudeModelInvocation.java | 48 ++++++++++++++----- .../anthropic/claude/ClaudeModelRequest.java | 16 ++++++- 
.../anthropic/claude/ClaudeModelResponse.java | 8 +++- .../command/CommandModelInvocation.java | 48 ++++++++++++++----- .../cohere/command/CommandModelRequest.java | 34 +++++++++++-- .../cohere/command/CommandModelResponse.java | 26 +++++++++- .../meta/llama2/Llama2ModelInvocation.java | 48 ++++++++++++++----- .../meta/llama2/Llama2ModelRequest.java | 16 ++++++- .../meta/llama2/Llama2ModelResponse.java | 8 +++- 19 files changed, 393 insertions(+), 104 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index e1fe8d8d71..02f9934b12 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -180,8 +180,8 @@ public Builder error() { return this; } - public Builder input() { - input = modelRequest.getInputText(); + public Builder input(int index) { + input = modelRequest.getInputText(index); return this; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index ceb230774d..a59d57eeb8 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -44,8 +44,9 @@ public interface ModelInvocation { * Record an LlmEmbedding event that captures data specific to the creation of an embedding. 
* * @param startTime start time of SDK invoke method + * @param index of the input message in an array */ - void recordLlmEmbeddingEvent(long startTime); + void recordLlmEmbeddingEvent(long startTime, int index); /** * Record an LlmChatCompletionSummary event that captures high-level data about @@ -62,8 +63,9 @@ public interface ModelInvocation { * * @param sequence index starting at 0 associated with each message * @param message String representing the input/output message + * @param isUser boolean representing if the current message event is from a user input prompt or an assistant response message */ - void recordLlmChatCompletionMessageEvent(int sequence, String message); + void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser); /** * Record all LLM events when using the sync client. @@ -180,17 +182,4 @@ static String getTraceId(Map linkingMetadata) { static String getRandomGuid() { return UUID.randomUUID().toString(); } - - /** - * Determine if the LLM is initiated by the user or assistant. - *

- * Assuming that one user request is always followed by one assistant - * response, an even sequence value is the user, while odd is the assistant. - * - * @param sequence index starting at 0 associated with each message - * @return true if is user, false if not - */ - default boolean isUser(int sequence) { - return sequence % 2 == 0; - } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java index d30c589bea..4b603a9ee2 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -27,25 +27,42 @@ public interface ModelRequest { float getTemperature(); /** - * Get the content of the request message. + * Get the content of the request message, potentially from a specific array index + * if multiple messages are returned. * + * @param index int indicating the index of a message in an array. May be ignored for request structures that always return a single message. * @return String representing the content of the request message */ - String getRequestMessage(); + String getRequestMessage(int index); + + /** + * Get the number of request messages returned + * + * @return int representing the number of request messages returned + */ + int getNumberOfRequestMessages(); /** * Get the role of the requester. * * @return String representing the role of the requester */ - String getRole(); + String getRole(); // TODO can this just be deleted? /** * Get the input to the embedding creation call. * + * @param index int indicating the index of a message in an array. May be ignored for request structures that always return a single message. 
* @return String representing the input to the embedding creation call */ - String getInputText(); + String getInputText(int index); + + /** + * Get the number of input text messages from the embedding request. + * + * @return int representing the number of request messages returned + */ + int getNumberOfInputTextMessages(); /** * Get the LLM model ID. diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java index 47134ec73f..e51b172a87 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -17,11 +17,20 @@ public interface ModelResponse { String EMBEDDING = "embedding"; /** - * Get the response message. + * Get the response message, potentially from a specific array index + * if multiple messages are returned. * + * @param index int indicating the index of a message in an array. May be ignored for response structures that always return a single message. * @return String representing the response message */ - String getResponseMessage(); + String getResponseMessage(int index); + + /** + * Get the number of response messages returned + * + * @return int representing the number of response messages returned + */ + int getNumberOfResponseMessages(); /** * Get the stop reason. 
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java index 69ccf86b1a..db364205a8 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java @@ -52,7 +52,7 @@ public void setSegmentName(Segment segment, String functionName) { } @Override - public void recordLlmEmbeddingEvent(long startTime) { + public void recordLlmEmbeddingEvent(long startTime, int index) { if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -66,7 +66,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .ingestSource() .id(modelResponse.getLlmEmbeddingId()) .requestId() - .input() + .input(index) .requestModel() .responseModel() .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API @@ -106,9 +106,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess } @Override - public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = isUser(sequence); - + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmChatCompletionMessageEvent = builder @@ -136,7 +134,7 @@ public void recordLlmEvents(long startTime) { if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime); + recordLlmEmbeddingEvents(startTime); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); } @@ -170,12 +168,38 @@ public void reportLlmError() { * The number of 
LlmChatCompletionMessage events produced can differ based on vendor. */ private void recordLlmChatCompletionEvents(long startTime) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(startTime, 2); + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. 
+ */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java index 7462e5d34e..c35ddb9889 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java @@ -118,7 +118,13 @@ public float getTemperature() { } @Override - public String getRequestMessage() { + public int getNumberOfRequestMessages() { + // The Jurassic request only ever contains a single prompt message + return 1; + } + + @Override + public String getRequestMessage(int index) { return parseStringValue(PROMPT); } @@ -129,11 +135,17 @@ public String getRole() { } @Override - public String getInputText() { + public String getInputText(int index) { // This is a NoOp for Jurassic as it doesn't support embeddings return ""; } + @Override + public int getNumberOfInputTextMessages() { + // This is a NoOp for Jurassic as it doesn't support embeddings + return 0; + } + private String parseStringValue(String fieldToParse) { String parsedStringValue = ""; try { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java index 3354118737..8d78cacf4e 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java @@ -123,15 +123,15 @@ private void setOperationType(String invokeModelResponseBody) { } @Override - public String getResponseMessage() { + public String getResponseMessage(int index) { String parsedResponseMessage = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); if (completionsJsonNode.isArray()) { - List jsonNodeArray = completionsJsonNode.asArray(); - if (!jsonNodeArray.isEmpty()) { - JsonNode jsonNode = jsonNodeArray.get(0); + List completionsJsonNodeArray = completionsJsonNode.asArray(); + if (!completionsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = completionsJsonNodeArray.get(index); if (jsonNode.isObject()) { Map jsonNodeObject = jsonNode.asObject(); if (!jsonNodeObject.isEmpty()) { @@ -159,6 +159,28 @@ public String getResponseMessage() { return parsedResponseMessage; } + @Override + public int getNumberOfResponseMessages() { + int numberOfResponseMessages = 0; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); + if (completionsJsonNode.isArray()) { + List completionsJsonNodeArray = completionsJsonNode.asArray(); + if (!completionsJsonNodeArray.isEmpty()) { + numberOfResponseMessages = completionsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, COMPLETIONS); + } + if (numberOfResponseMessages == 0) { + logParsingFailure(null, COMPLETIONS); + } + return numberOfResponseMessages; + } + @Override public String getStopReason() { String parsedStopReason = ""; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java index fd0b629950..ddc9308bdf 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java @@ -52,7 +52,7 @@ public void setSegmentName(Segment segment, String functionName) { } @Override - public void recordLlmEmbeddingEvent(long startTime) { + public void recordLlmEmbeddingEvent(long startTime, int index) { if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -66,7 +66,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .ingestSource() .id(modelResponse.getLlmEmbeddingId()) .requestId() - .input() + .input(index) .requestModel() .responseModel() .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API @@ -106,9 +106,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess } @Override - public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = isUser(sequence); - + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmChatCompletionMessageEvent = builder @@ -136,7 +134,7 @@ public void recordLlmEvents(long startTime) { if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime); + recordLlmEmbeddingEvents(startTime); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); } @@ -170,12 +168,38 @@ public void reportLlmError() { * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
*/ private void recordLlmChatCompletionEvents(long startTime) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(startTime, 2); + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. 
+ */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java index a2e8b69771..2b8bf64d5b 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java @@ -131,7 +131,13 @@ public float getTemperature() { } @Override - public String getRequestMessage() { + public int getNumberOfRequestMessages() { + // The Titan request only ever contains a single inputText message + return 1; + } + + @Override + public String getRequestMessage(int index) { return parseStringValue(INPUT_TEXT); } @@ -142,10 +148,16 @@ public String getRole() { } @Override - public String getInputText() { + public String getInputText(int index) { return parseStringValue(INPUT_TEXT); } + @Override + public int getNumberOfInputTextMessages() { + // There is only ever a single inputText message for Titan embeddings + return 1; + } + private String parseStringValue(String fieldToParse) { String parsedStringValue = ""; try { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java index 26b424d1fd..ea1e741bdf 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -122,7 +122,7 @@ private void setOperationType(String invokeModelResponseBody) { } @Override - public String getResponseMessage() { + public String getResponseMessage(int index) { String parsedResponseMessage = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { @@ -130,7 +130,7 @@ public String getResponseMessage() { if (jsonNode.isArray()) { List resultsJsonNodeArray = jsonNode.asArray(); if (!resultsJsonNodeArray.isEmpty()) { - JsonNode resultsJsonNode = resultsJsonNodeArray.get(0); + JsonNode resultsJsonNode = resultsJsonNodeArray.get(index); if (resultsJsonNode.isObject()) { Map resultsJsonNodeObject = resultsJsonNode.asObject(); if (!resultsJsonNodeObject.isEmpty()) { @@ -152,6 +152,28 @@ public String getResponseMessage() { return parsedResponseMessage; } + @Override + public int getNumberOfResponseMessages() { + int numberOfResponseMessages = 0; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(RESULTS); + if (jsonNode.isArray()) { + List resultsJsonNodeArray = jsonNode.asArray(); + if (!resultsJsonNodeArray.isEmpty()) { + numberOfResponseMessages = resultsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, RESULTS); + } + if (numberOfResponseMessages == 0) { + logParsingFailure(null, RESULTS); + } + return numberOfResponseMessages; + } + @Override public String getStopReason() { String parsedStopReason = ""; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java index e31666a380..9185047102 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java @@ -52,7 +52,7 @@ public void setSegmentName(Segment segment, String functionName) { } @Override - public void recordLlmEmbeddingEvent(long startTime) { + public void recordLlmEmbeddingEvent(long startTime, int index) { if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -66,7 +66,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .ingestSource() .id(modelResponse.getLlmEmbeddingId()) .requestId() - .input() + .input(index) .requestModel() .responseModel() .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API @@ -106,9 +106,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess } @Override - public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = isUser(sequence); - + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmChatCompletionMessageEvent = builder @@ -136,7 +134,7 @@ public void recordLlmEvents(long startTime) { if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime); + recordLlmEmbeddingEvents(startTime); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); } @@ -170,12 +168,38 @@ public void reportLlmError() { * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
*/ private void recordLlmChatCompletionEvents(long startTime) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(startTime, 2); + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. 
+ */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java index da461b4849..3de8e1a2f1 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java @@ -118,7 +118,13 @@ public float getTemperature() { } @Override - public String getRequestMessage() { + public int getNumberOfRequestMessages() { + // The Claude request only ever contains a single prompt message + return 1; + } + + @Override + public String getRequestMessage(int index) { return parseStringValue(PROMPT); } @@ -129,11 +135,17 @@ public String getRole() { } @Override - public String getInputText() { + public String getInputText(int index) { // This is a NoOp for Claude as it doesn't support embeddings return ""; } + @Override + public int getNumberOfInputTextMessages() { + // This is a NoOp for Claude as it doesn't support embeddings + return 0; + } + private String parseStringValue(String fieldToParse) { String parsedStringValue = ""; try { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java index 42d7526b59..d34c3dd3ac 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java @@ -118,10 +118,16 @@ private void setOperationType(String invokeModelResponseBody) { } @Override - public String getResponseMessage() { + public String getResponseMessage(int index) { return parseStringValue(COMPLETION); } + @Override + public int getNumberOfResponseMessages() { + // There is only ever a single response message + return 1; + } + @Override public String getStopReason() { return parseStringValue(STOP_REASON); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java index 4ebfaf2231..ef4eb69acc 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java @@ -52,7 +52,7 @@ public void setSegmentName(Segment segment, String functionName) { } @Override - public void recordLlmEmbeddingEvent(long startTime) { + public void recordLlmEmbeddingEvent(long startTime, int index) { if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -66,7 +66,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .ingestSource() .id(modelResponse.getLlmEmbeddingId()) .requestId() - .input() + .input(index) .requestModel() .responseModel() .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API @@ -106,9 +106,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess } @Override - public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = isUser(sequence); - + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent 
llmChatCompletionMessageEvent = builder @@ -136,7 +134,7 @@ public void recordLlmEvents(long startTime) { if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime); + recordLlmEmbeddingEvents(startTime); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); } @@ -170,12 +168,38 @@ public void reportLlmError() { * The number of LlmChatCompletionMessage events produced can differ based on vendor. */ private void recordLlmChatCompletionEvents(long startTime) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(startTime, 2); + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + 
recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. + */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java index 7e819017a5..782425881d 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java @@ -120,7 +120,13 @@ public float getTemperature() { } @Override - public String getRequestMessage() { + public int getNumberOfRequestMessages() { + // The Command request only ever contains a single prompt message + return 1; + } + + @Override + public String getRequestMessage(int index) { return parseStringValue(PROMPT); } @@ -131,7 +137,7 @@ public String getRole() { } @Override - public String getInputText() { + public String getInputText(int index) { String parsedInputText = ""; try { if (!getRequestBodyJsonMap().isEmpty()) { @@ -139,7 +145,7 @@ public String getInputText() { if (textsJsonNode.isArray()) { List textsJsonNodeArray = textsJsonNode.asArray(); if (!textsJsonNodeArray.isEmpty()) { - JsonNode jsonNode = textsJsonNodeArray.get(0); + JsonNode jsonNode = textsJsonNodeArray.get(index); if (jsonNode.isString()) { parsedInputText = jsonNode.asString(); } @@ -155,6 +161,28 @@ public String 
getInputText() { return parsedInputText; } + @Override + public int getNumberOfInputTextMessages() { + int numberOfInputTextMessages = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textsJsonNode = getRequestBodyJsonMap().get(TEXTS); + if (textsJsonNode.isArray()) { + List textsJsonNodeArray = textsJsonNode.asArray(); + if (!textsJsonNodeArray.isEmpty()) { + numberOfInputTextMessages = textsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXTS); + } + if (numberOfInputTextMessages == 0) { + logParsingFailure(null, TEXTS); + } + return numberOfInputTextMessages; + } + private String parseStringValue(String fieldToParse) { String parsedStringValue = ""; try { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java index 6a53cfdeea..1308759a70 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java @@ -123,7 +123,7 @@ private void setOperationType(String invokeModelResponseBody) { } @Override - public String getResponseMessage() { + public String getResponseMessage(int index) { String parsedResponseMessage = ""; try { if (!getResponseBodyJsonMap().isEmpty()) { @@ -131,7 +131,7 @@ public String getResponseMessage() { if (generationsJsonNode.isArray()) { List generationsJsonNodeArray = generationsJsonNode.asArray(); if (!generationsJsonNodeArray.isEmpty()) { - JsonNode jsonNode = generationsJsonNodeArray.get(0); + JsonNode jsonNode = generationsJsonNodeArray.get(index); if (jsonNode.isObject()) { Map jsonNodeObject = jsonNode.asObject(); if (!jsonNodeObject.isEmpty()) { @@ -153,6 +153,28 @@ public String getResponseMessage() { return parsedResponseMessage; } + @Override + 
public int getNumberOfResponseMessages() { + int numberOfResponseMessages = 0; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode generationsJsonNode = getResponseBodyJsonMap().get(GENERATIONS); + if (generationsJsonNode.isArray()) { + List generationsJsonNodeArray = generationsJsonNode.asArray(); + if (!generationsJsonNodeArray.isEmpty()) { + numberOfResponseMessages = generationsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, GENERATIONS); + } + if (numberOfResponseMessages == 0) { + logParsingFailure(null, GENERATIONS); + } + return numberOfResponseMessages; + } + @Override public String getStopReason() { String parsedStopReason = ""; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java index 5de3a55233..70ede61f9a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java @@ -52,7 +52,7 @@ public void setSegmentName(Segment segment, String functionName) { } @Override - public void recordLlmEmbeddingEvent(long startTime) { + public void recordLlmEmbeddingEvent(long startTime, int index) { if (modelResponse.isErrorResponse()) { reportLlmError(); } @@ -66,7 +66,7 @@ public void recordLlmEmbeddingEvent(long startTime) { .ingestSource() .id(modelResponse.getLlmEmbeddingId()) .requestId() - .input() + .input(index) .requestModel() .responseModel() .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API @@ -106,9 +106,7 @@ public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMess } @Override - public void recordLlmChatCompletionMessageEvent(int sequence, String message) { - boolean isUser = isUser(sequence); - + public void 
recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { LlmEvent.Builder builder = new LlmEvent.Builder(this); LlmEvent llmChatCompletionMessageEvent = builder @@ -136,7 +134,7 @@ public void recordLlmEvents(long startTime) { if (operationType.equals(COMPLETION)) { recordLlmChatCompletionEvents(startTime); } else if (operationType.equals(EMBEDDING)) { - recordLlmEmbeddingEvent(startTime); + recordLlmEmbeddingEvents(startTime); } else { NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); } @@ -170,12 +168,38 @@ public void reportLlmError() { * The number of LlmChatCompletionMessage events produced can differ based on vendor. */ private void recordLlmChatCompletionEvents(long startTime) { - // First LlmChatCompletionMessage represents the user input prompt - recordLlmChatCompletionMessageEvent(0, modelRequest.getRequestMessage()); - // Second LlmChatCompletionMessage represents the completion message from the LLM response - recordLlmChatCompletionMessageEvent(1, modelResponse.getResponseMessage()); - // A summary of all LlmChatCompletionMessage events - recordLlmChatCompletionSummaryEvent(startTime, 2); + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + 
sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. + */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } } @Override diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java index a3b793fade..1a19567deb 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java @@ -118,7 +118,13 @@ public float getTemperature() { } @Override - public String getRequestMessage() { + public int getNumberOfRequestMessages() { + // The Llama request only ever contains a single prompt message + return 1; + } + + @Override + public String getRequestMessage(int index) { return parseStringValue(PROMPT); } @@ -129,11 +135,17 @@ public String getRole() { } @Override - public String getInputText() { + public String getInputText(int index) { // This is a NoOp for Llama as it doesn't support embeddings return ""; } + @Override + public int getNumberOfInputTextMessages() { + // This is a NoOp for Llama as it doesn't support embeddings + return 0; + } + private String parseStringValue(String fieldToParse) { String parsedStringValue = ""; try { diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java index b65247397f..d15eef7136 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java @@ -119,10 +119,16 @@ private void setOperationType(String invokeModelResponseBody) { } @Override - public String getResponseMessage() { + public String getResponseMessage(int index) { return parseStringValue(GENERATION); } + @Override + public int getNumberOfResponseMessages() { + // There is only ever a single response message + return 1; + } + @Override public String getStopReason() { return parseStringValue(STOP_REASON); From c601590ab6386e255772841d41167eee30339e3c Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Thu, 7 Mar 2024 18:44:57 -0800 Subject: [PATCH 26/68] Remove role from interface --- .../src/main/java/llm/events/LlmEvent.java | 5 +---- .../src/main/java/llm/models/ModelRequest.java | 7 ------- .../llm/models/ai21labs/jurassic/JurassicModelRequest.java | 6 ------ .../java/llm/models/amazon/titan/TitanModelRequest.java | 6 ------ .../llm/models/anthropic/claude/ClaudeModelRequest.java | 6 ------ .../llm/models/cohere/command/CommandModelRequest.java | 6 ------ .../java/llm/models/meta/llama2/Llama2ModelRequest.java | 6 ------ 7 files changed, 1 insertion(+), 41 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index 02f9934b12..81d8211ff9 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -132,10 +132,7 @@ public Builder 
role(boolean isUser) { if (isUser) { role = "user"; } else { - role = modelRequest.getRole(); - if (role.isEmpty()) { - role = "assistant"; - } + role = "assistant"; } return this; } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java index 4b603a9ee2..9a2b491828 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java @@ -42,13 +42,6 @@ public interface ModelRequest { */ int getNumberOfRequestMessages(); - /** - * Get the role of the requester. - * - * @return String representing the role of the requester - */ - String getRole(); // TODO can this just be deleted? - /** * Get the input to the embedding creation call. * diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java index c35ddb9889..610ae4ea40 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java @@ -128,12 +128,6 @@ public String getRequestMessage(int index) { return parseStringValue(PROMPT); } - @Override - public String getRole() { - // This is a NoOp for Jurassic as the request doesn't contain any signifier of the role - return ""; - } - @Override public String getInputText(int index) { // This is a NoOp for Jurassic as it doesn't support embeddings diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java index 
2b8bf64d5b..0221b3fee8 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java @@ -141,12 +141,6 @@ public String getRequestMessage(int index) { return parseStringValue(INPUT_TEXT); } - @Override - public String getRole() { - // This is a NoOp for Titan as the request doesn't contain any signifier of the role - return ""; - } - @Override public String getInputText(int index) { return parseStringValue(INPUT_TEXT); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java index 3de8e1a2f1..0355983642 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java @@ -128,12 +128,6 @@ public String getRequestMessage(int index) { return parseStringValue(PROMPT); } - @Override - public String getRole() { - // This is a NoOp for Claude as the request doesn't contain any signifier of the role - return ""; - } - @Override public String getInputText(int index) { // This is a NoOp for Claude as it doesn't support embeddings diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java index 782425881d..12c4218ce9 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java @@ -130,12 +130,6 @@ public String getRequestMessage(int index) { return 
parseStringValue(PROMPT); } - @Override - public String getRole() { - // This is a NoOp for Jurassic as the request doesn't contain any signifier of the role - return ""; - } - @Override public String getInputText(int index) { String parsedInputText = ""; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java index 1a19567deb..adf248b3eb 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java @@ -128,12 +128,6 @@ public String getRequestMessage(int index) { return parseStringValue(PROMPT); } - @Override - public String getRole() { - // This is a NoOp for Llama as the request doesn't contain any signifier of the role - return ""; - } - @Override public String getInputText(int index) { // This is a NoOp for Llama as it doesn't support embeddings From 3a8fb76189b1df719fb508c816a3afee5804e7e1 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Fri, 8 Mar 2024 08:53:43 -0800 Subject: [PATCH 27/68] End segment in case of exception --- .../BedrockRuntimeAsyncClient_Instrumentation.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 32f2890566..8efe4c9eda 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -74,7 +74,7 @@ public CompletableFuture invokeModel(InvokeModelRequest inv Token token = txn.getToken(); - // instrumentation fails if the BiConsumer is replaced with a lambda + // Instrumentation fails if the BiConsumer is replaced with a lambda invokeModelResponseFuture.whenComplete(new BiConsumer() { @Override public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) { @@ -111,8 +111,13 @@ public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) jurassicModelInvocation.setTracedMethodName(txn, "invokeModel"); jurassicModelInvocation.recordLlmEventsAsync(startTime, token); } - segment.end(); + if (segment != null) { + segment.endAsync(); + } } catch (Throwable t) { + if (segment != null) { + segment.endAsync(); + } AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); } } From 400b198dc9aa2b3dbcbc80e13d41a733e92f5513 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 12 Mar 2024 10:40:53 -0700 Subject: [PATCH 28/68] Add config supportability metrics --- .../aimonitoring/AiMonitoringUtils.java | 30 +- .../aws-bedrock-runtime-2.20/build.gradle | 4 + .../main/java/llm/models/ModelInvocation.java | 7 - ...ockRuntimeAsyncClient_Instrumentation.java | 2 - .../test/java/llm/events/LlmEventTest.java | 340 ++++++++++++++++++ .../src/test/resources/llm_enabled.yml | 0 6 files changed, 371 insertions(+), 12 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java index 161cfd1aa9..686b12e8ad 
100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java @@ -22,7 +22,15 @@ public class AiMonitoringUtils { * @return true if AI monitoring is enabled, else false */ public static boolean isAiMonitoringEnabled() { - return NewRelic.getAgent().getConfig().getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); + Boolean enabled = NewRelic.getAgent().getConfig().getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); + + if (enabled) { + NewRelic.incrementCounter("Supportability/Java/ML/Enabled"); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/Disabled"); + } + + return enabled; } /** @@ -31,7 +39,15 @@ public static boolean isAiMonitoringEnabled() { * @return true if streaming is enabled, else false */ public static boolean isAiMonitoringStreamingEnabled() { - return NewRelic.getAgent().getConfig().getValue("ai_monitoring.streaming.enabled", AI_MONITORING_STREAMING_ENABLED_DEFAULT); + Boolean enabled = NewRelic.getAgent().getConfig().getValue("ai_monitoring.streaming.enabled", AI_MONITORING_STREAMING_ENABLED_DEFAULT); + + if (enabled) { + NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Enabled"); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); + } + + return enabled; } /** @@ -40,6 +56,14 @@ public static boolean isAiMonitoringStreamingEnabled() { * @return true if adding content is enabled, else false */ public static boolean isAiMonitoringRecordContentEnabled() { - return NewRelic.getAgent().getConfig().getValue("ai_monitoring.record_content.enabled", AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT); + Boolean enabled = NewRelic.getAgent().getConfig().getValue("ai_monitoring.record_content.enabled", AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT); + + if (enabled) { + NewRelic.incrementCounter("Supportability/Java/ML/RecordContent/Enabled"); + } 
else { + NewRelic.incrementCounter("Supportability/Java/ML/RecordContent/Disabled"); + } + + return enabled; } } diff --git a/instrumentation/aws-bedrock-runtime-2.20/build.gradle b/instrumentation/aws-bedrock-runtime-2.20/build.gradle index b42f3d37cb..01da205a22 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/build.gradle +++ b/instrumentation/aws-bedrock-runtime-2.20/build.gradle @@ -5,6 +5,10 @@ jar { dependencies { implementation(project(":agent-bridge")) implementation 'software.amazon.awssdk:bedrockruntime:2.20.157' + + testImplementation 'software.amazon.awssdk:bedrockruntime:2.20.157' + testImplementation 'org.mockito:mockito-inline:4.11.0' + testImplementation 'org.json:json:20240303' } verifyInstrumentation { diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index a59d57eeb8..9f85c73f0f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -131,13 +131,6 @@ static void incrementInstrumentedSupportabilityMetric(String vendorVersion) { NewRelic.incrementCounter("Supportability/Java/ML/" + BEDROCK + "/" + vendorVersion); } - /** - * Increment a Supportability metric indicating that streaming support is disabled. - */ - static void incrementStreamingDisabledSupportabilityMetric() { - NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); - } - /** * Set the llm:true attribute on the active transaction. 
* diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java index 8efe4c9eda..f99ab6e635 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java @@ -137,8 +137,6 @@ public CompletableFuture invokeModelWithResponseStream( .getLogger() .log(Level.FINER, "aws-bedrock-runtime-2.20 instrumentation does not currently support response streaming. Enabling ai_monitoring.streaming will have no effect."); - } else { - ModelInvocation.incrementStreamingDisabledSupportabilityMetric(); } } return Weaver.callOriginal(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java new file mode 100644 index 0000000000..d22edff53a --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java @@ -0,0 +1,340 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.events; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import llm.models.ModelInvocation; +import llm.models.amazon.titan.TitanModelInvocation; +import llm.models.anthropic.claude.ClaudeModelInvocation; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static llm.events.LlmEvent.Builder; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") +public class LlmEventTest { + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + @Before + public void reset() { + introspector.clear(); + } + + @Test + public void testRecordLlmEmbeddingEvent() { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-890"); + userAttributes.put("llm.testPrefix", 
"testPrefix"); + userAttributes.put("test", "test"); + + // Mock out ModelRequest + InvokeModelRequest mockInvokeModelRequest = Mockito.mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = Mockito.mock(SdkBytes.class); + Mockito.when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + Mockito.when(mockRequestSdkBytes.asUtf8String()).thenReturn("{\"inputText\":\"What is the color of the sky?\"}"); + Mockito.when(mockInvokeModelRequest.modelId()).thenReturn("amazon.titan-embed-text-v1"); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = Mockito.mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = Mockito.mock(SdkBytes.class); + Mockito.when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + Mockito.when(mockResponseSdkBytes.asUtf8String()).thenReturn("{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"); + + SdkHttpResponse mockSdkHttpResponse = Mockito.mock(SdkHttpResponse.class); + Mockito.when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + Mockito.when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + Mockito.when(mockSdkHttpResponse.statusCode()).thenReturn(200); + Mockito.when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = Mockito.mock(BedrockRuntimeResponseMetadata.class); + Mockito.when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + Mockito.when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + // Instantiate ModelInvocation + TitanModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + + // When + // Build LlmEmbedding event + Builder builder = new Builder(titanModelInvocation); + LlmEvent llmEmbeddingEvent = builder + .spanId() // attribute 1 + 
.traceId() // attribute 2 + .vendor() // attribute 3 + .ingestSource() // attribute 4 + .id(titanModelInvocation.getModelResponse().getLlmEmbeddingId()) // attribute 5 + .requestId() // attribute 6 + .input(0) // attribute 7 + .requestModel() // attribute 8 + .responseModel() // attribute 9 + .tokenCount(123) // attribute 10 + .error() // not added + .duration(9000f) // attribute 11 + .build(); + + // attributes 12 & 13 should be the two llm.* prefixed userAttributes + + // Record LlmEmbedding event + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + + // Then + Collection customEvents = introspector.getCustomEvents(LLM_EMBEDDING); + Assert.assertEquals(1, customEvents.size()); + + Event event = customEvents.iterator().next(); + Assert.assertEquals(LLM_EMBEDDING, event.getType()); + + Map attributes = event.getAttributes(); + Assert.assertEquals(13, attributes.size()); + Assert.assertEquals("span-id-123", attributes.get("span_id")); + Assert.assertEquals("trace-id-xyz", attributes.get("trace_id")); + Assert.assertEquals("bedrock", attributes.get("vendor")); + Assert.assertEquals("Java", attributes.get("ingest_source")); + Assert.assertFalse(((String) attributes.get("id")).isEmpty()); + Assert.assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); + Assert.assertEquals("What is the color of the sky?", attributes.get("input")); + Assert.assertEquals("amazon.titan-embed-text-v1", attributes.get("request.model")); + Assert.assertEquals("amazon.titan-embed-text-v1", attributes.get("response.model")); + Assert.assertEquals(123, attributes.get("token_count")); + Assert.assertEquals(9000f, attributes.get("duration")); + Assert.assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); + Assert.assertEquals("testPrefix", attributes.get("llm.testPrefix")); + } + + @Test + public void testRecordLlmChatCompletionMessageEvent() { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + 
linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-890"); + userAttributes.put("llm.testPrefix", "testPrefix"); + userAttributes.put("test", "test"); + + String expectedUserPrompt = "Human: What is the color of the sky?\n\nAssistant:"; + + // Mock out ModelRequest + InvokeModelRequest mockInvokeModelRequest = Mockito.mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = Mockito.mock(SdkBytes.class); + Mockito.when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + Mockito.when(mockRequestSdkBytes.asUtf8String()) + .thenReturn( + "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}"); + Mockito.when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2"); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = Mockito.mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = Mockito.mock(SdkBytes.class); + Mockito.when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + Mockito.when(mockResponseSdkBytes.asUtf8String()) + .thenReturn( + "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}"); + + SdkHttpResponse mockSdkHttpResponse = Mockito.mock(SdkHttpResponse.class); + Mockito.when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + Mockito.when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + Mockito.when(mockSdkHttpResponse.statusCode()).thenReturn(200); + Mockito.when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = Mockito.mock(BedrockRuntimeResponseMetadata.class); + 
Mockito.when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + Mockito.when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + ClaudeModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + + LlmEvent.Builder builder = new LlmEvent.Builder(claudeModelInvocation); + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() // attribute 1 + .traceId() // attribute 2 + .vendor() // attribute 3 + .ingestSource() // attribute 4 + .id(ModelInvocation.getRandomGuid()) // attribute 5 + .content(expectedUserPrompt) // attribute 6 + .role(true) // attribute 7 + .isResponse(true) // attribute 8 + .requestId() // attribute 9 + .responseModel() // attribute 10 + .sequence(0) // attribute 11 + .completionId() // attribute 12 + .tokenCount(123) // attribute 13 + .build(); + + // attributes 14 & 15 should be the two llm.* prefixed userAttributes + + // Record LlmChatCompletionMessage event + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + + // Then + Collection customEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + Assert.assertEquals(1, customEvents.size()); + + Event event = customEvents.iterator().next(); + Assert.assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType()); + + Map attributes = event.getAttributes(); + Assert.assertEquals(15, attributes.size()); + Assert.assertEquals("span-id-123", attributes.get("span_id")); + Assert.assertEquals("trace-id-xyz", attributes.get("trace_id")); + Assert.assertEquals("bedrock", attributes.get("vendor")); + Assert.assertEquals("Java", attributes.get("ingest_source")); + Assert.assertFalse(((String) attributes.get("id")).isEmpty()); + Assert.assertEquals(expectedUserPrompt, attributes.get("content")); + Assert.assertEquals("user", attributes.get("role")); + Assert.assertEquals(false, 
attributes.get("is_response")); + Assert.assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); + Assert.assertEquals("anthropic.claude-v2", attributes.get("response.model")); + Assert.assertEquals(0, attributes.get("sequence")); + Assert.assertFalse(((String) attributes.get("completion_id")).isEmpty()); + Assert.assertEquals(123, attributes.get("token_count")); + Assert.assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); + Assert.assertEquals("testPrefix", attributes.get("llm.testPrefix")); + } + + @Test + public void testRecordLlmChatCompletionSummaryEvent() { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-890"); + userAttributes.put("llm.testPrefix", "testPrefix"); + userAttributes.put("test", "test"); + + // Mock out ModelRequest + InvokeModelRequest mockInvokeModelRequest = Mockito.mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = Mockito.mock(SdkBytes.class); + Mockito.when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + Mockito.when(mockRequestSdkBytes.asUtf8String()) + .thenReturn( + "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}"); + Mockito.when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2"); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = Mockito.mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = Mockito.mock(SdkBytes.class); + Mockito.when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + Mockito.when(mockResponseSdkBytes.asUtf8String()) + .thenReturn( + "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's 
atmosphere.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}"); + + SdkHttpResponse mockSdkHttpResponse = Mockito.mock(SdkHttpResponse.class); + Mockito.when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + Mockito.when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + Mockito.when(mockSdkHttpResponse.statusCode()).thenReturn(200); + Mockito.when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = Mockito.mock(BedrockRuntimeResponseMetadata.class); + Mockito.when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + Mockito.when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + ClaudeModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + + LlmEvent.Builder builder = new LlmEvent.Builder(claudeModelInvocation); + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() // attribute 1 + .traceId() // attribute 2 + .vendor() // attribute 3 + .ingestSource() // attribute 4 + .id(claudeModelInvocation.getModelResponse().getLlmChatCompletionSummaryId()) // attribute 5 + .requestId() // attribute 6 + .requestTemperature() // attribute 7 + .requestMaxTokens() // attribute 8 + .requestModel() // attribute 9 + .responseModel() // attribute 10 + .responseNumberOfMessages(2) // attribute 11 + .responseChoicesFinishReason() // attribute 12 + .error() // not added + .duration(9000f) // attribute 13 + .build(); + + // attributes 14 & 15 should be the two llm.* prefixed userAttributes + + // Record LlmChatCompletionSummary event + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + + // Then + Collection customEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + Assert.assertEquals(1, customEvents.size()); + + Event event 
= customEvents.iterator().next(); + Assert.assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType()); + + Map attributes = event.getAttributes(); + Assert.assertEquals(15, attributes.size()); + Assert.assertEquals("span-id-123", attributes.get("span_id")); + Assert.assertEquals("trace-id-xyz", attributes.get("trace_id")); + Assert.assertEquals("bedrock", attributes.get("vendor")); + Assert.assertEquals("Java", attributes.get("ingest_source")); + Assert.assertFalse(((String) attributes.get("id")).isEmpty()); + Assert.assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); + Assert.assertEquals(0.5f, attributes.get("request.temperature")); + Assert.assertEquals(1000, attributes.get("request.max_tokens")); + Assert.assertEquals("anthropic.claude-v2", attributes.get("request.model")); + Assert.assertEquals("anthropic.claude-v2", attributes.get("response.model")); + Assert.assertEquals(2, attributes.get("response.number_of_messages")); + Assert.assertEquals("stop_sequence", attributes.get("response.choices.finish_reason")); + Assert.assertEquals(9000f, attributes.get("duration")); + Assert.assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); + Assert.assertEquals("testPrefix", attributes.get("llm.testPrefix")); + } + +// @Test +// public void testUntruncatedAttributes() { +// String expectedAttribute = 
"aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqq
qqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__"; +// +// Map atts = new HashMap<>(); +// atts.put("content", expectedAttribute); +// atts.put("input", expectedAttribute); +// atts.put("vendor", expectedAttribute); +// +// NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, atts); +// +// Collection customEvents = introspector.getCustomEvents(LLM_EMBEDDING); 
+// Assert.assertEquals(1, customEvents.size()); +// +// Event event = customEvents.iterator().next(); +// Assert.assertEquals(LLM_EMBEDDING, event.getType()); +// +// Map attributes = event.getAttributes(); +// Assert.assertEquals(expectedAttribute, attributes.get("content")); +// Assert.assertEquals(expectedAttribute, attributes.get("input")); +// Assert.assertEquals(expectedAttribute, attributes.get("vendor")); +// } + +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml b/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml new file mode 100644 index 0000000000..e69de29bb2 From eccec66687f3f66c856fc47e3030092ec5ea7843 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 13 Mar 2024 09:43:03 -0700 Subject: [PATCH 29/68] Add ModelInvocation tests --- .../aws-bedrock-runtime-2.20/README.md | 13 +- .../test/java/llm/events/LlmEventTest.java | 214 ++++++------- .../src/test/java/llm/models/TestUtil.java | 96 ++++++ .../jurassic/JurassicModelInvocationTest.java | 162 ++++++++++ .../titan/TitanModelInvocationTest.java | 200 ++++++++++++ .../claude/ClaudeModelInvocationTest.java | 160 ++++++++++ .../command/CommandModelInvocationTest.java | 200 ++++++++++++ .../llama2/Llama2ModelInvocationTest.java | 160 ++++++++++ .../BedrockRuntimeClientMock.java | 101 +++++++ ...rockRuntimeClient_InstrumentationTest.java | 286 ++++++++++++++++++ .../BedrockRuntimeResponseMetadataMock.java | 18 ++ .../src/test/resources/llm_enabled.yml | 11 + 12 files changed, 1511 insertions(+), 110 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java create mode 100644 
instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md index c2bd43eb0e..37d019c1c9 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/README.md +++ b/instrumentation/aws-bedrock-runtime-2.20/README.md @@ -100,13 +100,18 @@ A supportability metric is reported each time an instrumented framework method i Note: The vendor version isn't obtainable from the AWS Bedrock SDK for Java so the instrumentation version is used instead. -Additionally, a supportability metric is recorded to indicate if streaming is disabled. Streaming is considered disabled if the value of the `ai_monitoring.streaming.enabled` configuration setting is `false`. If streaming is enabled, no supportability metric will be sent. The metric uses the following format: +Additionally, the following supportability metrics are recorded to indicate the agent config state. 
-`Supportability/{language}/ML/Streaming/Disabled` +``` +Supportability/Java/ML/Enabled +Supportability/Java/ML/Disabled -* `language`: Name of language agent (ex: `Java`) +Supportability/Java/ML/Streaming/Enabled +Supportability/Java/ML/Streaming/Disabled -Note: Streaming is not currently supported. +Supportability/Java/ML/RecordContent/Enabled +Supportability/Java/ML/RecordContent/Disabled +``` ## Config diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java index d22edff53a..ca3d417df0 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java @@ -14,11 +14,9 @@ import llm.models.ModelInvocation; import llm.models.amazon.titan.TitanModelInvocation; import llm.models.anthropic.claude.ClaudeModelInvocation; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mockito; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata; @@ -34,6 +32,10 @@ import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; import static llm.events.LlmEvent.LLM_EMBEDDING; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; @RunWith(InstrumentationTestRunner.class) @InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") @@ -41,7 +43,7 @@ public class LlmEventTest { private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); @Before - public void 
reset() { + public void before() { introspector.clear(); } @@ -58,27 +60,27 @@ public void testRecordLlmEmbeddingEvent() { userAttributes.put("test", "test"); // Mock out ModelRequest - InvokeModelRequest mockInvokeModelRequest = Mockito.mock(InvokeModelRequest.class); - SdkBytes mockRequestSdkBytes = Mockito.mock(SdkBytes.class); - Mockito.when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); - Mockito.when(mockRequestSdkBytes.asUtf8String()).thenReturn("{\"inputText\":\"What is the color of the sky?\"}"); - Mockito.when(mockInvokeModelRequest.modelId()).thenReturn("amazon.titan-embed-text-v1"); + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()).thenReturn("{\"inputText\":\"What is the color of the sky?\"}"); + when(mockInvokeModelRequest.modelId()).thenReturn("amazon.titan-embed-text-v1"); // Mock out ModelResponse - InvokeModelResponse mockInvokeModelResponse = Mockito.mock(InvokeModelResponse.class); - SdkBytes mockResponseSdkBytes = Mockito.mock(SdkBytes.class); - Mockito.when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); - Mockito.when(mockResponseSdkBytes.asUtf8String()).thenReturn("{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"); + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()).thenReturn("{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"); - SdkHttpResponse mockSdkHttpResponse = Mockito.mock(SdkHttpResponse.class); - Mockito.when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); - Mockito.when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); - 
Mockito.when(mockSdkHttpResponse.statusCode()).thenReturn(200); - Mockito.when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class); + when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); - BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = Mockito.mock(BedrockRuntimeResponseMetadata.class); - Mockito.when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); - Mockito.when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); // Instantiate ModelInvocation TitanModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, @@ -109,26 +111,26 @@ public void testRecordLlmEmbeddingEvent() { // Then Collection customEvents = introspector.getCustomEvents(LLM_EMBEDDING); - Assert.assertEquals(1, customEvents.size()); + assertEquals(1, customEvents.size()); Event event = customEvents.iterator().next(); - Assert.assertEquals(LLM_EMBEDDING, event.getType()); + assertEquals(LLM_EMBEDDING, event.getType()); Map attributes = event.getAttributes(); - Assert.assertEquals(13, attributes.size()); - Assert.assertEquals("span-id-123", attributes.get("span_id")); - Assert.assertEquals("trace-id-xyz", attributes.get("trace_id")); - Assert.assertEquals("bedrock", attributes.get("vendor")); - Assert.assertEquals("Java", 
attributes.get("ingest_source")); - Assert.assertFalse(((String) attributes.get("id")).isEmpty()); - Assert.assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); - Assert.assertEquals("What is the color of the sky?", attributes.get("input")); - Assert.assertEquals("amazon.titan-embed-text-v1", attributes.get("request.model")); - Assert.assertEquals("amazon.titan-embed-text-v1", attributes.get("response.model")); - Assert.assertEquals(123, attributes.get("token_count")); - Assert.assertEquals(9000f, attributes.get("duration")); - Assert.assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); - Assert.assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals(13, attributes.size()); + assertEquals("span-id-123", attributes.get("span_id")); + assertEquals("trace-id-xyz", attributes.get("trace_id")); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals("Java", attributes.get("ingest_source")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); + assertEquals("What is the color of the sky?", attributes.get("input")); + assertEquals("amazon.titan-embed-text-v1", attributes.get("request.model")); + assertEquals("amazon.titan-embed-text-v1", attributes.get("response.model")); + assertEquals(123, attributes.get("token_count")); + assertEquals(9000f, attributes.get("duration")); + assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); } @Test @@ -146,31 +148,31 @@ public void testRecordLlmChatCompletionMessageEvent() { String expectedUserPrompt = "Human: What is the color of the sky?\n\nAssistant:"; // Mock out ModelRequest - InvokeModelRequest mockInvokeModelRequest = Mockito.mock(InvokeModelRequest.class); - SdkBytes mockRequestSdkBytes = Mockito.mock(SdkBytes.class); - 
Mockito.when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); - Mockito.when(mockRequestSdkBytes.asUtf8String()) + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()) .thenReturn( "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}"); - Mockito.when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2"); + when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2"); // Mock out ModelResponse - InvokeModelResponse mockInvokeModelResponse = Mockito.mock(InvokeModelResponse.class); - SdkBytes mockResponseSdkBytes = Mockito.mock(SdkBytes.class); - Mockito.when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); - Mockito.when(mockResponseSdkBytes.asUtf8String()) + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()) .thenReturn( "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}"); - SdkHttpResponse mockSdkHttpResponse = Mockito.mock(SdkHttpResponse.class); - Mockito.when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); - Mockito.when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); - Mockito.when(mockSdkHttpResponse.statusCode()).thenReturn(200); - Mockito.when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class); + 
when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); - BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = Mockito.mock(BedrockRuntimeResponseMetadata.class); - Mockito.when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); - Mockito.when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); ClaudeModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, mockInvokeModelResponse); @@ -199,28 +201,28 @@ public void testRecordLlmChatCompletionMessageEvent() { // Then Collection customEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); - Assert.assertEquals(1, customEvents.size()); + assertEquals(1, customEvents.size()); Event event = customEvents.iterator().next(); - Assert.assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType()); + assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType()); Map attributes = event.getAttributes(); - Assert.assertEquals(15, attributes.size()); - Assert.assertEquals("span-id-123", attributes.get("span_id")); - Assert.assertEquals("trace-id-xyz", attributes.get("trace_id")); - Assert.assertEquals("bedrock", attributes.get("vendor")); - Assert.assertEquals("Java", attributes.get("ingest_source")); - Assert.assertFalse(((String) attributes.get("id")).isEmpty()); - Assert.assertEquals(expectedUserPrompt, 
attributes.get("content")); - Assert.assertEquals("user", attributes.get("role")); - Assert.assertEquals(false, attributes.get("is_response")); - Assert.assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); - Assert.assertEquals("anthropic.claude-v2", attributes.get("response.model")); - Assert.assertEquals(0, attributes.get("sequence")); - Assert.assertFalse(((String) attributes.get("completion_id")).isEmpty()); - Assert.assertEquals(123, attributes.get("token_count")); - Assert.assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); - Assert.assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals(15, attributes.size()); + assertEquals("span-id-123", attributes.get("span_id")); + assertEquals("trace-id-xyz", attributes.get("trace_id")); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals("Java", attributes.get("ingest_source")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertEquals(expectedUserPrompt, attributes.get("content")); + assertEquals("user", attributes.get("role")); + assertEquals(false, attributes.get("is_response")); + assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); + assertEquals("anthropic.claude-v2", attributes.get("response.model")); + assertEquals(0, attributes.get("sequence")); + assertFalse(((String) attributes.get("completion_id")).isEmpty()); + assertEquals(123, attributes.get("token_count")); + assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); } @Test @@ -236,31 +238,31 @@ public void testRecordLlmChatCompletionSummaryEvent() { userAttributes.put("test", "test"); // Mock out ModelRequest - InvokeModelRequest mockInvokeModelRequest = Mockito.mock(InvokeModelRequest.class); - SdkBytes mockRequestSdkBytes = Mockito.mock(SdkBytes.class); - Mockito.when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); - 
Mockito.when(mockRequestSdkBytes.asUtf8String()) + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()) .thenReturn( "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}"); - Mockito.when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2"); + when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2"); // Mock out ModelResponse - InvokeModelResponse mockInvokeModelResponse = Mockito.mock(InvokeModelResponse.class); - SdkBytes mockResponseSdkBytes = Mockito.mock(SdkBytes.class); - Mockito.when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); - Mockito.when(mockResponseSdkBytes.asUtf8String()) + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()) .thenReturn( "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}"); - SdkHttpResponse mockSdkHttpResponse = Mockito.mock(SdkHttpResponse.class); - Mockito.when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); - Mockito.when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); - Mockito.when(mockSdkHttpResponse.statusCode()).thenReturn(200); - Mockito.when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class); + when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); 
+ when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); - BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = Mockito.mock(BedrockRuntimeResponseMetadata.class); - Mockito.when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); - Mockito.when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); ClaudeModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, mockInvokeModelResponse); @@ -290,28 +292,28 @@ public void testRecordLlmChatCompletionSummaryEvent() { // Then Collection customEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); - Assert.assertEquals(1, customEvents.size()); + assertEquals(1, customEvents.size()); Event event = customEvents.iterator().next(); - Assert.assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType()); + assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType()); Map attributes = event.getAttributes(); - Assert.assertEquals(15, attributes.size()); - Assert.assertEquals("span-id-123", attributes.get("span_id")); - Assert.assertEquals("trace-id-xyz", attributes.get("trace_id")); - Assert.assertEquals("bedrock", attributes.get("vendor")); - Assert.assertEquals("Java", attributes.get("ingest_source")); - Assert.assertFalse(((String) attributes.get("id")).isEmpty()); - Assert.assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); - Assert.assertEquals(0.5f, attributes.get("request.temperature")); - Assert.assertEquals(1000, 
attributes.get("request.max_tokens")); - Assert.assertEquals("anthropic.claude-v2", attributes.get("request.model")); - Assert.assertEquals("anthropic.claude-v2", attributes.get("response.model")); - Assert.assertEquals(2, attributes.get("response.number_of_messages")); - Assert.assertEquals("stop_sequence", attributes.get("response.choices.finish_reason")); - Assert.assertEquals(9000f, attributes.get("duration")); - Assert.assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); - Assert.assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals(15, attributes.size()); + assertEquals("span-id-123", attributes.get("span_id")); + assertEquals("trace-id-xyz", attributes.get("trace_id")); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals("Java", attributes.get("ingest_source")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id")); + assertEquals(0.5f, attributes.get("request.temperature")); + assertEquals(1000, attributes.get("request.max_tokens")); + assertEquals("anthropic.claude-v2", attributes.get("request.model")); + assertEquals("anthropic.claude-v2", attributes.get("response.model")); + assertEquals(2, attributes.get("response.number_of_messages")); + assertEquals("stop_sequence", attributes.get("response.choices.finish_reason")); + assertEquals(9000f, attributes.get("duration")); + assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); } // @Test diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java new file mode 100644 index 0000000000..bfe1187b9d --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java @@ -0,0 +1,96 @@ +package llm.models; + +import 
com.newrelic.agent.introspec.ErrorEvent; +import com.newrelic.agent.introspec.Event; + +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class TestUtil { + public static void assertLlmChatCompletionMessageAttributes(Event event, String modelId, String requestInput, String responseContent, boolean isResponse) { + assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType()); + + Map attributes = event.getAttributes(); + assertEquals("Java", attributes.get("ingest_source")); + assertFalse(((String) attributes.get("completion_id")).isEmpty()); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertFalse(((String) attributes.get("request_id")).isEmpty()); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals(modelId, attributes.get("response.model")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + + if (isResponse) { + assertEquals("assistant", attributes.get("role")); + assertEquals(responseContent, attributes.get("content")); + assertEquals(true, attributes.get("is_response")); + assertEquals(1, attributes.get("sequence")); + } else { + assertEquals("user", attributes.get("role")); + assertEquals(requestInput, attributes.get("content")); + assertEquals(false, attributes.get("is_response")); + assertEquals(0, attributes.get("sequence")); + } + } + + public static void assertLlmChatCompletionSummaryAttributes(Event event, String modelId, String finishReason) { + assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType()); + + Map attributes = event.getAttributes(); + assertEquals("Java", 
attributes.get("ingest_source")); + assertEquals(0.5f, attributes.get("request.temperature")); + assertTrue(((Float) attributes.get("duration")) >= 0); + assertEquals(finishReason, attributes.get("response.choices.finish_reason")); + assertEquals(modelId, attributes.get("request.model")); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals(modelId, attributes.get("response.model")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertFalse(((String) attributes.get("request_id")).isEmpty()); + assertEquals(2, attributes.get("response.number_of_messages")); + assertEquals(1000, attributes.get("request.max_tokens")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + } + + public static void assertLlmEmbeddingAttributes(Event event, String modelId, String requestInput) { + assertEquals(LLM_EMBEDDING, event.getType()); + + Map attributes = event.getAttributes(); + assertEquals("Java", attributes.get("ingest_source")); + assertTrue(((Float) attributes.get("duration")) >= 0); + assertEquals(requestInput, attributes.get("input")); + assertEquals(modelId, attributes.get("request.model")); + assertEquals(modelId, attributes.get("response.model")); + assertEquals("bedrock", attributes.get("vendor")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertFalse(((String) attributes.get("request_id")).isEmpty()); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + } + + public static void assertErrorEvent(boolean isError, Collection errorEvents) { + if (isError) { + assertEquals(1, errorEvents.size()); + Iterator errorEventIterator = errorEvents.iterator(); + ErrorEvent errorEvent = errorEventIterator.next(); + + assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorClass()); + assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorMessage()); + 
+ Map errorEventAttributes = errorEvent.getAttributes(); + assertFalse(errorEventAttributes.isEmpty()); + assertEquals(400, errorEventAttributes.get("error.code")); + assertEquals(400, errorEventAttributes.get("http.statusCode")); + } else { + assertTrue(errorEvents.isEmpty()); + } + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java new file mode 100644 index 0000000000..99f5f49a79 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java @@ -0,0 +1,162 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.ai21labs.jurassic; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static 
llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") +public class JurassicModelInvocationTest { + + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + // Completion + private final String completionModelId = "ai21.j2-mid-v1"; + private final String completionRequestBody = "{\"temperature\":0.5,\"maxTokens\":1000,\"prompt\":\"What is the color of the sky?\"}"; + private final String completionResponseBody = + "{\"id\":1234,\"prompt\":{\"text\":\"What is the color of the sky?\",\"tokens\":[{\"generatedToken\":{\"token\":\"▁What▁is▁the\",\"logprob\":-9.992481231689453,\"raw_logprob\":-9.992481231689453}\n" + + ",\"topTokens\":null,\"textRange\":{\"start\":0,\"end\":11}}]},\"completions\":[{\"data\":{\"text\":\"\\nThe color of the sky is blue.\",\"tokens\":[{\"generatedToken\":{\"token\":\"<|newline|>\",\"logprob\":0.0,\"raw_logprob\":-1.389883691444993E-4},\"topTokens\":null,\"textRange\":{\"start\":0,\"end\":1}}]},\"finishReason\":{\"reason\":\"endoftext\"}}]}"; + private final String completionRequestInput = "What is the color of the sky?"; + private final String completionResponseContent = "\nThe color of the sky is blue."; + private final String finishReason = "endoftext"; + + @Before + public void before() { + introspector.clear(); + } + + @Test + public void testCompletion() { + boolean isError = false; + + JurassicModelInvocation jurassicModelInvocation = mockJurassicModelInvocation(completionModelId, completionRequestBody, completionResponseBody, + isError); + jurassicModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = 
introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testCompletionError() { + boolean isError = true; + + JurassicModelInvocation jurassicModelInvocation = mockJurassicModelInvocation(completionModelId, completionRequestBody, completionResponseBody, + isError); + jurassicModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, 
completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private JurassicModelInvocation mockJurassicModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-value"); + userAttributes.put("llm.testPrefix", "testPrefix"); + userAttributes.put("test", "test"); + + // Mock out ModelRequest + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody); + when(mockInvokeModelRequest.modelId()).thenReturn(modelId); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody); + + SdkHttpResponse 
mockSdkHttpResponse = mock(SdkHttpResponse.class); + when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + + if (isError) { + when(mockSdkHttpResponse.statusCode()).thenReturn(400); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(false); + } else { + when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + } + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + // Instantiate ModelInvocation + return new JurassicModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java new file mode 100644 index 0000000000..113ece5d47 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java @@ -0,0 +1,200 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static llm.models.TestUtil.assertLlmEmbeddingAttributes; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") +public class TitanModelInvocationTest { + + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + // Embedding + private final String embeddingModelId = "amazon.titan-embed-text-v1"; + private final String embeddingRequestBody = "{\"inputText\":\"What is the color of the sky?\"}"; + private final String 
embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"; + private final String embeddingRequestInput = "What is the color of the sky?"; + + // Completion + private final String completionModelId = "amazon.titan-text-lite-v1"; + private final String completionRequestBody = "{\"inputText\":\"What is the color of the sky?\",\"textGenerationConfig\":{\"maxTokenCount\":1000,\"stopSequences\":[\"User:\"],\"temperature\":0.5,\"topP\":0.9}}"; + private final String completionResponseBody = "{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}"; + private final String completionRequestInput = "What is the color of the sky?"; + private final String completionResponseContent = "\nThe color of the sky is blue."; + private final String finishReason = "FINISH"; + + @Before + public void before() { + introspector.clear(); + } + + @Test + public void testEmbedding() { + boolean isError = false; + + TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError); + titanModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testCompletion() { + boolean isError = false; + + TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError); + titanModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = 
introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testEmbeddingError() { + boolean isError = true; + + TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError); + titanModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testCompletionError() { + boolean isError = 
true; + + TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError); + titanModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private TitanModelInvocation mockTitanModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-value"); + userAttributes.put("llm.testPrefix", "testPrefix"); + userAttributes.put("test", "test"); + + // Mock 
out ModelRequest + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody); + when(mockInvokeModelRequest.modelId()).thenReturn(modelId); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody); + + SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class); + when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + + if (isError) { + when(mockSdkHttpResponse.statusCode()).thenReturn(400); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(false); + } else { + when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + } + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + // Instantiate ModelInvocation + return new TitanModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java new file mode 100644 index 
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models.anthropic.claude;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.models.TestUtil.assertErrorEvent;
import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes;
import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Verifies that {@link ClaudeModelInvocation} records the expected LLM custom events for the
 * Anthropic Claude completion model: two LlmChatCompletionMessage events (request + response)
 * and one LlmChatCompletionSummary event, in both the success and HTTP-error cases.
 */
@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class ClaudeModelInvocationTest {

    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    // Completion fixtures mirroring the Anthropic Claude request/response JSON schema.
    private final String completionModelId = "anthropic.claude-v2";
    private final String completionRequestBody = "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}";
    private final String completionResponseBody = "{\"completion\":\" The color of the sky is blue.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}";
    private final String completionRequestInput = "Human: What is the color of the sky?\n\nAssistant:";
    private final String completionResponseContent = " The color of the sky is blue.";
    private final String finishReason = "stop_sequence";

    @Before
    public void before() {
        introspector.clear();
    }

    @Test
    public void testCompletion() {
        runAndAssertCompletion(false);
    }

    @Test
    public void testCompletionError() {
        runAndAssertCompletion(true);
    }

    /**
     * Runs a mocked Claude completion and asserts the recorded events: two
     * LlmChatCompletionMessage events (user prompt first, assistant response second) and one
     * LlmChatCompletionSummary event, plus an error event when {@code isError} is true.
     */
    private void runAndAssertCompletion(boolean isError) {
        ClaudeModelInvocation claudeModelInvocation = mockClaudeModelInvocation(completionModelId, completionRequestBody, completionResponseBody,
                isError);
        claudeModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventIterator.next(), completionModelId, completionRequestInput,
                completionResponseContent, false);
        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventIterator.next(), completionModelId, completionRequestInput,
                completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEvents.iterator().next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    /**
     * Builds a {@link ClaudeModelInvocation} backed by Mockito mocks of the AWS SDK
     * request/response types. When {@code isError} is true the mocked HTTP response reports
     * 400 BAD_REQUEST, otherwise 200 OK.
     */
    private ClaudeModelInvocation mockClaudeModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) {
        // Linking metadata ties the generated LLM events back to the active trace/span.
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        // Only "llm."-prefixed user attributes are expected to be copied onto LLM events.
        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-value");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody);
        when(mockInvokeModelRequest.modelId()).thenReturn(modelId);

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody);

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);

        if (isError) {
            when(mockSdkHttpResponse.statusCode()).thenReturn(400);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(false);
        } else {
            when(mockSdkHttpResponse.statusCode()).thenReturn(200);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        }

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        return new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);
    }
}
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models.cohere.command;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.events.LlmEvent.LLM_EMBEDDING;
import static llm.models.TestUtil.assertErrorEvent;
import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes;
import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes;
import static llm.models.TestUtil.assertLlmEmbeddingAttributes;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Verifies that {@link CommandModelInvocation} records the expected LLM custom events for the
 * Cohere Command models: one LlmEmbedding event for embedding calls, and two
 * LlmChatCompletionMessage events plus one LlmChatCompletionSummary event for completion calls,
 * in both the success and HTTP-error cases.
 */
@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class CommandModelInvocationTest {

    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    // Embedding fixtures mirroring the Cohere Embed request/response JSON schema.
    private final String embeddingModelId = "cohere.embed-english-v3";
    private final String embeddingRequestBody = "{\"texts\":[\"What is the color of the sky?\"],\"truncate\":\"NONE\",\"input_type\":\"search_document\"}";
    private final String embeddingResponseBody = "{\"embeddings\":[[-0.002828598,0.012145996]],\"id\":\"c2c5c119-1268-4155-8c98-50ae199ffa16\",\"response_type\":\"embeddings_floats\",\"texts\":[\"what is the color of the sky?\"]}";
    private final String embeddingRequestInput = "What is the color of the sky?";

    // Completion fixtures mirroring the Cohere Command request/response JSON schema.
    private final String completionModelId = "cohere.command-light-text-v14";
    private final String completionRequestBody = "{\"p\":0.9,\"stop_sequences\":[\"User:\"],\"truncate\":\"END\",\"max_tokens\":1000,\"stream\":false,\"temperature\":0.5,\"k\":0,\"return_likelihoods\":\"NONE\",\"prompt\":\"What is the color of the sky?\"}";
    private final String completionResponseBody = "{\"generations\":[{\"finish_reason\":\"COMPLETE\",\"id\":\"314ba8cf-778d-49ed-a2cb-cf260008a2cc\",\"text\":\" The color of the sky is blue.\"}],\"id\":\"3070a2a7-b5a3-44cf-9908-554fa25473a6\",\"prompt\":\"What is the color of the sky?\"}";
    private final String completionRequestInput = "What is the color of the sky?";
    private final String completionResponseContent = " The color of the sky is blue.";
    private final String finishReason = "COMPLETE";

    @Before
    public void before() {
        introspector.clear();
    }

    @Test
    public void testEmbedding() {
        runAndAssertEmbedding(false);
    }

    @Test
    public void testEmbeddingError() {
        runAndAssertEmbedding(true);
    }

    @Test
    public void testCompletion() {
        runAndAssertCompletion(false);
    }

    @Test
    public void testCompletionError() {
        runAndAssertCompletion(true);
    }

    /**
     * Runs a mocked Command embedding and asserts exactly one LlmEmbedding event was recorded,
     * plus an error event when {@code isError} is true.
     */
    private void runAndAssertEmbedding(boolean isError) {
        CommandModelInvocation commandModelInvocation = mockCommandModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError);
        commandModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING);
        assertEquals(1, llmEmbeddingEvents.size());
        Event llmEmbeddingEvent = llmEmbeddingEvents.iterator().next();

        assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    /**
     * Runs a mocked Command completion and asserts the recorded events: two
     * LlmChatCompletionMessage events (user prompt first, assistant response second) and one
     * LlmChatCompletionSummary event, plus an error event when {@code isError} is true.
     */
    private void runAndAssertCompletion(boolean isError) {
        CommandModelInvocation commandModelInvocation = mockCommandModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError);
        commandModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventIterator.next(), completionModelId, completionRequestInput,
                completionResponseContent, false);
        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventIterator.next(), completionModelId, completionRequestInput,
                completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEvents.iterator().next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    /**
     * Builds a {@link CommandModelInvocation} backed by Mockito mocks of the AWS SDK
     * request/response types. When {@code isError} is true the mocked HTTP response reports
     * 400 BAD_REQUEST, otherwise 200 OK.
     */
    private CommandModelInvocation mockCommandModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) {
        // Linking metadata ties the generated LLM events back to the active trace/span.
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        // Only "llm."-prefixed user attributes are expected to be copied onto LLM events.
        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-value");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody);
        when(mockInvokeModelRequest.modelId()).thenReturn(modelId);

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody);

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);

        if (isError) {
            when(mockSdkHttpResponse.statusCode()).thenReturn(400);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(false);
        } else {
            when(mockSdkHttpResponse.statusCode()).thenReturn(200);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        }

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        return new CommandModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);
    }
}
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models.meta.llama2;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.models.TestUtil.assertErrorEvent;
import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes;
import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Verifies that {@link Llama2ModelInvocation} records the expected LLM custom events for the
 * Meta Llama 2 completion model: two LlmChatCompletionMessage events (request + response) and
 * one LlmChatCompletionSummary event, in both the success and HTTP-error cases.
 */
@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class Llama2ModelInvocationTest {

    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    // Completion fixtures mirroring the Meta Llama 2 request/response JSON schema.
    private final String completionModelId = "meta.llama2-13b-chat-v1";
    private final String completionRequestBody = "{\"top_p\":0.9,\"max_gen_len\":1000,\"temperature\":0.5,\"prompt\":\"What is the color of the sky?\"}";
    private final String completionResponseBody = "{\"generation\":\"\\n\\nThe color of the sky is blue.\",\"prompt_token_count\":9,\"generation_token_count\":306,\"stop_reason\":\"stop\"}";
    private final String completionRequestInput = "What is the color of the sky?";
    private final String completionResponseContent = "\n\nThe color of the sky is blue.";
    private final String finishReason = "stop";

    @Before
    public void before() {
        introspector.clear();
    }

    @Test
    public void testCompletion() {
        runAndAssertCompletion(false);
    }

    @Test
    public void testCompletionError() {
        runAndAssertCompletion(true);
    }

    /**
     * Runs a mocked Llama 2 completion and asserts the recorded events: two
     * LlmChatCompletionMessage events (user prompt first, assistant response second) and one
     * LlmChatCompletionSummary event, plus an error event when {@code isError} is true.
     */
    private void runAndAssertCompletion(boolean isError) {
        Llama2ModelInvocation llama2ModelInvocation = mockLlama2ModelInvocation(completionModelId, completionRequestBody, completionResponseBody,
                isError);
        llama2ModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventIterator.next(), completionModelId, completionRequestInput,
                completionResponseContent, false);
        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventIterator.next(), completionModelId, completionRequestInput,
                completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEvents.iterator().next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    /**
     * Builds a {@link Llama2ModelInvocation} backed by Mockito mocks of the AWS SDK
     * request/response types. When {@code isError} is true the mocked HTTP response reports
     * 400 BAD_REQUEST, otherwise 200 OK.
     */
    private Llama2ModelInvocation mockLlama2ModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) {
        // Linking metadata ties the generated LLM events back to the active trace/span.
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        // Only "llm."-prefixed user attributes are expected to be copied onto LLM events.
        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-value");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody);
        when(mockInvokeModelRequest.modelId()).thenReturn(modelId);

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody);

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);

        if (isError) {
            when(mockSdkHttpResponse.statusCode()).thenReturn(400);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(false);
        } else {
            when(mockSdkHttpResponse.statusCode()).thenReturn(200);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        }

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        return new Llama2ModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);
    }
}
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package software.amazon.awssdk.services.bedrockruntime;

import software.amazon.awssdk.awscore.AwsResponseMetadata;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpFullResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.AccessDeniedException;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException;
import software.amazon.awssdk.services.bedrockruntime.model.InternalServerException;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ModelErrorException;
import software.amazon.awssdk.services.bedrockruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.bedrockruntime.model.ModelTimeoutException;
import software.amazon.awssdk.services.bedrockruntime.model.ResourceNotFoundException;
import software.amazon.awssdk.services.bedrockruntime.model.ServiceQuotaExceededException;
import software.amazon.awssdk.services.bedrockruntime.model.ThrottlingException;
import software.amazon.awssdk.services.bedrockruntime.model.ValidationException;

import java.util.HashMap;
import java.util.function.Consumer;

/**
 * Test double for {@link BedrockRuntimeClient} that returns canned responses for the models
 * exercised by the instrumentation tests: "anthropic.claude-v2" (chat completion) and
 * "amazon.titan-embed-text-v1" (embedding). A request body containing {@code "errorTest":true}
 * produces a 400 BAD_REQUEST HTTP status; other requests get 200 OK. Unknown model ids
 * return null.
 */
public class BedrockRuntimeClientMock implements BedrockRuntimeClient {
    // Canned AWS request id surfaced through the mocked response metadata.
    private static final String AWS_REQUEST_ID_VALUE = "9d32a71a-e285-4b14-a23d-4f7d67b50ac3";

    @Override
    public String serviceName() {
        return null;
    }

    @Override
    public void close() {
        // Nothing to release in the mock.
    }

    @Override
    public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest)
            throws AccessDeniedException, ResourceNotFoundException, ThrottlingException, ModelTimeoutException, InternalServerException, ValidationException,
            ModelNotReadyException, ServiceQuotaExceededException, ModelErrorException, AwsServiceException, SdkClientException, BedrockRuntimeException {

        HashMap<String, String> metadata = new HashMap<>();
        metadata.put("AWS_REQUEST_ID", AWS_REQUEST_ID_VALUE);
        AwsResponseMetadata awsResponseMetadata = new BedrockRuntimeResponseMetadataMock(metadata);

        // Tests opt into the error path by embedding this marker in the request JSON.
        boolean isError = invokeModelRequest.body().asUtf8String().contains("\"errorTest\":true");
        SdkHttpFullResponse sdkHttpFullResponse = isError
                ? SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build()
                : SdkHttpResponse.builder().statusCode(200).statusText("OK").build();

        SdkResponse sdkResponse = null;
        if (invokeModelRequest.modelId().equals("anthropic.claude-v2")) {
            // This case will mock out a chat completion request/response
            sdkResponse = InvokeModelResponse.builder()
                    .body(SdkBytes.fromUtf8String(
                            "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere. The main gases in our atmosphere are nitrogen and oxygen. These gases are transparent to visible light wavelengths, but they scatter shorter wavelengths more, specifically blue light. This scattering makes the sky look blue from the ground.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}"))
                    .contentType("application/json")
                    .responseMetadata(awsResponseMetadata)
                    .sdkHttpResponse(sdkHttpFullResponse)
                    .build();
        } else if (invokeModelRequest.modelId().equals("amazon.titan-embed-text-v1")) {
            // This case will mock out an embedding request/response
            sdkResponse = InvokeModelResponse.builder()
                    .body(SdkBytes.fromUtf8String("{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"))
                    .contentType("application/json")
                    .responseMetadata(awsResponseMetadata)
                    .sdkHttpResponse(sdkHttpFullResponse)
                    .build();
        }
        // Unknown model ids fall through and return null, matching the original mock behavior.
        return (InvokeModelResponse) sdkResponse;
    }

    @Override
    public InvokeModelResponse invokeModel(Consumer<InvokeModelRequest.Builder> invokeModelRequest)
            throws AccessDeniedException, ResourceNotFoundException, ThrottlingException, ModelTimeoutException, InternalServerException, ValidationException,
            ModelNotReadyException, ServiceQuotaExceededException, ModelErrorException, AwsServiceException, SdkClientException, BedrockRuntimeException {
        // Consumer-style overload is not exercised by the tests.
        return null;
    }

    @Override
    public BedrockRuntimeServiceClientConfiguration serviceClientConfiguration() {
        return null;
    }
}
+1,286 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.introspec.ErrorEvent; +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import com.newrelic.agent.introspec.TracedMetricData; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import llm.models.ModelResponse; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") + +public class BedrockRuntimeClient_InstrumentationTest { + private static final BedrockRuntimeClientMock mockBedrockRuntimeClient = new BedrockRuntimeClientMock(); + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + @Before + public void before() { + introspector.clear(); + } + + @Test + public void 
testInvokeModelCompletion() { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAnthropicClaudeCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + + verifyTransactionResults(ModelResponse.COMPLETION); + verifySupportabilityMetricResults(); + verifyEventResults(ModelResponse.COMPLETION); + verifyErrorResults(isError); + } + + @Test + public void testInvokeModelEmbedding() { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + + verifyTransactionResults(ModelResponse.EMBEDDING); + verifySupportabilityMetricResults(); + verifyEventResults(ModelResponse.EMBEDDING); + verifyErrorResults(isError); + } + + @Test + public void testInvokeModelCompletionError() { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAnthropicClaudeCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + + verifyTransactionResults(ModelResponse.COMPLETION); + verifySupportabilityMetricResults(); + verifyEventResults(ModelResponse.COMPLETION); + verifyErrorResults(isError); + } + + @Test + public void testInvokeModelEmbeddingError() { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + + verifyTransactionResults(ModelResponse.EMBEDDING); + verifySupportabilityMetricResults(); + verifyEventResults(ModelResponse.EMBEDDING); + verifyErrorResults(isError); + } + + private static InvokeModelRequest buildAnthropicClaudeCompletionRequest(boolean isError) { + String prompt = 
"Human: What is the color of the sky?\n\nAssistant:"; + String modelId = "anthropic.claude-v2"; + + String payload = new JSONObject() + .put("prompt", prompt) + .put("max_tokens_to_sample", 1000) + .put("temperature", 0.5) + .put("stop_sequences", Collections.singletonList("\n\nHuman:")) + .put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(modelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + private static InvokeModelRequest buildAmazonTitanEmbeddingRequest(boolean isError) { + String prompt = "{\"inputText\":\"What is the color of the sky?\"}"; + String modelId = "amazon.titan-embed-text-v1"; + + String payload = new JSONObject() + .put("inputText", prompt) + .put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(modelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + @Trace(dispatcher = true) + private InvokeModelResponse invokeModelInTransaction(InvokeModelRequest invokeModelRequest) { + NewRelic.addCustomParameter("llm.conversation_id", "conversation-id-value"); // Will be added to LLM events + NewRelic.addCustomParameter("llm.testPrefix", "testPrefix"); // Will be added to LLM events + NewRelic.addCustomParameter("test", "test"); // Will NOT be added to LLM events + return mockBedrockRuntimeClient.invokeModel(invokeModelRequest); + } + + private void verifyTransactionResults(String operationType) { + assertEquals(1, introspector.getFinishedTransactionCount(TimeUnit.SECONDS.toMillis(2))); + Collection transactionNames = introspector.getTransactionNames(); + String transactionName = transactionNames.iterator().next(); + Map metrics = introspector.getMetricsForTransaction(transactionName); + 
assertTrue(metrics.containsKey("Llm/" + operationType + "/Bedrock/invokeModel")); + assertEquals(1, metrics.get("Llm/" + operationType + "/Bedrock/invokeModel").getCallCount()); + } + + private void verifySupportabilityMetricResults() { + Map unscopedMetrics = introspector.getUnscopedMetrics(); + assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); + } + + private void verifyErrorResults(boolean isError) { + Collection errorEvents = introspector.getErrorEvents(); + if (isError) { + assertEquals(1, errorEvents.size()); + Iterator errorEventIterator = errorEvents.iterator(); + ErrorEvent errorEvent = errorEventIterator.next(); + + assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorClass()); + assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorMessage()); + + Map errorEventAttributes = errorEvent.getAttributes(); + assertFalse(errorEventAttributes.isEmpty()); + assertEquals(400, errorEventAttributes.get("error.code")); + assertEquals(400, errorEventAttributes.get("http.statusCode")); + } else { + assertTrue(errorEvents.isEmpty()); + } + Map unscopedMetrics = introspector.getUnscopedMetrics(); + assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); + } + + private void verifyEventResults(String operationType) { + if (ModelResponse.COMPLETION.equals(operationType)) { + // LlmChatCompletionMessage events + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + // LlmChatCompletionMessage event for user request message + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, false); + + // LlmChatCompletionMessage event for assistant response message + Event llmChatCompletionMessageEventTwo = 
llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, true); + + // LlmCompletionSummary events + Collection llmCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmCompletionSummaryEvents.size()); + + Iterator llmCompletionSummaryEventIterator = llmCompletionSummaryEvents.iterator(); + // Summary event for both LlmChatCompletionMessage events + Event llmCompletionSummaryEvent = llmCompletionSummaryEventIterator.next(); + assertLlmChatCompletionSummaryAttributes(llmCompletionSummaryEvent); + } else if (ModelResponse.EMBEDDING.equals(operationType)) { + // LlmEmbedding events + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + // LlmEmbedding event + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + assertLlmEmbeddingAttributes(llmEmbeddingEvent); + } + } + + private void assertLlmChatCompletionMessageAttributes(Event event, boolean isResponse) { + assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType()); + + Map attributes = event.getAttributes(); + assertEquals("Java", attributes.get("ingest_source")); + assertFalse(((String) attributes.get("completion_id")).isEmpty()); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertFalse(((String) attributes.get("request_id")).isEmpty()); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals("anthropic.claude-v2", attributes.get("response.model")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + + if (isResponse) { + assertEquals("assistant", attributes.get("role")); + assertEquals( + " The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere. 
The main gases in our atmosphere are nitrogen and oxygen. These gases are transparent to visible light wavelengths, but they scatter shorter wavelengths more, specifically blue light. This scattering makes the sky look blue from the ground.", + attributes.get("content")); + assertEquals(true, attributes.get("is_response")); + assertEquals(1, attributes.get("sequence")); + } else { + assertEquals("user", attributes.get("role")); + assertEquals("Human: What is the color of the sky?\n\nAssistant:", attributes.get("content")); + assertEquals(false, attributes.get("is_response")); + assertEquals(0, attributes.get("sequence")); + } + } + + private void assertLlmChatCompletionSummaryAttributes(Event event) { + assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType()); + + Map attributes = event.getAttributes(); + assertEquals("Java", attributes.get("ingest_source")); + assertEquals(0.5f, attributes.get("request.temperature")); + assertTrue(((Float) attributes.get("duration")) > 0); + assertEquals("stop_sequence", attributes.get("response.choices.finish_reason")); + assertEquals("anthropic.claude-v2", attributes.get("request.model")); + assertEquals("bedrock", attributes.get("vendor")); + assertEquals("anthropic.claude-v2", attributes.get("response.model")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertFalse(((String) attributes.get("request_id")).isEmpty()); + assertEquals(2, attributes.get("response.number_of_messages")); + assertEquals(1000, attributes.get("request.max_tokens")); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + } + + private void assertLlmEmbeddingAttributes(Event event) { + assertEquals(LLM_EMBEDDING, event.getType()); + + Map attributes = event.getAttributes(); + assertEquals("Java", attributes.get("ingest_source")); + assertTrue(((Float) attributes.get("duration")) >= 0); + assertEquals("{\"inputText\":\"What is the color of the 
sky?\"}", attributes.get("input")); + assertEquals("amazon.titan-embed-text-v1", attributes.get("request.model")); + assertEquals("amazon.titan-embed-text-v1", attributes.get("response.model")); + assertEquals("bedrock", attributes.get("vendor")); + assertFalse(((String) attributes.get("id")).isEmpty()); + assertFalse(((String) attributes.get("request_id")).isEmpty()); + assertEquals("testPrefix", attributes.get("llm.testPrefix")); + assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + } +} \ No newline at end of file diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java new file mode 100644 index 0000000000..b7967e4381 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java @@ -0,0 +1,18 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import software.amazon.awssdk.awscore.AwsResponseMetadata; + +import java.util.Map; + +public class BedrockRuntimeResponseMetadataMock extends AwsResponseMetadata { + protected BedrockRuntimeResponseMetadataMock(Map metadata) { + super(metadata); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml b/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml index e69de29bb2..dbcd129940 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml @@ -0,0 +1,11 @@ +common: &default_settings + ai_monitoring: + enabled: true + record_content: + enabled: true + streaming: + enabled: true + + custom_insights_events: + max_samples_stored: 30000 + max_attribute_value: 255 From 3a86c870193fc7cc5e665cca2d4ea3a344cc6857 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 13 Mar 2024 10:25:32 -0700 Subject: [PATCH 30/68] Cleanup tests --- .../test/java/llm/events/LlmEventTest.java | 24 --- .../BedrockRuntimeClientMock.java | 24 ++- ...rockRuntimeClient_InstrumentationTest.java | 178 +++++------------- 3 files changed, 68 insertions(+), 158 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java index ca3d417df0..39858fc881 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java @@ -315,28 +315,4 @@ public void testRecordLlmChatCompletionSummaryEvent() { assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); assertEquals("testPrefix", attributes.get("llm.testPrefix")); } - -// @Test -// public void 
testUntruncatedAttributes() { -// String expectedAttribute = "aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkk
kkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__aaaaaaaaaabbbbbbbbbbccccccccccddddddddddeeeeeeeeeeffffffffffgggggggggghhhhhhhhhhiiiiiiiiiijjjjjjjjjjkkkkkkkkkkllllllllllmmmmmmmmmmnnnnnnnnnnooooooooooppppppppppqqqqqqqqqqrrrrrrrrrrssssssssssttttttttttuuuuuuuuuuvvvvvvvvvvwwwwwwwwwwxxxxxxxxxxyyyyyyyyyyzzzzzzzzzz__"; -// -// Map atts = new HashMap<>(); -// atts.put("content", expectedAttribute); -// atts.put("input", expectedAttribute); -// atts.put("vendor", expectedAttribute); -// -// NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, atts); -// -// Collection 
customEvents = introspector.getCustomEvents(LLM_EMBEDDING); -// Assert.assertEquals(1, customEvents.size()); -// -// Event event = customEvents.iterator().next(); -// Assert.assertEquals(LLM_EMBEDDING, event.getType()); -// -// Map attributes = event.getAttributes(); -// Assert.assertEquals(expectedAttribute, attributes.get("content")); -// Assert.assertEquals(expectedAttribute, attributes.get("input")); -// Assert.assertEquals(expectedAttribute, attributes.get("vendor")); -// } - } diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java index 6b3a0d46ff..e936971f3a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java @@ -31,6 +31,21 @@ import java.util.function.Consumer; public class BedrockRuntimeClientMock implements BedrockRuntimeClient { + + // Embedding + public static final String embeddingModelId = "amazon.titan-embed-text-v1"; + public static final String embeddingRequestBody = "{\"inputText\":\"What is the color of the sky?\"}"; + public static final String embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"; + public static final String embeddingRequestInput = "What is the color of the sky?"; + + // Completion + public static final String completionModelId = "amazon.titan-text-lite-v1"; + public static final String completionRequestBody = "{\"inputText\":\"What is the color of the sky?\",\"textGenerationConfig\":{\"maxTokenCount\":1000,\"stopSequences\":[\"User:\"],\"temperature\":0.5,\"topP\":0.9}}"; + public static final String completionResponseBody = 
"{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}"; + public static final String completionRequestInput = "What is the color of the sky?"; + public static final String completionResponseContent = "\nThe color of the sky is blue."; + public static final String finishReason = "FINISH"; + @Override public String serviceName() { return null; @@ -54,7 +69,7 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) boolean isError = invokeModelRequest.body().asUtf8String().contains("\"errorTest\":true"); - if (invokeModelRequest.modelId().equals("anthropic.claude-v2")) { + if (invokeModelRequest.modelId().equals(completionModelId)) { // This case will mock out a chat completion request/response if (isError) { sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); @@ -63,13 +78,12 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) } sdkResponse = InvokeModelResponse.builder() - .body(SdkBytes.fromUtf8String( - "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere. The main gases in our atmosphere are nitrogen and oxygen. These gases are transparent to visible light wavelengths, but they scatter shorter wavelengths more, specifically blue light. 
This scattering makes the sky look blue from the ground.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}")) + .body(SdkBytes.fromUtf8String(completionResponseBody)) .contentType("application/json") .responseMetadata(awsResponseMetadata) .sdkHttpResponse(sdkHttpFullResponse) .build(); - } else if (invokeModelRequest.modelId().equals("amazon.titan-embed-text-v1")) { + } else if (invokeModelRequest.modelId().equals(embeddingModelId)) { // This case will mock out an embedding request/response if (isError) { sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); @@ -78,7 +92,7 @@ public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) } sdkResponse = InvokeModelResponse.builder() - .body(SdkBytes.fromUtf8String("{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}")) + .body(SdkBytes.fromUtf8String(embeddingResponseBody)) .contentType("application/json") .responseMetadata(awsResponseMetadata) .sdkHttpResponse(sdkHttpFullResponse) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java index f3927993e8..a7dcc1b96f 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java @@ -7,7 +7,6 @@ package software.amazon.awssdk.services.bedrockruntime; -import com.newrelic.agent.introspec.ErrorEvent; import com.newrelic.agent.introspec.Event; import com.newrelic.agent.introspec.InstrumentationTestConfig; import com.newrelic.agent.introspec.InstrumentationTestRunner; @@ -33,10 +32,19 @@ 
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; import static llm.events.LlmEvent.LLM_EMBEDDING; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static llm.models.TestUtil.assertLlmEmbeddingAttributes; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.completionModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.completionRequestInput; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.completionResponseContent; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.embeddingModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.embeddingRequestInput; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.finishReason; @RunWith(InstrumentationTestRunner.class) @InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") @@ -53,15 +61,14 @@ public void before() { @Test public void testInvokeModelCompletion() { boolean isError = false; - InvokeModelRequest invokeModelRequest = buildAnthropicClaudeCompletionRequest(isError); + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); assertNotNull(invokeModelResponse); - - verifyTransactionResults(ModelResponse.COMPLETION); - verifySupportabilityMetricResults(); - 
verifyEventResults(ModelResponse.COMPLETION); - verifyErrorResults(isError); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); } @Test @@ -71,25 +78,23 @@ public void testInvokeModelEmbedding() { InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); assertNotNull(invokeModelResponse); - - verifyTransactionResults(ModelResponse.EMBEDDING); - verifySupportabilityMetricResults(); - verifyEventResults(ModelResponse.EMBEDDING); - verifyErrorResults(isError); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); } @Test public void testInvokeModelCompletionError() { boolean isError = true; - InvokeModelRequest invokeModelRequest = buildAnthropicClaudeCompletionRequest(isError); + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); assertNotNull(invokeModelResponse); - - verifyTransactionResults(ModelResponse.COMPLETION); - verifySupportabilityMetricResults(); - verifyEventResults(ModelResponse.COMPLETION); - verifyErrorResults(isError); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); } @Test @@ -99,45 +104,42 @@ public void testInvokeModelEmbeddingError() { InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); assertNotNull(invokeModelResponse); - - verifyTransactionResults(ModelResponse.EMBEDDING); - verifySupportabilityMetricResults(); - verifyEventResults(ModelResponse.EMBEDDING); - verifyErrorResults(isError); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + 
assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); } - private static InvokeModelRequest buildAnthropicClaudeCompletionRequest(boolean isError) { - String prompt = "Human: What is the color of the sky?\n\nAssistant:"; - String modelId = "anthropic.claude-v2"; + private static InvokeModelRequest buildAmazonTitanCompletionRequest(boolean isError) { + JSONObject textGenerationConfig = new JSONObject() + .put("maxTokenCount", 1000) + .put("stopSequences", Collections.singletonList("User:")) + .put("temperature", 0.5) + .put("topP", 0.9); String payload = new JSONObject() - .put("prompt", prompt) - .put("max_tokens_to_sample", 1000) - .put("temperature", 0.5) - .put("stop_sequences", Collections.singletonList("\n\nHuman:")) + .put("inputText", completionRequestInput) + .put("textGenerationConfig", textGenerationConfig) .put("errorTest", isError) // this is not a real model attribute, just adding for testing .toString(); return InvokeModelRequest.builder() .body(SdkBytes.fromUtf8String(payload)) - .modelId(modelId) + .modelId(completionModelId) .contentType("application/json") .accept("application/json") .build(); } private static InvokeModelRequest buildAmazonTitanEmbeddingRequest(boolean isError) { - String prompt = "{\"inputText\":\"What is the color of the sky?\"}"; - String modelId = "amazon.titan-embed-text-v1"; - String payload = new JSONObject() - .put("inputText", prompt) + .put("inputText", embeddingRequestInput) .put("errorTest", isError) // this is not a real model attribute, just adding for testing .toString(); return InvokeModelRequest.builder() .body(SdkBytes.fromUtf8String(payload)) - .modelId(modelId) + .modelId(embeddingModelId) .contentType("application/json") .accept("application/json") .build(); @@ -151,7 +153,7 @@ private InvokeModelResponse invokeModelInTransaction(InvokeModelRequest invokeMo return mockBedrockRuntimeClient.invokeModel(invokeModelRequest); } - private void 
verifyTransactionResults(String operationType) { + private void assertTransaction(String operationType) { assertEquals(1, introspector.getFinishedTransactionCount(TimeUnit.SECONDS.toMillis(2))); Collection transactionNames = introspector.getTransactionNames(); String transactionName = transactionNames.iterator().next(); @@ -160,33 +162,12 @@ private void verifyTransactionResults(String operationType) { assertEquals(1, metrics.get("Llm/" + operationType + "/Bedrock/invokeModel").getCallCount()); } - private void verifySupportabilityMetricResults() { + private void assertSupportabilityMetrics() { Map unscopedMetrics = introspector.getUnscopedMetrics(); assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); } - private void verifyErrorResults(boolean isError) { - Collection errorEvents = introspector.getErrorEvents(); - if (isError) { - assertEquals(1, errorEvents.size()); - Iterator errorEventIterator = errorEvents.iterator(); - ErrorEvent errorEvent = errorEventIterator.next(); - - assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorClass()); - assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorMessage()); - - Map errorEventAttributes = errorEvent.getAttributes(); - assertFalse(errorEventAttributes.isEmpty()); - assertEquals(400, errorEventAttributes.get("error.code")); - assertEquals(400, errorEventAttributes.get("http.statusCode")); - } else { - assertTrue(errorEvents.isEmpty()); - } - Map unscopedMetrics = introspector.getUnscopedMetrics(); - assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); - } - - private void verifyEventResults(String operationType) { + private void assertLlmEvents(String operationType) { if (ModelResponse.COMPLETION.equals(operationType)) { // LlmChatCompletionMessage events Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); @@ -195,11 +176,13 @@ private void verifyEventResults(String operationType) { Iterator 
llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); // LlmChatCompletionMessage event for user request message Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); - assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, false); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, + false); // LlmChatCompletionMessage event for assistant response message Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); - assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, true); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, + true); // LlmCompletionSummary events Collection llmCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); @@ -208,7 +191,7 @@ private void verifyEventResults(String operationType) { Iterator llmCompletionSummaryEventIterator = llmCompletionSummaryEvents.iterator(); // Summary event for both LlmChatCompletionMessage events Event llmCompletionSummaryEvent = llmCompletionSummaryEventIterator.next(); - assertLlmChatCompletionSummaryAttributes(llmCompletionSummaryEvent); + assertLlmChatCompletionSummaryAttributes(llmCompletionSummaryEvent, completionModelId, finishReason); } else if (ModelResponse.EMBEDDING.equals(operationType)) { // LlmEmbedding events Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); @@ -217,70 +200,7 @@ private void verifyEventResults(String operationType) { Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); // LlmEmbedding event Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); - assertLlmEmbeddingAttributes(llmEmbeddingEvent); - } - } - - private void assertLlmChatCompletionMessageAttributes(Event event, boolean isResponse) 
{ - assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType()); - - Map attributes = event.getAttributes(); - assertEquals("Java", attributes.get("ingest_source")); - assertFalse(((String) attributes.get("completion_id")).isEmpty()); - assertFalse(((String) attributes.get("id")).isEmpty()); - assertFalse(((String) attributes.get("request_id")).isEmpty()); - assertEquals("bedrock", attributes.get("vendor")); - assertEquals("anthropic.claude-v2", attributes.get("response.model")); - assertEquals("testPrefix", attributes.get("llm.testPrefix")); - assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); - - if (isResponse) { - assertEquals("assistant", attributes.get("role")); - assertEquals( - " The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere. The main gases in our atmosphere are nitrogen and oxygen. These gases are transparent to visible light wavelengths, but they scatter shorter wavelengths more, specifically blue light. 
This scattering makes the sky look blue from the ground.", - attributes.get("content")); - assertEquals(true, attributes.get("is_response")); - assertEquals(1, attributes.get("sequence")); - } else { - assertEquals("user", attributes.get("role")); - assertEquals("Human: What is the color of the sky?\n\nAssistant:", attributes.get("content")); - assertEquals(false, attributes.get("is_response")); - assertEquals(0, attributes.get("sequence")); + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); } } - - private void assertLlmChatCompletionSummaryAttributes(Event event) { - assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType()); - - Map attributes = event.getAttributes(); - assertEquals("Java", attributes.get("ingest_source")); - assertEquals(0.5f, attributes.get("request.temperature")); - assertTrue(((Float) attributes.get("duration")) > 0); - assertEquals("stop_sequence", attributes.get("response.choices.finish_reason")); - assertEquals("anthropic.claude-v2", attributes.get("request.model")); - assertEquals("bedrock", attributes.get("vendor")); - assertEquals("anthropic.claude-v2", attributes.get("response.model")); - assertFalse(((String) attributes.get("id")).isEmpty()); - assertFalse(((String) attributes.get("request_id")).isEmpty()); - assertEquals(2, attributes.get("response.number_of_messages")); - assertEquals(1000, attributes.get("request.max_tokens")); - assertEquals("testPrefix", attributes.get("llm.testPrefix")); - assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); - } - - private void assertLlmEmbeddingAttributes(Event event) { - assertEquals(LLM_EMBEDDING, event.getType()); - - Map attributes = event.getAttributes(); - assertEquals("Java", attributes.get("ingest_source")); - assertTrue(((Float) attributes.get("duration")) >= 0); - assertEquals("{\"inputText\":\"What is the color of the sky?\"}", attributes.get("input")); - assertEquals("amazon.titan-embed-text-v1", 
attributes.get("request.model")); - assertEquals("amazon.titan-embed-text-v1", attributes.get("response.model")); - assertEquals("bedrock", attributes.get("vendor")); - assertFalse(((String) attributes.get("id")).isEmpty()); - assertFalse(((String) attributes.get("request_id")).isEmpty()); - assertEquals("testPrefix", attributes.get("llm.testPrefix")); - assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); - } -} \ No newline at end of file +} From 09a64a97fa0d5d40702e1a170c569510edebb90a Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 13 Mar 2024 10:58:35 -0700 Subject: [PATCH 31/68] Add tests for async Bedrock client --- .../BedrockRuntimeAsyncClientMock.java | 108 +++++++++ ...untimeAsyncClient_InstrumentationTest.java | 209 ++++++++++++++++++ .../BedrockRuntimeClientMock.java | 2 - 3 files changed, 317 insertions(+), 2 deletions(-) create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java create mode 100644 instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java new file mode 100644 index 0000000000..6558c5957c --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java @@ -0,0 +1,108 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import software.amazon.awssdk.awscore.AwsResponseMetadata; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; + +import java.util.HashMap; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +public class BedrockRuntimeAsyncClientMock implements BedrockRuntimeAsyncClient { + + // Embedding + public static final String embeddingModelId = "amazon.titan-embed-text-v1"; + public static final String embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"; + public static final String embeddingRequestInput = "What is the color of the sky?"; + + // Completion + public static final String completionModelId = "amazon.titan-text-lite-v1"; + public static final String completionResponseBody = "{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}"; + public static final String completionRequestInput = "What is the color of the sky?"; + public static final String completionResponseContent = "\nThe color of the sky is blue."; + public static final String finishReason = "FINISH"; + + @Override + public String serviceName() { + return null; + } + + @Override + public void close() { + + } + + @Override + public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { + HashMap metadata 
= new HashMap<>(); + metadata.put("AWS_REQUEST_ID", "9d32a71a-e285-4b14-a23d-4f7d67b50ac3"); + AwsResponseMetadata awsResponseMetadata = new BedrockRuntimeResponseMetadataMock(metadata); + SdkHttpFullResponse sdkHttpFullResponse; + SdkResponse sdkResponse = null; + + boolean isError = invokeModelRequest.body().asUtf8String().contains("\"errorTest\":true"); + + if (invokeModelRequest.modelId().equals(completionModelId)) { + // This case will mock out a chat completion request/response + if (isError) { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); + } else { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(200).statusText("OK").build(); + } + + sdkResponse = InvokeModelResponse.builder() + .body(SdkBytes.fromUtf8String(completionResponseBody)) + .contentType("application/json") + .responseMetadata(awsResponseMetadata) + .sdkHttpResponse(sdkHttpFullResponse) + .build(); + } else if (invokeModelRequest.modelId().equals(embeddingModelId)) { + // This case will mock out an embedding request/response + if (isError) { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); + } else { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(200).statusText("OK").build(); + } + + sdkResponse = InvokeModelResponse.builder() + .body(SdkBytes.fromUtf8String(embeddingResponseBody)) + .contentType("application/json") + .responseMetadata(awsResponseMetadata) + .sdkHttpResponse(sdkHttpFullResponse) + .build(); + } + return CompletableFuture.completedFuture((InvokeModelResponse) sdkResponse); + } + + @Override + public CompletableFuture invokeModelWithResponseStream(InvokeModelWithResponseStreamRequest invokeModelWithResponseStreamRequest, + InvokeModelWithResponseStreamResponseHandler asyncResponseHandler) { + return BedrockRuntimeAsyncClient.super.invokeModelWithResponseStream(invokeModelWithResponseStreamRequest, asyncResponseHandler); + // Streaming not currently 
supported + } + + @Override + public CompletableFuture invokeModelWithResponseStream(Consumer invokeModelWithResponseStreamRequest, + InvokeModelWithResponseStreamResponseHandler asyncResponseHandler) { + return BedrockRuntimeAsyncClient.super.invokeModelWithResponseStream(invokeModelWithResponseStreamRequest, asyncResponseHandler); + // Streaming not currently supported + } + + @Override + public BedrockRuntimeServiceClientConfiguration serviceClientConfiguration() { + return null; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java new file mode 100644 index 0000000000..928ecc87b3 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java @@ -0,0 +1,209 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import com.newrelic.agent.introspec.TracedMetricData; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import llm.models.ModelResponse; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static llm.models.TestUtil.assertLlmEmbeddingAttributes; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.completionModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.completionRequestInput; +import static 
software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.completionResponseContent; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.embeddingModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.embeddingRequestInput; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.finishReason; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") + +public class BedrockRuntimeAsyncClient_InstrumentationTest { + private static final BedrockRuntimeAsyncClientMock mockBedrockRuntimeAsyncClient = new BedrockRuntimeAsyncClientMock(); + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + @Before + public void before() { + introspector.clear(); + } + + @Test + public void testInvokeModelCompletion() throws ExecutionException, InterruptedException { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelEmbedding() throws ExecutionException, InterruptedException { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, 
introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelCompletionError() throws ExecutionException, InterruptedException { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelEmbeddingError() throws ExecutionException, InterruptedException { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private static InvokeModelRequest buildAmazonTitanCompletionRequest(boolean isError) { + JSONObject textGenerationConfig = new JSONObject() + .put("maxTokenCount", 1000) + .put("stopSequences", Collections.singletonList("User:")) + .put("temperature", 0.5) + .put("topP", 0.9); + + String payload = new JSONObject() + .put("inputText", completionRequestInput) + .put("textGenerationConfig", textGenerationConfig) + .put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(completionModelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + private static InvokeModelRequest buildAmazonTitanEmbeddingRequest(boolean isError) { + String payload = new JSONObject() + .put("inputText", embeddingRequestInput) + 
.put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(embeddingModelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + @Trace(dispatcher = true) + private InvokeModelResponse invokeModelInTransaction(InvokeModelRequest invokeModelRequest) throws ExecutionException, InterruptedException { + NewRelic.addCustomParameter("llm.conversation_id", "conversation-id-value"); // Will be added to LLM events + NewRelic.addCustomParameter("llm.testPrefix", "testPrefix"); // Will be added to LLM events + NewRelic.addCustomParameter("test", "test"); // Will NOT be added to LLM events + CompletableFuture invokeModelResponseCompletableFuture = mockBedrockRuntimeAsyncClient.invokeModel(invokeModelRequest); + return invokeModelResponseCompletableFuture.get(); + } + + private void assertTransaction(String operationType) { + assertEquals(1, introspector.getFinishedTransactionCount(TimeUnit.SECONDS.toMillis(2))); + Collection transactionNames = introspector.getTransactionNames(); + String transactionName = transactionNames.iterator().next(); + Map metrics = introspector.getMetricsForTransaction(transactionName); + assertTrue(metrics.containsKey("Llm/" + operationType + "/Bedrock/invokeModel")); + assertEquals(1, metrics.get("Llm/" + operationType + "/Bedrock/invokeModel").getCallCount()); + } + + private void assertSupportabilityMetrics() { + Map unscopedMetrics = introspector.getUnscopedMetrics(); + assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); + } + + private void assertLlmEvents(String operationType) { + if (ModelResponse.COMPLETION.equals(operationType)) { + // LlmChatCompletionMessage events + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + + Iterator 
llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + // LlmChatCompletionMessage event for user request message + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, + false); + + // LlmChatCompletionMessage event for assistant response message + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, + true); + + // LlmCompletionSummary events + Collection llmCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmCompletionSummaryEvents.size()); + + Iterator llmCompletionSummaryEventIterator = llmCompletionSummaryEvents.iterator(); + // Summary event for both LlmChatCompletionMessage events + Event llmCompletionSummaryEvent = llmCompletionSummaryEventIterator.next(); + assertLlmChatCompletionSummaryAttributes(llmCompletionSummaryEvent, completionModelId, finishReason); + } else if (ModelResponse.EMBEDDING.equals(operationType)) { + // LlmEmbedding events + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + // LlmEmbedding event + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + } + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java index 
e936971f3a..81b164861e 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java @@ -34,13 +34,11 @@ public class BedrockRuntimeClientMock implements BedrockRuntimeClient { // Embedding public static final String embeddingModelId = "amazon.titan-embed-text-v1"; - public static final String embeddingRequestBody = "{\"inputText\":\"What is the color of the sky?\"}"; public static final String embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"; public static final String embeddingRequestInput = "What is the color of the sky?"; // Completion public static final String completionModelId = "amazon.titan-text-lite-v1"; - public static final String completionRequestBody = "{\"inputText\":\"What is the color of the sky?\",\"textGenerationConfig\":{\"maxTokenCount\":1000,\"stopSequences\":[\"User:\"],\"temperature\":0.5,\"topP\":0.9}}"; public static final String completionResponseBody = "{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}"; public static final String completionRequestInput = "What is the color of the sky?"; public static final String completionResponseContent = "\nThe color of the sky is blue."; From cb7753d134e83686baa467544a6286906d7cc7c4 Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Wed, 13 Mar 2024 11:12:59 -0700 Subject: [PATCH 32/68] Add copyright header --- .../src/test/java/llm/models/TestUtil.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java index bfe1187b9d..15d0149b1e 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java @@ -1,3 +1,10 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + package llm.models; import com.newrelic.agent.introspec.ErrorEvent; From db637ec2c893031dc1d391f2253fddf9324126b2 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:13:38 -0700 Subject: [PATCH 33/68] add ai monitoring interface --- .../com/newrelic/api/agent/AiMonitoring.java | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java new file mode 100644 index 0000000000..67f54a039f --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java @@ -0,0 +1,37 @@ +package com.newrelic.api.agent; + +import java.util.Map; + +public interface AiMonitoring { + /** + * Records an LlmFeedbackMessage event. + * + * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct + * the llmFeedbackEventAttributes map, use + * {@link LlmFeedbackEventAttributes.Builder} + *

The map must include:

+ * + *

Optional attributes: + *

+ * + * + */ + void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes); + + /** + * Registers a callback function for providing token counts to LLM events. + * + * @param callback Callback function for calculating token counts + */ +// void setLlmTokenCountCallback(LlmTokenCountCallback callback); + +} From b22730eff54c6bf40a400adb5c97c010752032b6 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:14:42 -0700 Subject: [PATCH 34/68] added ai monitoring implementation --- .../api/agent/LlmFeedbackEventRecorder.java | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java new file mode 100644 index 0000000000..4190066274 --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java @@ -0,0 +1,38 @@ +package com.newrelic.api.agent; + +import java.util.Map; + +/** + * A utility class for recording LlmFeedbackMessage events using the AI Monitoring API. + *

+ * This class implements the {@link AiMonitoring} interface and provides a method to record LlmFeedbackMessage events + * by delegating to the Insights API for custom event recording. + */ + +public class LlmFeedbackEventRecorder implements AiMonitoring { + /** + * Records an LlmFeedbackMessage event. + * + * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct + * the llmFeedbackEventAttributes map, use + * {@link LlmFeedbackEventAttributes.Builder} + *

The map must include:

+ *
    + *
  • "traceId" (String): Trace ID where the chat completion related to the + * feedback event occurred
  • + *
  • "rating" (Integer/String): Rating provided by an end user
  • + *
+ * Optional attributes: + *
    + *
  • "category" (String): Category of the feedback as provided by the end user
  • + *
  • "message" (String): Freeform text feedback from an end user.
  • + *
  • "metadata" (Map<String, String>): Set of key-value pairs to store + * additional data to submit with the feedback event
  • + *
+ */ + @Override + public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) { + // Delegate to Insights API for event recording + NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); + } +} From 440fce97434120065c8549b5a8927fe686b85652 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:23:08 -0700 Subject: [PATCH 35/68] create builder class --- .../api/agent/LlmFeedbackEventAttributes.java | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java new file mode 100644 index 0000000000..c0e0214350 --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java @@ -0,0 +1,106 @@ +package com.newrelic.api.agent; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +public class LlmFeedbackEventAttributes { + private final String traceId; + private final Object rating; + private final String category; + private final String message; + private final Map metadata; + private final UUID id; + private final String ingestSource; + + protected LlmFeedbackEventAttributes(String traceId, Object rating, String category, String message, Map metadata, UUID id, String ingestSource) { + this.traceId = traceId; + this.rating = rating; + this.category = category; + this.message = message; + this.metadata = metadata; + this.id = id; + this.ingestSource = ingestSource; + } + + public String getTraceId() { + return traceId; + } + + public Object getRating() { + return rating; + } + + + public String getCategory() { + return category; + } + + public String getMessage() { + return message; + } + + public Map getMetadata() { + return metadata; + } + + public UUID getId() { + return 
id; + } + + public String getIngestSource() { + return ingestSource; + } + + public Map toMap() { + Map feedbackParametersMap = new HashMap<>(); + feedbackParametersMap.put("traceId", getTraceId()); + feedbackParametersMap.put("rating", getRating()); + feedbackParametersMap.put("id", getId()); + feedbackParametersMap.put("ingestSource", getIngestSource()); + if (category != null) { + feedbackParametersMap.put("category", getCategory()); + } + if (message != null) { + feedbackParametersMap.put("message", getMessage()); + } + if (metadata != null) { + feedbackParametersMap.put("metadata", getMetadata()); + } + return feedbackParametersMap; + } + + public static class Builder { + private final String traceId; + private final Object rating; + private String category; + private String message; + private Map metadata; + private final UUID id = UUID.randomUUID(); + + public Builder(String traceId, Object rating) { + this.traceId = traceId; + this.rating = rating; + } + + public Builder category(String category) { + this.category = category; + return this; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Map build() { + return new LlmFeedbackEventAttributes(traceId, rating, category, message, metadata, id, "Java").toMap(); + + } + } +} From cc3171796234ca24008096fd96995aa59faddeff Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:23:54 -0700 Subject: [PATCH 36/68] add test suite for builder class --- .../agent/LlmFeedbackEventAttributesTest.java | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java new file mode 100644 
index 0000000000..c42340b8ed --- /dev/null +++ b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java @@ -0,0 +1,91 @@ +package com.newrelic.api.agent; + +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +public class LlmFeedbackEventAttributesTest { + + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; + Map llmFeedbackEventAttributes; + + @Before + public void setup() { + String traceId = "123456"; + Object rating = 3; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + } + + @Test + public void testBuilderWithRequiredParamsOnly() { + llmFeedbackEventAttributes = llmFeedbackEventBuilder.build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals(3, llmFeedbackEventAttributes.get("rating")); + assertNotNull(llmFeedbackEventAttributes.get("id")); + assertEquals("Java", llmFeedbackEventAttributes.get("ingestSource")); + assertFalse(llmFeedbackEventAttributes.containsKey("category")); + assertFalse(llmFeedbackEventAttributes.containsKey("message")); + assertFalse(llmFeedbackEventAttributes.containsKey("metadata")); + } + + @Test + public void testBuilderWithRequiredAndOptionalParams() { + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category("exampleCategory") + .message("exampleMessage") + .metadata(createMetadataMap()) + .build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals(3, llmFeedbackEventAttributes.get("rating")); + assertEquals("exampleCategory", llmFeedbackEventAttributes.get("category")); + assertEquals("exampleMessage", llmFeedbackEventAttributes.get("message")); + } + + 
@Test + public void testBuilderWithOptionalParamsSetToNull() { + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category(null) + .message(null) + .metadata(null) + .build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals(3, llmFeedbackEventAttributes.get("rating")); + assertNull(llmFeedbackEventAttributes.get("category")); + assertNull(llmFeedbackEventAttributes.get("message")); + assertNull(llmFeedbackEventAttributes.get("metadata")); + assertNotNull(llmFeedbackEventAttributes.get("id")); + assertEquals("Java", llmFeedbackEventAttributes.get("ingestSource")); + } + + @Test + public void testBuilderWithRatingParamAsStringType() { + String traceId2 = "123456"; + Object rating2 = "3"; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId2, rating2); + llmFeedbackEventAttributes = llmFeedbackEventBuilder.build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals("3", llmFeedbackEventAttributes.get("rating")); + } + + public Map createMetadataMap() { + Map map = new HashMap<>(); + map.put("key1", "val1"); + map.put("key2", "val2"); + return map; + } +} \ No newline at end of file From 8996831261638c995dc421314d6ff4b372a1637b Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:24:29 -0700 Subject: [PATCH 37/68] start test suite for feedbackEvent recorder class --- .../agent/LlmFeedbackEventRecorderTest.java | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java new file mode 100644 index 0000000000..a51a7731cb --- /dev/null +++ 
b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java @@ -0,0 +1,47 @@ +package com.newrelic.api.agent; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; + +public class LlmFeedbackEventRecorderTest { + + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; + LlmFeedbackEventRecorder recordLlmFeedbackEvent; + Map llmFeedbackEventParameters; + + @Before + public void setup() { + String traceId = "123456"; + Integer rating = 5; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + recordLlmFeedbackEvent = new LlmFeedbackEventRecorder(); + } + + @Test + public void testRecordLlmFeedbackEvent() { + llmFeedbackEventParameters = llmFeedbackEventBuilder + .category("General") + .message("Great experience") + .build(); + + recordLlmFeedbackEvent.recordLlmFeedbackEvent(llmFeedbackEventParameters); + + // TODO: verify recordCustomEvent was called with the correct parameters + } + + @Test + public void testRecordLlmFeedbackEvent_NullMap() { + // TODO: invoke the method with a null map + // TODO: verify recordCustomEvent was not called + } + + @Test + public void testRecordLlmFeedbackEvent_EmptyMap() { + // TODO: invoke the method with an empty map + // TODO: verify recordCustomEvent was not called + } + + +} From 5391ed5164f2e8505e603f26bc3363d9b0c3a8bb Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:25:53 -0700 Subject: [PATCH 38/68] add getAiMonitoring to Agent interface --- .../src/main/java/com/newrelic/api/agent/Agent.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java b/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java index ef289ca4ad..023c34d278 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java @@ -63,6 +63,13 @@ public interface Agent { */ Insights getInsights(); + /** + * Provides 
access to the AI Monitoring custom events API. + * + * @return Object for recording custom events. + */ + AiMonitoring getAiMonitoring(); + ErrorApi getErrorApi(); /** From edec98ed7207867e30f73b21c117e038b2fffb21 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 14 Mar 2024 17:28:52 -0700 Subject: [PATCH 39/68] update locations where Agent interface is implemented --- .../src/main/java/com/newrelic/agent/bridge/NoOpAgent.java | 6 ++++++ .../com/newrelic/agent/extension/FakeExtensionAgent.java | 6 ++++++ .../src/main/java/com/newrelic/agent/AgentImpl.java | 7 +++++++ .../src/main/java/com/newrelic/api/agent/NoOpAgent.java | 5 +++++ .../com/newrelic/opentelemetry/OpenTelemetryAgent.java | 6 ++++++ 5 files changed, 30 insertions(+) diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java index 3eb81761d1..90521cc071 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java @@ -7,6 +7,7 @@ package com.newrelic.agent.bridge; +import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.Config; import com.newrelic.api.agent.ErrorApi; import com.newrelic.api.agent.Insights; @@ -65,6 +66,11 @@ public Insights getInsights() { return NoOpInsights.INSTANCE; } + @Override + public AiMonitoring getAiMonitoring() { + return null; + } + @Override public ErrorApi getErrorApi() { return NoOpErrorApi.INSTANCE; diff --git a/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java b/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java index f71962dd0e..bc1721eed3 100644 --- a/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java +++ b/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java @@ -10,6 +10,7 @@ import com.newrelic.agent.bridge.Agent; import 
com.newrelic.agent.bridge.TracedMethod; import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.Config; import com.newrelic.api.agent.ErrorApi; import com.newrelic.api.agent.Insights; @@ -38,6 +39,11 @@ public Logger getLogger() { @Override public Insights getInsights() { throw new RuntimeException(); } + @Override + public AiMonitoring getAiMonitoring() { + return null; + } + @Override public ErrorApi getErrorApi() { throw new RuntimeException(); } diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java index 9e79c409f9..f509a67a5b 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java @@ -15,8 +15,10 @@ import com.newrelic.agent.bridge.Transaction; import com.newrelic.agent.service.ServiceFactory; import com.newrelic.agent.tracers.Tracer; +import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.ErrorApi; import com.newrelic.api.agent.Insights; +import com.newrelic.api.agent.LlmFeedbackEventRecorder; import com.newrelic.api.agent.Logger; import com.newrelic.api.agent.Logs; import com.newrelic.api.agent.MetricAggregator; @@ -136,6 +138,11 @@ public Insights getInsights() { return ServiceFactory.getServiceManager().getInsights(); } + @Override + public AiMonitoring getAiMonitoring() { + return new LlmFeedbackEventRecorder(); + } + @Override public Logs getLogSender() { return ServiceFactory.getServiceManager().getLogSenderService(); diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java b/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java index b207ad799f..16270a02b0 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java @@ -458,6 +458,11 @@ public Insights getInsights() { return 
INSIGHTS; } + @Override + public AiMonitoring getAiMonitoring() { + return null; + } + @Override public ErrorApi getErrorApi() { return ERROR_API; diff --git a/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java b/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java index 8fd5de5bbf..ce1fdeb23a 100644 --- a/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java +++ b/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java @@ -8,6 +8,7 @@ package com.newrelic.opentelemetry; import com.newrelic.api.agent.Agent; +import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.Config; import com.newrelic.api.agent.Insights; import com.newrelic.api.agent.Logger; @@ -73,6 +74,11 @@ public Insights getInsights() { return openTelemetryInsights; } + @Override + public AiMonitoring getAiMonitoring() { + return null; + } + @Override public TraceMetadata getTraceMetadata() { OpenTelemetryNewRelic.logUnsupportedMethod("Agent", "getTraceMetadata"); From 40594b5f269c955bd066fb0f7a25095ff74d4575 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 25 Mar 2024 11:29:00 -0700 Subject: [PATCH 40/68] refactored naming of ai api implementation class --- .../aimonitoring/AiMonitoringImplTest.java | 21 +++++++++++++++---- ...entRecorder.java => AiMonitoringImpl.java} | 21 +++++++++++++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) rename newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java => newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java (63%) rename newrelic-api/src/main/java/com/newrelic/api/agent/{LlmFeedbackEventRecorder.java => AiMonitoringImpl.java} (68%) diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java 
b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java similarity index 63% rename from newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java rename to newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java index a51a7731cb..7c5ea1266f 100644 --- a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventRecorderTest.java +++ b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java @@ -5,10 +5,10 @@ import java.util.Map; -public class LlmFeedbackEventRecorderTest { +public class AiMonitoringImplTest { LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; - LlmFeedbackEventRecorder recordLlmFeedbackEvent; + AiMonitoringImpl aiMonitoringImpl; Map llmFeedbackEventParameters; @Before @@ -16,7 +16,7 @@ public void setup() { String traceId = "123456"; Integer rating = 5; llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); - recordLlmFeedbackEvent = new LlmFeedbackEventRecorder(); + aiMonitoringImpl = new AiMonitoringImpl(); } @Test @@ -26,7 +26,7 @@ public void testRecordLlmFeedbackEvent() { .message("Great experience") .build(); - recordLlmFeedbackEvent.recordLlmFeedbackEvent(llmFeedbackEventParameters); + aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventParameters); // TODO: verify recordCustomEvent was called with the correct parameters } @@ -43,5 +43,18 @@ public void testRecordLlmFeedbackEvent_EmptyMap() { // TODO: verify recordCustomEvent was not called } + @Test + public void testSetLlmTokenCountCallbackReturnsIntegerGreaterThanZero() { + class TestCallback implements LlmTokenCountCallback { + + @Override + public Integer calculateLlmTokenCount(String model, String content) { + return -5; + } + } + + TestCallback testCallback = new TestCallback(); + aiMonitoringImpl.setLlmTokenCountCallback(testCallback); + } } diff --git 
a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java similarity index 68% rename from newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java rename to newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java index 4190066274..2251367547 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventRecorder.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java @@ -1,4 +1,8 @@ -package com.newrelic.api.agent; +package com.newrelic.agent.aimonitoring; + +import com.newrelic.api.agent.AiMonitoring; +import com.newrelic.api.agent.LlmFeedbackEventAttributes; +import com.newrelic.api.agent.NewRelic; import java.util.Map; @@ -9,7 +13,7 @@ * by delegating to the Insights API for custom event recording. */ -public class LlmFeedbackEventRecorder implements AiMonitoring { +public class AiMonitoringImpl implements AiMonitoring { /** * Records an LlmFeedbackMessage event. 
* @@ -30,9 +34,22 @@ public class LlmFeedbackEventRecorder implements AiMonitoring { * additional data to submit with the feedback event * */ + @Override public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) { // Delegate to Insights API for event recording NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); } + + @Override + public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { + String model = "SampleModel"; + String content = "SampleContent"; +// LlmTokenCountCallbackHolder llmTokenCountCallbackHolder = new LlmTokenCountCallbackHolder(llmTokenCountCallback); + LlmTokenCountCallbackHolder llmTokenCountCallbackHolder = new LlmTokenCountCallbackHolder(); + llmTokenCountCallbackHolder.setLlmTokenCountCallbackHolder(llmTokenCountCallback); + LlmTokenCountCallback tokenCounter = llmTokenCountCallbackHolder.getLlmTokenCountCallback(); + Integer tokenCount = llmTokenCountCallback.calculateLlmTokenCount(model, content); + + } } From e2b38085e7221443dc2f549b00ad71e9fefe10dd Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 25 Mar 2024 11:29:59 -0700 Subject: [PATCH 41/68] updated argument type for tokenCount --- .../src/main/java/llm/events/LlmEvent.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java index 81d8211ff9..221da47de2 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -197,7 +197,7 @@ public Builder requestModel() { return this; } - public Builder tokenCount(int count) { + public Builder tokenCount(Integer count) { tokenCount = count; return this; } From bc9d7ef77c9a828fa61a8584199066058420ffa0 Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 
15:29:07 -0700 Subject: [PATCH 42/68] created callback interface --- .../api/agent/LlmTokenCountCallback.java | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java new file mode 100644 index 0000000000..809c628e0b --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java @@ -0,0 +1,41 @@ +package com.newrelic.api.agent; + +/** + * An interface for calculating the number of tokens used for a given LLM (Large Language Model) and content. + *

+ * Implement this interface to define custom logic for token calculation based on your application's requirements. + *

+ *

+ * Example usage: + *

{@code
+ * class MyTokenCountCallback implements LlmTokenCountCallback {
+ *
+ *     @Override
+ *     public Integer calculateLlmTokenCount(String model, String content) {
+ *         // Implement your custom token calculating logic here
+ *         // This example calculates the number of tokens based on the length of the content
+ *         return content.length();
+ *     }
+ * }
+ *
+ * // Usage:
+ * LlmTokenCountCallback myCallback = new MyTokenCountCallback();
+ * // After creating the {@code myCallback} instance, it should be passed as an argument to the {@code setLlmTokenCountCallback}
+ * // method of the AI Monitoring API.
+ * NewRelic.getAgent().getAiMonitoring.setLlmTokenCountCallback(myCallback);
+ * }
+ *

+ */ +public interface LlmTokenCountCallback { + + + /** + * Calculates the number of tokens used for a given LLM model and content. + * + * @param model The name of the LLM model. + * @param content The message content or prompt. + * @return An integer representing the number of tokens used for the given model and content. + * If the count cannot be determined or is less than or equal to 0, null is returned. + */ + public Integer calculateLlmTokenCount(String model, String content); +} From c68c4784c734a86c341c573f5526019e40228d5f Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 15:31:09 -0700 Subject: [PATCH 43/68] created singleton holder class to store the users callback method --- .../agent/LlmTokenCountCallbackHolder.java | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java new file mode 100644 index 0000000000..de2dcdc5a4 --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java @@ -0,0 +1,55 @@ +package com.newrelic.api.agent; + +/** + * A singleton class for holding an instance of {@link LlmTokenCountCallback}. + * This class ensures that only one instance of the callback is stored and accessed throughout the application. + */ +public class LlmTokenCountCallbackHolder { + + private static volatile LlmTokenCountCallbackHolder INSTANCE; + private static volatile LlmTokenCountCallback llmTokenCountCallback; + + /** + * Private constructor to prevent instantiation from outside. + * + * @param llmTokenCountCallback The callback method to be stored. 
+ */ + private LlmTokenCountCallbackHolder(LlmTokenCountCallback llmTokenCountCallback) { + LlmTokenCountCallbackHolder.llmTokenCountCallback = llmTokenCountCallback; + }; + + /** + * Returns the singleton instance of the {@code LlmTokenCountCallbackHolder}. + * + * @return The singleton instance. + */ + public static LlmTokenCountCallbackHolder getInstance() { + if (INSTANCE == null) { + synchronized (LlmTokenCountCallbackHolder.class) { + if (INSTANCE == null) { + INSTANCE = new LlmTokenCountCallbackHolder(llmTokenCountCallback); + } + } + } + return INSTANCE; + } + + /** + * Sets the {@link LlmTokenCountCallback} instance to be stored. + * + * @param llmTokenCountCallback The callback instance to be stored. + */ + public static void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { + LlmTokenCountCallbackHolder.llmTokenCountCallback = llmTokenCountCallback; + } + + /** + * Retrieves the stored {@link LlmTokenCountCallback} instance. + * + * @return The stored callback instance. + */ + public LlmTokenCountCallback getLlmTokenCountCallback() { + return llmTokenCountCallback; + } + +} \ No newline at end of file From eb4d421b80e188090e31b97792f37fb6e27b7302 Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 16:33:33 -0700 Subject: [PATCH 44/68] added example for setting a tokenCount callback --- .../com/newrelic/api/agent/AiMonitoring.java | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java index 67f54a039f..42494b1aab 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java @@ -2,6 +2,9 @@ import java.util.Map; +/** + * This interface defines methods for recording LlmFeedbackMessage events and setting a callback for token calculation. 
+ */ public interface AiMonitoring { /** * Records an LlmFeedbackMessage event. @@ -28,10 +31,23 @@ public interface AiMonitoring { void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes); /** - * Registers a callback function for providing token counts to LLM events. + * Sets the callback function for calculating LLM tokens. * - * @param callback Callback function for calculating token counts + * @param llmTokenCountCallback The callback function to be invoked for counting LLM tokens. + * Example usage: + *
{@code
+     *                              LlmTokenCountCallback llmTokenCountCallback = new LlmTokenCountCallback() {
+     *                                  {@literal @}Override
+     *                                  public Integer calculateLlmTokenCount(String model, String content) {
+     *                                      // Token calculation based on model and content goes here
+     *                                      // Return the calculated token count
+     *                                  }
+     *                               };
+     *
+     *                               // Set the created callback instance
+     *                               NewRelic.getAgent().getAiMonitoring().setLlmTokenCountCallback(llmTokenCountCallback);
+     *                               }
*/ -// void setLlmTokenCountCallback(LlmTokenCountCallback callback); + void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback); } From cd0926c92502eccd68d52cfe724ae785f94a5fc2 Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 16:38:10 -0700 Subject: [PATCH 45/68] name refactoring --- .../src/main/java/com/newrelic/agent/AgentImpl.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java index f509a67a5b..cc1082bfdf 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java @@ -16,9 +16,9 @@ import com.newrelic.agent.service.ServiceFactory; import com.newrelic.agent.tracers.Tracer; import com.newrelic.api.agent.AiMonitoring; +import com.newrelic.api.agent.AiMonitoringImpl; import com.newrelic.api.agent.ErrorApi; import com.newrelic.api.agent.Insights; -import com.newrelic.api.agent.LlmFeedbackEventRecorder; import com.newrelic.api.agent.Logger; import com.newrelic.api.agent.Logs; import com.newrelic.api.agent.MetricAggregator; @@ -140,7 +140,7 @@ public Insights getInsights() { @Override public AiMonitoring getAiMonitoring() { - return new LlmFeedbackEventRecorder(); + return new AiMonitoringImpl(); } @Override From c6d39e480c8467f100f4aba34823421e3fc92e41 Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 16:41:13 -0700 Subject: [PATCH 46/68] created callback interface --- .../main/java/com/newrelic/api/agent/LlmTokenCountCallback.java | 1 - 1 file changed, 1 deletion(-) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java index 809c628e0b..ebaaed1e2c 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java +++ 
b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java @@ -18,7 +18,6 @@ * } * } * - * // Usage: * LlmTokenCountCallback myCallback = new MyTokenCountCallback(); * // After creating the {@code myCallback} instance, it should be passed as an argument to the {@code setLlmTokenCountCallback} * // method of the AI Monitoring API. From cfa407bf8f893ef694acb11d49bd45febdbd8db3 Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 16:43:51 -0700 Subject: [PATCH 47/68] added setLlmTokenCountCallback method and a supportability metric --- .../newrelic/api/agent/AiMonitoringImpl.java | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java index 2251367547..83a57e54ee 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java @@ -1,32 +1,29 @@ -package com.newrelic.agent.aimonitoring; - -import com.newrelic.api.agent.AiMonitoring; -import com.newrelic.api.agent.LlmFeedbackEventAttributes; -import com.newrelic.api.agent.NewRelic; +package com.newrelic.api.agent; import java.util.Map; /** - * A utility class for recording LlmFeedbackMessage events using the AI Monitoring API. - *

- * This class implements the {@link AiMonitoring} interface and provides a method to record LlmFeedbackMessage events - * by delegating to the Insights API for custom event recording. + * A utility class for interacting with the AI Monitoring API to record LlmFeedbackMessage events. + * This class implements the {@link AiMonitoring} interface and provides methods for feedback event recording + * and setting callbacks for token calculation. */ public class AiMonitoringImpl implements AiMonitoring { + private static final String SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET = "Supportability/AiMonitoringTokenCountCallback/set"; + /** * Records an LlmFeedbackMessage event. * * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct * the llmFeedbackEventAttributes map, use * {@link LlmFeedbackEventAttributes.Builder} - *

The map must include:

+ *

Required Attributes:

*
    *
  • "traceId" (String): Trace ID where the chat completion related to the * feedback event occurred
  • *
  • "rating" (Integer/String): Rating provided by an end user
  • *
- * Optional attributes: + * Optional Attributes: *
    *
  • "category" (String): Category of the feedback as provided by the end user
  • *
  • "message" (String): Freeform text feedback from an end user.
  • @@ -41,15 +38,16 @@ public void recordLlmFeedbackEvent(Map llmFeedbackEventAttribute NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); } + /** + * Sets the callback for token calculation and reports a supportability metric. + * + * @param llmTokenCountCallback The callback instance implementing {@link LlmTokenCountCallback} interface. + * This callback will be used for token calculation. + * @see LlmTokenCountCallback + */ @Override public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { - String model = "SampleModel"; - String content = "SampleContent"; -// LlmTokenCountCallbackHolder llmTokenCountCallbackHolder = new LlmTokenCountCallbackHolder(llmTokenCountCallback); - LlmTokenCountCallbackHolder llmTokenCountCallbackHolder = new LlmTokenCountCallbackHolder(); - llmTokenCountCallbackHolder.setLlmTokenCountCallbackHolder(llmTokenCountCallback); - LlmTokenCountCallback tokenCounter = llmTokenCountCallbackHolder.getLlmTokenCountCallback(); - Integer tokenCount = llmTokenCountCallback.calculateLlmTokenCount(model, content); - + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); + NewRelic.getAgent().getMetricAggregator().incrementCounter(SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET); } } From f545f4730a7405f7be4e4c7c5297bab5f3d99a5f Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 17:15:08 -0700 Subject: [PATCH 48/68] added setUp with LlmTokenCountCallback and updated tokenCount --- .../src/test/java/llm/events/LlmEventTest.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java index 39858fc881..0910a6fd70 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java +++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java @@ -11,6 +11,8 @@ import com.newrelic.agent.introspec.InstrumentationTestConfig; import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import llm.models.ModelInvocation; import llm.models.amazon.titan.TitanModelInvocation; import llm.models.anthropic.claude.ClaudeModelInvocation; @@ -45,8 +47,13 @@ public class LlmEventTest { @Before public void before() { introspector.clear(); + setUp(); } + public void setUp() { + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); + } @Test public void testRecordLlmEmbeddingEvent() { // Given @@ -191,7 +198,7 @@ public void testRecordLlmChatCompletionMessageEvent() { .responseModel() // attribute 10 .sequence(0) // attribute 11 .completionId() // attribute 12 - .tokenCount(123) // attribute 13 + .tokenCount(LlmTokenCountCallbackHolder.getInstance().getLlmTokenCountCallback().calculateLlmTokenCount("model", "content")) // attribute 13 .build(); // attributes 14 & 15 should be the two llm.* prefixed userAttributes @@ -220,7 +227,7 @@ public void testRecordLlmChatCompletionMessageEvent() { assertEquals("anthropic.claude-v2", attributes.get("response.model")); assertEquals(0, attributes.get("sequence")); assertFalse(((String) attributes.get("completion_id")).isEmpty()); - assertEquals(123, attributes.get("token_count")); + assertEquals(13, attributes.get("token_count")); assertEquals("conversation-id-890", attributes.get("llm.conversation_id")); assertEquals("testPrefix", attributes.get("llm.testPrefix")); } From a1cddff96b03a71b55d3a50e4ae4ee98c95c7d0e Mon Sep 17 00:00:00 2001 From: edeleon Date: Wed, 27 Mar 2024 17:16:38 -0700 Subject: [PATCH 49/68] added assertions for token_count 
value --- .../src/test/java/llm/models/TestUtil.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java index 15d0149b1e..4b3b86da9b 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java @@ -34,6 +34,7 @@ public static void assertLlmChatCompletionMessageAttributes(Event event, String assertEquals(modelId, attributes.get("response.model")); assertEquals("testPrefix", attributes.get("llm.testPrefix")); assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + assertEquals(13, attributes.get("token_count")); if (isResponse) { assertEquals("assistant", attributes.get("role")); @@ -81,6 +82,7 @@ public static void assertLlmEmbeddingAttributes(Event event, String modelId, Str assertFalse(((String) attributes.get("request_id")).isEmpty()); assertEquals("testPrefix", attributes.get("llm.testPrefix")); assertEquals("conversation-id-value", attributes.get("llm.conversation_id")); + assertEquals(13, attributes.get("token_count")); } public static void assertErrorEvent(boolean isError, Collection errorEvents) { From bc019070b0a75570fa24f79b311ecf7d955ce4e2 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 28 Mar 2024 12:16:45 -0700 Subject: [PATCH 50/68] refactored builders to invoke getTokenCount --- .../llm/models/ai21labs/jurassic/JurassicModelInvocation.java | 4 ++-- .../java/llm/models/amazon/titan/TitanModelInvocation.java | 4 ++-- .../llm/models/anthropic/claude/ClaudeModelInvocation.java | 4 ++-- .../llm/models/cohere/command/CommandModelInvocation.java | 4 ++-- .../java/llm/models/meta/llama2/Llama2ModelInvocation.java | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java index db364205a8..ea41563314 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java @@ -69,7 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime, int index) { .input(index) .requestModel() .responseModel() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -122,7 +122,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, bo .responseModel() .sequence(sequence) .completionId() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java index ddc9308bdf..b919ae1e94 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java @@ -69,7 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime, int index) { .input(index) .requestModel() .responseModel() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + 
.tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -122,7 +122,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, bo .responseModel() .sequence(sequence) .completionId() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java index 9185047102..f1861fba2c 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java @@ -69,7 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime, int index) { .input(index) .requestModel() .responseModel() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -122,7 +122,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, bo .responseModel() .sequence(sequence) .completionId() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java index ef4eb69acc..729900e38b 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java @@ -69,7 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime, int index) { .input(index) .requestModel() .responseModel() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) .error() .duration(System.currentTimeMillis() - startTime) .build(); @@ -122,7 +122,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, bo .responseModel() .sequence(sequence) .completionId() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java index 70ede61f9a..99e2820dde 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java @@ -69,7 +69,7 @@ public void recordLlmEmbeddingEvent(long startTime, int index) { .input(index) .requestModel() .responseModel() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) .error() .duration(System.currentTimeMillis() - startTime) 
.build(); @@ -122,7 +122,7 @@ public void recordLlmChatCompletionMessageEvent(int sequence, String message, bo .responseModel() .sequence(sequence) .completionId() - .tokenCount(0) // TODO set to value from the setLlmTokenCountCallback API + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) .build(); llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); From 8ee16467fef43904c402875f76d97d996f77f8e8 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 28 Mar 2024 12:31:53 -0700 Subject: [PATCH 51/68] update tests to include callback setting --- .../models/ai21labs/jurassic/JurassicModelInvocationTest.java | 4 ++++ .../llm/models/amazon/titan/TitanModelInvocationTest.java | 4 ++++ .../models/anthropic/claude/ClaudeModelInvocationTest.java | 4 ++++ .../llm/models/cohere/command/CommandModelInvocationTest.java | 4 ++++ .../llm/models/meta/llama2/Llama2ModelInvocationTest.java | 4 ++++ 5 files changed, 20 insertions(+) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java index 99f5f49a79..46040b9a9d 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java @@ -11,6 +11,8 @@ import com.newrelic.agent.introspec.InstrumentationTestConfig; import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -54,6 +56,8 @@ public class JurassicModelInvocationTest { @Before public void before() { 
introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); } @Test diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java index 113ece5d47..5a2ee428c5 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java @@ -11,6 +11,8 @@ import com.newrelic.agent.introspec.InstrumentationTestConfig; import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -60,6 +62,8 @@ public class TitanModelInvocationTest { @Before public void before() { introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); } @Test diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java index af4ed9906f..8c5bb9a2d3 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java @@ -11,6 +11,8 @@ import com.newrelic.agent.introspec.InstrumentationTestConfig; import 
com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,6 +54,8 @@ public class ClaudeModelInvocationTest { @Before public void before() { introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); } @Test diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java index c12daa47f0..f29f75fa46 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java @@ -11,6 +11,8 @@ import com.newrelic.agent.introspec.InstrumentationTestConfig; import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -60,6 +62,8 @@ public class CommandModelInvocationTest { @Before public void before() { introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); } @Test diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java index 
98a85f325d..6ddc491fb5 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java @@ -11,6 +11,8 @@ import com.newrelic.agent.introspec.InstrumentationTestConfig; import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,6 +54,8 @@ public class Llama2ModelInvocationTest { @Before public void before() { introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); } @Test From e9bdf5bfa6d5ea9005b0e313dec1483f148d032a Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 28 Mar 2024 13:15:57 -0700 Subject: [PATCH 52/68] refactored AiMonitoringImplTest --- .../aimonitoring/AiMonitoringImplTest.java | 60 ------------ .../api/agent/AiMonitoringImplTest.java | 94 +++++++++++++++++++ 2 files changed, 94 insertions(+), 60 deletions(-) delete mode 100644 newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java create mode 100644 newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java diff --git a/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java deleted file mode 100644 index 7c5ea1266f..0000000000 --- a/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java +++ /dev/null @@ -1,60 +0,0 @@ -package com.newrelic.api.agent; - -import org.junit.Before; -import org.junit.Test; - -import java.util.Map; - -public class AiMonitoringImplTest { - - 
LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; - AiMonitoringImpl aiMonitoringImpl; - Map llmFeedbackEventParameters; - - @Before - public void setup() { - String traceId = "123456"; - Integer rating = 5; - llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); - aiMonitoringImpl = new AiMonitoringImpl(); - } - - @Test - public void testRecordLlmFeedbackEvent() { - llmFeedbackEventParameters = llmFeedbackEventBuilder - .category("General") - .message("Great experience") - .build(); - - aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventParameters); - - // TODO: verify recordCustomEvent was called with the correct parameters - } - - @Test - public void testRecordLlmFeedbackEvent_NullMap() { - // TODO: invoke the method with a null map - // TODO: verify recordCustomEvent was not called - } - - @Test - public void testRecordLlmFeedbackEvent_EmptyMap() { - // TODO: invoke the method with an empty map - // TODO: verify recordCustomEvent was not called - } - - @Test - public void testSetLlmTokenCountCallbackReturnsIntegerGreaterThanZero() { - class TestCallback implements LlmTokenCountCallback { - - @Override - public Integer calculateLlmTokenCount(String model, String content) { - return -5; - } - } - - TestCallback testCallback = new TestCallback(); - aiMonitoringImpl.setLlmTokenCountCallback(testCallback); - } - -} diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java new file mode 100644 index 0000000000..e3dc07ed5d --- /dev/null +++ b/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java @@ -0,0 +1,94 @@ +package com.newrelic.agent.aimonitoring; + +import com.newrelic.api.agent.AiMonitoringImpl; +import com.newrelic.api.agent.Insights; +import com.newrelic.api.agent.LlmFeedbackEventAttributes; +import com.newrelic.api.agent.LlmTokenCountCallback; +import 
com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import java.util.Map; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.Mockito.*; + +@RunWith(MockitoJUnitRunner.class) +public class AiMonitoringImplTest { + + @Mock + Insights insights; + + AiMonitoringImpl aiMonitoringImpl; + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; + Map llmFeedbackEventAttributes; + + @Before + public void setup() { + String traceId = "123456"; + Integer rating = 5; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category("General") + .message("Great experience") + .build(); + } + + @Test + public void testRecordLlmFeedbackEventSuccess() { + aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); + verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); + } + + @Test + public void testRecordLlmFeedbackEventFailure() { + doThrow(new RuntimeException("Custom event recording failed")).when(aiMonitoringImpl).recordLlmFeedbackEvent(anyMap()); + try { + aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); + } catch (RuntimeException exception) { + verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); + assertEquals("Custom event recording failed", exception.getMessage()); + } + } + + @Test + public void testRecordLlmFeedbackEventWithNullAttributes() { + doThrow(new IllegalArgumentException("llmFeedbackEventAttributes cannot be null")) + .when(aiMonitoringImpl).recordLlmFeedbackEvent(null); + + try { + aiMonitoringImpl.recordLlmFeedbackEvent(null); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // Expected exception thrown, test passed + 
System.out.println("IllegalArgumentException successfully thrown!"); + } + } + + @Test + public void testSetLlmTokenCountCallbackSuccess() { + LlmTokenCountCallback testCallback = mock(LlmTokenCountCallback.class); + aiMonitoringImpl.setLlmTokenCountCallback(testCallback); + verify(aiMonitoringImpl).setLlmTokenCountCallback(testCallback); + assertNotNull(LlmTokenCountCallbackHolder.getInstance()); + } + + @Test + public void testSetLlmTokenCountCallbackReturnsIntegerGreaterThanZero() { + class TestCallback implements LlmTokenCountCallback { + + @Override + public Integer calculateLlmTokenCount(String model, String content) { + return 13; + } + } + + TestCallback testCallback = new TestCallback(); + aiMonitoringImpl.setLlmTokenCountCallback(testCallback); + } + +} From 1ba9614ed844578e91b9633d1076c40091fbc542 Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 28 Mar 2024 13:16:30 -0700 Subject: [PATCH 53/68] added getTokenCount method --- .../main/java/llm/models/ModelInvocation.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 9f85c73f0f..10de637ba4 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -9,10 +9,12 @@ import com.newrelic.agent.bridge.Token; import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.LlmTokenCountCallbackHolder; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; import java.util.Map; +import java.util.Objects; import java.util.UUID; import static llm.vendor.Vendor.BEDROCK; @@ -175,4 +177,23 @@ static String getTraceId(Map linkingMetadata) { static String getRandomGuid() { return UUID.randomUUID().toString(); } + + /** + * Calculates the tokenCount 
based on a user provided callback + * + * @param model String representation of the LLM model + * @param content String representation of the message content or prompt + * @return int representing the tokenCount + */ + static int getTokenCount(String model, String content) { + int tokenCount = 0; + + if (LlmTokenCountCallbackHolder.getInstance() != null && !Objects.equals(content, "")) { + return LlmTokenCountCallbackHolder + .getInstance() + .getLlmTokenCountCallback() + .calculateLlmTokenCount(model, content); + } + return tokenCount; + } } From 3ff383d829196b90b5f26c225052c4939ed13c7b Mon Sep 17 00:00:00 2001 From: edeleon Date: Thu, 28 Mar 2024 13:20:04 -0700 Subject: [PATCH 54/68] updates for class movement --- .../api/agent/AiMonitoringImplTest.java | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java index e3dc07ed5d..14e42d81c2 100644 --- a/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java +++ b/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java @@ -1,28 +1,21 @@ -package com.newrelic.agent.aimonitoring; +package com.newrelic.api.agent; -import com.newrelic.api.agent.AiMonitoringImpl; -import com.newrelic.api.agent.Insights; -import com.newrelic.api.agent.LlmFeedbackEventAttributes; -import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import java.util.Map; -import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) public class AiMonitoringImplTest { @Mock - 
Insights insights; - AiMonitoringImpl aiMonitoringImpl; LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; Map llmFeedbackEventAttributes; @@ -41,28 +34,28 @@ public void setup() { @Test public void testRecordLlmFeedbackEventSuccess() { aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); - verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); + Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); } @Test public void testRecordLlmFeedbackEventFailure() { - doThrow(new RuntimeException("Custom event recording failed")).when(aiMonitoringImpl).recordLlmFeedbackEvent(anyMap()); + Mockito.doThrow(new RuntimeException("Custom event recording failed")).when(aiMonitoringImpl).recordLlmFeedbackEvent(anyMap()); try { aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); } catch (RuntimeException exception) { - verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); - assertEquals("Custom event recording failed", exception.getMessage()); + Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); + Assert.assertEquals("Custom event recording failed", exception.getMessage()); } } @Test public void testRecordLlmFeedbackEventWithNullAttributes() { - doThrow(new IllegalArgumentException("llmFeedbackEventAttributes cannot be null")) + Mockito.doThrow(new IllegalArgumentException("llmFeedbackEventAttributes cannot be null")) .when(aiMonitoringImpl).recordLlmFeedbackEvent(null); try { aiMonitoringImpl.recordLlmFeedbackEvent(null); - fail("Expected IllegalArgumentException to be thrown"); + Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // Expected exception thrown, test passed System.out.println("IllegalArgumentException successfully thrown!"); @@ -71,10 +64,10 @@ public void testRecordLlmFeedbackEventWithNullAttributes() { @Test public void testSetLlmTokenCountCallbackSuccess() { - 
LlmTokenCountCallback testCallback = mock(LlmTokenCountCallback.class); + LlmTokenCountCallback testCallback = Mockito.mock(LlmTokenCountCallback.class); aiMonitoringImpl.setLlmTokenCountCallback(testCallback); - verify(aiMonitoringImpl).setLlmTokenCountCallback(testCallback); - assertNotNull(LlmTokenCountCallbackHolder.getInstance()); + Mockito.verify(aiMonitoringImpl).setLlmTokenCountCallback(testCallback); + Assert.assertNotNull(LlmTokenCountCallbackHolder.getInstance()); } @Test From c2722bfcb6c2809b6b1e4eefced94b5401a011be Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:11:56 -0700 Subject: [PATCH 55/68] moved LlmTokenCountCallbackHolder to agent-bridge and AiMonitoringImpl to newrelic-agent --- .../LlmTokenCountCallbackHolder.java | 19 ++++ .../agent/aimonitoring/AiMonitoringImpl.java | 59 ++++++++++++ .../aimonitoring/AiMonitoringImplTest.java | 90 +++++++++++++++++++ 3 files changed, 168 insertions(+) create mode 100644 agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java create mode 100644 newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java create mode 100644 newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java new file mode 100644 index 0000000000..f237565902 --- /dev/null +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java @@ -0,0 +1,19 @@ +package com.newrelic.agent.bridge.aimonitoring; + +import com.newrelic.api.agent.LlmTokenCountCallback; + +/** + * storage for an instance of {@link LlmTokenCountCallback}. 
+ */ +public class LlmTokenCountCallbackHolder { + + private static volatile LlmTokenCountCallback llmTokenCountCallback = null; + + public static void setLlmTokenCountCallback(LlmTokenCountCallback newLlmTokenCountCallback) { + llmTokenCountCallback = newLlmTokenCountCallback; + } + + public static LlmTokenCountCallback getLlmTokenCountCallback() { + return llmTokenCountCallback; + } +} \ No newline at end of file diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java new file mode 100644 index 0000000000..0ea8d1aa9d --- /dev/null +++ b/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java @@ -0,0 +1,59 @@ +package com.newrelic.agent; + +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; +import com.newrelic.api.agent.AiMonitoring; +import com.newrelic.api.agent.LlmFeedbackEventAttributes; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.NewRelic; + +import java.util.Map; + +/** + * A utility class for interacting with the AI Monitoring API to record LlmFeedbackMessage events. + * This class implements the {@link AiMonitoring} interface and provides methods for feedback event recording + * and setting callbacks for token calculation. + */ + +public class AiMonitoringImpl implements AiMonitoring { + private static final String SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET = "Supportability/AiMonitoringTokenCountCallback/set"; + + /** + * Records an LlmFeedbackMessage event. + * + * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct + * the llmFeedbackEventAttributes map, use + * {@link LlmFeedbackEventAttributes.Builder} + *

+ *                                   <p>Required Attributes:</p>
+ *                                   <ul>
+ *                                   <li>"traceId" (String): Trace ID where the chat completion related to the
+ *                                   feedback event occurred</li>
+ *                                   <li>"rating" (Integer/String): Rating provided by an end user</li>
+ *                                   </ul>
+ *                                   <p>Optional Attributes:</p>
+ *                                   <ul>
+ *                                   <li>"category" (String): Category of the feedback as provided by the end user</li>
+ *                                   <li>"message" (String): Freeform text feedback from an end user.</li>
+ *                                   <li>"metadata" (Map&lt;String, String&gt;): Set of key-value pairs to store
+ *                                   additional data to submit with the feedback event</li>
+ *                                   </ul>
    + */ + + @Override + public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) { + // Delegate to Insights API for event recording + NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); + } + + /** + * Sets the callback for token calculation and reports a supportability metric. + * + * @param llmTokenCountCallback The callback instance implementing {@link LlmTokenCountCallback} interface. + * This callback will be used for token calculation. + * @see LlmTokenCountCallback + */ + @Override + public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); + NewRelic.getAgent().getMetricAggregator().incrementCounter(SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET); + } +} diff --git a/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java new file mode 100644 index 0000000000..646141f0f7 --- /dev/null +++ b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java @@ -0,0 +1,90 @@ +package com.newrelic.agent.bridge.aimonitoring; + +import com.newrelic.api.agent.AiMonitoringImpl; +import com.newrelic.api.agent.LlmFeedbackEventAttributes; +import com.newrelic.api.agent.LlmTokenCountCallback; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.anyMap; + +@RunWith(MockitoJUnitRunner.class) +public class AiMonitoringImplTest { + + @Mock + AiMonitoringImpl aiMonitoringImpl; + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; + Map llmFeedbackEventAttributes; + + @Before + public void setup() { + String traceId = "123456"; + 
Integer rating = 5; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category("General") + .message("Great experience") + .build(); + } + + @Test + public void testRecordLlmFeedbackEventSuccess() { + aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); + Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); + } + + @Test + public void testRecordLlmFeedbackEventFailure() { + Mockito.doThrow(new RuntimeException("Custom event recording failed")).when(aiMonitoringImpl).recordLlmFeedbackEvent(anyMap()); + try { + aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); + } catch (RuntimeException exception) { + Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); + Assert.assertEquals("Custom event recording failed", exception.getMessage()); + } + } + + @Test + public void testRecordLlmFeedbackEventWithNullAttributes() { + Mockito.doThrow(new IllegalArgumentException("llmFeedbackEventAttributes cannot be null")) + .when(aiMonitoringImpl).recordLlmFeedbackEvent(null); + + try { + aiMonitoringImpl.recordLlmFeedbackEvent(null); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // Expected exception thrown, test passed + System.out.println("IllegalArgumentException successfully thrown!"); + } + } + + @Test + public void testSetLlmTokenCountCallbackSuccess() { + LlmTokenCountCallback testCallback = Mockito.mock(LlmTokenCountCallback.class); + aiMonitoringImpl.setLlmTokenCountCallback(testCallback); + Mockito.verify(aiMonitoringImpl).setLlmTokenCountCallback(testCallback); + Assert.assertNotNull(LlmTokenCountCallbackHolder.getInstance()); + } + + @Test + public void testSetLlmTokenCountCallbackReturnsIntegerGreaterThanZero() { + class TestCallback implements LlmTokenCountCallback { + + @Override + public Integer 
calculateLlmTokenCount(String model, String content) { + return 13; + } + } + + TestCallback testCallback = new TestCallback(); + aiMonitoringImpl.setLlmTokenCountCallback(testCallback); + } + +} From 6b9dabe44c4591203fe87607274a8ad4c386ad4f Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:13:49 -0700 Subject: [PATCH 56/68] added supportability metric to track when tokenCountCallbacks are set --- .../src/main/java/com/newrelic/agent/MetricNames.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java b/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java index 948e7ede9c..0a426c81b7 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java @@ -498,6 +498,9 @@ public class MetricNames { public static final String SUPPORTABILITY_SLOW_TXN_DETECTION_ENABLED = "Supportability/SlowTransactionDetection/enabled"; public static final String SUPPORTABILITY_SLOW_TXN_DETECTION_DISABLED = "Supportability/SlowTransactionDetection/disabled"; + // AiMonitoring Callback Set + public static final String SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET = "Supportability/AiMonitoringTokenCountCallback/Set"; + /** * Utility method for adding supportability metrics to APIs * From 9e692eccdb9266a350f522fc93d5686ab06d4ab5 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:19:25 -0700 Subject: [PATCH 57/68] implemented getTokenCount method to check for nulls and return token calculation --- .../src/main/java/llm/models/ModelInvocation.java | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java index 10de637ba4..0bb2fc4d8a 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java @@ -9,7 +9,7 @@ import com.newrelic.agent.bridge.Token; import com.newrelic.agent.bridge.Transaction; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; @@ -186,14 +186,11 @@ static String getRandomGuid() { * @return int representing the tokenCount */ static int getTokenCount(String model, String content) { - int tokenCount = 0; - - if (LlmTokenCountCallbackHolder.getInstance() != null && !Objects.equals(content, "")) { - return LlmTokenCountCallbackHolder - .getInstance() - .getLlmTokenCountCallback() - .calculateLlmTokenCount(model, content); + if (LlmTokenCountCallbackHolder.getLlmTokenCountCallback() == null || Objects.equals(content, "")) { + return 0; } - return tokenCount; + return LlmTokenCountCallbackHolder + .getLlmTokenCountCallback() + .calculateLlmTokenCount(model, content); } } From bee19fedda0b417a90b61cce7f772dd3c86e6d49 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:39:23 -0700 Subject: [PATCH 58/68] updated imports --- .../src/test/java/llm/events/LlmEventTest.java | 4 ++-- .../models/ai21labs/jurassic/JurassicModelInvocationTest.java | 2 +- .../llm/models/amazon/titan/TitanModelInvocationTest.java | 2 +- .../models/anthropic/claude/ClaudeModelInvocationTest.java | 2 +- .../llm/models/cohere/command/CommandModelInvocationTest.java | 2 +- .../llm/models/meta/llama2/Llama2ModelInvocationTest.java | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java index 0910a6fd70..f932918bd1 100644 --- 
a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java @@ -12,7 +12,7 @@ import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import llm.models.ModelInvocation; import llm.models.amazon.titan.TitanModelInvocation; import llm.models.anthropic.claude.ClaudeModelInvocation; @@ -198,7 +198,7 @@ public void testRecordLlmChatCompletionMessageEvent() { .responseModel() // attribute 10 .sequence(0) // attribute 11 .completionId() // attribute 12 - .tokenCount(LlmTokenCountCallbackHolder.getInstance().getLlmTokenCountCallback().calculateLlmTokenCount("model", "content")) // attribute 13 + .tokenCount(LlmTokenCountCallbackHolder.getLlmTokenCountCallback().calculateLlmTokenCount("model", "content")) // attribute 13 .build(); // attributes 14 & 15 should be the two llm.* prefixed userAttributes diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java index 46040b9a9d..320b54ac57 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java @@ -12,7 +12,7 @@ import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import 
com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java index 5a2ee428c5..521472afb5 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java @@ -12,7 +12,7 @@ import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java index 8c5bb9a2d3..36d53a92c3 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java @@ -12,7 +12,7 @@ import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; diff --git 
a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java index f29f75fa46..cc115dfa29 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/cohere/command/CommandModelInvocationTest.java @@ -12,7 +12,7 @@ import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java index 6ddc491fb5..da1399310a 100644 --- a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java @@ -12,7 +12,7 @@ import com.newrelic.agent.introspec.InstrumentationTestRunner; import com.newrelic.agent.introspec.Introspector; import com.newrelic.api.agent.LlmTokenCountCallback; -import com.newrelic.api.agent.LlmTokenCountCallbackHolder; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; From 1b4df474c983871944a48ac0cd17d3664d93991c Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:40:29 -0700 Subject: [PATCH 59/68] updated return type and comments --- 
.../java/com/newrelic/api/agent/LlmTokenCountCallback.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java index ebaaed1e2c..28f1d3864f 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java @@ -19,7 +19,7 @@ * } * * LlmTokenCountCallback myCallback = new MyTokenCountCallback(); - * // After creating the {@code myCallback} instance, it should be passed as an argument to the {@code setLlmTokenCountCallback} + * // After creating the myCallback instance, it should be passed as an argument to the setLlmTokenCountCallback * // method of the AI Monitoring API. * NewRelic.getAgent().getAiMonitoring.setLlmTokenCountCallback(myCallback); * } @@ -34,7 +34,7 @@ public interface LlmTokenCountCallback { * @param model The name of the LLM model. * @param content The message content or prompt. * @return An integer representing the number of tokens used for the given model and content. - * If the count cannot be determined or is less than or equal to 0, null is returned. + * If the count cannot be determined or is less than or equal to 0, 0 is returned. 
*/ - public Integer calculateLlmTokenCount(String model, String content); + public int calculateLlmTokenCount(String model, String content); } From 51e1285801362f6de09ac68d28e412fe5baeaaa7 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:41:06 -0700 Subject: [PATCH 60/68] refactored class --- .../aimonitoring/LlmTokenCountCallbackHolder.java | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java index f237565902..6c89efd7e0 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java @@ -3,16 +3,27 @@ import com.newrelic.api.agent.LlmTokenCountCallback; /** - * storage for an instance of {@link LlmTokenCountCallback}. + * A thread-safe holder for an instance of {@link LlmTokenCountCallback}. + * This class provides methods for setting and retrieving the callback instance. */ public class LlmTokenCountCallbackHolder { private static volatile LlmTokenCountCallback llmTokenCountCallback = null; + /** + * Sets the {@link LlmTokenCountCallback} instance to be stored. + * + * @param newLlmTokenCountCallback the callback instance + */ public static void setLlmTokenCountCallback(LlmTokenCountCallback newLlmTokenCountCallback) { llmTokenCountCallback = newLlmTokenCountCallback; } + /** + * Retrieves the stored {@link LlmTokenCountCallback} instance. 
+ * + * @return stored callback instance + */ public static LlmTokenCountCallback getLlmTokenCountCallback() { return llmTokenCountCallback; } From 18827a675ec7e6f0cf641b0813e93f761d53ae73 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:43:01 -0700 Subject: [PATCH 61/68] moved classes to update callback accessibility --- .../agent/LlmTokenCountCallbackHolder.java | 55 ------------ .../api/agent/AiMonitoringImplTest.java | 87 ------------------- 2 files changed, 142 deletions(-) delete mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java delete mode 100644 newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java deleted file mode 100644 index de2dcdc5a4..0000000000 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallbackHolder.java +++ /dev/null @@ -1,55 +0,0 @@ -package com.newrelic.api.agent; - -/** - * A singleton class for holding an instance of {@link LlmTokenCountCallback}. - * This class ensures that only one instance of the callback is stored and accessed throughout the application. - */ -public class LlmTokenCountCallbackHolder { - - private static volatile LlmTokenCountCallbackHolder INSTANCE; - private static volatile LlmTokenCountCallback llmTokenCountCallback; - - /** - * Private constructor to prevent instantiation from outside. - * - * @param llmTokenCountCallback The callback method to be stored. - */ - private LlmTokenCountCallbackHolder(LlmTokenCountCallback llmTokenCountCallback) { - LlmTokenCountCallbackHolder.llmTokenCountCallback = llmTokenCountCallback; - }; - - /** - * Returns the singleton instance of the {@code LlmTokenCountCallbackHolder}. - * - * @return The singleton instance. 
- */ - public static LlmTokenCountCallbackHolder getInstance() { - if (INSTANCE == null) { - synchronized (LlmTokenCountCallbackHolder.class) { - if (INSTANCE == null) { - INSTANCE = new LlmTokenCountCallbackHolder(llmTokenCountCallback); - } - } - } - return INSTANCE; - } - - /** - * Sets the {@link LlmTokenCountCallback} instance to be stored. - * - * @param llmTokenCountCallback The callback instance to be stored. - */ - public static void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { - LlmTokenCountCallbackHolder.llmTokenCountCallback = llmTokenCountCallback; - } - - /** - * Retrieves the stored {@link LlmTokenCountCallback} instance. - * - * @return The stored callback instance. - */ - public LlmTokenCountCallback getLlmTokenCountCallback() { - return llmTokenCountCallback; - } - -} \ No newline at end of file diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java deleted file mode 100644 index 14e42d81c2..0000000000 --- a/newrelic-api/src/test/java/com/newrelic/api/agent/AiMonitoringImplTest.java +++ /dev/null @@ -1,87 +0,0 @@ -package com.newrelic.api.agent; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; - -import java.util.Map; - -import static org.mockito.ArgumentMatchers.anyMap; - -@RunWith(MockitoJUnitRunner.class) -public class AiMonitoringImplTest { - - @Mock - AiMonitoringImpl aiMonitoringImpl; - LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; - Map llmFeedbackEventAttributes; - - @Before - public void setup() { - String traceId = "123456"; - Integer rating = 5; - llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); - llmFeedbackEventAttributes = llmFeedbackEventBuilder - .category("General") - .message("Great 
experience") - .build(); - } - - @Test - public void testRecordLlmFeedbackEventSuccess() { - aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); - Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); - } - - @Test - public void testRecordLlmFeedbackEventFailure() { - Mockito.doThrow(new RuntimeException("Custom event recording failed")).when(aiMonitoringImpl).recordLlmFeedbackEvent(anyMap()); - try { - aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); - } catch (RuntimeException exception) { - Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); - Assert.assertEquals("Custom event recording failed", exception.getMessage()); - } - } - - @Test - public void testRecordLlmFeedbackEventWithNullAttributes() { - Mockito.doThrow(new IllegalArgumentException("llmFeedbackEventAttributes cannot be null")) - .when(aiMonitoringImpl).recordLlmFeedbackEvent(null); - - try { - aiMonitoringImpl.recordLlmFeedbackEvent(null); - Assert.fail("Expected IllegalArgumentException to be thrown"); - } catch (IllegalArgumentException e) { - // Expected exception thrown, test passed - System.out.println("IllegalArgumentException successfully thrown!"); - } - } - - @Test - public void testSetLlmTokenCountCallbackSuccess() { - LlmTokenCountCallback testCallback = Mockito.mock(LlmTokenCountCallback.class); - aiMonitoringImpl.setLlmTokenCountCallback(testCallback); - Mockito.verify(aiMonitoringImpl).setLlmTokenCountCallback(testCallback); - Assert.assertNotNull(LlmTokenCountCallbackHolder.getInstance()); - } - - @Test - public void testSetLlmTokenCountCallbackReturnsIntegerGreaterThanZero() { - class TestCallback implements LlmTokenCountCallback { - - @Override - public Integer calculateLlmTokenCount(String model, String content) { - return 13; - } - } - - TestCallback testCallback = new TestCallback(); - aiMonitoringImpl.setLlmTokenCountCallback(testCallback); - } - -} From 
9443c6b3748cf1485a5cd9da102105051a4052db Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 03:45:08 -0700 Subject: [PATCH 62/68] refactored naming and initialized optional attrs to be null --- .../api/agent/LlmFeedbackEventAttributes.java | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java index c0e0214350..2f12a63695 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java @@ -12,6 +12,7 @@ public class LlmFeedbackEventAttributes { private final Map metadata; private final UUID id; private final String ingestSource; + private static final String INGEST_SOURCE = "Java"; protected LlmFeedbackEventAttributes(String traceId, Object rating, String category, String message, Map metadata, UUID id, String ingestSource) { this.traceId = traceId; @@ -53,29 +54,29 @@ public String getIngestSource() { } public Map toMap() { - Map feedbackParametersMap = new HashMap<>(); - feedbackParametersMap.put("traceId", getTraceId()); - feedbackParametersMap.put("rating", getRating()); - feedbackParametersMap.put("id", getId()); - feedbackParametersMap.put("ingestSource", getIngestSource()); + Map feedbackAttributesMap = new HashMap<>(); + feedbackAttributesMap.put("traceId", getTraceId()); + feedbackAttributesMap.put("rating", getRating()); + feedbackAttributesMap.put("id", getId()); + feedbackAttributesMap.put("ingestSource", getIngestSource()); if (category != null) { - feedbackParametersMap.put("category", getCategory()); + feedbackAttributesMap.put("category", getCategory()); } if (message != null) { - feedbackParametersMap.put("message", getMessage()); + feedbackAttributesMap.put("message", getMessage()); } if (metadata != null) { - 
feedbackParametersMap.put("metadata", getMetadata()); + feedbackAttributesMap.put("metadata", getMetadata()); } - return feedbackParametersMap; + return feedbackAttributesMap; } public static class Builder { private final String traceId; private final Object rating; - private String category; - private String message; - private Map metadata; + private String category = null; + private String message = null; + private Map metadata = null; private final UUID id = UUID.randomUUID(); public Builder(String traceId, Object rating) { @@ -99,7 +100,7 @@ public Builder metadata(Map metadata) { } public Map build() { - return new LlmFeedbackEventAttributes(traceId, rating, category, message, metadata, id, "Java").toMap(); + return new LlmFeedbackEventAttributes(traceId, rating, category, message, metadata, id, INGEST_SOURCE).toMap(); } } From 088123ddfccbb2e0a6cfa933f88b46b3665c7105 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 04:00:57 -0700 Subject: [PATCH 63/68] moved the ai-monitoring implementation, refactored imports and updated tests --- .../agent/aimonitoring/AiMonitoringImpl.java | 11 +++- .../aimonitoring/AiMonitoringImplTest.java | 61 ++++++++++--------- .../newrelic/api/agent/AiMonitoringImpl.java | 53 ---------------- 3 files changed, 40 insertions(+), 85 deletions(-) delete mode 100644 newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java index 0ea8d1aa9d..bb6f7c79c8 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java @@ -1,5 +1,6 @@ -package com.newrelic.agent; +package com.newrelic.agent.aimonitoring; +import com.newrelic.agent.MetricNames; import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; 
import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.LlmFeedbackEventAttributes; @@ -40,6 +41,9 @@ public class AiMonitoringImpl implements AiMonitoring { @Override public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) { + if (llmFeedbackEventAttributes == null) { + throw new IllegalArgumentException("llmFeedbackEventAttributes cannot be null"); + } // Delegate to Insights API for event recording NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); } @@ -53,7 +57,10 @@ public void recordLlmFeedbackEvent(Map llmFeedbackEventAttribute */ @Override public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { + if (llmTokenCountCallback == null) { + throw new IllegalArgumentException("llmTokenCountCallback cannot be null"); + } LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); - NewRelic.getAgent().getMetricAggregator().incrementCounter(SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET); + NewRelic.getAgent().getMetricAggregator().incrementCounter(MetricNames.SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET); } } diff --git a/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java index 646141f0f7..46254fa6e8 100644 --- a/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java +++ b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java @@ -1,6 +1,6 @@ -package com.newrelic.agent.bridge.aimonitoring; +package com.newrelic.agent.aimonitoring; -import com.newrelic.api.agent.AiMonitoringImpl; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; import com.newrelic.api.agent.LlmFeedbackEventAttributes; import com.newrelic.api.agent.LlmTokenCountCallback; import org.junit.Assert; @@ -8,53 +8,46 @@ import org.junit.Test; import 
org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import java.util.Map; -import static org.mockito.ArgumentMatchers.anyMap; +import static org.junit.Assert.assertEquals; @RunWith(MockitoJUnitRunner.class) public class AiMonitoringImplTest { @Mock - AiMonitoringImpl aiMonitoringImpl; - LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; + private AiMonitoringImpl aiMonitoringImpl; + + private LlmTokenCountCallback callback; Map llmFeedbackEventAttributes; @Before public void setup() { String traceId = "123456"; Integer rating = 5; - llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); llmFeedbackEventAttributes = llmFeedbackEventBuilder .category("General") .message("Great experience") .build(); + callback = getCallback(); + aiMonitoringImpl = new AiMonitoringImpl(); } @Test - public void testRecordLlmFeedbackEventSuccess() { - aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); - Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); - } - - @Test - public void testRecordLlmFeedbackEventFailure() { - Mockito.doThrow(new RuntimeException("Custom event recording failed")).when(aiMonitoringImpl).recordLlmFeedbackEvent(anyMap()); + public void testRecordLlmFeedbackEventSent() { try { aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); - } catch (RuntimeException exception) { - Mockito.verify(aiMonitoringImpl).recordLlmFeedbackEvent(llmFeedbackEventAttributes); - Assert.assertEquals("Custom event recording failed", exception.getMessage()); + } catch (IllegalArgumentException e) { + // test should not catch an exception } + } @Test public void testRecordLlmFeedbackEventWithNullAttributes() { - Mockito.doThrow(new IllegalArgumentException("llmFeedbackEventAttributes cannot be 
null")) - .when(aiMonitoringImpl).recordLlmFeedbackEvent(null); try { aiMonitoringImpl.recordLlmFeedbackEvent(null); @@ -66,25 +59,33 @@ public void testRecordLlmFeedbackEventWithNullAttributes() { } @Test - public void testSetLlmTokenCountCallbackSuccess() { - LlmTokenCountCallback testCallback = Mockito.mock(LlmTokenCountCallback.class); - aiMonitoringImpl.setLlmTokenCountCallback(testCallback); - Mockito.verify(aiMonitoringImpl).setLlmTokenCountCallback(testCallback); - Assert.assertNotNull(LlmTokenCountCallbackHolder.getInstance()); + public void testCallbackSetSuccessfully() { + aiMonitoringImpl.setLlmTokenCountCallback(callback); + assertEquals(callback, LlmTokenCountCallbackHolder.getLlmTokenCountCallback()); } @Test - public void testSetLlmTokenCountCallbackReturnsIntegerGreaterThanZero() { + public void testSetLlmTokenCountCallbackWithNull() { + + try { + aiMonitoringImpl.setLlmTokenCountCallback(null); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // Expected exception thrown, test passes + System.out.println("IllegalArgumentException successfully thrown!"); + } + + } + + public LlmTokenCountCallback getCallback() { class TestCallback implements LlmTokenCountCallback { @Override - public Integer calculateLlmTokenCount(String model, String content) { + public int calculateLlmTokenCount(String model, String content) { return 13; } } - - TestCallback testCallback = new TestCallback(); - aiMonitoringImpl.setLlmTokenCountCallback(testCallback); + return new TestCallback(); } } diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java deleted file mode 100644 index 83a57e54ee..0000000000 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoringImpl.java +++ /dev/null @@ -1,53 +0,0 @@ -package com.newrelic.api.agent; - -import java.util.Map; - -/** - * A utility class for interacting with 
the AI Monitoring API to record LlmFeedbackMessage events. - * This class implements the {@link AiMonitoring} interface and provides methods for feedback event recording - * and setting callbacks for token calculation. - */ - -public class AiMonitoringImpl implements AiMonitoring { - private static final String SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET = "Supportability/AiMonitoringTokenCountCallback/set"; - - /** - * Records an LlmFeedbackMessage event. - * - * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct - * the llmFeedbackEventAttributes map, use - * {@link LlmFeedbackEventAttributes.Builder} - *

    Required Attributes:

    - *
      - *
    • "traceId" (String): Trace ID where the chat completion related to the - * feedback event occurred
    • - *
    • "rating" (Integer/String): Rating provided by an end user
    • - *
    - * Optional Attributes: - *
      - *
    • "category" (String): Category of the feedback as provided by the end user
    • - *
    • "message" (String): Freeform text feedback from an end user.
    • - *
    • "metadata" (Map<String, String>): Set of key-value pairs to store - * additional data to submit with the feedback event
    • - *
    - */ - - @Override - public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) { - // Delegate to Insights API for event recording - NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); - } - - /** - * Sets the callback for token calculation and reports a supportability metric. - * - * @param llmTokenCountCallback The callback instance implementing {@link LlmTokenCountCallback} interface. - * This callback will be used for token calculation. - * @see LlmTokenCountCallback - */ - @Override - public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { - LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); - NewRelic.getAgent().getMetricAggregator().incrementCounter(SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET); - } -} From 82fecfc06c7ffa819fbecd1cec74c9ce8111b837 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 04:02:48 -0700 Subject: [PATCH 64/68] updated AiMonitoringImpl import --- newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java index cc1082bfdf..754375cb56 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java @@ -7,6 +7,7 @@ package com.newrelic.agent; +import com.newrelic.agent.aimonitoring.AiMonitoringImpl; import com.newrelic.agent.bridge.AgentBridge; import com.newrelic.agent.bridge.NoOpMetricAggregator; import com.newrelic.agent.bridge.NoOpTracedMethod; @@ -16,7 +17,6 @@ import com.newrelic.agent.service.ServiceFactory; import com.newrelic.agent.tracers.Tracer; import com.newrelic.api.agent.AiMonitoring; -import com.newrelic.api.agent.AiMonitoringImpl; import com.newrelic.api.agent.ErrorApi; import com.newrelic.api.agent.Insights; 
import com.newrelic.api.agent.Logger; From db09206f80f904b361f2af10e36c9d0571662152 Mon Sep 17 00:00:00 2001 From: edeleon Date: Mon, 1 Apr 2024 04:28:55 -0700 Subject: [PATCH 65/68] added NoOpAiMonitoring, updated implementations --- .../com/newrelic/agent/bridge/NoOpAgent.java | 4 +++- .../agent/bridge/NoOpAiMonitoring.java | 21 +++++++++++++++++++ .../com/newrelic/api/agent/NoOpAgent.java | 10 ++++++++- 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java index 90521cc071..f43dcfff9a 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java @@ -19,6 +19,8 @@ import java.util.Collections; import java.util.Map; +import static com.newrelic.agent.bridge.NoOpAiMonitoring.INSTANCE; + class NoOpAgent implements Agent { static final Agent INSTANCE = new NoOpAgent(); @@ -68,7 +70,7 @@ public Insights getInsights() { @Override public AiMonitoring getAiMonitoring() { - return null; + return NoOpAiMonitoring.INSTANCE; } @Override diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java new file mode 100644 index 0000000000..db26d3ac71 --- /dev/null +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java @@ -0,0 +1,21 @@ +package com.newrelic.agent.bridge; + +import com.newrelic.api.agent.AiMonitoring; +import com.newrelic.api.agent.LlmTokenCountCallback; + +import java.util.Map; + +public class NoOpAiMonitoring implements AiMonitoring { + + static final AiMonitoring INSTANCE = new NoOpAiMonitoring(); + + private NoOpAiMonitoring() {} + + @Override + public void recordLlmFeedbackEvent(Map 
llmFeedbackEventAttributes) { + } + + @Override + public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { + } +} diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java b/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java index 16270a02b0..dc61aee4f2 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java @@ -360,6 +360,14 @@ public void recordCustomEvent(String eventType, Map attributes) { } }; + private static final AiMonitoring AI_MONITORING = new AiMonitoring() { + @Override + public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) {} + + @Override + public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) {} + }; + private static final Segment SEGMENT = new Segment() { @Override public void setMetricName(String... metricNameParts) { @@ -460,7 +468,7 @@ public Insights getInsights() { @Override public AiMonitoring getAiMonitoring() { - return null; + return AI_MONITORING; } @Override From 387e32f2739c4c3d5101a0964c53e91aa10c0331 Mon Sep 17 00:00:00 2001 From: edeleon Date: Tue, 2 Apr 2024 10:14:25 -0700 Subject: [PATCH 66/68] update naming for attributes --- .../com/newrelic/api/agent/LlmFeedbackEventAttributes.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java index 2f12a63695..9211f5c1c8 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java @@ -55,10 +55,10 @@ public String getIngestSource() { public Map toMap() { Map feedbackAttributesMap = new HashMap<>(); - feedbackAttributesMap.put("traceId", getTraceId()); + feedbackAttributesMap.put("trace_id", 
getTraceId()); feedbackAttributesMap.put("rating", getRating()); feedbackAttributesMap.put("id", getId()); - feedbackAttributesMap.put("ingestSource", getIngestSource()); + feedbackAttributesMap.put("ingest_source", getIngestSource()); if (category != null) { feedbackAttributesMap.put("category", getCategory()); } From 66fb7203a9c945918dfb72c95ef4fef045960bb3 Mon Sep 17 00:00:00 2001 From: edeleon Date: Tue, 2 Apr 2024 12:06:52 -0700 Subject: [PATCH 67/68] update tests with refactored attribute names --- .../api/agent/LlmFeedbackEventAttributesTest.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java index c42340b8ed..fb3a7c9898 100644 --- a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java +++ b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java @@ -28,10 +28,10 @@ public void testBuilderWithRequiredParamsOnly() { llmFeedbackEventAttributes = llmFeedbackEventBuilder.build(); assertNotNull(llmFeedbackEventAttributes); - assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); assertEquals(3, llmFeedbackEventAttributes.get("rating")); assertNotNull(llmFeedbackEventAttributes.get("id")); - assertEquals("Java", llmFeedbackEventAttributes.get("ingestSource")); + assertEquals("Java", llmFeedbackEventAttributes.get("ingest_source")); assertFalse(llmFeedbackEventAttributes.containsKey("category")); assertFalse(llmFeedbackEventAttributes.containsKey("message")); assertFalse(llmFeedbackEventAttributes.containsKey("metadata")); @@ -46,7 +46,7 @@ public void testBuilderWithRequiredAndOptionalParams() { .build(); assertNotNull(llmFeedbackEventAttributes); - assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + 
assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); assertEquals(3, llmFeedbackEventAttributes.get("rating")); assertEquals("exampleCategory", llmFeedbackEventAttributes.get("category")); assertEquals("exampleMessage", llmFeedbackEventAttributes.get("message")); @@ -61,13 +61,13 @@ public void testBuilderWithOptionalParamsSetToNull() { .build(); assertNotNull(llmFeedbackEventAttributes); - assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); assertEquals(3, llmFeedbackEventAttributes.get("rating")); assertNull(llmFeedbackEventAttributes.get("category")); assertNull(llmFeedbackEventAttributes.get("message")); assertNull(llmFeedbackEventAttributes.get("metadata")); assertNotNull(llmFeedbackEventAttributes.get("id")); - assertEquals("Java", llmFeedbackEventAttributes.get("ingestSource")); + assertEquals("Java", llmFeedbackEventAttributes.get("ingest_source")); } @Test @@ -78,7 +78,7 @@ public void testBuilderWithRatingParamAsStringType() { llmFeedbackEventAttributes = llmFeedbackEventBuilder.build(); assertNotNull(llmFeedbackEventAttributes); - assertEquals("123456", llmFeedbackEventAttributes.get("traceId")); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); assertEquals("3", llmFeedbackEventAttributes.get("rating")); } From 434c3552b97d57f9daf880d7291e03f3a4f0241c Mon Sep 17 00:00:00 2001 From: Jason Keller Date: Tue, 2 Apr 2024 16:24:48 -0700 Subject: [PATCH 68/68] Disable AIM if HSM is enabled --- .../aimonitoring/AiMonitoringUtils.java | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java index 686b12e8ad..5332df36c8 100644 --- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java +++ 
b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java @@ -7,13 +7,17 @@ package com.newrelic.agent.bridge.aimonitoring; +import com.newrelic.api.agent.Config; import com.newrelic.api.agent.NewRelic; +import java.util.logging.Level; + public class AiMonitoringUtils { // Enabled defaults private static final boolean AI_MONITORING_ENABLED_DEFAULT = false; private static final boolean AI_MONITORING_STREAMING_ENABLED_DEFAULT = true; private static final boolean AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT = true; + private static final boolean HIGH_SECURITY_ENABLED_DEFAULT = false; /** * Check if ai_monitoring features are enabled. @@ -22,15 +26,20 @@ public class AiMonitoringUtils { * @return true if AI monitoring is enabled, else false */ public static boolean isAiMonitoringEnabled() { - Boolean enabled = NewRelic.getAgent().getConfig().getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); + Config config = NewRelic.getAgent().getConfig(); + Boolean aimEnabled = config.getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); + Boolean highSecurity = config.getValue("high_security", HIGH_SECURITY_ENABLED_DEFAULT); - if (enabled) { - NewRelic.incrementCounter("Supportability/Java/ML/Enabled"); - } else { + if (highSecurity || !aimEnabled) { + aimEnabled = false; + String disabledReason = highSecurity ? "High Security Mode." : "agent config."; + NewRelic.getAgent().getLogger().log(Level.FINE, "AIM: AI Monitoring is disabled due to " + disabledReason); NewRelic.incrementCounter("Supportability/Java/ML/Disabled"); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/Enabled"); } - return enabled; + return aimEnabled; } /**