Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Anthropic API support with custom version header #934

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream

func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
switch c.config.APIType {
case APITypeAzure, APITypeCloudflareAzure:
// Azure API Key authentication
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
case APITypeAnthropic:
// https://docs.anthropic.com/en/api/versioning
req.Header.Set("anthropic-version", c.config.APIVersion)
case APITypeOpenAI, APITypeAzureAD:
fallthrough
default:
if c.config.authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
}

if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
Expand Down
15 changes: 15 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
}
}

func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

client.setCommonHeaders(req)

if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}

func TestDecodeResponse(t *testing.T) {
stringInput := ""

Expand Down
22 changes: 21 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ const (

azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments"

AnthropicAPIVersion = "2023-06-01"
)

type APIType string
Expand All @@ -20,6 +22,7 @@ const (
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
APITypeAnthropic APIType = "ANTHROPIC"
)

const AzureAPIKeyHeader = "api-key"
Expand All @@ -37,7 +40,7 @@ type ClientConfig struct {
BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
AssistantVersion string
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient HTTPDoer
Expand Down Expand Up @@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
}
}

func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"
}
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAnthropic,
APIVersion: AnthropicAPIVersion,

HTTPClient: &http.Client{},

EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}

func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>"
}
Expand Down
40 changes: 40 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
})
}
}

func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"

config := openai.DefaultAnthropicConfig(apiKey, baseURL)

if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}

if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}

if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}

if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}

func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")

if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}

if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}

expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}
Loading