Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add custom providers #46

Merged
merged 7 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions macos/Onit/Data/Fetching/Endpoints/CustomChatEndpoint.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import Foundation

struct CustomChatEndpoint: Endpoint {
var baseURL: URL

typealias Request = OpenAIChatRequest
typealias Response = OpenAIChatResponse

let messages: [OpenAIChatMessage]
let token: String?
let model: String

var path: String { "/v1/chat/completions" }
var getParams: [String: String]? { nil }
var method: HTTPMethod { .post }
var requestBody: OpenAIChatRequest? {
OpenAIChatRequest(model: model, messages: messages)
}
var additionalHeaders: [String: String]? {
["Authorization": "Bearer \(token ?? "")"]
}
}
68 changes: 68 additions & 0 deletions macos/Onit/Data/Fetching/Endpoints/CustomModelsEndpoint.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import Foundation

struct CustomModelsEndpoint: Endpoint {
typealias Request = EmptyRequest
typealias Response = CustomModelsResponse

var baseURL: URL
let token: String?

var path: String { "/v1/models" }
var getParams: [String: String]? { nil }
var method: HTTPMethod { .get }
var requestBody: EmptyRequest? { nil }

var additionalHeaders: [String: String]? {
["Authorization": "Bearer \(token ?? "")"]
}
}

struct CustomModelsResponse: Codable {
let object: String?
let data: [CustomModelInfo]
}

struct CustomModelInfo: Codable {
let id: String

// These are the OpenRouter fields
let name: String?
let created: Int?
let description: String?
let context_length: Int?
let architecture: Architecture?
let pricing: Pricing?
let top_provider: TopProvider?
let per_request_limits: PerRequestLimits?

// These are the Groq fields
let context_window: Int?
let object: String?
let owned_by: String?
let active: Bool?

// 'id' is the only mutual field, so it's the only thing we can require...
}

struct Architecture: Codable {
let modality: String?
let tokenizer: String?
let instruct_type: String?
}

struct Pricing: Codable {
let prompt: String?
let completion: String?
let image: String?
let request: String?
}

struct TopProvider: Codable {
let context_length: Int?
let max_completion_tokens: Int?
let is_moderated: Bool?
}

struct PerRequestLimits: Codable {
// Define fields if necessary
}
12 changes: 10 additions & 2 deletions macos/Onit/Data/Fetching/FetchingClient+Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public enum FetchingError: Error {
case serverError(statusCode: Int, message: String)
case decodingError(Error)
case networkError(Error)
case invalidURL
}

extension FetchingError: LocalizedError {
Expand All @@ -40,6 +41,8 @@ extension FetchingError: LocalizedError {
"Failed to decode response: \(error.localizedDescription)"
case .networkError(let error):
"Network error: \(error.localizedDescription)"
case .invalidURL:
"Invalid URL"
}
}
}
Expand All @@ -49,7 +52,8 @@ extension FetchingError: Equatable {
switch (lhs, rhs) {
case (.invalidResponse, .invalidResponse),
(.unauthorized, .unauthorized),
(.notFound, .notFound):
(.notFound, .notFound),
(.invalidURL, .invalidURL):
return true
case (.forbidden(let lhsMessage), .forbidden(let rhsMessage)):
return lhsMessage == rhsMessage
Expand All @@ -73,7 +77,7 @@ extension FetchingError: Codable {
}

enum FetchingErrorType: String, Codable {
case invalidResponse, invalidRequest, unauthorized, forbidden, notFound, failedRequest, serverError, decodingError, networkError
case invalidResponse, invalidRequest, unauthorized, forbidden, notFound, failedRequest, serverError, decodingError, networkError, invalidURL
}

public init(from decoder: Decoder) throws {
Expand Down Expand Up @@ -111,6 +115,8 @@ extension FetchingError: Codable {
let errorDescription = try container.decode(String.self, forKey: .description)
let error = NSError(domain: "", code: 0, userInfo: [NSLocalizedDescriptionKey: errorDescription])
self = .networkError(error)
case .invalidURL:
self = .invalidURL
}
}

Expand Down Expand Up @@ -146,6 +152,8 @@ extension FetchingError: Codable {
case .networkError(let error):
try container.encode(FetchingErrorType.networkError, forKey: .type)
try container.encode(error.localizedDescription, forKey: .description)
case .invalidURL:
try container.encode(FetchingErrorType.invalidURL, forKey: .type)
}
}
}
2 changes: 1 addition & 1 deletion macos/Onit/Data/Fetching/FetchingClient+Execute.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extension FetchingClient {
}

// Helpful debugging method- put in the endpoint name and you can see the full request
if endpoint.path.contains("/v1beta/models") {
if endpoint.path.contains("/v1/models") {
printCurlRequest(endpoint: endpoint, url: url)
print("here")
}
Expand Down
51 changes: 51 additions & 0 deletions macos/Onit/Data/Fetching/FetchingClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import Foundation
import UniformTypeIdentifiers
import Defaults

actor FetchingClient {
let session = URLSession.shared
Expand Down Expand Up @@ -244,6 +245,56 @@ actor FetchingClient {
let endpoint = GoogleAIChatEndpoint(messages: googleAIMessageStack, model: model.id, token: apiToken)
let response = try await execute(endpoint)
return response.choices[0].message.content
case .custom:

var openAIMessageStack: [OpenAIChatMessage] = []

// Initialize messages with system prompt if needed
// if model.supportsSystemPrompts {

// 3rd Party model providers don't tell us if system prompts are enabled or not...
// How to handle? I guess the user needs to be able to toggle system prompts for each custom provider model.
openAIMessageStack.append(OpenAIChatMessage(role: "system", content: .text(systemMessage)))

for (index, userMessage) in userMessages.enumerated() {
if images[index].isEmpty {
let openAIMessage = OpenAIChatMessage(role: "user", content: .text(userMessage))
openAIMessageStack.append(openAIMessage)
} else {
var parts = [OpenAIChatContentPart(type: "text", text: userMessage, image_url: nil)]
for url in images[index] {
if let imageData = try? Data(contentsOf: url) {
let base64EncodedData = imageData.base64EncodedString()
let mimeType = mimeType(for: url)
let imagePart = OpenAIChatContentPart(
type: "image_url",
text: nil,
image_url: .init(url: "data:\(mimeType);base64,\(base64EncodedData)")
)
parts.append(imagePart)
} else {
print("Unable to read image data from URL: \(url)")
}
}
let openAIMessage = OpenAIChatMessage(role: "user", content: .multiContent(parts))
openAIMessageStack.append(openAIMessage)
}

// If there is a corresponding response, add it as an assistant message
if index < responses.count {
let responseMessage = OpenAIChatMessage(role: "assistant", content: .text(responses[index]))
openAIMessageStack.append(responseMessage)
}
}

if let customProvider = Defaults[.availableCustomProviders].first(where: { $0.name == model.customProviderName }) {
let url = URL(string: customProvider.baseURL)!
let endpoint = CustomChatEndpoint(baseURL: url, messages: openAIMessageStack, token: customProvider.token, model: model.id)
let response = try await execute(endpoint)
return response.choices[0].message.content
} else {
throw FetchingError.invalidRequest(message: "Custom provider not found")
}
}
}

Expand Down
41 changes: 41 additions & 0 deletions macos/Onit/Data/Model/CustomProvider.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import Defaults
import Foundation
import SwiftData

class CustomProvider: Codable, Identifiable, Defaults.Serializable {
var id: String { name }

var name: String
var baseURL: String
var token: String
var models: [AIModel]
var isEnabled: Bool

init(name: String, baseURL: String, token: String, models: [AIModel]) {
self.name = name
self.baseURL = baseURL
self.token = token
self.models = models
self.isEnabled = true
}

static func == (lhs: CustomProvider, rhs: CustomProvider) -> Bool {
return lhs.name == rhs.name && lhs.baseURL == rhs.baseURL
&& lhs.token == rhs.token && lhs.models == rhs.models
&& lhs.isEnabled == rhs.isEnabled
}

@MainActor
func fetchModels() async throws {
guard let url = URL(string: baseURL) else { throw URLError(.badURL) }

let endpoint = CustomModelsEndpoint(baseURL: url, token: token)
let client = FetchingClient()
let response = try await client.execute(endpoint)

// Initialize model IDs
models = response.data.map { model in
AIModel(from: model, providerName: name)
}
}
}
24 changes: 21 additions & 3 deletions macos/Onit/Data/Model/Model+Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ extension OnitModel {

do {
let chat: String
if Defaults[.mode] == .remote {
// chat(instructions: [String], inputs: [Input?], files: [[URL]], images[[URL]], responses: [String], model: AIModel?, apiToken: String?)
if Defaults[.mode] == .remote {
// Regular remote model chat
chat = try await client.chat(
instructions: instructionsHistory,
inputs: inputsHistory,
Expand Down Expand Up @@ -140,9 +140,23 @@ extension OnitModel {
*/
private func trackEventGeneration(prompt: Prompt) {
let eventName = "user_prompted"
var modelName = ""

if Defaults[.mode] == .remote {
if let model = Defaults[.remoteModel] {
if let customProviderName = model.customProviderName {
modelName = "\(customProviderName)/\(model.displayName)"
} else {
modelName = model.displayName
}
}
} else {
modelName = Defaults[.localModel] ?? ""
}

let eventProperties: [String: Any] = [
"prompt_mode": Defaults[.mode].rawValue,
"prompt_model": Defaults[.mode] == .remote ? Defaults[.remoteModel]?.displayName ?? "" : Defaults[.localModel] ?? ""
"prompt_model": modelName
]
PostHogSDK.shared.capture(eventName, properties: eventProperties)
}
Expand Down Expand Up @@ -172,6 +186,8 @@ extension OnitModel {
Defaults[.isXAITokenValidated] = isValid
case .googleAI:
Defaults[.isGoogleAITokenValidated] = isValid
case .custom:
break // TODO: KNA -
}
}

Expand All @@ -186,6 +202,8 @@ extension OnitModel {
return Defaults[.xAIToken]
case .googleAI:
return Defaults[.googleAIToken]
case .custom:
return nil // TODO: KNA -
}
}
return nil
Expand Down
13 changes: 13 additions & 0 deletions macos/Onit/Data/Model/Model+TokenValidation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Onit
//

import Defaults
import Foundation

struct TokenValidationState {
Expand Down Expand Up @@ -101,6 +102,18 @@ extension OnitModel {
let endpoint = GoogleAIValidationEndpoint(apiKey: token)
_ = try await FetchingClient().execute(endpoint)
state.setValid(provider: provider)

case .custom:
// For custom providers, we'll validate by trying to fetch the models list
if let customProviderName = Defaults[.remoteModel]?.customProviderName,
let customProvider = Defaults[.availableCustomProviders].first(where: { $0.name == customProviderName }),
let url = URL(string: customProvider.baseURL) {
let endpoint = CustomModelsEndpoint(baseURL: url, token: token)
_ = try await FetchingClient().execute(endpoint)
state.setValid(provider: provider)
} else {
throw FetchingError.invalidURL
}
}
setTokenIsValid(true)
} catch let error as FetchingError {
Expand Down
Loading