-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
732 additions
and
1 deletion.
There are no files selected for viewing
28 changes: 28 additions & 0 deletions
28
src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
public class IBMAuthBearerToken { | ||
@JsonProperty("access_token") | ||
String accessToken; | ||
@JsonProperty("expiration") | ||
int expiration; | ||
|
||
String getAccessToken() { | ||
return this.accessToken; | ||
} | ||
|
||
public void setAccessToken(String accessToken) { | ||
this.accessToken = accessToken; | ||
} | ||
|
||
int getExpiration() { | ||
return this.expiration; | ||
} | ||
|
||
public void setExpiration(int expiration) { | ||
this.expiration = expiration; | ||
} | ||
} |
80 changes: 80 additions & 0 deletions
80
src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import okhttp3.*; | ||
|
||
import java.io.IOException; | ||
import java.util.Base64; | ||
import java.util.Date; | ||
|
||
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; | ||
|
||
public class WatsonxAuthenticator { | ||
|
||
IBMAuthBearerToken bearerToken; | ||
OkHttpClient client; | ||
Request request; | ||
Boolean isZenApiKey=false; | ||
|
||
// On Cloud | ||
public WatsonxAuthenticator(String apiKey) { | ||
this.client = new OkHttpClient().newBuilder() | ||
.build(); | ||
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); | ||
RequestBody body = RequestBody.create(mediaType, "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey="+apiKey); | ||
this.request = new Request.Builder() | ||
.url("https://iam.cloud.ibm.com/identity/token") | ||
.method("POST", body) | ||
.addHeader("Content-Type", "application/x-www-form-urlencoded") | ||
.build(); | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
// Zen API Key | ||
public WatsonxAuthenticator(String username, String zenApiKey){ | ||
IBMAuthBearerToken token = new IBMAuthBearerToken(); | ||
String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes()).toString(); | ||
token.setAccessToken(tokenStr); | ||
this.bearerToken = token; | ||
this.isZenApiKey = true; | ||
} | ||
|
||
// Watsonx API Key | ||
public WatsonxAuthenticator(String username, String apiKey, String host){//TODO add support for password | ||
this.client = new OkHttpClient().newBuilder() | ||
.build(); | ||
MediaType mediaType = MediaType.parse("application/json"); | ||
RequestBody body = RequestBody.create(mediaType, "{\"username\":\""+username+"\",\"api_key\":\""+apiKey+"\"}"); | ||
this.request = new Request.Builder() | ||
.url(host + "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken | ||
.method("POST", body) | ||
.addHeader("Content-Type", "application/json") | ||
.build(); | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
private void generateNewBearerToken() { | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
public String getBearerTokenValue() { | ||
if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000) < new Date().getTime() + 1000000)) {//TODO add correct number of seconds | ||
generateNewBearerToken(); | ||
} | ||
return this.bearerToken.getAccessToken(); | ||
} | ||
} |
163 changes: 163 additions & 0 deletions
163
src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import ee.carlrobert.llm.PropertiesLoader; | ||
import ee.carlrobert.llm.client.DeserializationUtil; | ||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError; | ||
import ee.carlrobert.llm.completion.CompletionEventListener; | ||
import ee.carlrobert.llm.completion.CompletionEventSourceListener; | ||
import okhttp3.*; | ||
import okhttp3.sse.EventSource; | ||
import okhttp3.sse.EventSources; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; | ||
|
||
public class WatsonxClient { | ||
|
||
private static final MediaType APPLICATION_JSON = MediaType.parse("application/json"); | ||
private final OkHttpClient httpClient; | ||
private final String host; | ||
private final String apiVersion; | ||
private final WatsonxAuthenticator authenticator; | ||
|
||
private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) { | ||
this.httpClient = httpClientBuilder.build(); | ||
this.apiVersion = builder.apiVersion; | ||
this.host = builder.host; | ||
if (builder.isOnPrem) { | ||
if (builder.isZenApiKey) | ||
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey); | ||
else | ||
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey, builder.host); | ||
} else { | ||
this.authenticator = new WatsonxAuthenticator(builder.apiKey); | ||
} | ||
} | ||
|
||
public EventSource getCompletionAsync( | ||
WatsonxCompletionRequest request, | ||
CompletionEventListener<String> eventListener) { | ||
return EventSources.createFactory(httpClient).newEventSource( | ||
buildCompletionRequest(request), | ||
getCompletionEventSourceListener(eventListener)); | ||
} | ||
|
||
public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) { | ||
try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) { | ||
return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
protected Request buildCompletionRequest(WatsonxCompletionRequest request) { | ||
var headers = new HashMap<>(getRequiredHeaders()); | ||
if (request.getStream()) { | ||
headers.put("Accept", "text/event-stream"); | ||
} | ||
try { | ||
return new Request.Builder() | ||
.url(host + "/ml/v1/text/" + (request.getStream() ? "generation_stream" : "generation") + "?version=" + apiVersion) | ||
.headers(Headers.of(headers)) | ||
.post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON)) | ||
.build(); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException("Unable to process request", e); | ||
} | ||
} | ||
|
||
private Map<String, String> getRequiredHeaders() { | ||
return new HashMap<>(Map.of("Authorization", "Bearer " + authenticator.getBearerTokenValue())); | ||
} | ||
|
||
private CompletionEventSourceListener<String> getCompletionEventSourceListener( | ||
CompletionEventListener<String> eventListener) { | ||
return new CompletionEventSourceListener<>(eventListener) { | ||
@Override | ||
protected String getMessage(String data) { | ||
try { | ||
return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class) | ||
.getResults().get(0).getGeneratedText(); | ||
} catch (Exception e) { | ||
try { | ||
String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class) | ||
.getError() | ||
.getMessage(); | ||
if (message != null) return message; | ||
return ""; | ||
} catch (Exception ex) { | ||
return ""; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
protected ErrorDetails getErrorDetails(String error) { | ||
try { | ||
return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError(); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
}; | ||
} | ||
|
||
public static class Builder { | ||
|
||
private final String apiKey; | ||
private String host = PropertiesLoader.getValue("watsonx.baseUrl"); | ||
private String apiVersion = "2024-03-14"; | ||
private Boolean isOnPrem; | ||
private Boolean isZenApiKey; | ||
private String username; | ||
|
||
public Builder(String apiKey){ | ||
this.apiKey = apiKey; | ||
} | ||
public Builder setApiVersion(String apiVersion) { | ||
this.apiVersion = apiVersion; | ||
return this; | ||
} | ||
|
||
public Builder setHost(String host) { | ||
this.host = host; | ||
return this; | ||
} | ||
|
||
public Builder setIsZenApiKey(Boolean isZenApiKey) { | ||
this.isZenApiKey = isZenApiKey; | ||
return this; | ||
} | ||
|
||
public Builder setIsOnPrem(Boolean isOnPrem) { | ||
this.isOnPrem = isOnPrem; | ||
return this; | ||
} | ||
|
||
public Builder setUsername(String username) { | ||
this.username = username; | ||
return this; | ||
} | ||
|
||
public WatsonxClient build(OkHttpClient.Builder builder) { | ||
return new WatsonxClient(this, builder); | ||
} | ||
|
||
public WatsonxClient build() { | ||
return build(new OkHttpClient.Builder()); | ||
} | ||
} | ||
} | ||
|
||
|
||
|
||
|
||
|
||
|
38 changes: 38 additions & 0 deletions
38
src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package ee.carlrobert.llm.client.watsonx.completion; | ||
|
||
import com.fasterxml.jackson.annotation.JsonCreator; | ||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import ee.carlrobert.llm.client.BaseError; | ||
|
||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
public class WatsonxCompletionErrorDetails extends BaseError { | ||
|
||
private static final String DEFAULT_ERROR_MSG = "Something went wrong. Please try again later."; | ||
|
||
String code; | ||
String message; | ||
|
||
public WatsonxCompletionErrorDetails(String message) { | ||
this(message, null); | ||
} | ||
|
||
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES) | ||
public WatsonxCompletionErrorDetails( | ||
@JsonProperty("message") String message, | ||
@JsonProperty("code") String code) { | ||
this.message = message; | ||
this.code = code; | ||
} | ||
|
||
public static WatsonxCompletionErrorDetails DEFAULT_ERROR = new WatsonxCompletionErrorDetails(DEFAULT_ERROR_MSG,null); | ||
|
||
public String getMessage() { | ||
return message; | ||
} | ||
|
||
public String getCode() { | ||
return code; | ||
} | ||
|
||
} |
57 changes: 57 additions & 0 deletions
57
src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package ee.carlrobert.llm.client.watsonx.completion; | ||
|
||
import ee.carlrobert.llm.completion.CompletionModel; | ||
|
||
import java.util.Arrays; | ||
|
||
public enum WatsonxCompletionModel implements CompletionModel { | ||
|
||
GRANITE_3B_CODE_INSTRUCT("ibm/granite-3b-code-instruct","IBM Granite 3B Code Instruct", 8192), | ||
GRANITE_8B_CODE_INSTRUCT("ibm/granite-8b-code-instruct","IBM Granite 8B Code Instruct", 8192), | ||
GRANITE_20B_CODE_INSTRUCT( "ibm/granite-20b-code-instruct","IBM Granite 20B Code Instruct",8192), | ||
GRANITE_34B_CODE_INSTRUCT( "ibm/granite-34b-code-instruct","IBM Granite 34B Code Instruct",8192), | ||
CODELLAMA_34_B_INSTRUCT("codellama/codellama-34b-instruct-hf","Code Llama 34B Instruct", 8192), | ||
MIXTRAL_8_7B("mistralai/mixtral-8x7b-instruct-v01","Mixtral (8x7B)",32768), | ||
MIXTRAL_LARGE("mistralai/mistral-large","Mistral Large",128000), | ||
LLAMA_3_1_70B( "meta-llama/llama-3-1-70b-instruct","Llama 3.1 Instruct (70B)", 128000), | ||
LLAMA_3_1_8B( "meta-llama/llama-3-1-8b-instruct", "Llama 3.1 Instruct (8B)", 128000), | ||
LLAMA_2_7B("meta-llama/llama-2-70b-chat","Llama 2 Chat (70B)",4096), | ||
LLAMA_2_13B("meta-llama/llama-2-13b-chat","Llama 2 Chat (13B)",4096), | ||
GRANITE_13B_INSTRUCT_V2("ibm/granite-13b-instruct-v2","IBM Granite 13B Instruct V2",8192), | ||
GRANITE_13B_CHAT_V2("ibm/granite-13b-chat-v2","IBM Granite 13B Chat V2",8192), | ||
GRANITE_20B_MULTILINGUAL("ibm/granite-20b-multilingual","IBM Granite 20B Multilingual",8192); | ||
|
||
private final String code; | ||
private final String description; | ||
private final int maxTokens; | ||
|
||
WatsonxCompletionModel(String code, String description, int maxTokens) { | ||
this.code = code; | ||
this.description = description; | ||
this.maxTokens = maxTokens; | ||
} | ||
|
||
public String getCode() { | ||
return code; | ||
} | ||
|
||
public String getDescription() { | ||
return description; | ||
} | ||
|
||
public int getMaxTokens() { | ||
return maxTokens; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return description; | ||
} | ||
|
||
public static WatsonxCompletionModel findByCode(String code) { | ||
return Arrays.stream(WatsonxCompletionModel.values()) | ||
.filter(item -> item.getCode().equals(code)) | ||
.findFirst().orElseThrow(); | ||
} | ||
} | ||
|
Oops, something went wrong.