diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
index 6edfe92d..1f78000e 100644
--- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
+++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
@@ -126,7 +126,7 @@ public EventSource getCodeCompletionAsync(
           CodeCompletionRequestFactory.buildCustomRequest(requestDetails),
           new OpenAITextCompletionEventSourceListener(eventListener));
       case LLAMA_CPP -> CompletionClientProvider.getLlamaClient()
-          .getInfillAsync(
+          .getChatCompletionAsync(
               CodeCompletionRequestFactory.buildLlamaRequest(requestDetails),
               eventListener);
       case OLLAMA -> CompletionClientProvider.getOllamaClient().getCompletionAsync(
diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt
index aadc7d36..a743a76c 100644
--- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt
+++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt
@@ -13,7 +13,6 @@ import ee.carlrobert.codegpt.settings.service.llama.LlamaSettingsState
 import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
 import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
 import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
-import ee.carlrobert.llm.client.llama.completion.LlamaInfillRequest
 import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest
 import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
 import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest
@@ -82,16 +81,16 @@
     }
 
     @JvmStatic
-    fun buildLlamaRequest(details: InfillRequestDetails): LlamaInfillRequest {
+    fun buildLlamaRequest(details: InfillRequestDetails): LlamaCompletionRequest {
         val settings = LlamaSettings.getCurrentState()
         val promptTemplate = getLlamaInfillPromptTemplate(settings)
-        return LlamaInfillRequest(
-            LlamaCompletionRequest.Builder(null)
-                .setN_predict(settings.codeCompletionMaxTokens)
-                .setStream(true)
-                .setTemperature(0.4)
-                .setStop(promptTemplate.stopTokens), details.prefix, details.suffix
-        )
+        val prompt = promptTemplate.buildPrompt(details.prefix, details.suffix)
+        return LlamaCompletionRequest.Builder(prompt)
+            .setN_predict(settings.codeCompletionMaxTokens)
+            .setStream(true)
+            .setTemperature(0.4)
+            .setStop(promptTemplate.stopTokens)
+            .build()
     }
 
     fun buildOllamaRequest(details: InfillRequestDetails): OllamaCompletionRequest {
diff --git a/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt
index 5c2aeeed..71f02c20 100644
--- a/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt
+++ b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt
@@ -35,11 +35,11 @@ class CodeCompletionServiceTest : IntegrationTest() {
            ${"z".repeat(247)}
        """.trimIndent() // 128 tokens
        expectLlama(StreamHttpExchange { request: RequestEntity ->
-            assertThat(request.uri.path).isEqualTo("/infill")
+            assertThat(request.uri.path).isEqualTo("/completion")
            assertThat(request.method).isEqualTo("POST")
            assertThat(request.body)
-                .extracting("input_prefix", "input_suffix")
-                .containsExactly(prefix, suffix)
+                .extracting("prompt")
+                .isEqualTo(InfillPromptTemplate.CODE_LLAMA.buildPrompt(prefix, suffix))
            listOf(jsonMapResponse(e("content", expectedCompletion), e("stop", true)))
        })
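
Note: the net effect of this diff is that the client now renders the fill-in-the-middle prompt itself and POSTs it to llama.cpp's /completion endpoint, instead of sending separate prefix/suffix fields to /infill. A plausible motivation is that /infill support varies across llama.cpp builds and models, while client-side prompt rendering works uniformly for every template. The following is a minimal, self-contained Kotlin sketch of what the prompt construction looks like; InfillPromptTemplateSketch and the sample strings are hypothetical stand-ins for the real InfillPromptTemplate, and the <PRE>/<SUF>/<MID>/<EOT> markers assume Code Llama's published infill format.

// Illustrative sketch only: InfillPromptTemplateSketch is a hypothetical
// stand-in for the InfillPromptTemplate.CODE_LLAMA referenced in the test above.
enum class InfillPromptTemplateSketch(val stopTokens: List<String>) {
    // Assumes Code Llama's fill-in-the-middle markers; not authoritative.
    CODE_LLAMA(listOf("<EOT>")) {
        override fun buildPrompt(prefix: String, suffix: String) =
            "<PRE> $prefix <SUF>$suffix <MID>"
    };

    abstract fun buildPrompt(prefix: String, suffix: String): String
}

fun main() {
    // Before this change the request body carried {"input_prefix": ..., "input_suffix": ...}
    // to /infill; now a single flat "prompt" field goes to /completion, which is
    // exactly what the updated test asserts.
    val prompt = InfillPromptTemplateSketch.CODE_LLAMA.buildPrompt(
        prefix = "fun sum(a: Int, b: Int): Int {\n    return ",
        suffix = "\n}"
    )
    println(prompt)
    // Output:
    // <PRE> fun sum(a: Int, b: Int): Int {
    //     return  <SUF>
    // } <MID>
}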