Skip to content

Commit

Permalink
Revert "Add support for Bedrock Converse API (Anthropic Messages API,…
Browse files Browse the repository at this point in the history
… Claude 3.5 Sonnet) (#2851) (#2913)" (#2929)

This reverts commit ed37690.
  • Loading branch information
Zhangxunmt authored Sep 10, 2024
1 parent 24fc9c3 commit 0135cb9
Show file tree
Hide file tree
Showing 18 changed files with 52 additions and 1,869 deletions.
5 changes: 0 additions & 5 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,3 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) {
dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask")
dependsOn tasks.named("${baseName}#fullRestartClusterTask")
}

forbiddenPatterns {
exclude '**/*.pdf'
exclude '**/*.jpg'
}

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ public void processResponseAsync(
chatHistory,
searchResults,
timeout,
params.getLlmResponseField(),
params.getLlmMessages()
params.getLlmResponseField()
),
null,
llmQuestion,
Expand All @@ -203,8 +202,7 @@ public void processResponseAsync(
chatHistory,
searchResults,
timeout,
params.getLlmResponseField(),
params.getLlmMessages()
params.getLlmResponseField()
),
conversationId,
llmQuestion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.ParseField;
Expand All @@ -32,7 +30,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

import com.google.common.base.Preconditions;

Expand Down Expand Up @@ -84,8 +81,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
// that contains the chat completion text, i.e. "answer".
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");

private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");

public static final int SIZE_NULL_VALUE = -1;

static {
Expand All @@ -99,7 +94,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
}

@Setter
Expand Down Expand Up @@ -138,10 +132,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private String llmResponseField;

@Setter
@Getter
private List<MessageBlock> llmMessages = new ArrayList<>();

public GenerativeQAParameters(
String conversationId,
String llmModel,
Expand All @@ -152,32 +142,6 @@ public GenerativeQAParameters(
Integer interactionSize,
Integer timeout,
String llmResponseField
) {
this(
conversationId,
llmModel,
llmQuestion,
systemPrompt,
userInstructions,
contextSize,
interactionSize,
timeout,
llmResponseField,
null
);
}

public GenerativeQAParameters(
String conversationId,
String llmModel,
String llmQuestion,
String systemPrompt,
String userInstructions,
Integer contextSize,
Integer interactionSize,
Integer timeout,
String llmResponseField,
List<MessageBlock> llmMessages
) {
this.conversationId = conversationId;
this.llmModel = llmModel;
Expand All @@ -192,9 +156,6 @@ public GenerativeQAParameters(
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
this.llmResponseField = llmResponseField;
if (llmMessages != null) {
this.llmMessages.addAll(llmMessages);
}
}

public GenerativeQAParameters(StreamInput input) throws IOException {
Expand All @@ -207,7 +168,6 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
this.interactionSize = input.readInt();
this.timeout = input.readInt();
this.llmResponseField = input.readOptionalString();
this.llmMessages.addAll(input.readList(MessageBlock::new));
}

@Override
Expand All @@ -221,8 +181,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout)
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
.field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
}

@Override
Expand All @@ -238,7 +197,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInt(interactionSize);
out.writeInt(timeout);
out.writeOptionalString(llmResponseField);
out.writeList(llmMessages);
}

public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
Expand All @@ -265,8 +223,4 @@ public boolean equals(Object o) {
&& (this.timeout == other.getTimeout())
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
}

public void setMessageBlock(List<MessageBlock> blockList) {
this.llmMessages = blockList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,4 @@ public class ChatCompletionInput {
private String userInstructions;
private Llm.ModelProvider modelProvider;
private String llmResponseField;
private List<MessageBlock> llmMessages;
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
* @return
*/
@Override

public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
Expand Down Expand Up @@ -112,15 +113,14 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getModelProvider(),
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts(),
chatCompletionInput.getLlmMessages()
chatCompletionInput.getContexts()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
// log.info("Messages to LLM: {}", messages);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
|| chatCompletionInput.getLlmResponseField() != null) {
Expand All @@ -136,19 +136,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
chatCompletionInput.getContexts()
)
);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
// Bedrock Converse API does not include the system prompt as part of the Messages block.
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getModelProvider(),
null,
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts(),
chatCompletionInput.getLlmMessages()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
} else {
throw new IllegalArgumentException(
"Unknown/unsupported model provider: "
Expand All @@ -157,6 +144,7 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
);
}

// log.info("LLM input parameters: {}", inputParameters.toString());
return inputParameters;
}

Expand Down Expand Up @@ -196,20 +184,6 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
} else if (provider == ModelProvider.COHERE) {
answerField = "text";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
} else if (provider == ModelProvider.BEDROCK_CONVERSE) {
Map output = (Map) dataAsMap.get("output");
Map message = (Map) output.get("message");
if (message != null) {
List content = (List) message.get("content");
String answer = (String) ((Map) content.get(0)).get("text");
answers.add(answer);
} else {
Map error = (Map) output.get("error");
if (error == null) {
throw new RuntimeException("Unexpected output: " + output);
}
errors.add((String) error.get("message"));
}
} else {
throw new IllegalArgumentException(
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ public interface Llm {
enum ModelProvider {
OPENAI,
BEDROCK,
COHERE,
BEDROCK_CONVERSE
COHERE
}

void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public class LlmIOUtil {

public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
public static final String COHERE_PROVIDER_PREFIX = "cohere/";
public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/";

public static ChatCompletionInput createChatCompletionInput(
String llmModel,
Expand All @@ -50,8 +49,7 @@ public static ChatCompletionInput createChatCompletionInput(
chatHistory,
contexts,
timeoutInSeconds,
llmResponseField,
null
llmResponseField
);
}

Expand All @@ -63,8 +61,7 @@ public static ChatCompletionInput createChatCompletionInput(
List<Interaction> chatHistory,
List<String> contexts,
int timeoutInSeconds,
String llmResponseField,
List<MessageBlock> llmMessages
String llmResponseField
) {
Llm.ModelProvider provider = null;
if (llmResponseField == null) {
Expand All @@ -74,8 +71,6 @@ public static ChatCompletionInput createChatCompletionInput(
provider = Llm.ModelProvider.BEDROCK;
} else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.COHERE;
} else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.BEDROCK_CONVERSE;
}
}
}
Expand All @@ -88,8 +83,7 @@ public static ChatCompletionInput createChatCompletionInput(
systemPrompt,
userInstructions,
provider,
llmResponseField,
llmMessages
llmResponseField
);
}
}
Loading

0 comments on commit 0135cb9

Please sign in to comment.