diff --git a/internal/shared/union.go b/internal/shared/union.go index d40977166ee..cc27fa55eaf 100644 --- a/internal/shared/union.go +++ b/internal/shared/union.go @@ -84,6 +84,9 @@ func (UnionString) ImplementsPagerulesPageruleGetResponse() func (UnionString) ImplementsRateLimitsRateLimitNewResponse() {} func (UnionString) ImplementsRateLimitsRateLimitEditResponse() {} func (UnionString) ImplementsRateLimitsRateLimitGetResponse() {} +func (UnionString) ImplementsWorkersAIRunResponse() {} +func (UnionString) ImplementsWorkersAIRunParamsBody() {} +func (UnionString) ImplementsWorkersAIRunParamsBodyTextEmbeddingsText() {} func (UnionString) ImplementsWorkersScriptTailDeleteResponse() {} func (UnionString) ImplementsWorkersRouteNewResponse() {} func (UnionString) ImplementsWorkersRouteDeleteResponse() {} diff --git a/workers/ai.go b/workers/ai.go index 7cb83a97779..a9944e3c2b5 100644 --- a/workers/ai.go +++ b/workers/ai.go @@ -3,14 +3,20 @@ package workers import ( + "bytes" "context" "fmt" + "mime/multipart" "net/http" + "reflect" + "github.com/cloudflare/cloudflare-go/v2/internal/apiform" "github.com/cloudflare/cloudflare-go/v2/internal/apijson" "github.com/cloudflare/cloudflare-go/v2/internal/param" "github.com/cloudflare/cloudflare-go/v2/internal/requestconfig" + "github.com/cloudflare/cloudflare-go/v2/internal/shared" "github.com/cloudflare/cloudflare-go/v2/option" + "github.com/tidwall/gjson" ) // AIService contains methods and other services that help with interacting with @@ -51,61 +57,398 @@ func (r *AIService) Run(ctx context.Context, modelName string, params AIRunParam return } -type AIRunResponse = interface{} +// Union satisfied by [workers.AIRunResponseTextClassification], +// [shared.UnionString], [workers.AIRunResponseSentenceSimilarity], +// [workers.AIRunResponseTextEmbeddings], [workers.AIRunResponseSpeechRecognition], +// [workers.AIRunResponseImageClassification], +// [workers.AIRunResponseObjectDetection], [workers.AIRunResponseObject], +// [shared.UnionString], [workers.AIRunResponseTranslation], +// [workers.AIRunResponseSummarization] or [workers.AIRunResponseImageToText]. +type AIRunResponse interface { + ImplementsWorkersAIRunResponse() +} -type AIRunParams struct { - AccountID param.Field[string] `path:"account_id,required"` - Body param.Field[interface{}] `json:"body,required"` +func init() { + apijson.RegisterUnion( + reflect.TypeOf((*AIRunResponse)(nil)).Elem(), + "", + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseTextClassification{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.String, + Type: reflect.TypeOf(shared.UnionString("")), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseSentenceSimilarity{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseTextEmbeddings{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseSpeechRecognition{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseImageClassification{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseObjectDetection{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseObject{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.String, + Type: reflect.TypeOf(shared.UnionString("")), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseTranslation{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseSummarization{}), + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(AIRunResponseImageToText{}), + }, + ) } -func (r AIRunParams) MarshalJSON() (data []byte, err error) { - return apijson.MarshalRoot(r.Body) +type AIRunResponseTextClassification []AIRunResponseTextClassification + +func (r AIRunResponseTextClassification) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseSentenceSimilarity []float64 + +func (r AIRunResponseSentenceSimilarity) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseTextEmbeddings struct { + Data [][]float64 `json:"data"` + Shape []float64 `json:"shape"` + JSON aiRunResponseTextEmbeddingsJSON `json:"-"` } -type AIRunResponseEnvelope struct { - Errors []AIRunResponseEnvelopeErrors `json:"errors,required"` - Messages []string `json:"messages,required"` - Result AIRunResponse `json:"result,required"` - Success bool `json:"success,required"` - JSON aiRunResponseEnvelopeJSON `json:"-"` +// aiRunResponseTextEmbeddingsJSON contains the JSON metadata for the struct +// [AIRunResponseTextEmbeddings] +type aiRunResponseTextEmbeddingsJSON struct { + Data apijson.Field + Shape apijson.Field + raw string + ExtraFields map[string]apijson.Field } -// aiRunResponseEnvelopeJSON contains the JSON metadata for the struct -// [AIRunResponseEnvelope] -type aiRunResponseEnvelopeJSON struct { - Errors apijson.Field - Messages apijson.Field - Result apijson.Field - Success apijson.Field +func (r *AIRunResponseTextEmbeddings) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r aiRunResponseTextEmbeddingsJSON) RawJSON() string { + return r.raw +} + +func (r AIRunResponseTextEmbeddings) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseSpeechRecognition struct { + Text string `json:"text,required"` + WordCount float64 `json:"word_count"` + Words []AIRunResponseSpeechRecognitionWord `json:"words"` + JSON aiRunResponseSpeechRecognitionJSON `json:"-"` +} + +// aiRunResponseSpeechRecognitionJSON contains the JSON metadata for the struct +// [AIRunResponseSpeechRecognition] +type aiRunResponseSpeechRecognitionJSON struct { + Text apijson.Field + WordCount apijson.Field + Words apijson.Field raw string ExtraFields map[string]apijson.Field } -func (r *AIRunResponseEnvelope) UnmarshalJSON(data []byte) (err error) { +func (r *AIRunResponseSpeechRecognition) UnmarshalJSON(data []byte) (err error) { return apijson.UnmarshalRoot(data, r) } -func (r aiRunResponseEnvelopeJSON) RawJSON() string { +func (r aiRunResponseSpeechRecognitionJSON) RawJSON() string { return r.raw } -type AIRunResponseEnvelopeErrors struct { - Message string `json:"message,required"` - JSON aiRunResponseEnvelopeErrorsJSON `json:"-"` +func (r AIRunResponseSpeechRecognition) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseSpeechRecognitionWord struct { + End float64 `json:"end"` + Start float64 `json:"start"` + Word string `json:"word"` + JSON aiRunResponseSpeechRecognitionWordJSON `json:"-"` } -// aiRunResponseEnvelopeErrorsJSON contains the JSON metadata for the struct -// [AIRunResponseEnvelopeErrors] -type aiRunResponseEnvelopeErrorsJSON struct { - Message apijson.Field +// aiRunResponseSpeechRecognitionWordJSON contains the JSON metadata for the struct +// [AIRunResponseSpeechRecognitionWord] +type aiRunResponseSpeechRecognitionWordJSON struct { + End apijson.Field + Start apijson.Field + Word apijson.Field raw string ExtraFields map[string]apijson.Field } -func (r *AIRunResponseEnvelopeErrors) UnmarshalJSON(data []byte) (err error) { +func (r *AIRunResponseSpeechRecognitionWord) UnmarshalJSON(data []byte) (err error) { return apijson.UnmarshalRoot(data, r) } -func (r aiRunResponseEnvelopeErrorsJSON) RawJSON() string { +func (r aiRunResponseSpeechRecognitionWordJSON) RawJSON() string { + return r.raw +} + +type AIRunResponseImageClassification []AIRunResponseImageClassification + +func (r AIRunResponseImageClassification) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseObjectDetection []AIRunResponseObjectDetection + +func (r AIRunResponseObjectDetection) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseObject struct { + Response string `json:"response"` + JSON aiRunResponseObjectJSON `json:"-"` +} + +// aiRunResponseObjectJSON contains the JSON metadata for the struct +// [AIRunResponseObject] +type aiRunResponseObjectJSON struct { + Response apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *AIRunResponseObject) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r aiRunResponseObjectJSON) RawJSON() string { + return r.raw +} + +func (r AIRunResponseObject) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseTranslation struct { + TranslatedText string `json:"translated_text"` + JSON aiRunResponseTranslationJSON `json:"-"` +} + +// aiRunResponseTranslationJSON contains the JSON metadata for the struct +// [AIRunResponseTranslation] +type aiRunResponseTranslationJSON struct { + TranslatedText apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *AIRunResponseTranslation) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r aiRunResponseTranslationJSON) RawJSON() string { + return r.raw +} + +func (r AIRunResponseTranslation) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseSummarization struct { + Summary string `json:"summary"` + JSON aiRunResponseSummarizationJSON `json:"-"` +} + +// aiRunResponseSummarizationJSON contains the JSON metadata for the struct +// [AIRunResponseSummarization] +type aiRunResponseSummarizationJSON struct { + Summary apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *AIRunResponseSummarization) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r aiRunResponseSummarizationJSON) RawJSON() string { + return r.raw +} + +func (r AIRunResponseSummarization) ImplementsWorkersAIRunResponse() {} + +type AIRunResponseImageToText struct { + Description string `json:"description"` + JSON aiRunResponseImageToTextJSON `json:"-"` +} + +// aiRunResponseImageToTextJSON contains the JSON metadata for the struct +// [AIRunResponseImageToText] +type aiRunResponseImageToTextJSON struct { + Description apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *AIRunResponseImageToText) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r aiRunResponseImageToTextJSON) RawJSON() string { + return r.raw +} + +func (r AIRunResponseImageToText) ImplementsWorkersAIRunResponse() {} + +type AIRunParams struct { + AccountID param.Field[string] `path:"account_id,required"` + Body param.Field[AIRunParamsBody] `json:"body,required" format:"binary"` +} + +func (r AIRunParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +// Satisfied by [workers.AIRunParamsBodyTextClassification], +// [workers.AIRunParamsBodyTextToImage], +// [workers.AIRunParamsBodySentenceSimilarity], +// [workers.AIRunParamsBodyTextEmbeddings], [shared.UnionString], +// [workers.AIRunParamsBodyObject], [shared.UnionString], +// [workers.AIRunParamsBodyObject], [shared.UnionString], +// [workers.AIRunParamsBodyObject], [workers.AIRunParamsBodyObject], +// [workers.AIRunParamsBodyObject], [workers.AIRunParamsBodyTranslation], +// [workers.AIRunParamsBodySummarization], [shared.UnionString], +// [workers.AIRunParamsBodyObject]. +type AIRunParamsBody interface { + ImplementsWorkersAIRunParamsBody() +} + +type AIRunParamsBodyTextClassification struct { + Text param.Field[string] `json:"text,required"` +} + +func (r AIRunParamsBodyTextClassification) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodyTextClassification) ImplementsWorkersAIRunParamsBody() {} + +type AIRunParamsBodyTextToImage struct { + Prompt param.Field[string] `json:"prompt,required"` + Guidance param.Field[float64] `json:"guidance"` + Image param.Field[[]float64] `json:"image"` + Mask param.Field[[]float64] `json:"mask"` + NumSteps param.Field[int64] `json:"num_steps"` + Strength param.Field[float64] `json:"strength"` +} + +func (r AIRunParamsBodyTextToImage) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodyTextToImage) ImplementsWorkersAIRunParamsBody() {} + +type AIRunParamsBodySentenceSimilarity struct { + Sentences param.Field[[]string] `json:"sentences,required"` + Source param.Field[string] `json:"source,required"` +} + +func (r AIRunParamsBodySentenceSimilarity) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodySentenceSimilarity) ImplementsWorkersAIRunParamsBody() {} + +type AIRunParamsBodyTextEmbeddings struct { + Text param.Field[AIRunParamsBodyTextEmbeddingsText] `json:"text,required"` +} + +func (r AIRunParamsBodyTextEmbeddings) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodyTextEmbeddings) ImplementsWorkersAIRunParamsBody() {} + +// Satisfied by [shared.UnionString], +// [workers.AIRunParamsBodyTextEmbeddingsTextArray]. +type AIRunParamsBodyTextEmbeddingsText interface { + ImplementsWorkersAIRunParamsBodyTextEmbeddingsText() +} + +type AIRunParamsBodyTextEmbeddingsTextArray []string + +func (r AIRunParamsBodyTextEmbeddingsTextArray) ImplementsWorkersAIRunParamsBodyTextEmbeddingsText() { +} + +type AIRunParamsBodyObject struct { + Audio param.Field[[]float64] `json:"audio"` +} + +func (r AIRunParamsBodyObject) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodyObject) ImplementsWorkersAIRunParamsBody() {} + +type AIRunParamsBodyTranslation struct { + TargetLang param.Field[string] `json:"target_lang,required"` + Text param.Field[string] `json:"text,required"` + SourceLang param.Field[string] `json:"source_lang"` +} + +func (r AIRunParamsBodyTranslation) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodyTranslation) ImplementsWorkersAIRunParamsBody() {} + +type AIRunParamsBodySummarization struct { + InputText param.Field[string] `json:"input_text,required"` + MaxLength param.Field[int64] `json:"max_length"` +} + +func (r AIRunParamsBodySummarization) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r AIRunParamsBodySummarization) ImplementsWorkersAIRunParamsBody() {} + +type AIRunResponseEnvelope struct { + Result AIRunResponse `json:"result" format:"binary"` + JSON aiRunResponseEnvelopeJSON `json:"-"` +} + +// aiRunResponseEnvelopeJSON contains the JSON metadata for the struct +// [AIRunResponseEnvelope] +type aiRunResponseEnvelopeJSON struct { + Result apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *AIRunResponseEnvelope) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r aiRunResponseEnvelopeJSON) RawJSON() string { return r.raw } diff --git a/workers/ai_test.go b/workers/ai_test.go index 3ed6deaaf6f..d7a010be830 100644 --- a/workers/ai_test.go +++ b/workers/ai_test.go @@ -14,7 +14,7 @@ import ( "github.com/cloudflare/cloudflare-go/v2/workers" ) -func TestAIRun(t *testing.T) { +func TestAIRunWithOptionalParams(t *testing.T) { t.Skip("skipped: tests are disabled for the time being") baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { @@ -33,7 +33,9 @@ func TestAIRun(t *testing.T) { "string", workers.AIRunParams{ AccountID: cloudflare.F("023e105f4ecef8ad9ca31a8372d0c353"), - Body: cloudflare.F[any](map[string]interface{}{}), + Body: cloudflare.F[workers.AIRunParamsBody](workers.AIRunParamsBodyTextClassification(workers.AIRunParamsBodyTextClassification{ + Text: cloudflare.F("string"), + })), }, ) if err != nil {