diff --git a/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java b/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java index 531c228..0159e56 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java +++ b/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java @@ -36,6 +36,10 @@ public Optional getEncodingForModel(final String modelName) { return Optional.of(getEncodingForModel(ModelType.GPT_4O)); } + if (modelName.startsWith(ModelType.O1.getName())) { + return Optional.of(getEncodingForModel(ModelType.O1)); + } + if (modelName.startsWith(ModelType.GPT_4_32K.getName())) { return Optional.of(getEncodingForModel(ModelType.GPT_4_32K)); } diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/ModelType.java b/lib/src/main/java/com/knuddels/jtokkit/api/ModelType.java index d2623da..5c1fe4f 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/ModelType.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/ModelType.java @@ -15,6 +15,8 @@ public enum ModelType { GPT_4_TURBO("gpt-4-turbo", EncodingType.CL100K_BASE, 128000), GPT_3_5_TURBO("gpt-3.5-turbo", EncodingType.CL100K_BASE, 16385), GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k", EncodingType.CL100K_BASE, 16385), + O1("o1", EncodingType.O200K_BASE, 200000), + O1_MINI("o1-mini", EncodingType.O200K_BASE, 128000), // text TEXT_DAVINCI_003("text-davinci-003", EncodingType.P50K_BASE, 4097), diff --git a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java index d14afa9..6a0af6c 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java @@ -102,6 +102,27 @@ void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt4oMini() { assertEquals(encoding.get().getName(), ModelType.GPT_4O_MINI.getEncodingType().getName()); } + @Test + void getEncodingForModelByPrefixReturnsCorrectEncodingForO1Mini() { + var encoding = registry.getEncodingForModel("o1-mini-2024-09-12"); + assertTrue(encoding.isPresent()); + assertEquals(encoding.get().getName(), ModelType.O1_MINI.getEncodingType().getName()); + } + + @Test + void getEncodingForModelByPrefixReturnsCorrectEncodingForO1() { + var encoding = registry.getEncodingForModel("o1-2024-12-17"); + assertTrue(encoding.isPresent()); + assertEquals(encoding.get().getName(), ModelType.O1.getEncodingType().getName()); + } + + @Test + void getEncodingForModelByPrefixReturnsCorrectEncodingForO1Preview() { + var encoding = registry.getEncodingForModel("o1-preview-2024-09-12"); + assertTrue(encoding.isPresent()); + assertEquals(encoding.get().getName(), ModelType.O1.getEncodingType().getName()); + } + @Test void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt4Turbo() { var encoding = registry.getEncodingForModel("gpt-4-turbo-123");