Skip to content

Commit

Permalink
[watsonx.ai] Add truncate_input_tokens property
Browse files Browse the repository at this point in the history
  • Loading branch information
andreadimaio committed Oct 14, 2024
1 parent 7068836 commit 6b3edf4
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void test_embed_list_of_three_textsegment() throws Exception {

var input = "Embedding THIS!";
EmbeddingRequest request = new EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, WireMockUtil.PROJECT_ID,
List.of(input, input, input));
List.of(input, input, input), null);

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
.body(mapper.writeValueAsString(request))
Expand Down Expand Up @@ -140,7 +140,7 @@ private List<Float> mockEmbeddingServer(String input) throws Exception {
.build();

EmbeddingRequest request = new EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, WireMockUtil.PROJECT_ID,
List.of(input));
List.of(input), null);

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
.body(mapper.writeValueAsString(request))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty;
Expand Down Expand Up @@ -64,6 +65,7 @@ public class GenerationAllPropertiesTest extends WireMockAbstract {
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.truncate-input-tokens", "0")
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.include-stop-sequence", "false")
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model")
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.truncate-input-tokens", "10")
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@Override
Expand Down Expand Up @@ -102,6 +104,8 @@ void handlerBeforeEach() {
.includeStopSequence(false)
.build();

static EmbeddingParameters embeddingParameters = new EmbeddingParameters(10);

@Test
void check_config() throws Exception {
var runtimeConfig = langchain4jWatsonConfig.defaultConfig();
Expand Down Expand Up @@ -133,6 +137,7 @@ void check_config() throws Exception {
assertEquals("@", runtimeConfig.chatModel().promptJoiner());
assertEquals(true, fixedRuntimeConfig.chatModel().promptFormatter());
assertEquals("my_super_embedding_model", runtimeConfig.embeddingModel().modelId());
assertEquals(10, runtimeConfig.embeddingModel().truncateInputTokens().orElse(null));
}

@Test
Expand All @@ -158,7 +163,7 @@ void check_embedding_model() throws Exception {
String projectId = config.projectId();

EmbeddingRequest request = new EmbeddingRequest(modelId, projectId,
List.of("Embedding THIS!"));
List.of("Embedding THIS!"), embeddingParameters);

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200, "aaaa-mm-dd")
.body(mapper.writeValueAsString(request))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ void check_config() throws Exception {
assertTrue(runtimeConfig.chatModel().includeStopSequence().isEmpty());
assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType());
assertEquals(WireMockUtil.DEFAULT_EMBEDDING_MODEL, runtimeConfig.embeddingModel().modelId());
assertTrue(runtimeConfig.embeddingModel().truncateInputTokens().isEmpty());
}

@Test
Expand All @@ -124,7 +125,7 @@ void check_embedding_model() throws Exception {
String projectId = config.projectId();

EmbeddingRequest request = new EmbeddingRequest(modelId, projectId,
List.of("Embedding THIS!"));
List.of("Embedding THIS!"), null);

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
.body(mapper.writeValueAsString(request))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.TokenCountEstimator;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse.Result;
Expand All @@ -28,6 +29,7 @@
public class WatsonxEmbeddingModel implements EmbeddingModel, TokenCountEstimator {

private final String modelId, projectId, version;
private final EmbeddingParameters parameters;
private final WatsonxRestApi client;

public WatsonxEmbeddingModel(Builder builder) {
Expand All @@ -49,6 +51,11 @@ public WatsonxEmbeddingModel(Builder builder) {
this.modelId = builder.modelId;
this.projectId = builder.projectId;
this.version = builder.version;

if (builder.truncateInputTokens != null)
this.parameters = new EmbeddingParameters(builder.truncateInputTokens);
else
this.parameters = null;
}

@Override
Expand All @@ -61,7 +68,7 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
.map(TextSegment::text)
.collect(Collectors.toList());

EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, inputs);
EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, inputs, parameters);
EmbeddingResponse result = retryOn(new Callable<EmbeddingResponse>() {
@Override
public EmbeddingResponse call() throws Exception {
Expand Down Expand Up @@ -102,6 +109,7 @@ public static final class Builder {
private String version;
private String projectId;
private Duration timeout;
private Integer truncateInputTokens;
private boolean logResponses;
private boolean logRequests;
private URL url;
Expand All @@ -127,6 +135,11 @@ public Builder timeout(Duration timeout) {
return this;
}

/**
* Sets the optional input-token truncation limit used for embedding requests.
* <p>
* When the value is non-null, the model constructor wraps it in an {@link EmbeddingParameters};
* when it is {@code null}, no parameters object is created.
*
* @param truncateInputTokens the truncation limit, or {@code null} to leave it unset
* @return this builder
*/
public Builder truncateInputTokens(Integer truncateInputTokens) {
this.truncateInputTokens = truncateInputTokens;
return this;
}

public Builder logRequests(boolean logRequests) {
this.logRequests = logRequests;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package io.quarkiverse.langchain4j.watsonx.bean;

/**
* Optional parameters attached to an {@link EmbeddingRequest}.
*
* @param truncateInputTokens maximum number of input tokens accepted (see the embedding-model
* {@code truncate-input-tokens} configuration property)
*/
public record EmbeddingParameters(Integer truncateInputTokens) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import java.util.List;

public record EmbeddingRequest(String modelId, String projectId, List<String> inputs) {
public record EmbeddingRequest(String modelId, String projectId, List<String> inputs, EmbeddingParameters parameters) {

public EmbeddingRequest of(String modelId, String projectId, String input) {
return new EmbeddingRequest(modelId, projectId, List.of(input));
/**
* Creates a request that embeds a single input string.
* <p>
* Declared {@code static} so it can be used as a factory without an existing instance.
*
* @param modelId id of the embedding model
* @param projectId id of the project the request is issued for
* @param input the single text to embed; it is wrapped in a one-element list
* @param parameters optional embedding parameters, may be {@code null}
* @return a new request containing exactly one input
*/
public static EmbeddingRequest of(String modelId, String projectId, String input, EmbeddingParameters parameters) {
return new EmbeddingRequest(modelId, projectId, List.of(input), parameters);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jWatsonxConfig runtimeC
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), watsonConfig.logResponses()))
.version(watsonConfig.version())
.projectId(watsonConfig.projectId())
.modelId(embeddingModelConfig.modelId());
.modelId(embeddingModelConfig.modelId())
.truncateInputTokens(embeddingModelConfig.truncateInputTokens().orElse(null));

return new Supplier<>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,25 @@ public interface EmbeddingModelConfig {
/**
* Model id to use.
*
* To view the complete model list, <a href=
* To view the complete model list,
* <a href=
* "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx&audience=wdp">click
* here</a>.
*/
@WithDefault("ibm/slate-125m-english-rtrvr")
String modelId();

/**
* Represents the maximum number of input tokens accepted. This can be used to avoid requests failing due to
* input being longer than configured limits. If the text is truncated, then it truncates the end of the input
* (on the right), so the start of the input will remain the same. If this value exceeds the maximum sequence
* length (refer to the documentation to find this value for the model) then the call will fail if the total
* number of tokens exceeds the maximum sequence length.
*/
Optional<Integer> truncateInputTokens();

/**
* Whether embedding model requests should be logged.
*/
Expand Down

0 comments on commit 6b3edf4

Please sign in to comment.