Skip to content

Commit

Permalink
add Watsonx client
Browse files Browse the repository at this point in the history
  • Loading branch information
mq200 committed Sep 6, 2024
1 parent 03648e9 commit 228420a
Show file tree
Hide file tree
Showing 11 changed files with 732 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ee.carlrobert.llm.client.watsonx;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * JSON-bound holder for the token payload returned by an IBM authorization endpoint.
 *
 * <p>Only the {@code access_token} and {@code expiration} properties are mapped; every other
 * property of the payload is ignored during deserialization.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
public class IBMAuthBearerToken {

  /** Token string, bound to the {@code access_token} JSON property. */
  @JsonProperty("access_token")
  String accessToken;

  /**
   * Expiration value, bound to the {@code expiration} JSON property.
   * NOTE(review): presumably epoch seconds (the authenticator multiplies it by 1000 before
   * comparing it with the current time in milliseconds) -- confirm against the IBM API.
   */
  @JsonProperty("expiration")
  int expiration;

  /** Returns the token string, or {@code null} when none has been set. */
  String getAccessToken() {
    return accessToken;
  }

  /** Returns the expiration value as stored; {@code 0} when none has been set. */
  int getExpiration() {
    return expiration;
  }

  /** Replaces the stored token string. */
  public void setAccessToken(String accessToken) {
    this.accessToken = accessToken;
  }

  /** Replaces the stored expiration value. */
  public void setExpiration(int expiration) {
    this.expiration = expiration;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package ee.carlrobert.llm.client.watsonx;

import okhttp3.*;

import java.io.IOException;
import java.util.Base64;
import java.util.Date;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

/**
 * Obtains and caches the credential used in the {@code Authorization} header of Watsonx requests.
 *
 * <p>Three modes are supported, one per constructor:
 * <ul>
 *   <li>IBM Cloud API key: exchanged for a token at {@code https://iam.cloud.ibm.com/identity/token};</li>
 *   <li>Zen API key: the credential is computed locally as Base64 of {@code username:zenApiKey}
 *       and is never refreshed;</li>
 *   <li>Watsonx (on-prem) API key: exchanged for a token at {@code <host>/icp4d-api/v1/authorize}.</li>
 * </ul>
 *
 * <p>Token requests are best-effort: an {@link IOException} is printed to standard output and the
 * request is retried on the next {@link #getBearerTokenValue()} call.
 */
public class WatsonxAuthenticator {

  /** A remote token is refreshed once it is within this many milliseconds of expiring. */
  private static final long REFRESH_MARGIN_MILLIS = 1_000_000L; // TODO add correct number of seconds

  IBMAuthBearerToken bearerToken;
  OkHttpClient client;
  Request request;
  Boolean isZenApiKey = false;

  /**
   * On Cloud: prepares the IAM token request for {@code apiKey} and fetches a first token.
   *
   * @param apiKey IBM Cloud API key
   */
  public WatsonxAuthenticator(String apiKey) {
    this.client = new OkHttpClient().newBuilder().build();
    MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
    // URL-encode the key so that characters such as '&', '=' or '+' cannot corrupt the form body.
    String form = "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey="
        + java.net.URLEncoder.encode(
            String.valueOf(apiKey), java.nio.charset.StandardCharsets.UTF_8);
    this.request = new Request.Builder()
        .url("https://iam.cloud.ibm.com/identity/token")
        .method("POST", RequestBody.create(form, mediaType))
        .addHeader("Content-Type", "application/x-www-form-urlencoded")
        .build();
    generateNewBearerToken();
  }

  /**
   * Zen API Key: derives the credential locally; no HTTP request is made.
   *
   * @param username  user the key belongs to
   * @param zenApiKey Zen API key
   */
  public WatsonxAuthenticator(String username, String zenApiKey) {
    IBMAuthBearerToken token = new IBMAuthBearerToken();
    // encodeToString() yields the Base64 text; calling toString() on the byte[] returned by
    // encode() would only produce the array's identity string (e.g. "[B@1a2b3c").
    String credentials = username + ":" + zenApiKey;
    token.setAccessToken(Base64.getEncoder()
        .encodeToString(credentials.getBytes(java.nio.charset.StandardCharsets.UTF_8)));
    this.bearerToken = token;
    this.isZenApiKey = true;
  }

  /**
   * Watsonx API Key: prepares the authorize request against {@code host} and fetches a first token.
   *
   * @param username user the key belongs to
   * @param apiKey   API key of that user
   * @param host     base URL of the instance, without a trailing slash
   */
  public WatsonxAuthenticator(String username, String apiKey, String host) { // TODO add support for password
    this.client = new OkHttpClient().newBuilder().build();
    MediaType mediaType = MediaType.parse("application/json");
    // Let Jackson build the payload so quotes or backslashes in the credentials are escaped.
    String json = OBJECT_MAPPER.createObjectNode()
        .put("username", username)
        .put("api_key", apiKey)
        .toString();
    this.request = new Request.Builder()
        .url(host + "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken
        .method("POST", RequestBody.create(json, mediaType))
        .addHeader("Content-Type", "application/json")
        .build();
    generateNewBearerToken();
  }

  /**
   * Executes the prepared token request and stores the deserialized token.
   * The response is always closed; an I/O failure leaves the previous token untouched.
   */
  private void generateNewBearerToken() {
    try (Response response = client.newCall(request).execute()) {
      this.bearerToken =
          OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class);
    } catch (IOException e) {
      // Best-effort: the next getBearerTokenValue() call tries again.
      System.out.println(e);
    }
  }

  /** Tells whether the cached token is missing or within the refresh margin of its expiration. */
  private boolean needsRefresh() {
    if (this.bearerToken == null) {
      return true;
    }
    // 1000L forces long arithmetic: epoch seconds * 1000 does not fit in an int.
    long expiresAtMillis = this.bearerToken.getExpiration() * 1000L;
    return expiresAtMillis < new Date().getTime() + REFRESH_MARGIN_MILLIS;
  }

  /**
   * Returns the current credential, refreshing a remote token first when it is missing or
   * about to expire. Zen API key credentials are returned as computed at construction time.
   *
   * @return the value to place after the scheme in the {@code Authorization} header
   * @throws IllegalStateException if no token is available because every request so far failed
   */
  public String getBearerTokenValue() {
    if (!isZenApiKey && needsRefresh()) {
      generateNewBearerToken();
    }
    if (this.bearerToken == null) {
      throw new IllegalStateException("Unable to obtain a Watsonx bearer token");
    }
    return this.bearerToken.getAccessToken();
  }
}
163 changes: 163 additions & 0 deletions src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package ee.carlrobert.llm.client.watsonx;

import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

/**
 * HTTP client for the Watsonx text generation endpoints.
 *
 * <p>Instances are created through {@link Builder}. Requests go to
 * {@code <host>/ml/v1/text/generation} or, for streaming requests,
 * {@code <host>/ml/v1/text/generation_stream}, with the API version passed as the
 * {@code version} query parameter.
 */
public class WatsonxClient {

  private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");
  private final OkHttpClient httpClient;
  private final String host;
  private final String apiVersion;
  private final WatsonxAuthenticator authenticator;

  private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
    this.httpClient = httpClientBuilder.build();
    this.apiVersion = builder.apiVersion;
    this.host = builder.host;
    // The builder flags are boxed and stay null unless their setters were called;
    // Boolean.TRUE.equals treats "not set" as false instead of failing on auto-unboxing.
    if (Boolean.TRUE.equals(builder.isOnPrem)) {
      if (Boolean.TRUE.equals(builder.isZenApiKey)) {
        this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey);
      } else {
        this.authenticator =
            new WatsonxAuthenticator(builder.username, builder.apiKey, builder.host);
      }
    } else {
      this.authenticator = new WatsonxAuthenticator(builder.apiKey);
    }
  }

  /**
   * Starts a server-sent-events completion and forwards the events to {@code eventListener}.
   *
   * @param request       completion request to send
   * @param eventListener receiver of the streamed messages and errors
   * @return the opened event source
   */
  public EventSource getCompletionAsync(
      WatsonxCompletionRequest request,
      CompletionEventListener<String> eventListener) {
    return EventSources.createFactory(httpClient).newEventSource(
        buildCompletionRequest(request),
        getCompletionEventSourceListener(eventListener));
  }

  /**
   * Executes a completion synchronously.
   *
   * @param request completion request to send
   * @return the deserialized response
   * @throws RuntimeException wrapping the {@link IOException} when the call fails
   */
  public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) {
    try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) {
      return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Builds the POST request for {@code request}, choosing the streaming endpoint and the
   * {@code text/event-stream} Accept header when the request asks for streaming.
   *
   * @throws RuntimeException if the request cannot be serialized to JSON
   */
  protected Request buildCompletionRequest(WatsonxCompletionRequest request) {
    boolean streaming = request.getStream();
    var headers = getRequiredHeaders();
    if (streaming) {
      headers.put("Accept", "text/event-stream");
    }
    String endpoint = streaming ? "generation_stream" : "generation";
    try {
      return new Request.Builder()
          .url(host + "/ml/v1/text/" + endpoint + "?version=" + apiVersion)
          .headers(Headers.of(headers))
          .post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON))
          .build();
    } catch (JsonProcessingException e) {
      throw new RuntimeException("Unable to process request", e);
    }
  }

  /**
   * Returns a fresh, mutable map holding the {@code Authorization} header.
   *
   * <p>A Zen API key credential is sent with the {@code ZenApiKey} scheme; tokens issued by an
   * authorization endpoint are sent with {@code Bearer}.
   */
  private Map<String, String> getRequiredHeaders() {
    String scheme = Boolean.TRUE.equals(authenticator.isZenApiKey) ? "ZenApiKey " : "Bearer ";
    Map<String, String> headers = new HashMap<>();
    headers.put("Authorization", scheme + authenticator.getBearerTokenValue());
    return headers;
  }

  /** Adapts {@code eventListener} to the SSE listener, decoding Watsonx payloads. */
  private CompletionEventSourceListener<String> getCompletionEventSourceListener(
      CompletionEventListener<String> eventListener) {
    return new CompletionEventSourceListener<>(eventListener) {
      /**
       * Extracts the generated text of the first result. When the payload is not a completion
       * response it is read as an error payload and that error's message is returned instead;
       * an empty string is returned when neither interpretation yields a message.
       */
      @Override
      protected String getMessage(String data) {
        try {
          return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class)
              .getResults().get(0).getGeneratedText();
        } catch (Exception e) {
          try {
            String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class)
                .getError()
                .getMessage();
            return message != null ? message : "";
          } catch (Exception ex) {
            return "";
          }
        }
      }

      /** Deserializes the error payload; fails with a RuntimeException if it is not valid. */
      @Override
      protected ErrorDetails getErrorDetails(String error) {
        try {
          return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError();
        } catch (JsonProcessingException e) {
          throw new RuntimeException(e);
        }
      }
    };
  }

  /**
   * Builder for {@link WatsonxClient}.
   *
   * <p>Defaults: host from the {@code watsonx.baseUrl} property, API version
   * {@code 2024-03-14}. Unless {@link #setIsOnPrem(Boolean)} is called with {@code true},
   * the client authenticates with the API key alone.
   */
  public static class Builder {

    private final String apiKey;
    private String host = PropertiesLoader.getValue("watsonx.baseUrl");
    private String apiVersion = "2024-03-14";
    private Boolean isOnPrem;
    private Boolean isZenApiKey;
    private String username;

    public Builder(String apiKey) {
      this.apiKey = apiKey;
    }

    public Builder setApiVersion(String apiVersion) {
      this.apiVersion = apiVersion;
      return this;
    }

    public Builder setHost(String host) {
      this.host = host;
      return this;
    }

    /** Marks the API key as a Zen API key; only consulted for on-prem clients. */
    public Builder setIsZenApiKey(Boolean isZenApiKey) {
      this.isZenApiKey = isZenApiKey;
      return this;
    }

    /** Selects on-prem authentication, which additionally uses the username. */
    public Builder setIsOnPrem(Boolean isOnPrem) {
      this.isOnPrem = isOnPrem;
      return this;
    }

    public Builder setUsername(String username) {
      this.username = username;
      return this;
    }

    /** Builds the client on top of a caller-supplied HTTP client builder. */
    public WatsonxClient build(OkHttpClient.Builder builder) {
      return new WatsonxClient(this, builder);
    }

    /** Builds the client with a default HTTP client. */
    public WatsonxClient build() {
      return build(new OkHttpClient.Builder());
    }
  }
}






Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package ee.carlrobert.llm.client.watsonx.completion;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import ee.carlrobert.llm.client.BaseError;

/**
 * Error details of a Watsonx completion response, consisting of a message and an optional code.
 *
 * <p>Deserialized from the {@code message} and {@code code} JSON properties; any other
 * property is ignored.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
public class WatsonxCompletionErrorDetails extends BaseError {

  private static final String DEFAULT_ERROR_MSG = "Something went wrong. Please try again later.";

  /**
   * Shared fallback error carrying a generic message and no code.
   * Declared {@code final} so the shared instance cannot be swapped out by other code.
   */
  public static final WatsonxCompletionErrorDetails DEFAULT_ERROR =
      new WatsonxCompletionErrorDetails(DEFAULT_ERROR_MSG, null);

  String code;
  String message;

  /**
   * Creates error details that carry only a message; the code is {@code null}.
   *
   * @param message error message
   */
  public WatsonxCompletionErrorDetails(String message) {
    this(message, null);
  }

  /**
   * Creates error details from the {@code message} and {@code code} JSON properties.
   *
   * @param message error message, may be {@code null}
   * @param code    error code, may be {@code null}
   */
  @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
  public WatsonxCompletionErrorDetails(
      @JsonProperty("message") String message,
      @JsonProperty("code") String code) {
    this.message = message;
    this.code = code;
  }

  /** Returns the error message, or {@code null} when none was provided. */
  public String getMessage() {
    return message;
  }

  /** Returns the error code, or {@code null} when none was provided. */
  public String getCode() {
    return code;
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package ee.carlrobert.llm.client.watsonx.completion;

import ee.carlrobert.llm.completion.CompletionModel;

import java.util.Arrays;

/**
 * Models selectable for Watsonx completions.
 *
 * <p>Each constant carries the model code, a human-readable description (also returned by
 * {@link #toString()}) and a maximum token count.
 */
public enum WatsonxCompletionModel implements CompletionModel {

  GRANITE_3B_CODE_INSTRUCT("ibm/granite-3b-code-instruct", "IBM Granite 3B Code Instruct", 8192),
  GRANITE_8B_CODE_INSTRUCT("ibm/granite-8b-code-instruct", "IBM Granite 8B Code Instruct", 8192),
  GRANITE_20B_CODE_INSTRUCT("ibm/granite-20b-code-instruct", "IBM Granite 20B Code Instruct", 8192),
  GRANITE_34B_CODE_INSTRUCT("ibm/granite-34b-code-instruct", "IBM Granite 34B Code Instruct", 8192),
  CODELLAMA_34_B_INSTRUCT("codellama/codellama-34b-instruct-hf", "Code Llama 34B Instruct", 8192),
  MIXTRAL_8_7B("mistralai/mixtral-8x7b-instruct-v01", "Mixtral (8x7B)", 32768),
  // NOTE(review): constant is named MIXTRAL but the code and description say Mistral Large.
  MIXTRAL_LARGE("mistralai/mistral-large", "Mistral Large", 128000),
  LLAMA_3_1_70B("meta-llama/llama-3-1-70b-instruct", "Llama 3.1 Instruct (70B)", 128000),
  LLAMA_3_1_8B("meta-llama/llama-3-1-8b-instruct", "Llama 3.1 Instruct (8B)", 128000),
  // NOTE(review): constant says 7B while the code and description refer to the 70B model.
  LLAMA_2_7B("meta-llama/llama-2-70b-chat", "Llama 2 Chat (70B)", 4096),
  LLAMA_2_13B("meta-llama/llama-2-13b-chat", "Llama 2 Chat (13B)", 4096),
  GRANITE_13B_INSTRUCT_V2("ibm/granite-13b-instruct-v2", "IBM Granite 13B Instruct V2", 8192),
  GRANITE_13B_CHAT_V2("ibm/granite-13b-chat-v2", "IBM Granite 13B Chat V2", 8192),
  GRANITE_20B_MULTILINGUAL("ibm/granite-20b-multilingual", "IBM Granite 20B Multilingual", 8192);

  private final String code;
  private final String description;
  private final int maxTokens;

  WatsonxCompletionModel(String code, String description, int maxTokens) {
    this.code = code;
    this.description = description;
    this.maxTokens = maxTokens;
  }

  /**
   * Looks up the model whose code equals {@code code}.
   *
   * @param code model code to search for
   * @return the first constant with that code
   * @throws java.util.NoSuchElementException if no constant has that code
   */
  public static WatsonxCompletionModel findByCode(String code) {
    return Arrays.stream(values())
        .filter(model -> model.code.equals(code))
        .findFirst()
        .orElseThrow();
  }

  /** Returns the model code. */
  public String getCode() {
    return code;
  }

  /** Returns the human-readable description. */
  public String getDescription() {
    return description;
  }

  /** Returns the maximum token count configured for the model. */
  public int getMaxTokens() {
    return maxTokens;
  }

  /** Returns the description, so the constant displays by its readable name. */
  @Override
  public String toString() {
    return description;
  }
}

Loading

0 comments on commit 228420a

Please sign in to comment.