diff --git a/.gitignore b/.gitignore index fd3ad8e..c6d6835 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ website/vendor # Keep windows files with windows line endings *.winfile eol=crlf +**/.terraform.tfstate.lock.info diff --git a/internal/provider/ml_model/model.go b/internal/provider/ml_model/model.go index e509a28..cfbc8dc 100644 --- a/internal/provider/ml_model/model.go +++ b/internal/provider/ml_model/model.go @@ -2,7 +2,8 @@ package ml_model import ( "context" - + "encoding/json" + "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/types" @@ -15,14 +16,28 @@ func NewVellumMLModelCreateRequest(ctx context.Context, mlModelModel *TfMLModelR developedBy, _ := vellum.NewMlModelDeveloperFromString(mlModelModel.DevelopedBy.ValueString()) family, _ := vellum.NewMlModelFamilyFromString(mlModelModel.Family.ValueString()) - // TODO: Pass in actual values rather than dummy keys - // Create an empty slice for features - features := []vellum.MlModelFeature{"CHAT_MESSAGE_SYSTEM"} + features := []vellum.MlModelFeature{} + for _, feature := range mlModelModel.ExecConfig.Features.Elements() { + feature, _ := vellum.NewMlModelFeatureFromString(feature.(types.String).ValueString()) + features = append(features, feature) + } + + metadata := map[string]interface{}{} + for key, tfvalue := range mlModelModel.ExecConfig.Metadata.Elements() { + value := tfvalue.(types.String).ValueString() + var v interface{} + if err := json.Unmarshal([]byte(value), &v); err != nil { + metadata[key] = value + } else { + metadata[key] = v + } + } + execConfig := vellum.MlModelExecConfigRequest{ - ModelIdentifier: "test", - BaseUrl: "http://localhost:8080", - Metadata: map[string]interface{}{"key": "value"}, + ModelIdentifier: mlModelModel.ExecConfig.ModelIdentifier.ValueString(), + BaseUrl: mlModelModel.ExecConfig.BaseUrl.ValueString(), Features: features, + Metadata: metadata, } request := vellum.MlModelCreateRequest{ @@ -45,6 +60,30 @@ func NewTfMLModelModel(ctx context.Context, model *TfMLModelResourceModel, mlMod HostedBy: types.StringValue(string(mlModel.HostedBy)), DevelopedBy: types.StringValue(string(mlModel.DevelopedBy.Value)), Family: types.StringValue(string(mlModel.Family.Value)), + ExecConfig: TfMLModelExecConfig{ + ModelIdentifier: types.StringValue(mlModel.ExecConfig.ModelIdentifier), + BaseUrl: types.StringValue(mlModel.ExecConfig.BaseUrl), + Features: types.ListValueMust( + types.StringType, + func() []attr.Value { + var features []attr.Value + for _, feature := range mlModel.ExecConfig.Features { + features = append(features, types.StringValue(string(feature))) + } + return features + }(), + ), + Metadata: types.MapValueMust( + types.StringType, + func() map[string]attr.Value { + metadata := map[string]attr.Value{} + for key, value := range mlModel.ExecConfig.Metadata { + metadata[key] = types.StringValue(value) + } + return metadata + }(), + ), + }, } return mlModelModel, nil diff --git a/internal/provider/ml_model/resource.go b/internal/provider/ml_model/resource.go index 4123761..a2947cc 100644 --- a/internal/provider/ml_model/resource.go +++ b/internal/provider/ml_model/resource.go @@ -6,7 +6,9 @@ package ml_model import ( "context" "fmt" + "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" + "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/path" "github.com/hashicorp/terraform-plugin-framework/resource" "github.com/hashicorp/terraform-plugin-framework/resource/schema" @@ -28,13 +30,21 @@ func Resource() resource.Resource { return &MLModelResource{} } +type TfMLModelExecConfig struct { + ModelIdentifier types.String `tfsdk:"model_identifier"` + BaseUrl types.String `tfsdk:"base_url"` + Features types.List `tfsdk:"features"` + Metadata types.Map `tfsdk:"metadata"` +} + type TfMLModelResourceModel struct { - Id types.String `tfsdk:"id"` - Name types.String `tfsdk:"name"` - Visibility types.String `tfsdk:"visibility"` - HostedBy types.String `tfsdk:"hosted_by"` - DevelopedBy types.String `tfsdk:"developed_by"` - Family types.String `tfsdk:"family"` + Id types.String `tfsdk:"id"` + Name types.String `tfsdk:"name"` + Visibility types.String `tfsdk:"visibility"` + HostedBy types.String `tfsdk:"hosted_by"` + DevelopedBy types.String `tfsdk:"developed_by"` + Family types.String `tfsdk:"family"` + ExecConfig TfMLModelExecConfig `tfsdk:"exec_config"` } func (r *MLModelResource) Metadata(ctx context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) { @@ -155,6 +165,34 @@ func (r *MLModelResource) Schema(ctx context.Context, req resource.SchemaRequest ), }, }, + "exec_config": schema.ObjectAttribute{ + Description: "The execution configuration of the ML Model.", + MarkdownDescription: "The execution configuration of the ML Model.", + Required: true, + AttributeTypes: map[string]attr.Type{ + "model_identifier": schema.StringAttribute{ + Description: "The model identifier", + MarkdownDescription: "The model identifier", + Required: true, + }.GetType(), + "base_url": schema.StringAttribute{ + Description: "The base URL", + MarkdownDescription: "The base URL", + Required: true, + }.GetType(), + "features": schema.ListAttribute{ + Description: "The features", + MarkdownDescription: "The features", + Required: true, + ElementType: schema.StringAttribute{}.GetType(), + }.GetType(), + "metadata": schema.MapAttribute{ + Description: "Arbitrary JSON object", + Required: true, + ElementType: types.StringType, + }.GetType(), + }, + }, }, } } diff --git a/internal/sdk/ml_models.go b/internal/sdk/ml_models.go index 5545ac8..dde9328 100644 --- a/internal/sdk/ml_models.go +++ b/internal/sdk/ml_models.go @@ -384,8 +384,8 @@ func (m *MlModelDisplayTagEnumValueLabel) String() string { type MlModelExecConfig struct { ModelIdentifier string `json:"model_identifier"` BaseUrl string `json:"base_url"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Features []MlModelFeature `json:"features,omitempty"` + Metadata map[string]string `json:"metadata"` + Features []MlModelFeature `json:"features"` TokenizerConfig *MlModelTokenizerConfig `json:"tokenizer_config,omitempty"` RequestConfig *MlModelRequestConfig `json:"request_config,omitempty"` ResponseConfig *MlModelResponseConfig `json:"response_config,omitempty"` @@ -419,8 +419,8 @@ func (m *MlModelExecConfig) String() string { type MlModelExecConfigRequest struct { ModelIdentifier string `json:"model_identifier"` BaseUrl string `json:"base_url"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Features []MlModelFeature `json:"features,omitempty"` + Metadata map[string]interface{} `json:"metadata"` + Features []MlModelFeature `json:"features"` TokenizerConfig *MlModelTokenizerConfigRequest `json:"tokenizer_config,omitempty"` RequestConfig *MlModelRequestConfigRequest `json:"request_config,omitempty"` ResponseConfig *MlModelResponseConfigRequest `json:"response_config,omitempty"`