diff --git a/build.gradle b/build.gradle index 8d479256..f58a907f 100644 --- a/build.gradle +++ b/build.gradle @@ -101,12 +101,12 @@ dependencies { implementation 'com.google.inject:guice:5.1.0' // Environment variables implementation 'io.github.cdimascio:java-dotenv:5.2.2' - // Gson retrofit converter - implementation "com.squareup.retrofit2:converter-gson:2.9.0" - // Retrofit guava adapter + // Retrofit implementation 'com.squareup.retrofit2:adapter-guava:2.9.0' implementation 'com.squareup.retrofit2:converter-jackson:2.9.0' testImplementation 'com.squareup.retrofit2:retrofit-mock:2.9.0' + implementation 'com.squareup.okhttp3:logging-interceptor:4.9.2' + implementation "com.squareup.retrofit2:converter-gson:2.9.0" } // Testing related dependencies diff --git a/src/main/java/ai/knowly/langtorch/llm/Utils.java b/src/main/java/ai/knowly/langtorch/llm/Utils.java deleted file mode 100644 index df7392e9..00000000 --- a/src/main/java/ai/knowly/langtorch/llm/Utils.java +++ /dev/null @@ -1,34 +0,0 @@ -package ai.knowly.langtorch.llm; - -import com.google.common.flogger.FluentLogger; -import io.github.cdimascio.dotenv.Dotenv; -import java.util.Optional; - -public class Utils { - public static void logPartialApiKey(FluentLogger logger, String provider, String apiKey) { - logger.atInfo().log( - "Using %s API key: ***************" + apiKey.substring(apiKey.length() - 6), provider); - } - - public static String getOpenAIApiKeyFromEnv() { - return getOpenAIApiKeyFromEnv(Optional.empty()); - } - - public static String getOpenAIApiKeyFromEnv(Optional logger) { - Dotenv dotenv = Dotenv.configure().ignoreIfMissing().load(); - String openaiApiKey = dotenv.get("OPENAI_API_KEY"); - logger.ifPresent(l -> logPartialApiKey(l, "OpenAI", openaiApiKey)); - return openaiApiKey; - } - - public static String getCohereAIApiKeyFromEnv() { - return getCohereAIApiKeyFromEnv(Optional.empty()); - } - - public static String getCohereAIApiKeyFromEnv(Optional logger) { - Dotenv dotenv = Dotenv.configure().ignoreIfMissing().load(); - String openaiApiKey = dotenv.get("COHERE_API_KEY"); - logger.ifPresent(l -> logPartialApiKey(l, "CohereAI", openaiApiKey)); - return openaiApiKey; - } -} diff --git a/src/main/java/ai/knowly/langtorch/llm/processor/openai/OpenAIServiceProvider.java b/src/main/java/ai/knowly/langtorch/llm/processor/openai/OpenAIServiceProvider.java index 0f303faf..114a3b3e 100644 --- a/src/main/java/ai/knowly/langtorch/llm/processor/openai/OpenAIServiceProvider.java +++ b/src/main/java/ai/knowly/langtorch/llm/processor/openai/OpenAIServiceProvider.java @@ -1,6 +1,6 @@ package ai.knowly.langtorch.llm.processor.openai; -import ai.knowly.langtorch.llm.Utils; +import ai.knowly.langtorch.utils.ApiKeyUtils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIApi; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import com.google.common.flogger.FluentLogger; @@ -17,7 +17,7 @@ public static OpenAIApi createOpenAiAPI(String apiKey) { public static OpenAIApi createOpenAiAPI() { return OpenAIService.buildApi( - Utils.getOpenAIApiKeyFromEnv(Optional.of(logger)), DEFAULT_TIMEOUT); + ApiKeyUtils.getOpenAIApiKeyFromEnv(Optional.of(logger)), DEFAULT_TIMEOUT); } public static OpenAIService createOpenAIService(String apiKey) { diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeAPI.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeAPI.java new file mode 100644 index 00000000..a36f5736 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeAPI.java @@ -0,0 +1,36 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete.DeleteRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete.DeleteResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import com.google.common.util.concurrent.ListenableFuture; +import java.util.List; +import retrofit2.http.Body; +import retrofit2.http.GET; +import retrofit2.http.POST; +import retrofit2.http.Query; + +public interface PineconeAPI { + + @POST("/vectors/upsert") + ListenableFuture upsert(@Body UpsertRequest request); + + @POST("/query") + ListenableFuture query(@Body QueryRequest request); + + @POST("/vectors/delete") + ListenableFuture delete(@Body DeleteRequest request); + + @GET("/vectors/fetch") + ListenableFuture fetch( + @Query("namespace") String namespace, @Query("ids") List ids); + + @POST("/vectors/update") + ListenableFuture update(@Body UpdateRequest request); +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeAuthenticationInterceptor.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeAuthenticationInterceptor.java new file mode 100644 index 00000000..bd0bf15d --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeAuthenticationInterceptor.java @@ -0,0 +1,31 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import java.io.IOException; +import java.util.Objects; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +/** OkHttp Interceptor that adds an authorization header */ +public class PineconeAuthenticationInterceptor implements Interceptor { + + private final String apiKey; + + PineconeAuthenticationInterceptor(String apiKey) { + Objects.requireNonNull(apiKey, "Pinecone API required"); + this.apiKey = apiKey; + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = + chain + .request() + .newBuilder() + .header("accept", "application/json") + .header("content-type", "application/json") + .header("Api-Key", apiKey) + .build(); + return chain.proceed(request); + } +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeService.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeService.java new file mode 100644 index 00000000..e40a7767 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/PineconeService.java @@ -0,0 +1,150 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.PineconeServiceConfig; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete.DeleteRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete.DeleteResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.google.common.flogger.FluentLogger; +import com.google.common.util.concurrent.ListenableFuture; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import okhttp3.*; +import okhttp3.OkHttpClient.Builder; +import okhttp3.logging.HttpLoggingInterceptor; +import retrofit2.HttpException; +import retrofit2.Retrofit; +import retrofit2.adapter.guava.GuavaCallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; + +public class PineconeService { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); + + private static final ObjectMapper mapper = defaultObjectMapper(); + private final PineconeAPI api; + private final ExecutorService executorService; + + public PineconeService(final PineconeServiceConfig pineconeServiceConfig) { + ObjectMapper mapper = defaultObjectMapper(); + OkHttpClient client = buildClient(pineconeServiceConfig); + Retrofit retrofit = defaultRetrofit(pineconeServiceConfig.endpoint(), client, mapper); + + this.api = retrofit.create(PineconeAPI.class); + this.executorService = client.dispatcher().executorService(); + } + + public PineconeService(final PineconeAPI api) { + this.api = api; + this.executorService = null; + } + + public PineconeService(final PineconeAPI api, final ExecutorService executorService) { + this.api = api; + this.executorService = executorService; + } + + public static T execute(ListenableFuture apiCall) { + try { + return apiCall.get(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + if (e.getCause() instanceof HttpException) { + HttpException httpException = (HttpException) e.getCause(); + try { + String errorBody = httpException.response().errorBody().string(); + logger.atSevere().log("HTTP Error: %s", errorBody); + throw new RuntimeException(errorBody); + } catch (IOException ioException) { + logger.atSevere().withCause(ioException).log("Error while reading errorBody"); + } + } + throw new RuntimeException(e); + } + } + + public static ObjectMapper defaultObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); + return mapper; + } + + public static OkHttpClient buildClient(PineconeServiceConfig pineconeServiceConfig) { + logger.atInfo().log("Pinecone:" + pineconeServiceConfig.apiKey()); + Builder builder = + new Builder() + .addInterceptor(new PineconeAuthenticationInterceptor(pineconeServiceConfig.apiKey())) + .connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS)) + .readTimeout(pineconeServiceConfig.timeoutDuration().toMillis(), TimeUnit.MILLISECONDS); + + if (pineconeServiceConfig.enableLogging()) { + HttpLoggingInterceptor logging = new HttpLoggingInterceptor(); + builder.addInterceptor(logging.setLevel(HttpLoggingInterceptor.Level.BODY)); + } + return builder.build(); + } + + public static Retrofit defaultRetrofit( + String endpoint, OkHttpClient client, ObjectMapper mapper) { + return new Retrofit.Builder() + .baseUrl(endpoint.startsWith("https://") ? endpoint : "https://" + endpoint) + .client(client) + .addConverterFactory(JacksonConverterFactory.create(mapper)) + .addCallAdapterFactory(GuavaCallAdapterFactory.create()) + .build(); + } + + public UpsertResponse upsert(UpsertRequest request) { + return execute(api.upsert(request)); + } + + public ListenableFuture upsertAsync(UpsertRequest request) { + return api.upsert(request); + } + + public QueryResponse query(QueryRequest request) { + return execute(api.query(request)); + } + + public ListenableFuture queryAsync(QueryRequest request) { + return api.query(request); + } + + public DeleteResponse delete(DeleteRequest request) { + return execute(api.delete(request)); + } + + public ListenableFuture queryAsync(DeleteRequest request) { + return api.delete(request); + } + + public FetchResponse fetch(FetchRequest request) { + return execute(api.fetch(request.getNamespace(), request.getIds())); + } + + public ListenableFuture fetchAsync(FetchRequest request) { + return api.fetch(request.getNamespace(), request.getIds()); + } + + public UpdateResponse update(UpdateRequest request) { + return execute(api.update(request)); + } + + public ListenableFuture updateAsync(UpdateRequest request) { + return api.update(request); + } +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/PineconeServiceConfig.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/PineconeServiceConfig.java new file mode 100644 index 00000000..4f592931 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/PineconeServiceConfig.java @@ -0,0 +1,34 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema; + +import com.google.auto.value.AutoValue; +import java.time.Duration; + +@AutoValue +public abstract class PineconeServiceConfig { + public static Builder builder() { + return new AutoValue_PineconeServiceConfig.Builder() + .setTimeoutDuration(Duration.ofSeconds(10)) + .setEnableLogging(false); + } + + public abstract String apiKey(); + + public abstract String endpoint(); + + public abstract Duration timeoutDuration(); + + public abstract boolean enableLogging(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setEndpoint(String endpoint); + + public abstract Builder setApiKey(String newApiKey); + + public abstract Builder setTimeoutDuration(Duration timeoutDuration); + + public abstract Builder setEnableLogging(boolean enableLogging); + + public abstract PineconeServiceConfig build(); + } +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/SparseValues.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/SparseValues.java new file mode 100644 index 00000000..21e0a89a --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/SparseValues.java @@ -0,0 +1,20 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +@NoArgsConstructor +@AllArgsConstructor +public class SparseValues { + @JsonProperty("indices") + private List indices; + + @JsonProperty("values") + private List values; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/Vector.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/Vector.java new file mode 100644 index 00000000..f8617e44 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/Vector.java @@ -0,0 +1,27 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +@NoArgsConstructor +@AllArgsConstructor +public class Vector { + @JsonProperty("id") + private String id; + + @JsonProperty("values") + private List values; + + @JsonProperty("sparseValues") + private SparseValues sparseValues; + + @JsonProperty("metadata") + private Map metadata; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/delete/DeleteRequest.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/delete/DeleteRequest.java new file mode 100644 index 00000000..117b66ce --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/delete/DeleteRequest.java @@ -0,0 +1,23 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.Map; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +public class DeleteRequest { + @JsonProperty("ids") + private List ids; + + @JsonProperty("deleteAll") + private boolean deleteAll; + + @JsonProperty("namespace") + private String namespace; + + @JsonProperty("filter") + private Map filter; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/delete/DeleteResponse.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/delete/DeleteResponse.java new file mode 100644 index 00000000..8823699c --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/delete/DeleteResponse.java @@ -0,0 +1,8 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete; + +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +public class DeleteResponse {} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/fetch/FetchRequest.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/fetch/FetchRequest.java new file mode 100644 index 00000000..81e7e2a6 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/fetch/FetchRequest.java @@ -0,0 +1,16 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +public class FetchRequest { + @JsonProperty("ids") + private List ids; + + @JsonProperty("namespace") + private String namespace; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/fetch/FetchResponse.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/fetch/FetchResponse.java new file mode 100644 index 00000000..16e7f45a --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/fetch/FetchResponse.java @@ -0,0 +1,19 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class FetchResponse { + @JsonProperty("vectors") + private Map vectors; + + @JsonProperty("namespace") + private String namespace; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/Match.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/Match.java new file mode 100644 index 00000000..531003f0 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/Match.java @@ -0,0 +1,29 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.SparseValues; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class Match { + @JsonProperty("id") + private String id; + + @JsonProperty("score") + private Double score; + + @JsonProperty("values") + private List values; + + @JsonProperty("sparseValues") + private SparseValues sparseValues; + + @JsonProperty("metadata") + private Map metadata; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/QueryRequest.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/QueryRequest.java new file mode 100644 index 00000000..3481959f --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/QueryRequest.java @@ -0,0 +1,22 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.SparseValues; +import java.util.List; +import java.util.Map; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +public class QueryRequest { + private String namespace; + private long topK; + // The filter to apply. You can use vector metadata to limit your search. See + // https://www.pinecone.io/docs/metadata-filtering/. + private Map filter; + private boolean includeValues; + private boolean includeMetadata; + private List vector; + private SparseValues sparseVector; + private String id; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/QueryResponse.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/QueryResponse.java new file mode 100644 index 00000000..8e926ae1 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/query/QueryResponse.java @@ -0,0 +1,18 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class QueryResponse { + @JsonProperty("matches") + private List matches; + + @JsonProperty("namespace") + private String namespace; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/update/UpdateRequest.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/update/UpdateRequest.java new file mode 100644 index 00000000..87a1e139 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/update/UpdateRequest.java @@ -0,0 +1,29 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.SparseValues; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.Map; +import lombok.Builder; +import lombok.Data; +import lombok.NonNull; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +public class UpdateRequest { + @JsonProperty("id") + @NonNull + private String id; + + @JsonProperty("values") + private List values; + + @JsonProperty("sparseValues") + private SparseValues sparseValues; + + @JsonProperty("setMetadata") + private Map setMetadata; + + @JsonProperty("namespace") + private String namespace; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/update/UpdateResponse.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/update/UpdateResponse.java new file mode 100644 index 00000000..df131068 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/update/UpdateResponse.java @@ -0,0 +1,9 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +public class UpdateResponse {} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/upsert/UpsertRequest.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/upsert/UpsertRequest.java new file mode 100644 index 00000000..7989244c --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/upsert/UpsertRequest.java @@ -0,0 +1,13 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import java.util.List; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder(toBuilder = true, setterPrefix = "set") +public class UpsertRequest { + private List vectors; + private String namespace; +} diff --git a/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/upsert/UpsertResponse.java b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/upsert/UpsertResponse.java new file mode 100644 index 00000000..17414700 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/schema/dto/upsert/UpsertResponse.java @@ -0,0 +1,14 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class UpsertResponse { + @JsonProperty("upsertedCount") + long upsertedCount; +} diff --git a/src/main/java/ai/knowly/langtorch/utils/ApiKey.java b/src/main/java/ai/knowly/langtorch/utils/ApiKey.java new file mode 100644 index 00000000..405b249f --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/utils/ApiKey.java @@ -0,0 +1,7 @@ +package ai.knowly.langtorch.utils; + +public enum ApiKey { + OPENAI_API_KEY, + PINECONE_API_KEY, + COHERE_API_KEY, +} diff --git a/src/main/java/ai/knowly/langtorch/utils/ApiKeyUtils.java b/src/main/java/ai/knowly/langtorch/utils/ApiKeyUtils.java new file mode 100644 index 00000000..2a33c5b3 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/utils/ApiKeyUtils.java @@ -0,0 +1,48 @@ +package ai.knowly.langtorch.utils; + +import com.google.common.flogger.FluentLogger; +import io.github.cdimascio.dotenv.Dotenv; +import java.util.Optional; + +public class ApiKeyUtils { + public static void logPartialApiKey(FluentLogger logger, String provider, String apiKey) { + logger.atInfo().log( + "Using %s API key: ***************" + apiKey.substring(apiKey.length() - 6), provider); + } + + public static String getOpenAIApiKeyFromEnv() { + return getOpenAIApiKeyFromEnv(Optional.empty()); + } + + public static String getOpenAIApiKeyFromEnv(Optional logger) { + String keyFromEnv = getKeyFromEnv(ApiKey.OPENAI_API_KEY); + logger.ifPresent(l -> logPartialApiKey(l, ApiKey.OPENAI_API_KEY.name(), keyFromEnv)); + return keyFromEnv; + } + + public static String getPineconeKeyFromEnv(Optional logger) { + String keyFromEnv = getKeyFromEnv(ApiKey.PINECONE_API_KEY); + logger.ifPresent(l -> logPartialApiKey(l, ApiKey.PINECONE_API_KEY.name(), keyFromEnv)); + return keyFromEnv; + } + + public static String getPineconeKeyFromEnv() { + return getPineconeKeyFromEnv(Optional.empty()); + } + + public static String getCohereAIApiKeyFromEnv() { + return getCohereAIApiKeyFromEnv(Optional.empty()); + } + + public static String getCohereAIApiKeyFromEnv(Optional logger) { + String keyFromEnv = getKeyFromEnv(ApiKey.COHERE_API_KEY); + logger.ifPresent(l -> logPartialApiKey(l, ApiKey.COHERE_API_KEY.name(), keyFromEnv)); + return keyFromEnv; + } + + private static String getKeyFromEnv(ApiKey apiKey) { + Dotenv dotenv = Dotenv.configure().ignoreIfMissing().load(); + String key = dotenv.get(apiKey.name()); + return key; + } +} diff --git a/src/main/resources/.env.example b/src/main/resources/.env.example index dd7959b2..9b6d3464 100644 --- a/src/main/resources/.env.example +++ b/src/main/resources/.env.example @@ -1,2 +1,3 @@ OPENAI_API_KEY= -COHERE_API_KEY= \ No newline at end of file +COHERE_API_KEY= +PINECONE_API_KEY= \ No newline at end of file diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/TestingUtils.java b/src/test/java/ai/knowly/langtorch/TestingUtils.java similarity index 93% rename from src/test/java/ai/knowly/langtorch/llm/integration/openai/TestingUtils.java rename to src/test/java/ai/knowly/langtorch/TestingUtils.java index 6d4c336c..2456c509 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/TestingUtils.java +++ b/src/test/java/ai/knowly/langtorch/TestingUtils.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.llm.integration.openai; +package ai.knowly.langtorch; import com.google.common.collect.ImmutableMap; import java.io.InputStream; diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/ChatCompletionTest.java b/src/test/java/ai/knowly/langtorch/llm/integration/openai/ChatCompletionTest.java index 8e8fcd3c..3013b69c 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/ChatCompletionTest.java +++ b/src/test/java/ai/knowly/langtorch/llm/integration/openai/ChatCompletionTest.java @@ -2,24 +2,24 @@ import static org.junit.jupiter.api.Assertions.*; -import ai.knowly.langtorch.llm.Utils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.completion.chat.ChatCompletionChoice; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.completion.chat.ChatCompletionRequest; import ai.knowly.langtorch.schema.chat.ChatMessage; import ai.knowly.langtorch.schema.chat.SystemMessage; +import ai.knowly.langtorch.utils.ApiKeyUtils; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") class ChatCompletionTest { @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createChatCompletion() { - String token = Utils.getOpenAIApiKeyFromEnv(); + String token = ApiKeyUtils.getOpenAIApiKeyFromEnv(); OpenAIService service = new OpenAIService(token); final List messages = new ArrayList<>(); messages.add(SystemMessage.of("You are a dog and will speak as such.")); diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/CompletionTest.java b/src/test/java/ai/knowly/langtorch/llm/integration/openai/CompletionTest.java index 82050aef..664fce32 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/CompletionTest.java +++ b/src/test/java/ai/knowly/langtorch/llm/integration/openai/CompletionTest.java @@ -2,21 +2,21 @@ import static org.junit.jupiter.api.Assertions.*; -import ai.knowly.langtorch.llm.Utils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.completion.CompletionChoice; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.completion.CompletionRequest; +import ai.knowly.langtorch.utils.ApiKeyUtils; import java.util.HashMap; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") class CompletionTest { @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createCompletion() { - String token = Utils.getOpenAIApiKeyFromEnv(); + String token = ApiKeyUtils.getOpenAIApiKeyFromEnv(); OpenAIService service = new OpenAIService(token); CompletionRequest completionRequest = CompletionRequest.builder() diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/EditTest.java b/src/test/java/ai/knowly/langtorch/llm/integration/openai/EditTest.java index fc6f8b3f..51399a18 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/EditTest.java +++ b/src/test/java/ai/knowly/langtorch/llm/integration/openai/EditTest.java @@ -2,20 +2,20 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; -import ai.knowly.langtorch.llm.Utils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.OpenAIHttpException; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.edit.EditRequest; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.edit.EditResult; +import ai.knowly.langtorch.utils.ApiKeyUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") class EditTest { @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void edit() throws OpenAIHttpException { - String token = Utils.getOpenAIApiKeyFromEnv(); + String token = ApiKeyUtils.getOpenAIApiKeyFromEnv(); OpenAIService service = new OpenAIService(token); EditRequest request = EditRequest.builder() diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/EmbeddingTest.java b/src/test/java/ai/knowly/langtorch/llm/integration/openai/EmbeddingTest.java index afb07663..ef898484 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/EmbeddingTest.java +++ b/src/test/java/ai/knowly/langtorch/llm/integration/openai/EmbeddingTest.java @@ -2,21 +2,21 @@ import static org.junit.jupiter.api.Assertions.assertFalse; -import ai.knowly.langtorch.llm.Utils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.embedding.Embedding; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.embedding.EmbeddingRequest; +import ai.knowly.langtorch.utils.ApiKeyUtils; import java.util.Collections; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") class EmbeddingTest { @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createEmbeddings() { - String token = Utils.getOpenAIApiKeyFromEnv(); + String token = ApiKeyUtils.getOpenAIApiKeyFromEnv(); OpenAIService service = new OpenAIService(token); EmbeddingRequest embeddingRequest = EmbeddingRequest.builder() diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/ImageTest.java b/src/test/java/ai/knowly/langtorch/llm/integration/openai/ImageTest.java index 397dcb1f..19553265 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/ImageTest.java +++ b/src/test/java/ai/knowly/langtorch/llm/integration/openai/ImageTest.java @@ -3,17 +3,18 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -import ai.knowly.langtorch.llm.Utils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.image.CreateImageEditRequest; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.image.CreateImageRequest; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.image.CreateImageVariationRequest; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.image.Image; +import ai.knowly.langtorch.utils.ApiKeyUtils; import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") class ImageTest { static String filePath = "src/test/resources/penguin.png"; @@ -25,12 +26,11 @@ class ImageTest { @BeforeEach void setUp() { - token = Utils.getOpenAIApiKeyFromEnv(); + token = ApiKeyUtils.getOpenAIApiKeyFromEnv(); service = new OpenAIService(token); } @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createImageUrl() { CreateImageRequest createImageRequest = CreateImageRequest.builder().prompt("penguin").n(3).size("256x256").user("testing").build(); @@ -41,7 +41,6 @@ void createImageUrl() { } @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createImageBase64() { CreateImageRequest createImageRequest = CreateImageRequest.builder() @@ -56,7 +55,6 @@ void createImageBase64() { } @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createImageEdit() { CreateImageEditRequest createImageRequest = CreateImageEditRequest.builder() @@ -74,7 +72,6 @@ void createImageEdit() { } @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createImageEditWithMask() { CreateImageEditRequest createImageRequest = CreateImageEditRequest.builder() @@ -91,7 +88,6 @@ void createImageEditWithMask() { } @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createImageVariation() { CreateImageVariationRequest createImageVariationRequest = CreateImageVariationRequest.builder() diff --git a/src/test/java/ai/knowly/langtorch/llm/integration/openai/ModerationTest.java b/src/test/java/ai/knowly/langtorch/llm/integration/openai/ModerationTest.java index f47743c4..6d160b1c 100644 --- a/src/test/java/ai/knowly/langtorch/llm/integration/openai/ModerationTest.java +++ b/src/test/java/ai/knowly/langtorch/llm/integration/openai/ModerationTest.java @@ -2,18 +2,18 @@ import static org.junit.jupiter.api.Assertions.assertTrue; -import ai.knowly.langtorch.llm.Utils; import ai.knowly.langtorch.llm.integration.openai.service.OpenAIService; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.moderation.Moderation; import ai.knowly.langtorch.llm.integration.openai.service.schema.dto.moderation.ModerationRequest; +import ai.knowly.langtorch.utils.ApiKeyUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIf; +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") class ModerationTest { @Test - @EnabledIf("ai.knowly.langtorch.llm.integration.openai.TestingUtils#testWithHttpRequestEnabled") void createModeration() { - String token = Utils.getOpenAIApiKeyFromEnv(); + String token = ApiKeyUtils.getOpenAIApiKeyFromEnv(); OpenAIService service = new OpenAIService(token); ModerationRequest moderationRequest = ModerationRequest.builder() diff --git a/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/DeleteTest.java b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/DeleteTest.java new file mode 100644 index 00000000..d65a246b --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/DeleteTest.java @@ -0,0 +1,59 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import static com.google.common.truth.Truth.assertThat; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.PineconeServiceConfig; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete.DeleteRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.delete.DeleteResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import ai.knowly.langtorch.utils.ApiKeyUtils; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") +final class DeleteTest { + @Test + void test() { + // Arrange. + String token = ApiKeyUtils.getPineconeKeyFromEnv(); + PineconeService service = + new PineconeService( + PineconeServiceConfig.builder() + .setApiKey(token) + .setEndpoint("https://test1-c4943a1.svc.us-west4-gcp-free.pinecone.io") + .build()); + + UpsertRequest upsertRequest = + UpsertRequest.builder() + .setVectors( + List.of( + Vector.builder() + .setId("test2") + .setValues(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)) + .setMetadata(Map.of("key", "val")) + .build())) + .setNamespace("namespace") + .build(); + + DeleteRequest deleteRequest = + DeleteRequest.builder().setIds(List.of("test2")).setNamespace("namespace").build(); + + QueryRequest queryRequest = + QueryRequest.builder().setNamespace("namespace").setId("test2").setTopK(1).build(); + + // Act. + UpsertResponse upsertResponse = service.upsert(upsertRequest); + DeleteResponse deleteResponse = service.delete(deleteRequest); + QueryResponse queryResponse = service.query(queryRequest); + // Assert. + assertThat(upsertResponse.getUpsertedCount()).isEqualTo(1); + assertThat(deleteResponse).isNotNull(); + assertThat(queryResponse.getMatches()).isEmpty(); + } +} diff --git a/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/FetchTest.java b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/FetchTest.java new file mode 100644 index 00000000..87227cdf --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/FetchTest.java @@ -0,0 +1,51 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import static com.google.common.truth.Truth.assertThat; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.PineconeServiceConfig; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import ai.knowly.langtorch.utils.ApiKeyUtils; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") +class FetchTest { + @Test + void test() { + // Arrange. + String token = ApiKeyUtils.getPineconeKeyFromEnv(); + PineconeService service = + new PineconeService( + PineconeServiceConfig.builder() + .setApiKey(token) + .setEndpoint("https://test1-c4943a1.svc.us-west4-gcp-free.pinecone.io") + .build()); + + Vector vector = + Vector.builder() + .setId("test2") + .setValues(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)) + .setMetadata(Map.of("key", "val")) + .build(); + + UpsertRequest upsertRequest = + UpsertRequest.builder().setVectors(List.of(vector)).setNamespace("namespace").build(); + + FetchRequest fetchRequest = + FetchRequest.builder().setIds(List.of("test2")).setNamespace("namespace").build(); + + // Act. + UpsertResponse response = service.upsert(upsertRequest); + FetchResponse fetchResponse = service.fetch(fetchRequest); + + // Assert. + assertThat(response.getUpsertedCount()).isEqualTo(1); + assertThat(fetchResponse.getVectors().get("test2")).isEqualTo(vector); + } +} diff --git a/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/QueryTest.java b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/QueryTest.java new file mode 100644 index 00000000..ad74492f --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/QueryTest.java @@ -0,0 +1,63 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import static com.google.common.truth.Truth.assertThat; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.PineconeServiceConfig; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.query.QueryResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import ai.knowly.langtorch.utils.ApiKeyUtils; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") +final class QueryTest { + @Test + void test() { + // Arrange. + String token = ApiKeyUtils.getPineconeKeyFromEnv(); + PineconeService service = + new PineconeService( + PineconeServiceConfig.builder() + .setApiKey(token) + .setEndpoint("https://test1-c4943a1.svc.us-west4-gcp-free.pinecone.io") + .build()); + + UpsertRequest upsertRequest = + UpsertRequest.builder() + .setVectors( + List.of( + Vector.builder() + .setId("test2") + .setValues(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)) + .setMetadata(Map.of("key", "val")) + .build())) + .setNamespace("namespace") + .build(); + + QueryRequest queryRequest = + QueryRequest.builder() + .setVector(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)) + .setTopK(3) + .setNamespace("namespace") + .setIncludeValues(true) + .setIncludeMetadata(true) + .build(); + + // Act. + UpsertResponse response = service.upsert(upsertRequest); + QueryResponse queryResponse = service.query(queryRequest); + + // Assert. + assertThat(response.getUpsertedCount()).isEqualTo(1); + assertThat(queryResponse.getMatches()).isNotEmpty(); + assertThat(queryResponse.getMatches().get(0).getId()).isEqualTo("test2"); + assertThat(queryResponse.getMatches().get(0).getValues()).isNotEmpty(); + assertThat(queryResponse.getMatches().get(0).getMetadata()).isNotEmpty(); + assertThat(queryResponse.getNamespace()).isEqualTo("namespace"); + } +} diff --git a/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/UpdateTest.java b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/UpdateTest.java new file mode 100644 index 00000000..60b70d9c --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/UpdateTest.java @@ -0,0 +1,63 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import static com.google.common.truth.Truth.assertThat; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.PineconeServiceConfig; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.fetch.FetchResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateResponse; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import ai.knowly.langtorch.utils.ApiKeyUtils; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") +public class UpdateTest { + @Test + void test() { + // Arrange. + String token = ApiKeyUtils.getPineconeKeyFromEnv(); + PineconeService service = + new PineconeService( + PineconeServiceConfig.builder() + .setApiKey(token) + .setEndpoint("https://test1-c4943a1.svc.us-west4-gcp-free.pinecone.io") + .build()); + + UpsertRequest upsertRequest = + UpsertRequest.builder() + .setVectors( + List.of( + Vector.builder() + .setId("test2") + .setValues(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)) + .setMetadata(Map.of("key", "val")) + .build())) + .setNamespace("testr2") + .build(); + + UpdateRequest updateRequest = + UpdateRequest.builder() + .setId("test2") + .setValues(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9)) + .setNamespace("testr2") + .build(); + + FetchRequest fetchRequest = + FetchRequest.builder().setIds(List.of("test2")).setNamespace("testr2").build(); + + // Act. + UpsertResponse response = service.upsert(upsertRequest); + UpdateResponse updateResponse = service.update(updateRequest); + FetchResponse fetchResponse = service.fetch(fetchRequest); + // Assert. + assertThat(response.getUpsertedCount()).isEqualTo(1); + assertThat(fetchResponse.getVectors().get("test2").getValues()) + .isEqualTo(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9)); + } +} diff --git a/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/UpsertTest.java b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/UpsertTest.java new file mode 100644 index 00000000..025a3757 --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/store/vectordb/integration/pinecone/UpsertTest.java @@ -0,0 +1,45 @@ +package ai.knowly.langtorch.store.vectordb.integration.pinecone; + +import static com.google.common.truth.Truth.assertThat; + +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.PineconeServiceConfig; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.Vector; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest; +import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse; +import ai.knowly.langtorch.utils.ApiKeyUtils; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +@EnabledIf("ai.knowly.langtorch.TestingUtils#testWithHttpRequestEnabled") +class UpsertTest { + @Test + void upsertTest() { + // Arrange. + String token = ApiKeyUtils.getPineconeKeyFromEnv(); + PineconeService service = + new PineconeService( + PineconeServiceConfig.builder() + .setApiKey(token) + .setEndpoint("https://test1-c4943a1.svc.us-west4-gcp-free.pinecone.io") + .build()); + + UpsertRequest upsertRequest = + UpsertRequest.builder() + .setVectors( + List.of( + Vector.builder() + .setId("test2") + .setValues(List.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)) + .setMetadata(Map.of("key", "val")) + .build())) + .setNamespace("testr2") + .build(); + + // Act. + UpsertResponse response = service.upsert(upsertRequest); + // Assert. + assertThat(response.getUpsertedCount()).isEqualTo(1); + } +}