From 127c8d91b8399e504c9c25dc98c1e036714f2d5b Mon Sep 17 00:00:00 2001 From: openhands Date: Wed, 29 Jan 2025 22:50:12 +0000 Subject: [PATCH 1/7] Add custom providers support - Add CustomChatEndpoint and CustomModelsEndpoint - Add CustomProvider model and UI components - Support custom providers in model selection - Display custom models in settings and model selection view --- .../Endpoints/CustomChatEndpoint.swift | 22 +++ .../Endpoints/CustomModelsEndpoint.swift | 30 ++++ .../Data/Fetching/FetchingClient+Error.swift | 12 +- macos/Onit/Data/Fetching/FetchingClient.swift | 2 + macos/Onit/Data/Model/CustomProvider.swift | 88 ++++++++++ macos/Onit/Data/Model/Model+Chat.swift | 74 +++++++-- .../Data/Model/Model+TokenValidation.swift | 12 ++ macos/Onit/Data/Model/OnitModel.swift | 5 + macos/Onit/Data/Structures/AIModel.swift | 24 +++ .../Onit/UI/Prompt/Dialogs/SetUpDialogs.swift | 2 + macos/Onit/UI/RemoteModelsState.swift | 9 ++ .../Models/CustomProvidersSection.swift | 150 ++++++++++++++++++ .../Onit/UI/Settings/Models/ModelToggle.swift | 2 +- .../Settings/Models/RemoteModelSection.swift | 12 ++ macos/Onit/UI/Settings/ModelsTab.swift | 1 + 15 files changed, 430 insertions(+), 15 deletions(-) create mode 100644 macos/Onit/Data/Fetching/Endpoints/CustomChatEndpoint.swift create mode 100644 macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift create mode 100644 macos/Onit/Data/Model/CustomProvider.swift create mode 100644 macos/Onit/UI/Settings/Models/CustomProvidersSection.swift diff --git a/macos/Onit/Data/Fetching/Endpoints/CustomChatEndpoint.swift b/macos/Onit/Data/Fetching/Endpoints/CustomChatEndpoint.swift new file mode 100644 index 00000000..334a4449 --- /dev/null +++ b/macos/Onit/Data/Fetching/Endpoints/CustomChatEndpoint.swift @@ -0,0 +1,22 @@ +import Foundation + +struct CustomChatEndpoint: Endpoint { + var baseURL: URL + + typealias Request = OpenAIChatRequest + typealias Response = OpenAIChatResponse + + let messages: [OpenAIChatMessage] + let token: String? + let model: String + + var path: String { "/v1/chat/completions" } + var getParams: [String: String]? { nil } + var method: HTTPMethod { .post } + var requestBody: OpenAIChatRequest? { + OpenAIChatRequest(model: model, messages: messages) + } + var additionalHeaders: [String: String]? { + ["Authorization": "Bearer \(token ?? "")"] + } +} \ No newline at end of file diff --git a/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift b/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift new file mode 100644 index 00000000..5f92a5e5 --- /dev/null +++ b/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift @@ -0,0 +1,30 @@ +import Foundation + +struct CustomModelsEndpoint: Endpoint { + typealias Request = EmptyRequest + typealias Response = CustomModelsResponse + + var baseURL: URL + let token: String? + + var path: String { "/v1/models" } + var getParams: [String: String]? { nil } + var method: HTTPMethod { .get } + var requestBody: EmptyRequest? { nil } + + var additionalHeaders: [String: String]? { + ["Authorization": "Bearer \(token ?? "")"] + } +} + +struct CustomModelsResponse: Codable { + let object: String + let data: [CustomModelInfo] +} + +struct CustomModelInfo: Codable { + let id: String + let object: String + let created: Int + let owned_by: String +} \ No newline at end of file diff --git a/macos/Onit/Data/Fetching/FetchingClient+Error.swift b/macos/Onit/Data/Fetching/FetchingClient+Error.swift index 1584e6be..bdfd80c4 100644 --- a/macos/Onit/Data/Fetching/FetchingClient+Error.swift +++ b/macos/Onit/Data/Fetching/FetchingClient+Error.swift @@ -17,6 +17,7 @@ public enum FetchingError: Error { case serverError(statusCode: Int, message: String) case decodingError(Error) case networkError(Error) + case invalidURL } extension FetchingError: LocalizedError { @@ -40,6 +41,8 @@ extension FetchingError: LocalizedError { "Failed to decode response: \(error.localizedDescription)" case .networkError(let error): "Network error: \(error.localizedDescription)" + case .invalidURL: + "Invalid URL" } } } @@ -49,7 +52,8 @@ extension FetchingError: Equatable { switch (lhs, rhs) { case (.invalidResponse, .invalidResponse), (.unauthorized, .unauthorized), - (.notFound, .notFound): + (.notFound, .notFound), + (.invalidURL, .invalidURL): return true case (.forbidden(let lhsMessage), .forbidden(let rhsMessage)): return lhsMessage == rhsMessage @@ -73,7 +77,7 @@ extension FetchingError: Codable { } enum FetchingErrorType: String, Codable { - case invalidResponse, invalidRequest, unauthorized, forbidden, notFound, failedRequest, serverError, decodingError, networkError + case invalidResponse, invalidRequest, unauthorized, forbidden, notFound, failedRequest, serverError, decodingError, networkError, invalidURL } public init(from decoder: Decoder) throws { @@ -111,6 +115,8 @@ extension FetchingError: Codable { let errorDescription = try container.decode(String.self, forKey: .description) let error = NSError(domain: "", code: 0, userInfo: [NSLocalizedDescriptionKey: errorDescription]) self = .networkError(error) + case .invalidURL: + self = .invalidURL } } @@ -146,6 +152,8 @@ extension FetchingError: Codable { case .networkError(let error): try container.encode(FetchingErrorType.networkError, forKey: .type) try container.encode(error.localizedDescription, forKey: .description) + case .invalidURL: + try container.encode(FetchingErrorType.invalidURL, forKey: .type) } } } diff --git a/macos/Onit/Data/Fetching/FetchingClient.swift b/macos/Onit/Data/Fetching/FetchingClient.swift index 56bbf748..cdeefcf8 100644 --- a/macos/Onit/Data/Fetching/FetchingClient.swift +++ b/macos/Onit/Data/Fetching/FetchingClient.swift @@ -244,6 +244,8 @@ actor FetchingClient { let endpoint = GoogleAIChatEndpoint(messages: googleAIMessageStack, model: model.id, token: apiToken) let response = try await execute(endpoint) return response.choices[0].message.content + case .custom: + return "" // TODO: KNA - } } diff --git a/macos/Onit/Data/Model/CustomProvider.swift b/macos/Onit/Data/Model/CustomProvider.swift new file mode 100644 index 00000000..7dafb354 --- /dev/null +++ b/macos/Onit/Data/Model/CustomProvider.swift @@ -0,0 +1,88 @@ +import Defaults +import Foundation +import SwiftData + +@Model +class CustomProvider: Codable, Defaults.Serializable { + var name: String + var baseURL: String + var token: String + var models: [String] + var isEnabled: Bool { + didSet { + // Update model visibility when enabled state changes + let modelIds = Set(models) + if isEnabled { + // Add model IDs to visible set + Defaults[.visibleModelIds].formUnion(modelIds) + } else { + // Remove model IDs from visible set + Defaults[.visibleModelIds].subtract(modelIds) + } + } + } + + init(name: String, baseURL: String, token: String) { + self.name = name + self.baseURL = baseURL + self.token = token + self.models = [] + self.isEnabled = true + } + + func fetchModels() async throws { + guard let url = URL(string: baseURL) else { return } + + let endpoint = CustomModelsEndpoint(baseURL: url, token: token) + let client = FetchingClient() + let response = try await client.execute(endpoint) + + models = response.data.map { $0.id } + + // Initialize model IDs + let newModels = models.map { modelId in + AIModel(from: CustomModelInfo( + id: modelId, + object: "model", + created: Int(Date().timeIntervalSince1970), + owned_by: name + ), provider: self) + } + + // Add new models to available remote models + Defaults[.availableRemoteModels].append(contentsOf: newModels) + + // Initialize visible model IDs + for model in newModels { + Defaults[.visibleModelIds].insert(model.id) + } + } + + // MARK: - Decodable + + enum CodingKeys: CodingKey { + case name + case baseURL + case token + case models + case isEnabled + } + + required init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + name = try container.decode(String.self, forKey: .name) + baseURL = try container.decode(String.self, forKey: .baseURL) + token = try container.decode(String.self, forKey: .token) + models = try container.decode([String].self, forKey: .models) + isEnabled = try container.decode(Bool.self, forKey: .isEnabled) + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + try container.encode(baseURL, forKey: .baseURL) + try container.encode(token, forKey: .token) + try container.encode(models, forKey: .models) + try container.encode(isEnabled, forKey: .isEnabled) + } +} diff --git a/macos/Onit/Data/Model/Model+Chat.swift b/macos/Onit/Data/Model/Model+Chat.swift index e5b7496e..e280f02d 100644 --- a/macos/Onit/Data/Model/Model+Chat.swift +++ b/macos/Onit/Data/Model/Model+Chat.swift @@ -85,17 +85,35 @@ extension OnitModel { do { let chat: String if Defaults[.mode] == .remote { - // chat(instructions: [String], inputs: [Input?], files: [[URL]], images[[URL]], responses: [String], model: AIModel?, apiToken: String?) - chat = try await client.chat( - instructions: instructionsHistory, - inputs: inputsHistory, - files: filesHistory, - images: imagesHistory, - autoContexts: autoContextsHistory, - responses: responsesHistory, - model: Defaults[.remoteModel], - apiToken: getTokenForModel(Defaults[.remoteModel] ?? nil) - ) + if let customProvider = Defaults[.remoteModel]?.customProvider { + // Handle custom provider chat + let messages = createOpenAIMessages( + instructions: instructionsHistory, + responses: responsesHistory + ) + if let endpoint = getCustomEndpoint( + for: customProvider, + messages: messages, + model: Defaults[.remoteModel]?.id ?? "" + ) { + let response = try await client.execute(endpoint) + chat = response.choices[0].message.content + } else { + throw FetchingError.invalidURL + } + } else { + // Regular remote model chat + chat = try await client.chat( + instructions: instructionsHistory, + inputs: inputsHistory, + files: filesHistory, + images: imagesHistory, + autoContexts: autoContextsHistory, + responses: responsesHistory, + model: Defaults[.remoteModel], + apiToken: getTokenForModel(Defaults[.remoteModel] ?? nil) + ) + } } else { // TODO implement history for local chat! chat = try await client.localChat( @@ -140,9 +158,23 @@ extension OnitModel { */ private func trackEventGeneration(prompt: Prompt) { let eventName = "user_prompted" + var modelName = "" + + if Defaults[.mode] == .remote { + if let model = Defaults[.remoteModel] { + if model.provider == .custom { + modelName = "\(model.customProvider?.name ?? "Custom")/\(model.displayName)" + } else { + modelName = model.displayName + } + } + } else { + modelName = Defaults[.localModel] ?? "" + } + let eventProperties: [String: Any] = [ "prompt_mode": Defaults[.mode].rawValue, - "prompt_model": Defaults[.mode] == .remote ? Defaults[.remoteModel]?.displayName ?? "" : Defaults[.localModel] ?? "" + "prompt_model": modelName ] PostHogSDK.shared.capture(eventName, properties: eventProperties) } @@ -172,6 +204,8 @@ extension OnitModel { Defaults[.isXAITokenValidated] = isValid case .googleAI: Defaults[.isGoogleAITokenValidated] = isValid + case .custom: + break // TODO: KNA - } } @@ -186,6 +220,8 @@ extension OnitModel { return Defaults[.xAIToken] case .googleAI: return Defaults[.googleAIToken] + case .custom: + return nil // TODO: KNA - } } return nil @@ -197,4 +233,18 @@ extension OnitModel { prompt.generationIndex = (prompt.responses.count - 1) prompt.generationState = .done } + + private func createOpenAIMessages(instructions: [String], responses: [String]) -> [OpenAIChatMessage] { + var messages: [OpenAIChatMessage] = [] + + for (index, instruction) in instructions.enumerated() { + messages.append(OpenAIChatMessage(role: "user", content: .text(instruction))) + + if index < responses.count { + messages.append(OpenAIChatMessage(role: "assistant", content: .text(responses[index]))) + } + } + + return messages + } } diff --git a/macos/Onit/Data/Model/Model+TokenValidation.swift b/macos/Onit/Data/Model/Model+TokenValidation.swift index 149c58eb..314272b7 100644 --- a/macos/Onit/Data/Model/Model+TokenValidation.swift +++ b/macos/Onit/Data/Model/Model+TokenValidation.swift @@ -3,6 +3,7 @@ // Onit // +import Defaults import Foundation struct TokenValidationState { @@ -101,6 +102,17 @@ extension OnitModel { let endpoint = GoogleAIValidationEndpoint(apiKey: token) _ = try await FetchingClient().execute(endpoint) state.setValid(provider: provider) + + case .custom: + // For custom providers, we'll validate by trying to fetch the models list + if let customProvider = Defaults[.remoteModel]?.customProvider, + let url = URL(string: customProvider.baseURL) { + let endpoint = CustomModelsEndpoint(baseURL: url, token: token) + _ = try await FetchingClient().execute(endpoint) + state.setValid(provider: provider) + } else { + throw FetchingError.invalidURL + } } setTokenIsValid(true) } catch let error as FetchingError { diff --git a/macos/Onit/Data/Model/OnitModel.swift b/macos/Onit/Data/Model/OnitModel.swift index d40b0163..d709f4eb 100644 --- a/macos/Onit/Data/Model/OnitModel.swift +++ b/macos/Onit/Data/Model/OnitModel.swift @@ -72,6 +72,11 @@ import Defaults var remoteFetchFailed: Bool = false var localFetchFailed: Bool = false + func getCustomEndpoint(for provider: CustomProvider, messages: [OpenAIChatMessage], model: String) -> CustomChatEndpoint? { + guard let url = URL(string: provider.baseURL) else { return nil } + return CustomChatEndpoint(baseURL: url, messages: messages, token: provider.token, model: model) + } + @MainActor func fetchLocalModels() async { do { diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index 08905fce..d0b84619 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -17,6 +17,24 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { let supportsSystemPrompts: Bool var isNew: Bool = false var isDeprecated: Bool = false + var customProvider: CustomProvider? + + var formattedDisplayName: String { + if provider == .custom, let provider = customProvider { + return "\(provider.name) / \(displayName)" + } + return displayName + } + + init(from customModel: CustomModelInfo, provider: CustomProvider) { + self.id = customModel.id + self.displayName = customModel.id + self.provider = .custom + self.defaultOn = false + self.supportsVision = false + self.supportsSystemPrompts = true + self.customProvider = provider + } init?(from modelInfo: ModelInfo) { guard let provider = ModelProvider(rawValue: modelInfo.provider.lowercased()) else { @@ -30,6 +48,7 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { self.supportsSystemPrompts = modelInfo.supportsSystemPrompts } + @MainActor static func fetchModels() async throws -> [AIModel] { let client = FetchingClient() let endpoint = RemoteModelsEndpoint() @@ -42,6 +61,7 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { case anthropic = "anthropic" case xAI = "xai" case googleAI = "googleai" + case custom = "custom" var title: String { switch self { @@ -49,6 +69,7 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { case .anthropic: return "Anthropic" case .xAI: return "xAI" case .googleAI: return "Google AI" + case .custom: return "Custom Providers" } } @@ -58,6 +79,7 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { case .anthropic: return "Claude" case .xAI: return "Grok" case .googleAI: return "Gemini" + case .custom: return "Custom Model" } } @@ -71,6 +93,8 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { return URL(string: "https://accounts.x.ai/account")! case .googleAI: return URL(string: "https://makersuite.google.com/app/apikey")! + case .custom: + return URL(string: "about:blank")! } } } diff --git a/macos/Onit/UI/Prompt/Dialogs/SetUpDialogs.swift b/macos/Onit/UI/Prompt/Dialogs/SetUpDialogs.swift index 05b851d2..61378ca9 100644 --- a/macos/Onit/UI/Prompt/Dialogs/SetUpDialogs.swift +++ b/macos/Onit/UI/Prompt/Dialogs/SetUpDialogs.swift @@ -248,6 +248,8 @@ struct SetUpDialogs: View { closedXAI = true case .googleAI: closedGoogleAI = true + case .custom: + break // TODO: KNA - } } } diff --git a/macos/Onit/UI/RemoteModelsState.swift b/macos/Onit/UI/RemoteModelsState.swift index 32b6e47c..6278ffc7 100644 --- a/macos/Onit/UI/RemoteModelsState.swift +++ b/macos/Onit/UI/RemoteModelsState.swift @@ -55,6 +55,15 @@ final class RemoteModelsState: ObservableObject { models = models.filter { $0.provider != .googleAI } } + // Filter out models from disabled custom providers + models = models.filter { model in + if model.provider == .custom, + let provider = model.customProvider { + return provider.isEnabled + } + return true + } + return models } } diff --git a/macos/Onit/UI/Settings/Models/CustomProvidersSection.swift b/macos/Onit/UI/Settings/Models/CustomProvidersSection.swift new file mode 100644 index 00000000..9a88fc6c --- /dev/null +++ b/macos/Onit/UI/Settings/Models/CustomProvidersSection.swift @@ -0,0 +1,150 @@ +import SwiftUI +import SwiftData + +struct CustomProvidersSection: View { + @Environment(\.model) var model + @Environment(\.modelContext) private var modelContext + @Query private var customProviders: [CustomProvider] + + @State private var isAddingProvider = false + @State private var newProviderName = "" + @State private var newProviderURL = "" + @State private var newProviderToken = "" + @State private var errorMessage: String? + + var body: some View { + RemoteModelSection(title: "Custom Providers") { + VStack(alignment: .leading, spacing: 12) { + ForEach(customProviders) { provider in + CustomProviderRow(provider: provider) + } + + if isAddingProvider { + VStack(alignment: .leading, spacing: 8) { + TextField("Provider Name", text: $newProviderName) + .textFieldStyle(.roundedBorder) + + TextField("Base URL", text: $newProviderURL) + .textFieldStyle(.roundedBorder) + + SecureField("API Token", text: $newProviderToken) + .textFieldStyle(.roundedBorder) + + HStack { + Button("Add Provider") { + addProvider() + } + .disabled(newProviderName.isEmpty || newProviderURL.isEmpty || newProviderToken.isEmpty) + + Button("Cancel") { + isAddingProvider = false + resetForm() + } + } + + if let error = errorMessage { + Text(error) + .foregroundColor(.red) + .font(.caption) + } + } + .padding() + .background(Color(.textBackgroundColor)) + .cornerRadius(8) + } else { + Button("Add Custom Provider") { + isAddingProvider = true + } + } + } + } + } + + private func addProvider() { + Task { + do { + let provider = CustomProvider( + name: newProviderName, + baseURL: newProviderURL, + token: newProviderToken + ) + + try await provider.fetchModels() + modelContext.insert(provider) + + isAddingProvider = false + resetForm() + } catch { + errorMessage = "Failed to fetch models: \(error.localizedDescription)" + } + } + } + + private func resetForm() { + newProviderName = "" + newProviderURL = "" + newProviderToken = "" + errorMessage = nil + } +} + +struct CustomProviderRow: View { + @Environment(\.model) var model + @Environment(\.modelContext) private var modelContext + + let provider: CustomProvider + + var body: some View { + VStack(alignment: .leading) { + HStack { + Toggle(isOn: .init( + get: { provider.isEnabled }, + set: { provider.isEnabled = $0 } + )) { + Text(provider.name) + .font(.headline) + } + + Spacer() + + Button(role: .destructive) { + // Remove provider's models from available remote models + model.updatePreferences { prefs in + prefs.availableRemoteModels.removeAll { model in + model.customProvider?.id == provider.id + } + } + modelContext.delete(provider) + } label: { + Image(systemName: "trash") + } + } + + if provider.isEnabled { + GroupBox { + VStack(alignment: .leading, spacing: 0) { + ForEach(provider.models, id: \.self) { modelId in + let aiModel = AIModel( + from: CustomModelInfo( + id: modelId, + object: "model", + created: Int(Date().timeIntervalSince1970), + owned_by: provider.name + ), + provider: provider + ) + ModelToggle(aiModel: aiModel) + .frame(height: 36) + } + } + .padding(.vertical, -4) + .padding(.horizontal, 4) + .frame(maxWidth: .infinity, alignment: .leading) + } + } + } + .padding() + .background(Color(.textBackgroundColor)) + .cornerRadius(8) + } +} diff --git a/macos/Onit/UI/Settings/Models/ModelToggle.swift b/macos/Onit/UI/Settings/Models/ModelToggle.swift index 9527fd54..6661cc74 100644 --- a/macos/Onit/UI/Settings/Models/ModelToggle.swift +++ b/macos/Onit/UI/Settings/Models/ModelToggle.swift @@ -30,7 +30,7 @@ struct ModelToggle: View { var body: some View { Toggle(isOn: isOn) { HStack { - Text(aiModel.displayName) + Text(aiModel.formattedDisplayName) .font(.system(size: 13)) .fontWeight(.regular) .opacity(0.85) diff --git a/macos/Onit/UI/Settings/Models/RemoteModelSection.swift b/macos/Onit/UI/Settings/Models/RemoteModelSection.swift index 2aecea98..91ce22a6 100644 --- a/macos/Onit/UI/Settings/Models/RemoteModelSection.swift +++ b/macos/Onit/UI/Settings/Models/RemoteModelSection.swift @@ -203,6 +203,8 @@ struct RemoteModelSection: View { xAIToken = key.isEmpty ? nil : key case .googleAI: googleAIToken = key.isEmpty ? nil : key + case .custom: + break // TODO: KNA - } } @@ -216,6 +218,8 @@ struct RemoteModelSection: View { key = xAIToken ?? "" case .googleAI: key = googleAIToken ?? "" + case .custom: + break // TODO: KNA - } } @@ -229,6 +233,8 @@ struct RemoteModelSection: View { use = useXAI case .googleAI: use = useGoogleAI + case .custom: + break // TODO: KNA - } } @@ -242,6 +248,8 @@ struct RemoteModelSection: View { validated = isXAITokenValidated case .googleAI: validated = isGoogleAITokenValidated + case .custom: + break // TODO: KNA - } } @@ -255,6 +263,8 @@ struct RemoteModelSection: View { useXAI = use case .googleAI: useGoogleAI = use + case .custom: + break // TODO: KNA - } } @@ -268,6 +278,8 @@ struct RemoteModelSection: View { isXAITokenValidated = validated case .googleAI: isGoogleAITokenValidated = validated + case .custom: + break // TODO: KNA - } } diff --git a/macos/Onit/UI/Settings/ModelsTab.swift b/macos/Onit/UI/Settings/ModelsTab.swift index 977f19fe..a906ff8e 100644 --- a/macos/Onit/UI/Settings/ModelsTab.swift +++ b/macos/Onit/UI/Settings/ModelsTab.swift @@ -16,6 +16,7 @@ struct ModelsTab: View { RemoteModelsSection() LocalModelsSection() DefaultModelsSection() + CustomProvidersSection() } .padding(.vertical, 20) .padding(.horizontal, 86) From 76f26ae1bc5b89e557ba5a3981c21aa8d5f6220e Mon Sep 17 00:00:00 2001 From: Niduank Date: Wed, 5 Feb 2025 17:55:39 +0100 Subject: [PATCH 2/7] Custom endpoint almost working --- macos/Onit/Data/Model/CustomProvider.swift | 85 ++-------- macos/Onit/Data/Model/Model+Chat.swift | 8 +- .../Data/Model/Model+TokenValidation.swift | 3 +- macos/Onit/Data/Model/OnitModel.swift | 7 +- macos/Onit/Data/Persistence/Defaults.swift | 1 + macos/Onit/Data/Structures/AIModel.swift | 10 +- .../Preview Content/PreviewSampleData.swift | 5 + .../Onit/UI/Components/LabeledTextField.swift | 44 +++++ .../Prompt/Selection/ModelSelectionView.swift | 59 +++++-- macos/Onit/UI/RemoteModelsState.swift | 35 ++-- .../CustomProviderFormView.swift | 116 ++++++++++++++ .../CustomProvider/CustomProviderRow.swift | 76 +++++++++ .../CustomProvidersSection.swift | 59 +++++++ .../Models/CustomProvidersSection.swift | 150 ------------------ .../Settings/Models/RemoteModelsSection.swift | 17 +- macos/Onit/UI/Settings/ModelsTab.swift | 1 - 16 files changed, 409 insertions(+), 267 deletions(-) create mode 100644 macos/Onit/UI/Components/LabeledTextField.swift create mode 100644 macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift create mode 100644 macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift create mode 100644 macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift delete mode 100644 macos/Onit/UI/Settings/Models/CustomProvidersSection.swift diff --git a/macos/Onit/Data/Model/CustomProvider.swift b/macos/Onit/Data/Model/CustomProvider.swift index 7dafb354..78244be5 100644 --- a/macos/Onit/Data/Model/CustomProvider.swift +++ b/macos/Onit/Data/Model/CustomProvider.swift @@ -2,87 +2,26 @@ import Defaults import Foundation import SwiftData -@Model -class CustomProvider: Codable, Defaults.Serializable { +struct CustomProvider: Codable, Identifiable, Hashable, Defaults.Serializable { + var id: String { name } + var name: String var baseURL: String var token: String var models: [String] - var isEnabled: Bool { - didSet { - // Update model visibility when enabled state changes - let modelIds = Set(models) - if isEnabled { - // Add model IDs to visible set - Defaults[.visibleModelIds].formUnion(modelIds) - } else { - // Remove model IDs from visible set - Defaults[.visibleModelIds].subtract(modelIds) - } - } - } - - init(name: String, baseURL: String, token: String) { + var isEnabled: Bool + + init(name: String, baseURL: String, token: String, models: [String]) { self.name = name self.baseURL = baseURL self.token = token - self.models = [] + self.models = models self.isEnabled = true } - - func fetchModels() async throws { - guard let url = URL(string: baseURL) else { return } - - let endpoint = CustomModelsEndpoint(baseURL: url, token: token) - let client = FetchingClient() - let response = try await client.execute(endpoint) - - models = response.data.map { $0.id } - - // Initialize model IDs - let newModels = models.map { modelId in - AIModel(from: CustomModelInfo( - id: modelId, - object: "model", - created: Int(Date().timeIntervalSince1970), - owned_by: name - ), provider: self) - } - - // Add new models to available remote models - Defaults[.availableRemoteModels].append(contentsOf: newModels) - - // Initialize visible model IDs - for model in newModels { - Defaults[.visibleModelIds].insert(model.id) - } - } - - // MARK: - Decodable - - enum CodingKeys: CodingKey { - case name - case baseURL - case token - case models - case isEnabled - } - - required init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - name = try container.decode(String.self, forKey: .name) - baseURL = try container.decode(String.self, forKey: .baseURL) - token = try container.decode(String.self, forKey: .token) - models = try container.decode([String].self, forKey: .models) - isEnabled = try container.decode(Bool.self, forKey: .isEnabled) - } - - func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(name, forKey: .name) - try container.encode(baseURL, forKey: .baseURL) - try container.encode(token, forKey: .token) - try container.encode(models, forKey: .models) - try container.encode(isEnabled, forKey: .isEnabled) + + static func == (lhs: CustomProvider, rhs: CustomProvider) -> Bool { + return lhs.name == rhs.name && lhs.baseURL == rhs.baseURL + && lhs.token == rhs.token && lhs.models == rhs.models + && lhs.isEnabled == rhs.isEnabled } } diff --git a/macos/Onit/Data/Model/Model+Chat.swift b/macos/Onit/Data/Model/Model+Chat.swift index e280f02d..0797291b 100644 --- a/macos/Onit/Data/Model/Model+Chat.swift +++ b/macos/Onit/Data/Model/Model+Chat.swift @@ -85,7 +85,9 @@ extension OnitModel { do { let chat: String if Defaults[.mode] == .remote { - if let customProvider = Defaults[.remoteModel]?.customProvider { + if let customProviderName = Defaults[.remoteModel]?.customProviderName, + let customProvider = Defaults[.availableCustomProvider].first(where: { $0.name == customProviderName }) { + // Handle custom provider chat let messages = createOpenAIMessages( instructions: instructionsHistory, @@ -162,8 +164,8 @@ extension OnitModel { if Defaults[.mode] == .remote { if let model = Defaults[.remoteModel] { - if model.provider == .custom { - modelName = "\(model.customProvider?.name ?? "Custom")/\(model.displayName)" + if let customProviderName = model.customProviderName { + modelName = "\(customProviderName)/\(model.displayName)" } else { modelName = model.displayName } diff --git a/macos/Onit/Data/Model/Model+TokenValidation.swift b/macos/Onit/Data/Model/Model+TokenValidation.swift index 314272b7..f908a56e 100644 --- a/macos/Onit/Data/Model/Model+TokenValidation.swift +++ b/macos/Onit/Data/Model/Model+TokenValidation.swift @@ -105,7 +105,8 @@ extension OnitModel { case .custom: // For custom providers, we'll validate by trying to fetch the models list - if let customProvider = Defaults[.remoteModel]?.customProvider, + if let customProviderName = Defaults[.remoteModel]?.customProviderName, + let customProvider = Defaults[.availableCustomProvider].first(where: { $0.name == customProviderName }), let url = URL(string: customProvider.baseURL) { let endpoint = CustomModelsEndpoint(baseURL: url, token: token) _ = try await FetchingClient().execute(endpoint) diff --git a/macos/Onit/Data/Model/OnitModel.swift b/macos/Onit/Data/Model/OnitModel.swift index d709f4eb..131d0abc 100644 --- a/macos/Onit/Data/Model/OnitModel.swift +++ b/macos/Onit/Data/Model/OnitModel.swift @@ -91,6 +91,7 @@ import Defaults } else if localModel == nil || !models.contains(localModel!) { Defaults[.localModel] = models[0] } + print("KNA - OnitModel fetchLocalModels") if remoteModels.listedModels.isEmpty { Defaults[.mode] = .local } @@ -114,11 +115,11 @@ import Defaults var models = try await AIModel.fetchModels() // This means we've never successfully fetched before - if Defaults[.availableLocalModels].isEmpty { if Defaults[.visibleModelIds].isEmpty { Defaults[.visibleModelIds] = Set(models.filter { $0.defaultOn }.map { $0.id }) } + Defaults[.availableRemoteModels] = models if !remoteModels.listedModels.isEmpty { Defaults[.remoteModel] = remoteModels.listedModels.first @@ -140,13 +141,13 @@ import Defaults deprecatedModels[index].isDeprecated = true } - // We only save deprecated models if the user has them visibile. Otherwise, quietly remove them from the list. + // We only save deprecated models if the user has them visibile. Otherwise, quietly remove them from the list. let visibleModelIds = Set(Defaults[.visibleModelIds]) let visibleDeprecatedModels = deprecatedModels.filter { visibleModelIds.contains($0.id) } remoteFetchFailed = false Defaults[.availableRemoteModels] = models + visibleDeprecatedModels - if Defaults[.visibleModelIds].isEmpty { + if visibleModelIds.isEmpty { Defaults[.visibleModelIds] = Set((models + visibleDeprecatedModels).filter { $0.defaultOn }.map { $0.id }) } diff --git a/macos/Onit/Data/Persistence/Defaults.swift b/macos/Onit/Data/Persistence/Defaults.swift index 24230072..3d251f85 100644 --- a/macos/Onit/Data/Persistence/Defaults.swift +++ b/macos/Onit/Data/Persistence/Defaults.swift @@ -50,6 +50,7 @@ extension Defaults.Keys { static let mode = Key("mode", default: .remote) static let availableLocalModels = Key<[String]>("availableLocalModels", default: []) static let availableRemoteModels = Key<[AIModel]>("availableRemoteModels", default: []) + static let availableCustomProvider = Key<[CustomProvider]>("availableCustomProvider", default: []) static let visibleModelIds = Key>("visibleModelIds", default: Set([])) static let localEndpointURL = Key("localEndpointURL", default: URL(string: "http://localhost:11434")!) diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index d0b84619..309eb600 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -17,23 +17,23 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { let supportsSystemPrompts: Bool var isNew: Bool = false var isDeprecated: Bool = false - var customProvider: CustomProvider? + var customProviderName: String? var formattedDisplayName: String { - if provider == .custom, let provider = customProvider { - return "\(provider.name) / \(displayName)" + if provider == .custom, let providerName = customProviderName { + return "\(providerName) / \(displayName)" } return displayName } - init(from customModel: CustomModelInfo, provider: CustomProvider) { + init(from customModel: CustomModelInfo, providerName: String) { self.id = customModel.id self.displayName = customModel.id self.provider = .custom self.defaultOn = false self.supportsVision = false self.supportsSystemPrompts = true - self.customProvider = provider + self.customProviderName = providerName } init?(from modelInfo: ModelInfo) { diff --git a/macos/Onit/Preview Content/PreviewSampleData.swift b/macos/Onit/Preview Content/PreviewSampleData.swift index 8afc7570..f954f935 100644 --- a/macos/Onit/Preview Content/PreviewSampleData.swift +++ b/macos/Onit/Preview Content/PreviewSampleData.swift @@ -17,6 +17,11 @@ actor PreviewSampleData { static var remoteModels: RemoteModelsState = { RemoteModelsState() }() + + @MainActor + static var customProvider: CustomProvider = { + CustomProvider(name: "Provider name", baseURL: "http://google.com", token: "aiZafeoi", models: []) + }() @MainActor static var inMemoryContainer: () throws -> ModelContainer = { diff --git a/macos/Onit/UI/Components/LabeledTextField.swift b/macos/Onit/UI/Components/LabeledTextField.swift new file mode 100644 index 00000000..b658c984 --- /dev/null +++ b/macos/Onit/UI/Components/LabeledTextField.swift @@ -0,0 +1,44 @@ +// +// LabeledTextField.swift +// Onit +// +// Created by Kévin Naudin on 05/02/2025. +// + +import SwiftUI + +struct LabeledTextField: View { + let label: String + @Binding var text: String + + var secure: Bool = false + + var body: some View { + VStack(alignment: .leading, spacing: 4) { + Text(label) + .font(.subheadline) + .foregroundColor(.secondary) + .padding(.leading, 8) + + if secure { + SecureField("", text: $text) + .textFieldStyle(.roundedBorder) + .frame(height: 22) + .font(.system(size: 13, weight: .regular)) + .foregroundColor(.primary) + .textFieldStyle(PlainTextFieldStyle()) + } else { + TextField("", text: $text) + .textFieldStyle(.roundedBorder) + .frame(height: 22) + .font(.system(size: 13, weight: .regular)) + .foregroundColor(.primary) + .textFieldStyle(PlainTextFieldStyle()) + } + } + } +} + +#Preview { + LabeledTextField(label: "Title", text: .constant("")) +} diff --git a/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift b/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift index 2c28450c..9fe69138 100644 --- a/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift +++ b/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift @@ -88,22 +88,57 @@ struct ModelSelectionView: View { } } } + + var custom: some View { + VStack(alignment: .leading, spacing: 2) { + HStack { + Text("Custom models") + .appFont(.medium13) + .foregroundStyle(.white.opacity(0.6)) + Spacer() + if remoteModels.remoteNeedsSetup || (!remoteModels.remoteNeedsSetup && availableRemoteModels.isEmpty) { + Image(.warningSettings) + } + } + .padding(.horizontal, 12) + + if remoteModels.listedModels.isEmpty { + Button("Setup remote models") { + model.settingsTab = .models + openSettings() + } + .buttonStyle(SetUpButtonStyle(showArrow: true)) + .frame(maxWidth: .infinity, alignment: .leading) + .padding(.horizontal, 12) + .padding(.top, 6) + .padding(.bottom, 10) + + } else { + remoteModelsView + } + } + } var remoteModelsView: some View { - Picker("", selection: selectedModel) { - ForEach(remoteModels.listedModels) { model in - Text(model.displayName) - .appFont(.medium14) - .tag(SelectedModel.remote(model)) - .padding(.vertical, 4) + ScrollView { + VStack(alignment: .leading) { + Picker("", selection: selectedModel) { + ForEach(remoteModels.listedModels) { model in + Text(model.formattedDisplayName) + .appFont(.medium14) + .tag(SelectedModel.remote(model)) + .padding(.vertical, 4) + } + } + .pickerStyle(.inline) + .clipped() + .padding(.vertical, 4) + .padding(.bottom, 5) + .padding(.leading, 5) + .tint(.blue600) } } - .frame(maxWidth: .infinity, alignment: .leading) - .pickerStyle(.inline) - .padding(.vertical, 4) - .padding(.bottom, 5) - .padding(.leading, 5) - .tint(.blue600) + .frame(maxHeight: 300) } var divider: some View { diff --git a/macos/Onit/UI/RemoteModelsState.swift b/macos/Onit/UI/RemoteModelsState.swift index 6278ffc7..d2ccf82f 100644 --- a/macos/Onit/UI/RemoteModelsState.swift +++ b/macos/Onit/UI/RemoteModelsState.swift @@ -15,9 +15,13 @@ final class RemoteModelsState: ObservableObject { @ObservationIgnored var availableRemoteModels: [AIModel] - @ObservableDefault(.visibleModelIds) + @ObservableDefault(.availableCustomProvider) @ObservationIgnored - var visibleModelIds: Set + var availableCustomProvider: [CustomProvider] + +// @ObservableDefault(.visibleModelIds) +// @ObservationIgnored +// var visibleModelIds: Set @ObservableDefault(.useOpenAI) @ObservationIgnored @@ -34,13 +38,9 @@ final class RemoteModelsState: ObservableObject { @ObservableDefault(.useGoogleAI) @ObservationIgnored var useGoogleAI: Bool - - var remoteNeedsSetup: Bool { - !useOpenAI && !useAnthropic && !useXAI && !useGoogleAI - } var listedModels: [AIModel] { - var models = availableRemoteModels.filter { visibleModelIds.contains($0.id) } + var models = availableRemoteModels.filter { Defaults[.visibleModelIds].contains($0.id) } if !useOpenAI { models = models.filter { $0.provider != .openAI } @@ -56,14 +56,21 @@ final class RemoteModelsState: ObservableObject { } // Filter out models from disabled custom providers - models = models.filter { model in - if model.provider == .custom, - let provider = model.customProvider { - return provider.isEnabled - } - return true - } +// for customProvider in availableCustomProvider { +// models = models.filter { model in +// if model.customProviderName == customProvider.name { +// print("customProvider.isEnabled: \(customProvider.isEnabled), model.defaultOn: \(model.defaultOn)") +// return customProvider.isEnabled ? model.defaultOn : customProvider.isEnabled +// } +// +// return true +// } +// } return models } + + var remoteNeedsSetup: Bool { + listedModels.isEmpty + } } diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift new file mode 100644 index 00000000..82498ff0 --- /dev/null +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift @@ -0,0 +1,116 @@ +// +// CustomProviderFormView.swift +// Onit +// +// Created by Kévin Naudin on 05/02/2025. +// + +import Defaults +import SwiftUI + +struct CustomProviderFormView: View { + @Environment(\.dismiss) private var dismiss + + @Default(.availableCustomProvider) var availableCustomProviders + @Default(.availableRemoteModels) var availableRemoteModels + @Default(.visibleModelIds) var visibleModelIds + + @State var name: String = "" + @State var baseURL: String = "" + @State var token: String = "" + + @Binding var isSubmitted: Bool + + @State private var errorMessage: String? + + var body: some View { + Form { + + Text("Add a new provider") + .font(.headline) + .padding(.bottom, 12) + + VStack(spacing: 16) { + LabeledTextField(label: "Provider Name", text: $name) + LabeledTextField(label: "URL", text: $baseURL) + LabeledTextField(label: "Token", text: $token, secure: true) + } + + HStack { + Spacer() + + Button("Cancel") { + dismiss() + } + + Button("Verify & Save") { + errorMessage = nil + addProvider() + } + .keyboardShortcut(.defaultAction) + .disabled(name.isEmpty || baseURL.isEmpty || token.isEmpty) + }.padding(.top, 8) + + if let errorMessage = errorMessage{ + Text(errorMessage) + .foregroundColor(.red) + } + } + .padding() + .frame(width: 330) + } + + private func addProvider() { + Task { + do { + + try await fetchModels() + + DispatchQueue.main.async { + isSubmitted = true + dismiss() + } + } catch { + errorMessage = "Failed to fetch models: \(error.localizedDescription)" + } + } + } + + func fetchModels() async throws { + guard let url = URL(string: baseURL) else { return } + + let endpoint = CustomModelsEndpoint(baseURL: url, token: token) + let client = FetchingClient() + let response = try await client.execute(endpoint) + + let models = response.data.map { $0.id } + let provider = CustomProvider( + name: name, + baseURL: baseURL, + token: token, + models: models + ) + + availableCustomProviders.append(provider) + + // Initialize model IDs + let newModels = models.map { modelId in + AIModel(from: CustomModelInfo( + id: modelId, + object: "model", + created: Int(Date().timeIntervalSince1970), + owned_by: name + ), providerName: provider.name) + } + + // Initialize visible model IDs + visibleModelIds = visibleModelIds.union(Set(newModels.map { $0.id })) + + // Add new models to available remote models + availableRemoteModels.append(contentsOf: newModels) + } +} + +#Preview { + CustomProviderFormView(isSubmitted: .constant(false)) +} diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift new file mode 100644 index 00000000..7fd697ee --- /dev/null +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift @@ -0,0 +1,76 @@ +// +// CustomProviderRow.swift +// Onit +// +// Created by Kévin Naudin on 05/02/2025. +// + +import Defaults +import SwiftUI + +struct CustomProviderRow: View { + @Default(.availableRemoteModels) var availableRemoteModels + @Default(.visibleModelIds) var visibleModelIds + @Binding var provider: CustomProvider + + private var providerModels: [AIModel] { + availableRemoteModels.filter { $0.customProviderName == provider.name } + } + + var body: some View { + VStack(alignment: .leading) { + HStack { + Text(provider.name) + .font(.system(size: 13)) + + Spacer() + + Toggle("", isOn: $provider.isEnabled) + .toggleStyle(.switch) + .controlSize(.small) + +// Button(role: .destructive) { +// +// // TODO: KNA - Remove from visible also +// // Remove provider's models from available remote models +//// Defaults[.availableRemoteModels].removeAll { model in +//// model.customProvider?.id == provider.id +//// } +// +// } label: { +// Image(systemName: "trash") +// } + } + + if provider.isEnabled { + GroupBox { + VStack(alignment: .leading, spacing: 0) { + ForEach(providerModels, id: \.self) { model in + ModelToggle(aiModel: model) + .frame(height: 36) + } + } + .padding(.vertical, -4) + .padding(.horizontal, 4) + .frame(maxWidth: .infinity, alignment: .leading) + } + } + } + .cornerRadius(8) + .onChange(of: provider.isEnabled, initial: false) { old, new in + let modelIds = Set(provider.models) + + if new { + visibleModelIds.formUnion(modelIds) + } else { + visibleModelIds.subtract(modelIds) + } + + print(visibleModelIds) + } + } +} + +#Preview { + CustomProviderRow(provider: .constant(PreviewSampleData.customProvider)) +} diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift new file mode 100644 index 00000000..adb4c5a6 --- /dev/null +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift @@ -0,0 +1,59 @@ +import Defaults +import SwiftUI +import SwiftData + +struct CustomProvidersSection: View { + @Default(.availableCustomProvider) private var availableCustomProvider + + @State private var showForm = false + @State private var isSubmitted = false + + var body: some View { + VStack(alignment: .leading, spacing: 10) { + title + caption + + ForEach($availableCustomProvider) { provider in + CustomProviderRow(provider: provider) + } + } + } + var title: some View { + HStack { + Text("Other providers / private models") + .font(.system(size: 13)) + + Spacer() + + Button { + showForm = true + } label: { + Text("Add new") + } + .foregroundStyle(.white) + .buttonStyle(.borderedProminent) + .frame(height: 22) + .fontWeight(.regular) + } + .sheet(isPresented: $showForm) { + CustomProviderFormView(isSubmitted: $isSubmitted) + } + .onChange(of: isSubmitted, initial: false) { old, new in + if new { + isSubmitted = false + } + } + } + + var caption: some View { + Text("Manually add remote models from other providers.") + .foregroundStyle(.foreground.opacity(0.65)) + .fontWeight(.regular) + .font(.system(size: 12)) + } + +} + +#Preview { + CustomProvidersSection() +} diff --git a/macos/Onit/UI/Settings/Models/CustomProvidersSection.swift b/macos/Onit/UI/Settings/Models/CustomProvidersSection.swift deleted file mode 100644 index 9a88fc6c..00000000 --- a/macos/Onit/UI/Settings/Models/CustomProvidersSection.swift +++ /dev/null @@ -1,150 +0,0 @@ -import SwiftUI -import SwiftData - -struct CustomProvidersSection: View { - @Environment(\.model) var model - @Environment(\.modelContext) private var modelContext - @Query private var customProviders: [CustomProvider] - - @State private var isAddingProvider = false - @State private var newProviderName = "" - @State private var newProviderURL = "" - @State private var newProviderToken = "" - @State private var errorMessage: String? - - var body: some View { - RemoteModelSection(title: "Custom Providers") { - VStack(alignment: .leading, spacing: 12) { - ForEach(customProviders) { provider in - CustomProviderRow(provider: provider) - } - - if isAddingProvider { - VStack(alignment: .leading, spacing: 8) { - TextField("Provider Name", text: $newProviderName) - .textFieldStyle(.roundedBorder) - - TextField("Base URL", text: $newProviderURL) - .textFieldStyle(.roundedBorder) - - SecureField("API Token", text: $newProviderToken) - .textFieldStyle(.roundedBorder) - - HStack { - Button("Add Provider") { - addProvider() - } - .disabled(newProviderName.isEmpty || newProviderURL.isEmpty || newProviderToken.isEmpty) - - Button("Cancel") { - isAddingProvider = false - resetForm() - } - } - - if let error = errorMessage { - Text(error) - .foregroundColor(.red) - .font(.caption) - } - } - .padding() - .background(Color(.textBackgroundColor)) - .cornerRadius(8) - } else { - Button("Add Custom Provider") { - isAddingProvider = true - } - } - } - } - } - - private func addProvider() { - Task { - do { - let provider = CustomProvider( - name: newProviderName, - baseURL: newProviderURL, - token: newProviderToken - ) - - try await provider.fetchModels() - modelContext.insert(provider) - - isAddingProvider = false - resetForm() - } catch { - errorMessage = "Failed to fetch models: \(error.localizedDescription)" - } - } - } - - private func resetForm() { - newProviderName = "" - newProviderURL = "" - newProviderToken = "" - errorMessage = nil - } -} - -struct CustomProviderRow: View { - @Environment(\.model) var model - @Environment(\.modelContext) private var modelContext - - let provider: CustomProvider - - var body: some View { - VStack(alignment: .leading) { - HStack { - Toggle(isOn: .init( - get: { provider.isEnabled }, - set: { provider.isEnabled = $0 } - )) { - Text(provider.name) - .font(.headline) - } - - Spacer() - - Button(role: .destructive) { - // Remove provider's models from available remote models - model.updatePreferences { prefs in - prefs.availableRemoteModels.removeAll { model in - model.customProvider?.id == provider.id - } - } - modelContext.delete(provider) - } label: { - Image(systemName: "trash") - } - } - - if provider.isEnabled { - GroupBox { - VStack(alignment: .leading, spacing: 0) { - ForEach(provider.models, id: \.self) { modelId in - let aiModel = AIModel( - from: CustomModelInfo( - id: modelId, - object: "model", - created: Int(Date().timeIntervalSince1970), - owned_by: provider.name - ), - provider: provider - ) - ModelToggle(aiModel: aiModel) - .frame(height: 36) - } - } - .padding(.vertical, -4) - .padding(.horizontal, 4) - .frame(maxWidth: .infinity, alignment: .leading) - } - } - } - .padding() - .background(Color(.textBackgroundColor)) - .cornerRadius(8) - } -} diff --git a/macos/Onit/UI/Settings/Models/RemoteModelsSection.swift b/macos/Onit/UI/Settings/Models/RemoteModelsSection.swift index 6a2cc1bc..d7cb6d6b 100644 --- a/macos/Onit/UI/Settings/Models/RemoteModelsSection.swift +++ b/macos/Onit/UI/Settings/Models/RemoteModelsSection.swift @@ -11,11 +11,18 @@ struct RemoteModelsSection: View { @Environment(\.model) var model var body: some View { - ModelsSection(title: "Remote Models") { - RemoteModelSection(provider: .openAI) - RemoteModelSection(provider: .anthropic) - RemoteModelSection(provider: .xAI) - RemoteModelSection(provider: .googleAI) + ScrollView { + ModelsSection(title: "Remote Models") { + RemoteModelSection(provider: .openAI) + RemoteModelSection(provider: .anthropic) + RemoteModelSection(provider: .xAI) + RemoteModelSection(provider: .googleAI) + CustomProvidersSection() + } } } } + +#Preview { + RemoteModelsSection() +} diff --git a/macos/Onit/UI/Settings/ModelsTab.swift b/macos/Onit/UI/Settings/ModelsTab.swift index a906ff8e..977f19fe 100644 --- a/macos/Onit/UI/Settings/ModelsTab.swift +++ b/macos/Onit/UI/Settings/ModelsTab.swift @@ -16,7 +16,6 @@ struct ModelsTab: View { RemoteModelsSection() LocalModelsSection() DefaultModelsSection() - CustomProvidersSection() } .padding(.vertical, 20) .padding(.horizontal, 86) From 766f5c1b7e3fd154e17079ce1d676cebaadd6d49 Mon Sep 17 00:00:00 2001 From: timl Date: Wed, 5 Feb 2025 10:59:32 -0800 Subject: [PATCH 3/7] fix CustomModelsResponse --- .../Endpoints/CustomModelsEndpoint.swift | 46 +++++++++++++++++-- .../Fetching/FetchingClient+Execute.swift | 2 +- macos/Onit/Data/Structures/AIModel.swift | 6 ++- .../CustomProviderFormView.swift | 13 ++---- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift b/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift index 5f92a5e5..8dc421be 100644 --- a/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift +++ b/macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift @@ -18,13 +18,51 @@ struct CustomModelsEndpoint: Endpoint { } struct CustomModelsResponse: Codable { - let object: String + let object: String? let data: [CustomModelInfo] } struct CustomModelInfo: Codable { let id: String - let object: String - let created: Int - let owned_by: String + + // These are the OpenRouter fields + let name: String? + let created: Int? + let description: String? + let context_length: Int? + let architecture: Architecture? + let pricing: Pricing? + let top_provider: TopProvider? + let per_request_limits: PerRequestLimits? + + // These are the Groq fields + let context_window: Int? + let object: String? + let owned_by: String? + let active: Bool? + + // 'id' is the only mutual field, so it's the only thing we can require... +} + +struct Architecture: Codable { + let modality: String? + let tokenizer: String? + let instruct_type: String? +} + +struct Pricing: Codable { + let prompt: String? + let completion: String? + let image: String? + let request: String? +} + +struct TopProvider: Codable { + let context_length: Int? + let max_completion_tokens: Int? + let is_moderated: Bool? +} + +struct PerRequestLimits: Codable { + // Define fields if necessary } \ No newline at end of file diff --git a/macos/Onit/Data/Fetching/FetchingClient+Execute.swift b/macos/Onit/Data/Fetching/FetchingClient+Execute.swift index 306a141f..9a90f80a 100644 --- a/macos/Onit/Data/Fetching/FetchingClient+Execute.swift +++ b/macos/Onit/Data/Fetching/FetchingClient+Execute.swift @@ -28,7 +28,7 @@ extension FetchingClient { } // Helpful debugging method- put in the endpoint name and you can see the full request - if endpoint.path.contains("/v1beta/models") { + if endpoint.path.contains("/v1/models") { printCurlRequest(endpoint: endpoint, url: url) print("here") } diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index 309eb600..67895676 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -53,7 +53,11 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { let client = FetchingClient() let endpoint = RemoteModelsEndpoint() let response = try await client.execute(endpoint) - return response.models.compactMap { AIModel(from: $0) } + let onitModels = response.models.compactMap { AIModel(from: $0) } + + // TODO include custom models + + return onitModels } enum ModelProvider: String, Codable, Equatable, Hashable, Defaults.Serializable { diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift index 82498ff0..bbd2d452 100644 --- a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift @@ -83,24 +83,19 @@ struct CustomProviderFormView: View { let client = FetchingClient() let response = try await client.execute(endpoint) - let models = response.data.map { $0.id } + let modelIds = response.data.map { $0.id } let provider = CustomProvider( name: name, baseURL: baseURL, token: token, - models: models + models: modelIds ) availableCustomProviders.append(provider) // Initialize model IDs - let newModels = models.map { modelId in - AIModel(from: CustomModelInfo( - id: modelId, - object: "model", - created: Int(Date().timeIntervalSince1970), - owned_by: name - ), providerName: provider.name) + let newModels = response.data.map { model in + AIModel(from: model, providerName: provider.name) } // Initialize visible model IDs From 91878fb030ceafc363dca93aa1d70fa10a4b3391 Mon Sep 17 00:00:00 2001 From: timl Date: Wed, 5 Feb 2025 11:27:21 -0800 Subject: [PATCH 4/7] add uniqueId to AIModel and add custom providers to initial model fetch --- macos/Onit/Data/Persistence/Defaults.swift | 3 +++ macos/Onit/Data/Structures/AIModel.swift | 22 +++++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/macos/Onit/Data/Persistence/Defaults.swift b/macos/Onit/Data/Persistence/Defaults.swift index 3d251f85..64c44526 100644 --- a/macos/Onit/Data/Persistence/Defaults.swift +++ b/macos/Onit/Data/Persistence/Defaults.swift @@ -51,7 +51,10 @@ extension Defaults.Keys { static let availableLocalModels = Key<[String]>("availableLocalModels", default: []) static let availableRemoteModels = Key<[AIModel]>("availableRemoteModels", default: []) static let availableCustomProvider = Key<[CustomProvider]>("availableCustomProvider", default: []) + + // Updated visible model identifiers to use composite ModelID (combining provider and id) static let visibleModelIds = Key>("visibleModelIds", default: Set([])) + static let localEndpointURL = Key("localEndpointURL", default: URL(string: "http://localhost:11434")!) // Feature flags diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index 67895676..71dba9cb 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -26,6 +26,13 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { return displayName } + var uniqueId: String { + if provider == .custom, let providerName = customProviderName { + return "\(providerName)-\(id)" + } + return "\(provider)-\(id)" + } + init(from customModel: CustomModelInfo, providerName: String) { self.id = customModel.id self.displayName = customModel.id @@ -53,11 +60,20 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { let client = FetchingClient() let endpoint = RemoteModelsEndpoint() let response = try await client.execute(endpoint) - let onitModels = response.models.compactMap { AIModel(from: $0) } + let remoteModels = response.models.compactMap { AIModel(from: $0) } - // TODO include custom models + var customModels: [AIModel] = [] + for provider in Defaults[.availableCustomProvider] { + do { + let models = try await provider.fetchModels() + let custom = models.map { AIModel(from: $0, providerName: provider.name) } + customModels.append(contentsOf: custom) + } catch { + print("Error fetching custom models for provider \(provider.name): \(error)") + } + } - return onitModels + return remoteModels + customModels } enum ModelProvider: String, Codable, Equatable, Hashable, Defaults.Serializable { From 02f01dd5b1303b2c60c78ebbbefc16fa6092c24e Mon Sep 17 00:00:00 2001 From: openhands Date: Wed, 5 Feb 2025 19:32:26 +0000 Subject: [PATCH 5/7] Add model ID migration support for custom providers - Add migration from legacy model IDs to unique provider-based IDs - Add testing utilities for verifying migration - Update model visibility tracking to use unique IDs --- macos/Onit/Data/Model/OnitModel.swift | 15 +++++-- macos/Onit/Data/Persistence/Defaults.swift | 3 +- macos/Onit/Data/Structures/AIModel.swift | 18 ++++++++ macos/Onit/Testing/ModelMigrationTests.swift | 45 +++++++++++++++++++ .../Onit/UI/Settings/Models/ModelToggle.swift | 8 ++-- 5 files changed, 80 insertions(+), 9 deletions(-) create mode 100644 macos/Onit/Testing/ModelMigrationTests.swift diff --git a/macos/Onit/Data/Model/OnitModel.swift b/macos/Onit/Data/Model/OnitModel.swift index 131d0abc..b807da21 100644 --- a/macos/Onit/Data/Model/OnitModel.swift +++ b/macos/Onit/Data/Model/OnitModel.swift @@ -111,13 +111,20 @@ import Defaults @MainActor func fetchRemoteModels() async { do { - // if var models = try await AIModel.fetchModels() + // Migrate legacy model IDs if needed + if !Defaults[.hasPerformedModelIdMigration] { + let legacyIds = Defaults[.visibleModelIds] + let migratedIds = AIModel.migrateVisibleModelIds(models: models, legacyIds: legacyIds) + Defaults[.visibleModelIds] = migratedIds + Defaults[.hasPerformedModelIdMigration] = true + } + // This means we've never successfully fetched before if Defaults[.availableLocalModels].isEmpty { if Defaults[.visibleModelIds].isEmpty { - Defaults[.visibleModelIds] = Set(models.filter { $0.defaultOn }.map { $0.id }) + Defaults[.visibleModelIds] = Set(models.filter { $0.defaultOn }.map { $0.uniqueId }) } Defaults[.availableRemoteModels] = models @@ -143,12 +150,12 @@ import Defaults // We only save deprecated models if the user has them visibile. Otherwise, quietly remove them from the list. let visibleModelIds = Set(Defaults[.visibleModelIds]) - let visibleDeprecatedModels = deprecatedModels.filter { visibleModelIds.contains($0.id) } + let visibleDeprecatedModels = deprecatedModels.filter { visibleModelIds.contains($0.uniqueId) } remoteFetchFailed = false Defaults[.availableRemoteModels] = models + visibleDeprecatedModels if visibleModelIds.isEmpty { - Defaults[.visibleModelIds] = Set((models + visibleDeprecatedModels).filter { $0.defaultOn }.map { $0.id }) + Defaults[.visibleModelIds] = Set((models + visibleDeprecatedModels).filter { $0.defaultOn }.map { $0.uniqueId }) } if !remoteModels.listedModels.isEmpty && (Defaults[.remoteModel] == nil || !Defaults[.availableRemoteModels].contains(Defaults[.remoteModel]!)) { diff --git a/macos/Onit/Data/Persistence/Defaults.swift b/macos/Onit/Data/Persistence/Defaults.swift index 64c44526..aaa74f87 100644 --- a/macos/Onit/Data/Persistence/Defaults.swift +++ b/macos/Onit/Data/Persistence/Defaults.swift @@ -52,8 +52,9 @@ extension Defaults.Keys { static let availableRemoteModels = Key<[AIModel]>("availableRemoteModels", default: []) static let availableCustomProvider = Key<[CustomProvider]>("availableCustomProvider", default: []) - // Updated visible model identifiers to use composite ModelID (combining provider and id) + // Stores unique model identifiers in the format "provider-id" or "customProviderName-id" for custom providers static let visibleModelIds = Key>("visibleModelIds", default: Set([])) + static let hasPerformedModelIdMigration = Key("hasPerformedModelIdMigration", default: false) static let localEndpointURL = Key("localEndpointURL", default: URL(string: "http://localhost:11434")!) diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index 71dba9cb..56e737a2 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -33,6 +33,24 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { return "\(provider)-\(id)" } + // Helper method to check if a legacy ID matches this model + func matchesLegacyId(_ legacyId: String) -> Bool { + return id == legacyId + } + + // Helper method to migrate legacy IDs to unique IDs + static func migrateVisibleModelIds(models: [AIModel], legacyIds: Set) -> Set { + var newIds = Set() + + // For each legacy ID, find all matching models and add their unique IDs + for legacyId in legacyIds { + let matchingModels = models.filter { $0.matchesLegacyId(legacyId) } + newIds.formUnion(matchingModels.map { $0.uniqueId }) + } + + return newIds + } + init(from customModel: CustomModelInfo, providerName: String) { self.id = customModel.id self.displayName = customModel.id diff --git a/macos/Onit/Testing/ModelMigrationTests.swift b/macos/Onit/Testing/ModelMigrationTests.swift new file mode 100644 index 00000000..e417ff2b --- /dev/null +++ b/macos/Onit/Testing/ModelMigrationTests.swift @@ -0,0 +1,45 @@ +// +// ModelMigrationTests.swift +// Onit +// +// Created by OpenHands on 2/13/25. +// + +import Foundation +import Defaults + +#if DEBUG +class ModelMigrationTests { + static func resetToLegacyState(legacyIds: Set) { + // Reset migration flag + Defaults[.hasPerformedModelIdMigration] = false + + // Set legacy model IDs + Defaults[.visibleModelIds] = legacyIds + } + + static func printMigrationState() { + print("=== Model Migration State ===") + print("Has performed migration:", Defaults[.hasPerformedModelIdMigration]) + print("Visible Model IDs:", Defaults[.visibleModelIds]) + print("Available Remote Models:", Defaults[.availableRemoteModels].map { "\($0.provider): \($0.id) -> \($0.uniqueId)" }) + print("=========================") + } + + static func testMigration() async { + // Example legacy state with duplicate model IDs + let legacyIds: Set = ["gpt-4", "claude-3", "gemini-pro"] + + // Reset to legacy state + resetToLegacyState(legacyIds: legacyIds) + print("Before migration:") + printMigrationState() + + // Let the app perform migration (you'll need to trigger model fetch) + // This will happen automatically when fetchRemoteModels is called + + print("\nAfter migration:") + printMigrationState() + } +} +#endif \ No newline at end of file diff --git a/macos/Onit/UI/Settings/Models/ModelToggle.swift b/macos/Onit/UI/Settings/Models/ModelToggle.swift index 6661cc74..9ba49718 100644 --- a/macos/Onit/UI/Settings/Models/ModelToggle.swift +++ b/macos/Onit/UI/Settings/Models/ModelToggle.swift @@ -17,12 +17,12 @@ struct ModelToggle: View { var isOn: Binding { Binding { - visibleModelIds.contains(aiModel.id) + visibleModelIds.contains(aiModel.uniqueId) } set: { isOn in if isOn { - visibleModelIds.insert(aiModel.id) + visibleModelIds.insert(aiModel.uniqueId) } else { - visibleModelIds.remove(aiModel.id) + visibleModelIds.remove(aiModel.uniqueId) } } } @@ -52,7 +52,7 @@ struct ModelToggle: View { .onDisappear() { if aiModel.isNew { // Once we've displayed the "NEW" tag in settings, the model is no longer new - if let index = availableRemoteModels.firstIndex(where: { $0.id == aiModel.id }) { + if let index = availableRemoteModels.firstIndex(where: { $0.uniqueId == aiModel.uniqueId }) { availableRemoteModels[index].isNew = false } } From 268588ddec01a5197532155667e2521b5dfdd456 Mon Sep 17 00:00:00 2001 From: timl Date: Wed, 5 Feb 2025 13:02:53 -0800 Subject: [PATCH 6/7] fixing implementation; move api call into chat method, among other things --- macos/Onit/Data/Fetching/FetchingClient.swift | 51 +++++++++++++++- macos/Onit/Data/Model/CustomProvider.swift | 20 ++++++- macos/Onit/Data/Model/Model+Chat.swift | 58 ++++--------------- macos/Onit/Data/Model/OnitModel.swift | 24 ++++---- macos/Onit/Data/Structures/AIModel.swift | 12 +--- .../Prompt/Selection/ModelSelectionView.swift | 2 +- macos/Onit/UI/RemoteModelsState.swift | 20 +++---- .../CustomProviderFormView.swift | 29 +++------- .../CustomProvider/CustomProviderRow.swift | 54 ++++++++--------- .../Onit/UI/Settings/Models/ModelToggle.swift | 2 +- 10 files changed, 134 insertions(+), 138 deletions(-) diff --git a/macos/Onit/Data/Fetching/FetchingClient.swift b/macos/Onit/Data/Fetching/FetchingClient.swift index cdeefcf8..d6f168c9 100644 --- a/macos/Onit/Data/Fetching/FetchingClient.swift +++ b/macos/Onit/Data/Fetching/FetchingClient.swift @@ -7,6 +7,7 @@ import Foundation import UniformTypeIdentifiers +import Defaults actor FetchingClient { let session = URLSession.shared @@ -245,7 +246,55 @@ actor FetchingClient { let response = try await execute(endpoint) return response.choices[0].message.content case .custom: - return "" // TODO: KNA - + + var openAIMessageStack: [OpenAIChatMessage] = [] + + // Initialize messages with system prompt if needed + // if model.supportsSystemPrompts { + + // 3rd Party model providers don't tell us if system prompts are enabled or not... + // How to handle? I guess the user needs to be able to toggle system prompts for each custom provider model. + openAIMessageStack.append(OpenAIChatMessage(role: "system", content: .text(systemMessage))) + + for (index, userMessage) in userMessages.enumerated() { + if images[index].isEmpty { + let openAIMessage = OpenAIChatMessage(role: "user", content: .text(userMessage)) + openAIMessageStack.append(openAIMessage) + } else { + var parts = [OpenAIChatContentPart(type: "text", text: userMessage, image_url: nil)] + for url in images[index] { + if let imageData = try? Data(contentsOf: url) { + let base64EncodedData = imageData.base64EncodedString() + let mimeType = mimeType(for: url) + let imagePart = OpenAIChatContentPart( + type: "image_url", + text: nil, + image_url: .init(url: "data:\(mimeType);base64,\(base64EncodedData)") + ) + parts.append(imagePart) + } else { + print("Unable to read image data from URL: \(url)") + } + } + let openAIMessage = OpenAIChatMessage(role: "user", content: .multiContent(parts)) + openAIMessageStack.append(openAIMessage) + } + + // If there is a corresponding response, add it as an assistant message + if index < responses.count { + let responseMessage = OpenAIChatMessage(role: "assistant", content: .text(responses[index])) + openAIMessageStack.append(responseMessage) + } + } + + if let customProvider = Defaults[.availableCustomProvider].first(where: { $0.name == model.customProviderName }) { + let url = URL(string: customProvider.baseURL)! + let endpoint = CustomChatEndpoint(baseURL: url, messages: openAIMessageStack, token: customProvider.token, model: model.id) + let response = try await execute(endpoint) + return response.choices[0].message.content + } else { + throw FetchingError.invalidRequest(message: "Custom provider not found") + } } } diff --git a/macos/Onit/Data/Model/CustomProvider.swift b/macos/Onit/Data/Model/CustomProvider.swift index 78244be5..6afb5dec 100644 --- a/macos/Onit/Data/Model/CustomProvider.swift +++ b/macos/Onit/Data/Model/CustomProvider.swift @@ -2,16 +2,16 @@ import Defaults import Foundation import SwiftData -struct CustomProvider: Codable, Identifiable, Hashable, Defaults.Serializable { +class CustomProvider: Codable, Identifiable, Defaults.Serializable { var id: String { name } var name: String var baseURL: String var token: String - var models: [String] + var models: [AIModel] var isEnabled: Bool - init(name: String, baseURL: String, token: String, models: [String]) { + init(name: String, baseURL: String, token: String, models: [AIModel]) { self.name = name self.baseURL = baseURL self.token = token @@ -24,4 +24,18 @@ struct CustomProvider: Codable, Identifiable, Hashable, Defaults.Serializable { && lhs.token == rhs.token && lhs.models == rhs.models && lhs.isEnabled == rhs.isEnabled } + + @MainActor + func fetchModels() async throws { + guard let url = URL(string: baseURL) else { throw URLError(.badURL) } + + let endpoint = CustomModelsEndpoint(baseURL: url, token: token) + let client = FetchingClient() + let response = try await client.execute(endpoint) + + // Initialize model IDs + models = response.data.map { model in + AIModel(from: model, providerName: name) + } + } } diff --git a/macos/Onit/Data/Model/Model+Chat.swift b/macos/Onit/Data/Model/Model+Chat.swift index 0797291b..9ddded3e 100644 --- a/macos/Onit/Data/Model/Model+Chat.swift +++ b/macos/Onit/Data/Model/Model+Chat.swift @@ -84,38 +84,18 @@ extension OnitModel { do { let chat: String - if Defaults[.mode] == .remote { - if let customProviderName = Defaults[.remoteModel]?.customProviderName, - let customProvider = Defaults[.availableCustomProvider].first(where: { $0.name == customProviderName }) { - - // Handle custom provider chat - let messages = createOpenAIMessages( - instructions: instructionsHistory, - responses: responsesHistory - ) - if let endpoint = getCustomEndpoint( - for: customProvider, - messages: messages, - model: Defaults[.remoteModel]?.id ?? "" - ) { - let response = try await client.execute(endpoint) - chat = response.choices[0].message.content - } else { - throw FetchingError.invalidURL - } - } else { - // Regular remote model chat - chat = try await client.chat( - instructions: instructionsHistory, - inputs: inputsHistory, - files: filesHistory, - images: imagesHistory, - autoContexts: autoContextsHistory, - responses: responsesHistory, - model: Defaults[.remoteModel], - apiToken: getTokenForModel(Defaults[.remoteModel] ?? nil) - ) - } + if Defaults[.mode] == .remote { + // Regular remote model chat + chat = try await client.chat( + instructions: instructionsHistory, + inputs: inputsHistory, + files: filesHistory, + images: imagesHistory, + autoContexts: autoContextsHistory, + responses: responsesHistory, + model: Defaults[.remoteModel], + apiToken: getTokenForModel(Defaults[.remoteModel] ?? nil) + ) } else { // TODO implement history for local chat! chat = try await client.localChat( @@ -235,18 +215,4 @@ extension OnitModel { prompt.generationIndex = (prompt.responses.count - 1) prompt.generationState = .done } - - private func createOpenAIMessages(instructions: [String], responses: [String]) -> [OpenAIChatMessage] { - var messages: [OpenAIChatMessage] = [] - - for (index, instruction) in instructions.enumerated() { - messages.append(OpenAIChatMessage(role: "user", content: .text(instruction))) - - if index < responses.count { - messages.append(OpenAIChatMessage(role: "assistant", content: .text(responses[index]))) - } - } - - return messages - } } diff --git a/macos/Onit/Data/Model/OnitModel.swift b/macos/Onit/Data/Model/OnitModel.swift index b807da21..5f4bca3d 100644 --- a/macos/Onit/Data/Model/OnitModel.swift +++ b/macos/Onit/Data/Model/OnitModel.swift @@ -71,12 +71,7 @@ import Defaults var remoteFetchFailed: Bool = false var localFetchFailed: Bool = false - - func getCustomEndpoint(for provider: CustomProvider, messages: [OpenAIChatMessage], model: String) -> CustomChatEndpoint? { - guard let url = URL(string: provider.baseURL) else { return nil } - return CustomChatEndpoint(baseURL: url, messages: messages, token: provider.token, model: model) - } - + @MainActor func fetchLocalModels() async { do { @@ -113,13 +108,7 @@ import Defaults do { var models = try await AIModel.fetchModels() - // Migrate legacy model IDs if needed - if !Defaults[.hasPerformedModelIdMigration] { - let legacyIds = Defaults[.visibleModelIds] - let migratedIds = AIModel.migrateVisibleModelIds(models: models, legacyIds: legacyIds) - Defaults[.visibleModelIds] = migratedIds - Defaults[.hasPerformedModelIdMigration] = true - } + // This means we've never successfully fetched before if Defaults[.availableLocalModels].isEmpty { @@ -134,6 +123,15 @@ import Defaults // If relevant shrink the dialog box to account for the removed SetupDialog. shrinkContent() } else { + + // Migrate legacy model IDs if needed + if !Defaults[.hasPerformedModelIdMigration] { + let legacyIds = Defaults[.visibleModelIds] + let migratedIds = AIModel.migrateVisibleModelIds(models: Defaults[.availableRemoteModels], legacyIds: legacyIds) + Defaults[.visibleModelIds] = migratedIds + Defaults[.hasPerformedModelIdMigration] = true + } + // Update the availableRemoteModels with the newly fetched models let newModelIds = Set(models.map { $0.id }) let existingModelIds = Set(Defaults[.availableRemoteModels].map { $0.id }) diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index 56e737a2..5d89d6cd 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -19,13 +19,6 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { var isDeprecated: Bool = false var customProviderName: String? - var formattedDisplayName: String { - if provider == .custom, let providerName = customProviderName { - return "\(providerName) / \(displayName)" - } - return displayName - } - var uniqueId: String { if provider == .custom, let providerName = customProviderName { return "\(providerName)-\(id)" @@ -83,9 +76,8 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { var customModels: [AIModel] = [] for provider in Defaults[.availableCustomProvider] { do { - let models = try await provider.fetchModels() - let custom = models.map { AIModel(from: $0, providerName: provider.name) } - customModels.append(contentsOf: custom) + try await provider.fetchModels() + customModels.append(contentsOf: provider.models) } catch { print("Error fetching custom models for provider \(provider.name): \(error)") } diff --git a/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift b/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift index 9fe69138..72462ea2 100644 --- a/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift +++ b/macos/Onit/UI/Prompt/Selection/ModelSelectionView.swift @@ -124,7 +124,7 @@ struct ModelSelectionView: View { VStack(alignment: .leading) { Picker("", selection: selectedModel) { ForEach(remoteModels.listedModels) { model in - Text(model.formattedDisplayName) + Text(model.displayName) .appFont(.medium14) .tag(SelectedModel.remote(model)) .padding(.vertical, 4) diff --git a/macos/Onit/UI/RemoteModelsState.swift b/macos/Onit/UI/RemoteModelsState.swift index d2ccf82f..8a186318 100644 --- a/macos/Onit/UI/RemoteModelsState.swift +++ b/macos/Onit/UI/RemoteModelsState.swift @@ -40,7 +40,7 @@ final class RemoteModelsState: ObservableObject { var useGoogleAI: Bool var listedModels: [AIModel] { - var models = availableRemoteModels.filter { Defaults[.visibleModelIds].contains($0.id) } + var models = availableRemoteModels.filter { Defaults[.visibleModelIds].contains($0.uniqueId) } if !useOpenAI { models = models.filter { $0.provider != .openAI } @@ -56,16 +56,14 @@ final class RemoteModelsState: ObservableObject { } // Filter out models from disabled custom providers -// for customProvider in availableCustomProvider { -// models = models.filter { model in -// if model.customProviderName == customProvider.name { -// print("customProvider.isEnabled: \(customProvider.isEnabled), model.defaultOn: \(model.defaultOn)") -// return customProvider.isEnabled ? model.defaultOn : customProvider.isEnabled -// } -// -// return true -// } -// } + for customProvider in availableCustomProvider { + models = models.filter { model in + if model.customProviderName == customProvider.name { + return customProvider.isEnabled + } + return true + } + } return models } diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift index bbd2d452..88639270 100644 --- a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift @@ -13,7 +13,7 @@ struct CustomProviderFormView: View { @Default(.availableCustomProvider) var availableCustomProviders @Default(.availableRemoteModels) var availableRemoteModels - @Default(.visibleModelIds) var visibleModelIds + @State var name: String = "" @State var baseURL: String = "" @@ -65,7 +65,6 @@ struct CustomProviderFormView: View { do { try await fetchModels() - DispatchQueue.main.async { isSubmitted = true dismiss() @@ -77,32 +76,18 @@ struct CustomProviderFormView: View { } func fetchModels() async throws { - guard let url = URL(string: baseURL) else { return } - - let endpoint = CustomModelsEndpoint(baseURL: url, token: token) - let client = FetchingClient() - let response = try await client.execute(endpoint) - - let modelIds = response.data.map { $0.id } - let provider = CustomProvider( + var provider = CustomProvider( name: name, baseURL: baseURL, token: token, - models: modelIds + models: [] ) - availableCustomProviders.append(provider) + try await provider.fetchModels() - // Initialize model IDs - let newModels = response.data.map { model in - AIModel(from: model, providerName: provider.name) - } - - // Initialize visible model IDs - visibleModelIds = visibleModelIds.union(Set(newModels.map { $0.id })) - - // Add new models to available remote models - availableRemoteModels.append(contentsOf: newModels) + // If the above doesn't crash, we're good! + availableCustomProviders.append(provider) + availableRemoteModels.append(contentsOf: provider.models) } } diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift index 7fd697ee..ad145522 100644 --- a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift @@ -13,6 +13,14 @@ struct CustomProviderRow: View { @Default(.visibleModelIds) var visibleModelIds @Binding var provider: CustomProvider + @State private var searchText: String = "" + + private var filteredProviderModels: [AIModel] { + providerModels.filter { model in + searchText.isEmpty || model.displayName.localizedCaseInsensitiveContains(searchText) + } + } + private var providerModels: [AIModel] { availableRemoteModels.filter { $0.customProviderName == provider.name } } @@ -28,46 +36,32 @@ struct CustomProviderRow: View { Toggle("", isOn: $provider.isEnabled) .toggleStyle(.switch) .controlSize(.small) - -// Button(role: .destructive) { -// -// // TODO: KNA - Remove from visible also -// // Remove provider's models from available remote models -//// Defaults[.availableRemoteModels].removeAll { model in -//// model.customProvider?.id == provider.id -//// } -// -// } label: { -// Image(systemName: "trash") -// } } if provider.isEnabled { GroupBox { - VStack(alignment: .leading, spacing: 0) { - ForEach(providerModels, id: \.self) { model in - ModelToggle(aiModel: model) - .frame(height: 36) + VStack { + TextField("Search models", text: $searchText) + .textFieldStyle(RoundedBorderTextFieldStyle()) + .padding(.horizontal, 4) + + ScrollView { + VStack(alignment: .leading, spacing: 0) { + ForEach(filteredProviderModels, id: \.self) { model in + ModelToggle(aiModel: model) + .frame(height: 36) + } + } + .padding(.vertical, -4) + .padding(.horizontal, 4) + .frame(maxWidth: .infinity, alignment: .leading) } + .frame(maxHeight: 5 * 36) // Limit to 5 rows } - .padding(.vertical, -4) - .padding(.horizontal, 4) - .frame(maxWidth: .infinity, alignment: .leading) } } } .cornerRadius(8) - .onChange(of: provider.isEnabled, initial: false) { old, new in - let modelIds = Set(provider.models) - - if new { - visibleModelIds.formUnion(modelIds) - } else { - visibleModelIds.subtract(modelIds) - } - - print(visibleModelIds) - } } } diff --git a/macos/Onit/UI/Settings/Models/ModelToggle.swift b/macos/Onit/UI/Settings/Models/ModelToggle.swift index 9ba49718..19abef4b 100644 --- a/macos/Onit/UI/Settings/Models/ModelToggle.swift +++ b/macos/Onit/UI/Settings/Models/ModelToggle.swift @@ -30,7 +30,7 @@ struct ModelToggle: View { var body: some View { Toggle(isOn: isOn) { HStack { - Text(aiModel.formattedDisplayName) + Text(aiModel.displayName) .font(.system(size: 13)) .fontWeight(.regular) .opacity(0.85) From 606c653e779347b1c4f6697038e228b71905a481 Mon Sep 17 00:00:00 2001 From: timl Date: Wed, 5 Feb 2025 13:45:46 -0800 Subject: [PATCH 7/7] add ability to change token and delete custom providers --- macos/Onit/Data/Fetching/FetchingClient.swift | 2 +- .../Data/Model/Model+TokenValidation.swift | 2 +- macos/Onit/Data/Persistence/Defaults.swift | 2 +- macos/Onit/Data/Structures/AIModel.swift | 2 +- macos/Onit/UI/RemoteModelsState.swift | 2 +- .../CustomProviderFormView.swift | 37 +++++---- .../CustomProvider/CustomProviderRow.swift | 80 ++++++++++++++++++- .../CustomProvidersSection.swift | 2 +- 8 files changed, 103 insertions(+), 26 deletions(-) diff --git a/macos/Onit/Data/Fetching/FetchingClient.swift b/macos/Onit/Data/Fetching/FetchingClient.swift index d6f168c9..d00d6267 100644 --- a/macos/Onit/Data/Fetching/FetchingClient.swift +++ b/macos/Onit/Data/Fetching/FetchingClient.swift @@ -287,7 +287,7 @@ actor FetchingClient { } } - if let customProvider = Defaults[.availableCustomProvider].first(where: { $0.name == model.customProviderName }) { + if let customProvider = Defaults[.availableCustomProviders].first(where: { $0.name == model.customProviderName }) { let url = URL(string: customProvider.baseURL)! let endpoint = CustomChatEndpoint(baseURL: url, messages: openAIMessageStack, token: customProvider.token, model: model.id) let response = try await execute(endpoint) diff --git a/macos/Onit/Data/Model/Model+TokenValidation.swift b/macos/Onit/Data/Model/Model+TokenValidation.swift index f908a56e..335ef1f4 100644 --- a/macos/Onit/Data/Model/Model+TokenValidation.swift +++ b/macos/Onit/Data/Model/Model+TokenValidation.swift @@ -106,7 +106,7 @@ extension OnitModel { case .custom: // For custom providers, we'll validate by trying to fetch the models list if let customProviderName = Defaults[.remoteModel]?.customProviderName, - let customProvider = Defaults[.availableCustomProvider].first(where: { $0.name == customProviderName }), + let customProvider = Defaults[.availableCustomProviders].first(where: { $0.name == customProviderName }), let url = URL(string: customProvider.baseURL) { let endpoint = CustomModelsEndpoint(baseURL: url, token: token) _ = try await FetchingClient().execute(endpoint) diff --git a/macos/Onit/Data/Persistence/Defaults.swift b/macos/Onit/Data/Persistence/Defaults.swift index aaa74f87..5f5c75e1 100644 --- a/macos/Onit/Data/Persistence/Defaults.swift +++ b/macos/Onit/Data/Persistence/Defaults.swift @@ -50,7 +50,7 @@ extension Defaults.Keys { static let mode = Key("mode", default: .remote) static let availableLocalModels = Key<[String]>("availableLocalModels", default: []) static let availableRemoteModels = Key<[AIModel]>("availableRemoteModels", default: []) - static let availableCustomProvider = Key<[CustomProvider]>("availableCustomProvider", default: []) + static let availableCustomProviders = Key<[CustomProvider]>("availableCustomProvider", default: []) // Stores unique model identifiers in the format "provider-id" or "customProviderName-id" for custom providers static let visibleModelIds = Key>("visibleModelIds", default: Set([])) diff --git a/macos/Onit/Data/Structures/AIModel.swift b/macos/Onit/Data/Structures/AIModel.swift index 5d89d6cd..9ad77be4 100644 --- a/macos/Onit/Data/Structures/AIModel.swift +++ b/macos/Onit/Data/Structures/AIModel.swift @@ -74,7 +74,7 @@ struct AIModel: Codable, Identifiable, Hashable, Defaults.Serializable { let remoteModels = response.models.compactMap { AIModel(from: $0) } var customModels: [AIModel] = [] - for provider in Defaults[.availableCustomProvider] { + for provider in Defaults[.availableCustomProviders] { do { try await provider.fetchModels() customModels.append(contentsOf: provider.models) diff --git a/macos/Onit/UI/RemoteModelsState.swift b/macos/Onit/UI/RemoteModelsState.swift index 8a186318..43b87561 100644 --- a/macos/Onit/UI/RemoteModelsState.swift +++ b/macos/Onit/UI/RemoteModelsState.swift @@ -15,7 +15,7 @@ final class RemoteModelsState: ObservableObject { @ObservationIgnored var availableRemoteModels: [AIModel] - @ObservableDefault(.availableCustomProvider) + @ObservableDefault(.availableCustomProviders) @ObservationIgnored var availableCustomProvider: [CustomProvider] diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift index 88639270..9e306235 100644 --- a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderFormView.swift @@ -11,7 +11,7 @@ import SwiftUI struct CustomProviderFormView: View { @Environment(\.dismiss) private var dismiss - @Default(.availableCustomProvider) var availableCustomProviders + @Default(.availableCustomProviders) var availableCustomProviders @Default(.availableRemoteModels) var availableRemoteModels @@ -51,7 +51,7 @@ struct CustomProviderFormView: View { .disabled(name.isEmpty || baseURL.isEmpty || token.isEmpty) }.padding(.top, 8) - if let errorMessage = errorMessage{ + if let errorMessage = errorMessage { Text(errorMessage) .foregroundColor(.red) } @@ -63,8 +63,22 @@ struct CustomProviderFormView: View { private func addProvider() { Task { do { - - try await fetchModels() + // Check for duplicate provider name + if availableCustomProviders.contains(where: { $0.name == name }) { + errorMessage = "A provider with this name already exists." + return + } + var provider = CustomProvider( + name: name, + baseURL: baseURL, + token: token, + models: [] + ) + try await provider.fetchModels() + + // If the above doesn't crash, we're good! + availableCustomProviders.append(provider) + availableRemoteModels.append(contentsOf: provider.models) DispatchQueue.main.async { isSubmitted = true dismiss() @@ -74,21 +88,6 @@ struct CustomProviderFormView: View { } } } - - func fetchModels() async throws { - var provider = CustomProvider( - name: name, - baseURL: baseURL, - token: token, - models: [] - ) - - try await provider.fetchModels() - - // If the above doesn't crash, we're good! - availableCustomProviders.append(provider) - availableRemoteModels.append(contentsOf: provider.models) - } } #Preview { diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift index ad145522..a8c6bf9d 100644 --- a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProviderRow.swift @@ -9,11 +9,16 @@ import Defaults import SwiftUI struct CustomProviderRow: View { + @Default(.availableCustomProviders) var availableCustomProviders @Default(.availableRemoteModels) var availableRemoteModels @Default(.visibleModelIds) var visibleModelIds @Binding var provider: CustomProvider @State private var searchText: String = "" + @State private var loading = false + @State private var validated = false + @State private var errorMessage: String? + @State private var showAlert = false private var filteredProviderModels: [AIModel] { providerModels.filter { model in @@ -30,12 +35,34 @@ struct CustomProviderRow: View { HStack { Text(provider.name) .font(.system(size: 13)) - Spacer() Toggle("", isOn: $provider.isEnabled) .toggleStyle(.switch) .controlSize(.small) + + Button(action: { + showAlert = true + }) { + Image(systemName: "trash") + .foregroundColor(.red) + } + .alert(isPresented: $showAlert) { + Alert( + title: Text("Remove Provider"), + message: Text("Are you sure you want to remove this provider?"), + primaryButton: .destructive(Text("Remove")) { + removeProvider() + }, + secondaryButton: .cancel() + ) + } + } + + tokenField + if let errorMessage = errorMessage { + Text(errorMessage) + .foregroundColor(.red) } if provider.isEnabled { @@ -63,6 +90,57 @@ struct CustomProviderRow: View { } .cornerRadius(8) } + + var tokenField: some View { + HStack(spacing: 7) { + TextField("Enter your \(provider.name) API key", text: $provider.token) + .textFieldStyle(.roundedBorder) + .frame(height: 22) + .font(.system(size: 13, weight: .regular)) + .foregroundColor(.primary) // Ensure placeholder text is not dimmed + + Button { + Task { + loading = true + do { + try await provider.fetchModels() + validated = true + } catch { + errorMessage = "Failed to fetch models: \(error.localizedDescription)" + } + loading = false + } + } label: { + if validated { + Text("Verified") + } else { + if loading { + ProgressView() + .controlSize(.small) + } else { + Text("Verify →") + } + } + } + .disabled(loading) + .foregroundStyle(.white) + .buttonStyle(.borderedProminent) + .frame(height: 22) + .fontWeight(.regular) + } + } + + private func removeProvider() { + let modelsToRemove = availableRemoteModels.filter { $0.customProviderName == provider.name } + availableRemoteModels.removeAll(where: { modelsToRemove.contains($0) }) + + let modelsToRemoveUniqueIDs = modelsToRemove.map { $0.uniqueId } + visibleModelIds.subtract(modelsToRemoveUniqueIDs) + + if let index = availableCustomProviders.firstIndex(where: { $0.name == provider.name }) { + availableCustomProviders.remove(at: index) + } + } } #Preview { diff --git a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift index adb4c5a6..034361ff 100644 --- a/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift +++ b/macos/Onit/UI/Settings/Models/CustomProvider/CustomProvidersSection.swift @@ -3,7 +3,7 @@ import SwiftUI import SwiftData struct CustomProvidersSection: View { - @Default(.availableCustomProvider) private var availableCustomProvider + @Default(.availableCustomProviders) private var availableCustomProvider @State private var showForm = false @State private var isSubmitted = false