From b0e8b924dd96fc68bb5e9d5a48dfb5055a7f9d23 Mon Sep 17 00:00:00 2001 From: mq200 Date: Sat, 14 Sep 2024 13:13:04 -0700 Subject: [PATCH 1/2] watonx client code formatting --- CHANGELOG.md | 9 +- README.md | 4 +- build.gradle.kts | 2 +- .../llm/client/ollama/OllamaClient.java | 3 +- .../client/watsonx/IBMAuthBearerToken.java | 29 +++ .../client/watsonx/WatsonxAuthenticator.java | 107 ++++++++ .../llm/client/watsonx/WatsonxClient.java | 178 +++++++++++++ .../WatsonxCompletionErrorDetails.java | 41 +++ .../completion/WatsonxCompletionModel.java | 56 ++++ .../completion/WatsonxCompletionRequest.java | 242 ++++++++++++++++++ .../completion/WatsonxCompletionResponse.java | 39 +++ .../WatsonxCompletionResponseError.java | 22 ++ .../completion/WatsonxCompletionResult.java | 33 +++ .../WatsonxCompletionStreamResponse.java | 40 +++ src/main/resources/application.properties | 3 +- 15 files changed, 802 insertions(+), 6 deletions(-) create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java create mode 100644 src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 395ed49..a3ed83a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.8.17] - 2024-09-12 + +### Fixed + +- Ollama host overriding + ## [0.8.16] - 2024-09-05 ### Fixed @@ -216,7 +222,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Upgrade OpenAI chat models: **gpt-4-0125-preview**, **gpt-3.5-turbo-0125** -[0.8.16]: https://github.com/carlrobertoh/llm-client/compare/d714854331915387da583c9a5b24877cc06286e...HEAD +[0.8.17]: https://github.com/carlrobertoh/llm-client/compare/6b7e26477b8e3454e78c8c639e97c8803fa5a301...HEAD +[0.8.16]: https://github.com/carlrobertoh/llm-client/compare/d714854331915387da583c9a5b24877cc06286e...6b7e26477b8e3454e78c8c639e97c8803fa5a301 [0.8.15]: https://github.com/carlrobertoh/llm-client/compare/fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc...d714854331915387da583c9a5b24877cc06286e [0.8.14]: https://github.com/carlrobertoh/llm-client/compare/6461c8458325e7b2a33670fc09493b3357eb094c...fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc [0.8.13]: https://github.com/carlrobertoh/llm-client/compare/a55fe7dcefbe6b911d5b99950d402dd06a66ec1e...6461c8458325e7b2a33670fc09493b3357eb094c diff --git a/README.md b/README.md index 43a6e65..5263b86 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,13 @@ To use the package, you need to use following Maven dependency: ee.carlrobert llm-client - 0.8.16 + 0.8.17 ``` Gradle dependency: ```kts dependencies { - implementation("ee.carlrobert:llm-client:0.8.16") + implementation("ee.carlrobert:llm-client:0.8.17") } ``` diff --git a/build.gradle.kts b/build.gradle.kts index a421d4d..150f3a0 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -7,7 +7,7 @@ plugins { } group = "ee.carlrobert" -version = "0.8.16" +version = "0.8.17" repositories { mavenCentral() diff --git a/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java b/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java index 18adde8..956e8ea 100644 --- a/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java +++ b/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java @@ -244,7 +244,8 @@ private void processStreamRequest( private HttpRequest buildPostHttpRequest( Object request, String path) throws JsonProcessingException { - var requestBuilder = HttpRequest.newBuilder(URI.create(BASE_URL + path)) + var baseHost = port == null ? BASE_URL : format("http://localhost:%d", port); + var requestBuilder = HttpRequest.newBuilder(URI.create((host == null ? baseHost : host) + path)) .POST(HttpRequest.BodyPublishers.ofString(new ObjectMapper().writeValueAsString(request))) .header("Content-Type", "application/x-ndjson"); diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java b/src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java new file mode 100644 index 0000000..58668d1 --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java @@ -0,0 +1,29 @@ +package ee.carlrobert.llm.client.watsonx; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class IBMAuthBearerToken { + + @JsonProperty("access_token") + String accessToken; + @JsonProperty("expiration") + int expiration; + + String getAccessToken() { + return this.accessToken; + } + + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + int getExpiration() { + return this.expiration; + } + + public void setExpiration(int expiration) { + this.expiration = expiration; + } +} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java new file mode 100644 index 0000000..640c913 --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java @@ -0,0 +1,107 @@ +package ee.carlrobert.llm.client.watsonx; + +import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.Base64; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.Map; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; + +public class WatsonxAuthenticator { + + IBMAuthBearerToken bearerToken; + OkHttpClient client; + Request request; + Boolean isZenApiKey = false; + + // On Cloud + public WatsonxAuthenticator(String apiKey) { + this.client = new OkHttpClient().newBuilder() + .build(); + MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); + RequestBody body = RequestBody.create(mediaType, + "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + apiKey); + this.request = new Request.Builder() + .url("https://iam.cloud.ibm.com/identity/token") + .method("POST", body) + .addHeader("Content-Type", "application/x-www-form-urlencoded") + .build(); + try { + Response response = client.newCall(request).execute(); + this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), + IBMAuthBearerToken.class); + } catch (IOException e) { + System.out.println(e); + } + } + + // Zen API Key + public WatsonxAuthenticator(String username, String zenApiKey) { + IBMAuthBearerToken token = new IBMAuthBearerToken(); + String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes()) + .toString(); + token.setAccessToken(tokenStr); + this.bearerToken = token; + this.isZenApiKey = true; + } + + // Watsonx API Key + public WatsonxAuthenticator(String username, String apiKey, + String host) {//TODO add support for password + this.client = new OkHttpClient().newBuilder() + .build(); + ObjectMapper mapper = new ObjectMapper(); + Map authParams = new LinkedHashMap<>(); + authParams.put("username", username); + authParams.put("api_key", apiKey); + + String authParamsStr = ""; + try { + authParamsStr = mapper.writeValueAsString(authParams); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + + MediaType mediaType = MediaType.parse("application/json"); + RequestBody body = RequestBody.create(mediaType, authParamsStr); + this.request = new Request.Builder() + .url(host + + "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken + .method("POST", body) + .addHeader("Content-Type", "application/json") + .build(); + try { + Response response = client.newCall(request).execute(); + this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), + IBMAuthBearerToken.class); + } catch (IOException e) { + System.out.println(e); + } + } + + private void generateNewBearerToken() { + try { + Response response = client.newCall(request).execute(); + this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), + IBMAuthBearerToken.class); + } catch (IOException e) { + System.out.println(e); + } + } + + public String getBearerTokenValue() { + if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000) + < (new Date().getTime() + 60000))) { + generateNewBearerToken(); + } + return this.bearerToken.getAccessToken(); + } +} diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java new file mode 100644 index 0000000..970c3cf --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java @@ -0,0 +1,178 @@ +package ee.carlrobert.llm.client.watsonx; + +import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; + +import com.fasterxml.jackson.core.JsonProcessingException; +import ee.carlrobert.llm.PropertiesLoader; +import ee.carlrobert.llm.client.DeserializationUtil; +import ee.carlrobert.llm.client.openai.completion.ErrorDetails; +import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest; +import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse; +import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError; +import ee.carlrobert.llm.completion.CompletionEventListener; +import ee.carlrobert.llm.completion.CompletionEventSourceListener; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import okhttp3.Headers; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSources; + +public class WatsonxClient { + + private static final MediaType APPLICATION_JSON = MediaType.parse("application/json"); + private final OkHttpClient httpClient; + private final String host; + private final String apiVersion; + private final WatsonxAuthenticator authenticator; + + private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) { + this.httpClient = httpClientBuilder.build(); + this.apiVersion = builder.apiVersion; + this.host = builder.host; + if (builder.isOnPrem) { + if (builder.isZenApiKey) { + this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey); + } else { + this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey, + builder.host); + } + } else { + this.authenticator = new WatsonxAuthenticator(builder.apiKey); + } + } + + public EventSource getCompletionAsync( + WatsonxCompletionRequest request, + CompletionEventListener eventListener) { + return EventSources.createFactory(httpClient).newEventSource( + buildCompletionRequest(request), + getCompletionEventSourceListener(eventListener)); + } + + public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) { + try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) { + return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Request buildCompletionRequest(WatsonxCompletionRequest request) { + var headers = new HashMap<>(getRequiredHeaders()); + if (request.getStream()) { + headers.put("Accept", "text/event-stream"); + } + try { + String deployment = request.getDeploymentId().isEmpty() ? "" + : "deployments/" + request.getDeploymentId() + "/"; + String generation = request.getStream() ? "generation_stream" : "generation"; + return new Request.Builder() + .url(host + "/ml/v1/text/" + deployment + generation + "?version=" + apiVersion) + .headers(Headers.of(headers)) + .post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON)) + .build(); + } catch (JsonProcessingException e) { + throw new RuntimeException("Unable to process request", e); + } + } + + private Map getRequiredHeaders() { + return new HashMap<>(Map.of("Authorization", + (this.authenticator.isZenApiKey ? "ZenApiKey " : "Bearer ") + + authenticator.getBearerTokenValue())); + } + + private CompletionEventSourceListener getCompletionEventSourceListener( + CompletionEventListener eventListener) { + return new CompletionEventSourceListener<>(eventListener) { + @Override + protected String getMessage(String data) { + try { + return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class) + .getResults().get(0).getGeneratedText(); + } catch (Exception e) { + try { + System.out.println(data); + String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class) + .getError() + .getMessage(); + if (message != null) { + return message; + } + return message; + } catch (Exception ex) { + System.out.println(ex.toString()); + return ""; + } + } + } + + @Override + protected ErrorDetails getErrorDetails(String error) { + try { + return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError(); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + }; + } + + public static class Builder { + + private final String apiKey; + private String host = PropertiesLoader.getValue("watsonx.baseUrl"); + private String apiVersion = "2024-03-14"; + private Boolean isOnPrem; + private Boolean isZenApiKey; + private String username; + + public Builder(String apiKey) { + this.apiKey = apiKey; + } + + public Builder setApiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + + public Builder setHost(String host) { + this.host = host; + return this; + } + + public Builder setIsZenApiKey(Boolean isZenApiKey) { + this.isZenApiKey = isZenApiKey; + return this; + } + + public Builder setIsOnPrem(Boolean isOnPrem) { + this.isOnPrem = isOnPrem; + return this; + } + + public Builder setUsername(String username) { + this.username = username; + return this; + } + + public WatsonxClient build(OkHttpClient.Builder builder) { + return new WatsonxClient(this, builder); + } + + public WatsonxClient build() { + return build(new OkHttpClient.Builder()); + } + } +} + + + + + + diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java new file mode 100644 index 0000000..4a45f9e --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java @@ -0,0 +1,41 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import ee.carlrobert.llm.client.openai.completion.ErrorDetails; + + +@JsonIgnoreProperties(ignoreUnknown = true) +public class WatsonxCompletionErrorDetails { + + private static final String DEFAULT_ERROR_MSG = "Something went wrong. Please try again later."; + public static WatsonxCompletionErrorDetails DEFAULT_ERROR = new WatsonxCompletionErrorDetails( + DEFAULT_ERROR_MSG, null); + String code; + String message; + ErrorDetails details; + + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public WatsonxCompletionErrorDetails( + @JsonProperty("message") String message, + @JsonProperty("code") String code) { + this.message = message; + this.code = code; + this.details = new ErrorDetails(message, null, null, code); + } + + public String getMessage() { + return message; + } + + public String getCode() { + return code; + } + + public ErrorDetails getDetails() { + return details; + } + +} diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java new file mode 100644 index 0000000..69b5c95 --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java @@ -0,0 +1,56 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import ee.carlrobert.llm.completion.CompletionModel; +import java.util.Arrays; + +public enum WatsonxCompletionModel implements CompletionModel { + + GRANITE_3B_CODE_INSTRUCT("ibm/granite-3b-code-instruct", "IBM Granite 3B Code Instruct", 8192), + GRANITE_8B_CODE_INSTRUCT("ibm/granite-8b-code-instruct", "IBM Granite 8B Code Instruct", 8192), + GRANITE_20B_CODE_INSTRUCT("ibm/granite-20b-code-instruct", "IBM Granite 20B Code Instruct", 8192), + GRANITE_34B_CODE_INSTRUCT("ibm/granite-34b-code-instruct", "IBM Granite 34B Code Instruct", 8192), + CODELLAMA_34_B_INSTRUCT("codellama/codellama-34b-instruct-hf", "Code Llama 34B Instruct", 8192), + MIXTRAL_8_7B("mistralai/mixtral-8x7b-instruct-v01", "Mixtral (8x7B)", 32768), + MIXTRAL_LARGE("mistralai/mistral-large", "Mistral Large", 128000), + LLAMA_3_1_70B("meta-llama/llama-3-1-70b-instruct", "Llama 3.1 Instruct (70B)", 128000), + LLAMA_3_1_8B("meta-llama/llama-3-1-8b-instruct", "Llama 3.1 Instruct (8B)", 128000), + LLAMA_2_7B("meta-llama/llama-2-70b-chat", "Llama 2 Chat (70B)", 4096), + LLAMA_2_13B("meta-llama/llama-2-13b-chat", "Llama 2 Chat (13B)", 4096), + GRANITE_13B_INSTRUCT_V2("ibm/granite-13b-instruct-v2", "IBM Granite 13B Instruct V2", 8192), + GRANITE_13B_CHAT_V2("ibm/granite-13b-chat-v2", "IBM Granite 13B Chat V2", 8192), + GRANITE_20B_MULTILINGUAL("ibm/granite-20b-multilingual", "IBM Granite 20B Multilingual", 8192); + + private final String code; + private final String description; + private final int maxTokens; + + WatsonxCompletionModel(String code, String description, int maxTokens) { + this.code = code; + this.description = description; + this.maxTokens = maxTokens; + } + + public static WatsonxCompletionModel findByCode(String code) { + return Arrays.stream(WatsonxCompletionModel.values()) + .filter(item -> item.getCode().equals(code)) + .findFirst().orElseThrow(); + } + + public String getCode() { + return code; + } + + public String getDescription() { + return description; + } + + public int getMaxTokens() { + return maxTokens; + } + + @Override + public String toString() { + return description; + } +} + diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java new file mode 100644 index 0000000..b2c6cc7 --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java @@ -0,0 +1,242 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import ee.carlrobert.llm.completion.CompletionRequest; + +@JsonInclude(JsonInclude.Include.NON_NULL) +public class WatsonxCompletionRequest implements CompletionRequest { + + String input; + @JsonProperty("project_id") + String projectId; + @JsonProperty("space_id") + String spaceId; + @JsonProperty("model_id") + String modelId; + String deploymentId; + Boolean stream; + WatsonxCompletionParameters parameters; + + public WatsonxCompletionRequest(Builder builder) { + System.out.println("Model ID: " + builder.modelId); + System.out.println("Deployment ID: " + builder.deploymentId); + System.out.println("decodingMethod: " + builder.decodingMethod); + System.out.println("maxNewTokens: " + builder.maxNewTokens); + System.out.println("minNewTokens: " + builder.minNewTokens); + System.out.println("randomSeed: " + builder.randomSeed); + System.out.println("stopSequences: " + builder.stopSequences); + System.out.println("timeLimit: " + builder.timeLimit); + System.out.println("topK: " + builder.topK); + System.out.println("topP: " + builder.topP); + System.out.println("temperature: " + builder.temperature); + System.out.println("rep penalty: " + builder.repetitionPenalty); + System.out.println("include stop seq: " + builder.includeStopSequence); + this.input = builder.input; + this.stream = builder.stream; + this.projectId = builder.projectId; + this.spaceId = builder.spaceId; + this.modelId = builder.modelId; + this.deploymentId = builder.deploymentId; + this.parameters = new WatsonxCompletionParameters( + builder.decodingMethod, + builder.maxNewTokens, + builder.minNewTokens, + builder.randomSeed, + builder.stopSequences, + builder.timeLimit, + builder.topK, + builder.topP, + builder.temperature, + builder.repetitionPenalty, + builder.includeStopSequence); + } + + public Boolean getStream() { + return this.stream; + } + + public String getModelId() { + return modelId; + } + + public String getDeploymentId() { + return deploymentId; + } + + public String getSpaceId() { + return spaceId; + } + + public String getProjectId() { + return projectId; + } + + public String getInput() { + return input; + } + + public WatsonxCompletionParameters getParameters() { + return parameters; + } + + public static class Builder { + + String input; + String projectId; + String spaceId; + String modelId; + String deploymentId; + Boolean stream; + String decodingMethod; + Integer maxNewTokens; + Integer minNewTokens; + Integer randomSeed; + String[] stopSequences; + Integer timeLimit; + Integer topK; + Double topP; + Double repetitionPenalty; + Boolean includeStopSequence; + Double temperature; + + public Builder(String prompt) { + this.input = prompt; + } + + public Builder setInput(String input) { + this.input = input; + return this; + } + + public Builder setModelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder setDeploymentId(String deploymentId) { + this.deploymentId = deploymentId; + return this; + } + + public Builder setSpaceId(String spaceId) { + this.spaceId = spaceId; + return this; + } + + public Builder setProjectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder setStream(Boolean stream) { + this.stream = stream; + return this; + } + + public Builder setMaxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + return this; + } + + public Builder setMinNewTokens(Integer minNewTokens) { + this.minNewTokens = minNewTokens; + return this; + } + + public Builder setTemperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder setRepetitionPenalty(Double frequencyPenalty) { + this.repetitionPenalty = frequencyPenalty; + return this; + } + + public Builder setDecodingMethod(String decodingMethod) { + this.decodingMethod = decodingMethod; + return this; + } + + public Builder setStopSequences(String[] stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder setIncludeStopSequence(Boolean includeStopSequence) { + this.includeStopSequence = includeStopSequence; + return this; + } + + public Builder setRandomSeed(Integer randomSeed) { + this.randomSeed = randomSeed; + return this; + } + + public Builder setTopP(Double topP) { + this.topP = topP; + return this; + } + + public Builder setTopK(Integer topK) { + this.topK = topK; + return this; + } + + public WatsonxCompletionRequest build() { + return new WatsonxCompletionRequest(this); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public class WatsonxCompletionParameters { + + @JsonProperty("decoding_method") + String decodingMethod; + @JsonProperty("max_new_tokens") + Integer maxNewTokens; + @JsonProperty("min_new_tokens") + Integer minNewTokens; + @JsonProperty("random_seed") + Integer randomSeed; + @JsonProperty("stop_sequences") + String[] stopSequences; + @JsonProperty("time_limit") + Integer timeLimit; + @JsonProperty("top_k") + Integer topK; + @JsonProperty("top_p") + Double topP; + Double temperature; + @JsonProperty("repetition_penalty") + Double repetitionPenalty; + @JsonProperty("include_stop_sequence") + Boolean includeStopSequence; + + public WatsonxCompletionParameters( + String decodingMethod, + Integer maxNewTokens, + Integer minNewTokens, + Integer randomSeed, + String[] stopSequences, + Integer timeLimit, + Integer topK, + Double topP, + Double temperature, + Double repetitionPenalty, + Boolean includeStopSequence) { + this.decodingMethod = decodingMethod; + this.maxNewTokens = maxNewTokens; + this.minNewTokens = minNewTokens; + this.randomSeed = randomSeed; + this.stopSequences = stopSequences; + this.timeLimit = timeLimit; + this.topK = topK; + this.topP = topP; + this.temperature = temperature; + this.repetitionPenalty = repetitionPenalty; + this.includeStopSequence = includeStopSequence; + } + } +} diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java new file mode 100644 index 0000000..01a48d0 --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java @@ -0,0 +1,39 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import ee.carlrobert.llm.completion.CompletionResponse; +import java.util.List; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class WatsonxCompletionResponse implements CompletionResponse { + + private final String modelId; + private String createdAt; + private List results; + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public WatsonxCompletionResponse( + @JsonProperty("model_id") String modelId, + @JsonProperty("created_at") String createdAt, + @JsonProperty("results") List results) { + this.modelId = modelId; + this.createdAt = createdAt; + this.results = results; + } + + public String getModelId() { + return modelId; + } + + public String getCreatedAt() { + return createdAt; + } + + public List getResults() { + return results; + } +} + + diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java new file mode 100644 index 0000000..14ed4c9 --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java @@ -0,0 +1,22 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import ee.carlrobert.llm.client.openai.completion.ErrorDetails; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class WatsonxCompletionResponseError { + + private final WatsonxCompletionErrorDetails error; + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public WatsonxCompletionResponseError( + @JsonProperty("error") WatsonxCompletionErrorDetails error) { + this.error = error; + } + + public ErrorDetails getError() { + return error.getDetails(); + } +} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java new file mode 100644 index 0000000..a24235c --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java @@ -0,0 +1,33 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class WatsonxCompletionResult { + + String generatedText; + String stopReason; + int generatedTokenCount; + int inputTokenCount; + int seed; + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public WatsonxCompletionResult( + @JsonProperty("generated_text") String generatedText, + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("generated_token_count") int generatedTokenCount, + @JsonProperty("input_token_count") int inputTokenCount, + @JsonProperty("seed") int seed) { + this.generatedText = generatedText; + this.stopReason = stopReason; + this.generatedTokenCount = generatedTokenCount; + this.inputTokenCount = inputTokenCount; + this.seed = seed; + } + + public String getGeneratedText() { + return this.generatedText; + } +} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java new file mode 100644 index 0000000..0164d4c --- /dev/null +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java @@ -0,0 +1,40 @@ +package ee.carlrobert.llm.client.watsonx.completion; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class WatsonxCompletionStreamResponse { + + private final String modelId; + private List items; + private String createdAt; + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public WatsonxCompletionStreamResponse( + @JsonProperty("model_id") String modelId, + @JsonProperty("created_at") String createdAt, + @JsonProperty("items") List items) { + this.modelId = modelId; + this.createdAt = createdAt; + this.items = items; + } + + public String getModelId() { + WatsonxCompletionResponse firstItem = items.get(0); + return firstItem.getModelId(); + } + + public String getCreatedAt() { + WatsonxCompletionResponse firstItem = items.get(0); + return firstItem.getCreatedAt(); + } + + public List getItems() { + return items; + } +} + + diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 5747d22..efcfd80 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -5,4 +5,5 @@ anthropic.baseUrl=https://api.anthropic.com you.baseUrl=https://you.com llama.baseUrl=http://localhost:8080 ollama.baseUrl=http://localhost:11434 -google.baseUrl=https://generativelanguage.googleapis.com \ No newline at end of file +google.baseUrl=https://generativelanguage.googleapis.com +watsonx.baseUrl=https://us-south.ml.cloud.ibm.com \ No newline at end of file From 8b8b09fa10feda2d7f000752403686393363f46a Mon Sep 17 00:00:00 2001 From: mq200 Date: Sat, 14 Sep 2024 13:25:23 -0700 Subject: [PATCH 2/2] - --- .../java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java index 970c3cf..85e7a08 100644 --- a/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java +++ b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java @@ -101,9 +101,7 @@ protected String getMessage(String data) { String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class) .getError() .getMessage(); - if (message != null) { - return message; - } + if (message == null) return ""; return message; } catch (Exception ex) { System.out.println(ex.toString());