Skip to content

Commit

Permalink
feat: Implement Ollama as a high-level service (#510)
Browse files Browse the repository at this point in the history
* Initial implementation of Ollama as a service

* Fix model selector in tool window

* Enable image attachment

* Rewrite OllamaSettingsForm in Kotlin

* Create OllamaInlineCompletionModel and use it for building completion template

* Add support for blocking code completion on models that are not known to support it

* Allow disabling code completion settings

* Disable code completion settings when an unsupported model is entered

* Track FIM template in settings as a derived state

* Update llm-client

* Initial implementation of model combo box

* Add Ollama icon and display models as list

* Make OllamaSettingsState immutable & convert OllamaSettings to Kotlin

* Add refresh models button

* Distinguish between empty/needs refresh/loading

* Avoid storing any model if the combo box is empty

* Fix icon size

* Back to mutable settings
There were some bugs with immutable settings

* Store available models in settings state

* Expose available models in model dropdown

* Add dark icon

* Cleanups for CompletionRequestProvider

* Fix checkstyle issues

* refactor: migrate to SimplePersistentStateComponent

* fix: add code completion stop tokens

* fix: display only one item in the model popup action group

* fix: add back multi model selection

---------

Co-authored-by: Carl-Robert Linnupuu <carlrobertoh@gmail.com>
  • Loading branch information
boswelja and carlrobertoh authored May 7, 2024
1 parent 7f7b35d commit e40630d
Show file tree
Hide file tree
Showing 23 changed files with 505 additions and 39 deletions.
1 change: 1 addition & 0 deletions src/main/java/ee/carlrobert/codegpt/Icons.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public final class Icons {
public static final Icon Sparkle = IconLoader.getIcon("/icons/sparkle.svg", Icons.class);
public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class);
public static final Icon YouSmall = IconLoader.getIcon("/icons/you_small.png", Icons.class);
public static final Icon Ollama = IconLoader.getIcon("/icons/ollama.svg", Icons.class);
public static final Icon User = IconLoader.getIcon("/icons/user.svg", Icons.class);
public static final Icon Upload = IconLoader.getIcon("/icons/upload.svg", Icons.class);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ee.carlrobert.codegpt.completions;

import com.intellij.openapi.application.ApplicationManager;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.completions.you.YouUserManager;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
Expand All @@ -8,11 +9,13 @@
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.anthropic.ClaudeClient;
import ee.carlrobert.llm.client.azure.AzureClient;
import ee.carlrobert.llm.client.azure.AzureCompletionRequestParams;
import ee.carlrobert.llm.client.llama.LlamaClient;
import ee.carlrobert.llm.client.ollama.OllamaClient;
import ee.carlrobert.llm.client.openai.OpenAIClient;
import ee.carlrobert.llm.client.you.UTMParameters;
import ee.carlrobert.llm.client.you.YouClient;
Expand Down Expand Up @@ -92,6 +95,16 @@ public static LlamaClient getLlamaClient() {
return builder.build(getDefaultClientBuilder());
}

/**
 * Creates an {@link OllamaClient} that targets the host stored in the {@link OllamaSettings}
 * application service.
 *
 * <p>The underlying HTTP client is configured through {@link #getDefaultClientBuilder()}.
 *
 * @return a newly built Ollama client
 */
public static OllamaClient getOllamaClient() {
  var ollamaState = ApplicationManager.getApplication().getService(OllamaSettings.class).getState();
  // Resolve the host first so the settings lookup happens before any builder is created.
  var configuredHost = ollamaState.getHost();
  var ollamaBuilder = new OllamaClient.Builder().setHost(configuredHost);
  return ollamaBuilder.build(getDefaultClientBuilder());
}

public static OkHttpClient.Builder getDefaultClientBuilder() {
OkHttpClient.Builder builder = new OkHttpClient.Builder();
var advancedSettings = AdvancedSettings.getCurrentState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceState;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
import ee.carlrobert.codegpt.telemetry.core.configuration.TelemetryConfiguration;
Expand All @@ -41,6 +41,8 @@
import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageImageContent;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageTextContent;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
Expand All @@ -56,6 +58,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
Expand Down Expand Up @@ -140,7 +143,8 @@ public static Request buildCustomOpenAICompletionRequest(String context, String

public static Request buildCustomOpenAILookupCompletionRequest(String context) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceState.class)
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage(
Expand Down Expand Up @@ -210,7 +214,7 @@ public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getCurrentState();
return new OpenAIChatCompletionRequest.Builder(buildMessages(model, callParameters))
return new OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.getMaxTokens())
.setStream(true)
Expand All @@ -222,7 +226,7 @@ public Request buildCustomOpenAIChatCompletionRequest(
CallParameters callParameters) {
return buildCustomOpenAIChatCompletionRequest(
settings,
buildMessages(callParameters),
buildOpenAIMessages(callParameters),
true);
}

Expand Down Expand Up @@ -307,7 +311,68 @@ public ClaudeCompletionRequest buildAnthropicChatCompletionRequest(
return request;
}

private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParameters) {
/**
 * Builds an Ollama chat completion request for the given call.
 *
 * <p>The model name is read from the {@link OllamaSettings} application service and the message
 * list is produced by {@code buildOllamaMessages}.
 *
 * @param callParameters parameters describing the current call
 * @return the chat completion request to send to Ollama
 */
public OllamaChatCompletionRequest buildOllamaChatCompletionRequest(
    CallParameters callParameters
) {
  var selectedModel = ApplicationManager.getApplication()
      .getService(OllamaSettings.class)
      .getState()
      .getModel();
  var chatMessages = buildOllamaMessages(callParameters);
  return new OllamaChatCompletionRequest.Builder(selectedModel, chatMessages).build();
}

/**
 * Assembles the ordered list of Ollama chat messages for a call.
 *
 * <p>The list is built in three steps:
 * <ol>
 *   <li>an optional "system" message, chosen by the conversation type;</li>
 *   <li>a "user"/"assistant" pair for every message returned by
 *       {@code conversation.getMessages()}, stopping at the message being retried when the call
 *       is a retry;</li>
 *   <li>the current "user" message, with the attached image (Base64-encoded) when the call
 *       carries a media type and non-empty image data.</li>
 * </ol>
 *
 * @param callParameters parameters describing the current call
 * @return the messages to send, oldest first
 * @throws RuntimeException wrapping the {@link IOException} raised when the image file of an
 *     earlier message cannot be read
 */
private List<OllamaChatCompletionMessage> buildOllamaMessages(CallParameters callParameters) {
  var currentMessage = callParameters.getMessage();
  var history = new ArrayList<OllamaChatCompletionMessage>();

  var conversationType = callParameters.getConversationType();
  if (conversationType == ConversationType.DEFAULT) {
    var configuredPrompt = ConfigurationSettings.getCurrentState().getSystemPrompt();
    history.add(new OllamaChatCompletionMessage("system", configuredPrompt, null));
  } else if (conversationType == ConversationType.FIX_COMPILE_ERRORS) {
    history.add(new OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null));
  }

  for (var earlier : conversation.getMessages()) {
    // On a retry, everything from the retried message onwards is left out of the history.
    if (callParameters.isRetry() && earlier.getId().equals(currentMessage.getId())) {
      break;
    }

    var storedImagePath = earlier.getImageFilePath();
    boolean hasStoredImage = storedImagePath != null && !storedImagePath.isEmpty();
    if (!hasStoredImage) {
      history.add(new OllamaChatCompletionMessage("user", earlier.getPrompt(), null));
    } else {
      String encodedImage;
      try {
        encodedImage = Base64.getEncoder()
            .encodeToString(Files.readAllBytes(Path.of(storedImagePath)));
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      history.add(
          new OllamaChatCompletionMessage("user", earlier.getPrompt(), List.of(encodedImage)));
    }
    history.add(new OllamaChatCompletionMessage("assistant", earlier.getResponse(), null));
  }

  boolean hasAttachedImage =
      callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0;
  if (!hasAttachedImage) {
    history.add(new OllamaChatCompletionMessage("user", currentMessage.getPrompt(), null));
    return history;
  }

  var attachedImage = Base64.getEncoder().encodeToString(callParameters.getImageData());
  history.add(
      new OllamaChatCompletionMessage("user", currentMessage.getPrompt(), List.of(attachedImage)));
  return history;
}

private List<OpenAIChatCompletionMessage> buildOpenAIMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
Expand Down Expand Up @@ -339,7 +404,9 @@ private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParam
} else {
messages.add(new OpenAIChatCompletionStandardMessage("user", prevMessage.getPrompt()));
}
messages.add(new OpenAIChatCompletionStandardMessage("assistant", prevMessage.getResponse()));
messages.add(
new OpenAIChatCompletionStandardMessage("assistant", prevMessage.getResponse())
);
}

if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) {
Expand All @@ -355,10 +422,10 @@ private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParam
return messages;
}

private List<OpenAIChatCompletionMessage> buildMessages(
private List<OpenAIChatCompletionMessage> buildOpenAIMessages(
@Nullable String model,
CallParameters callParameters) {
var messages = buildMessages(callParameters);
var messages = buildOpenAIMessages(callParameters);

if (model == null
|| GeneralSettings.getCurrentState().getSelectedService() == ServiceType.YOU) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
Expand Down Expand Up @@ -104,6 +107,9 @@ public EventSource getChatCompletionAsync(
callParameters.getMessage(),
callParameters.getConversationType()),
eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
requestProvider.buildOllamaChatCompletionRequest(callParameters),
eventListener);
};
}

Expand All @@ -123,6 +129,9 @@ public EventSource getCodeCompletionAsync(
.getInfillAsync(
CodeCompletionRequestFactory.buildLlamaRequest(requestDetails),
eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient().getCompletionAsync(
CodeCompletionRequestFactory.INSTANCE.buildOllamaRequest(requestDetails),
eventListener);
default ->
throw new IllegalArgumentException("Code completion not supported for selected service");
};
Expand Down Expand Up @@ -189,6 +198,20 @@ public void generateCommitMessageAsync(
.setRepeat_penalty(settings.getRepeatPenalty())
.build(), eventListener);
break;
case OLLAMA:
var model = ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getModel();
var request = new OllamaChatCompletionRequest.Builder(
model,
List.of(
new OllamaChatCompletionMessage("system", systemPrompt, null),
new OllamaChatCompletionMessage("user", gitDiff, null)
)
).build();
CompletionClientProvider.getOllamaClient().getChatCompletionAsync(request, eventListener);
break;
default:
LOG.debug("Unknown service: {}", selectedService);
break;
Expand Down Expand Up @@ -228,9 +251,9 @@ public static boolean isRequestAllowed(ServiceType serviceType) {
case OPENAI -> CredentialsStore.INSTANCE.isCredentialSet(CredentialKey.OPENAI_API_KEY);
case AZURE -> CredentialsStore.INSTANCE.isCredentialSet(
AzureSettings.getCurrentState().isUseAzureApiKeyAuthentication()
? CredentialKey.AZURE_OPENAI_API_KEY
: CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN);
case CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP -> true;
? CredentialKey.AZURE_OPENAI_API_KEY
: CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN);
case CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP, OLLAMA -> true;
case YOU -> false;
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import java.time.LocalDateTime;
import java.util.ArrayList;
Expand Down Expand Up @@ -195,9 +196,13 @@ private static String getModelForSelectedService(ServiceType serviceType) {
case LLAMA_CPP -> {
var llamaSettings = LlamaSettings.getCurrentState();
yield llamaSettings.isUseCustomModel()
? llamaSettings.getCustomLlamaModelPath()
: llamaSettings.getHuggingFaceModel().getCode();
? llamaSettings.getCustomLlamaModelPath()
: llamaSettings.getHuggingFaceModel().getCode();
}
case OLLAMA -> ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getModel();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import org.jetbrains.annotations.NotNull;

Expand Down Expand Up @@ -69,6 +70,9 @@ public void sync(Conversation conversation) {
if ("you.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.YOU);
}
if ("ollama.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.OLLAMA);
}
}

public String getModel() {
Expand Down Expand Up @@ -98,6 +102,11 @@ public String getModel() {
llamaModel.getLabel(),
huggingFaceModel.getParameterSize(),
huggingFaceModel.getQuantization());
case OLLAMA:
return ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getModel();
default:
return "Unknown";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OLLAMA;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;

Expand All @@ -20,6 +21,8 @@
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
Expand All @@ -45,6 +48,7 @@ public class GeneralSettingsComponent {
private final AzureSettingsForm azureSettingsForm;
private final YouSettingsForm youSettingsForm;
private final LlamaSettingsForm llamaSettingsForm;
private final OllamaSettingsForm ollamaSettingsForm;

public GeneralSettingsComponent(Disposable parentDisposable, GeneralSettings settings) {
displayNameField = new JBTextField(settings.getState().getDisplayName(), 20);
Expand All @@ -54,6 +58,7 @@ public GeneralSettingsComponent(Disposable parentDisposable, GeneralSettings set
azureSettingsForm = new AzureSettingsForm(AzureSettings.getCurrentState());
youSettingsForm = new YouSettingsForm(YouSettings.getCurrentState(), parentDisposable);
llamaSettingsForm = new LlamaSettingsForm(LlamaSettings.getCurrentState());
ollamaSettingsForm = new OllamaSettingsForm();

var cardLayout = new DynamicCardLayout();
var cards = new JPanel(cardLayout);
Expand All @@ -63,6 +68,7 @@ public GeneralSettingsComponent(Disposable parentDisposable, GeneralSettings set
cards.add(azureSettingsForm.getForm(), AZURE.getCode());
cards.add(youSettingsForm, YOU.getCode());
cards.add(llamaSettingsForm, LLAMA_CPP.getCode());
cards.add(ollamaSettingsForm.getForm(), OLLAMA.getCode());
var serviceComboBoxModel = new DefaultComboBoxModel<ServiceType>();
serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values()).toList());
serviceComboBox = new ComboBox<>(serviceComboBoxModel);
Expand Down Expand Up @@ -106,6 +112,10 @@ public YouSettingsForm getYouSettingsForm() {
return youSettingsForm;
}

/**
 * Returns the Ollama settings form held by this component.
 *
 * @return the Ollama settings form
 */
public OllamaSettingsForm getOllamaSettingsForm() {
  return this.ollamaSettingsForm;
}

/**
 * Returns the service type currently selected in the service combo box.
 *
 * @return the item reported by the service combo box
 */
public ServiceType getSelectedService() {
  var selected = this.serviceComboBox.getItem();
  return selected;
}
Expand Down Expand Up @@ -137,6 +147,7 @@ public void resetForms() {
azureSettingsForm.resetForm();
youSettingsForm.resetForm();
llamaSettingsForm.resetForm();
ollamaSettingsForm.resetForm();
}

static class DynamicCardLayout extends CardLayout {
Expand Down
Loading

0 comments on commit e40630d

Please sign in to comment.