Skip to content

Commit

Permalink
Migrate to the JsonSchemaElement API
Browse files Browse the repository at this point in the history
  • Loading branch information
edeandrea committed Nov 22, 2024
1 parent 2103f07 commit 0dd1268
Show file tree
Hide file tree
Showing 15 changed files with 543 additions and 102 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonArraySchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
@Override
public JsonArraySchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var a = (JsonArraySchema) obj;
return new Serialized(a.description(), a.items());
}

@Override
public JsonArraySchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var a = (JsonArraySchemaObjectSubstitution.Serialized) obj;
return JsonArraySchema.builder()
.description(a.description)
.items(a.items)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String description;
private final JsonSchemaElement items;

@RecordableConstructor
public Serialized(String description, JsonSchemaElement items) {
this.description = description;
this.items = items;
}

public String getDescription() {
return description;
}

public JsonSchemaElement getItems() {
return items;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonBooleanSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
@Override
public JsonBooleanSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var b = (JsonBooleanSchema) obj;
return new Serialized(b.description());
}

@Override
public JsonBooleanSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var b = (JsonBooleanSchemaObjectSubstitution.Serialized) obj;
return JsonBooleanSchema.builder()
.description(b.description)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String description;

@RecordableConstructor
public Serialized(String description) {
this.description = description;
}

public String getDescription() {
return description;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;

import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonEnumSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
@Override
public JsonEnumSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var e = (JsonEnumSchema) obj;
return new Serialized(e.description(), e.enumValues());
}

@Override
public JsonEnumSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var e = (JsonEnumSchemaObjectSubstitution.Serialized) obj;
return JsonEnumSchema.builder()
.description(e.description)
.enumValues(e.enumValues)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String description;
private final List<String> enumValues;

@RecordableConstructor
public Serialized(String description, List<String> enumValues) {
this.description = description;
this.enumValues = enumValues;
}

public String getDescription() {
return description;
}

public List<String> getEnumValues() {
return enumValues;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonIntegerSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
@Override
public JsonIntegerSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var i = (JsonIntegerSchema) obj;
return new Serialized(i.description());
}

@Override
public JsonIntegerSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var i = (JsonIntegerSchemaObjectSubstitution.Serialized) obj;
return JsonIntegerSchema.builder()
.description(i.description)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String description;

@RecordableConstructor
public Serialized(String description) {
this.description = description;
}

public String getDescription() {
return description;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonNumberSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
@Override
public JsonNumberSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var n = (JsonNumberSchema) obj;
return new Serialized(n.description());
}

@Override
public JsonNumberSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var n = (JsonNumberSchemaObjectSubstitution.Serialized) obj;
return JsonNumberSchema.builder()
.description(n.description)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String description;

@RecordableConstructor
public Serialized(String description) {
this.description = description;
}

public String getDescription() {
return description;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;
import java.util.Map;

import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonObjectSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
@Override
public JsonObjectSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var o = (JsonObjectSchema) obj;
return new Serialized(o.description(), o.properties(), o.required(), o.additionalProperties(), o.definitions());
}

@Override
public JsonObjectSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var o = (JsonObjectSchemaObjectSubstitution.Serialized) obj;
return JsonObjectSchema.builder()
.description(o.description)
.properties(o.properties)
.required(o.required)
.additionalProperties(o.additionalProperties)
.definitions(o.definitions)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String description;
private final Map<String, JsonSchemaElement> properties;
private final List<String> required;
private final Boolean additionalProperties;
private final Map<String, JsonSchemaElement> definitions;

@RecordableConstructor
public Serialized(String description, Map<String, JsonSchemaElement> properties, List<String> required,
Boolean additionalProperties, Map<String, JsonSchemaElement> definitions) {
this.description = description;
this.properties = properties;
this.required = required;
this.additionalProperties = additionalProperties;
this.definitions = definitions;
}

public String getDescription() {
return description;
}

public Map<String, JsonSchemaElement> getProperties() {
return properties;
}

public List<String> getRequired() {
return required;
}

public Boolean getAdditionalProperties() {
return additionalProperties;
}

public Map<String, JsonSchemaElement> getDefinitions() {
return definitions;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonReferenceSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution {
public JsonReferenceSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
var r = (JsonReferenceSchema) obj;
return new Serialized(r.reference());
}

public JsonReferenceSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
var r = (JsonReferenceSchemaObjectSubstitution.Serialized) obj;
return JsonReferenceSchema.builder()
.reference(r.reference)
.build();
}

public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized {
private final String reference;

@RecordableConstructor
public Serialized(String reference) {
this.reference = reference;
}

public String getReference() {
return reference;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import io.quarkus.runtime.ObjectSubstitution;

public sealed class JsonSchemaElementObjectSubstitution
implements ObjectSubstitution<JsonSchemaElement, JsonSchemaElementObjectSubstitution.Serialized>
permits JsonArraySchemaObjectSubstitution,
JsonBooleanSchemaObjectSubstitution,
JsonEnumSchemaObjectSubstitution,
JsonIntegerSchemaObjectSubstitution,
JsonNumberSchemaObjectSubstitution,
JsonObjectSchemaObjectSubstitution,
JsonReferenceSchemaObjectSubstitution,
JsonStringSchemaObjectSubstitution {

// Using ConcurrentHashMap in case multiple threads are using this class at the same time
// Not sure if this will ever happen
private final Map<Class<?>, JsonSchemaElementObjectSubstitution> substitutions = new ConcurrentHashMap<>(8);

@Override
public JsonSchemaElementObjectSubstitution.Serialized serialize(JsonSchemaElement obj) {
return getSubstitution(obj.getClass()).serialize(obj);
}

@Override
public JsonSchemaElement deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) {
return getSubstitution(obj.getClass()).deserialize(obj);
}

private JsonSchemaElementObjectSubstitution getSubstitution(Class<?> clazz) {
return this.substitutions.computeIfAbsent(clazz, c -> {
if (JsonArraySchema.class.isAssignableFrom(c)
|| JsonArraySchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonArraySchemaObjectSubstitution();
} else if (JsonBooleanSchema.class.isAssignableFrom(c)
|| JsonBooleanSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonBooleanSchemaObjectSubstitution();
} else if (JsonEnumSchema.class.isAssignableFrom(c)
|| JsonEnumSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonEnumSchemaObjectSubstitution();
} else if (JsonIntegerSchema.class.isAssignableFrom(c)
|| JsonIntegerSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonIntegerSchemaObjectSubstitution();
} else if (JsonNumberSchema.class.isAssignableFrom(c)
|| JsonNumberSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonNumberSchemaObjectSubstitution();
} else if (JsonObjectSchema.class.isAssignableFrom(c)
|| JsonObjectSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonObjectSchemaObjectSubstitution();
} else if (JsonReferenceSchema.class.isAssignableFrom(c)
|| JsonReferenceSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonReferenceSchemaObjectSubstitution();
} else if (JsonStringSchema.class.isAssignableFrom(c)
|| JsonStringSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) {
return new JsonStringSchemaObjectSubstitution();
}

// Handle unsupported types
throw new IllegalArgumentException("Unsupported type: %s".formatted(c.getName()));
});
}

public static sealed class Serialized
permits JsonArraySchemaObjectSubstitution.Serialized,
JsonBooleanSchemaObjectSubstitution.Serialized,
JsonEnumSchemaObjectSubstitution.Serialized,
JsonIntegerSchemaObjectSubstitution.Serialized,
JsonNumberSchemaObjectSubstitution.Serialized,
JsonObjectSchemaObjectSubstitution.Serialized,
JsonReferenceSchemaObjectSubstitution.Serialized,
JsonStringSchemaObjectSubstitution.Serialized {

}
}
Loading

0 comments on commit 0dd1268

Please sign in to comment.