Skip to content

Commit

Permalink
Merge pull request #20 from vellum-ai/noa/ml-model-details
Browse files Browse the repository at this point in the history
Flesh Out Support for ML Model Exec Configs
  • Loading branch information
noanflaherty committed Aug 8, 2024
2 parents 35cd2d1 + e91fe3e commit 002a03c
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ website/vendor

# Keep windows files with windows line endings
*.winfile eol=crlf
**/.terraform.tfstate.lock.info
53 changes: 46 additions & 7 deletions internal/provider/ml_model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package ml_model

import (
"context"

"encoding/json"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/types"

Expand All @@ -15,14 +16,28 @@ func NewVellumMLModelCreateRequest(ctx context.Context, mlModelModel *TfMLModelR
developedBy, _ := vellum.NewMlModelDeveloperFromString(mlModelModel.DevelopedBy.ValueString())
family, _ := vellum.NewMlModelFamilyFromString(mlModelModel.Family.ValueString())

// TODO: Pass in actual values rather than dummy keys
// Create an empty slice for features
features := []vellum.MlModelFeature{"CHAT_MESSAGE_SYSTEM"}
features := []vellum.MlModelFeature{}
for _, feature := range mlModelModel.ExecConfig.Features.Elements() {
feature, _ := vellum.NewMlModelFeatureFromString(feature.(types.String).ValueString())
features = append(features, feature)
}

metadata := map[string]interface{}{}
for key, tfvalue := range mlModelModel.ExecConfig.Metadata.Elements() {
value := tfvalue.(types.String).ValueString()
var v interface{}
if err := json.Unmarshal([]byte(value), &v); err != nil {
metadata[key] = value
} else {
metadata[key] = v
}
}

execConfig := vellum.MlModelExecConfigRequest{
ModelIdentifier: "test",
BaseUrl: "http://localhost:8080",
Metadata: map[string]interface{}{"key": "value"},
ModelIdentifier: mlModelModel.ExecConfig.ModelIdentifier.ValueString(),
BaseUrl: mlModelModel.ExecConfig.BaseUrl.ValueString(),
Features: features,
Metadata: metadata,
}

request := vellum.MlModelCreateRequest{
Expand All @@ -45,6 +60,30 @@ func NewTfMLModelModel(ctx context.Context, model *TfMLModelResourceModel, mlMod
HostedBy: types.StringValue(string(mlModel.HostedBy)),
DevelopedBy: types.StringValue(string(mlModel.DevelopedBy.Value)),
Family: types.StringValue(string(mlModel.Family.Value)),
ExecConfig: TfMLModelExecConfig{
ModelIdentifier: types.StringValue(mlModel.ExecConfig.ModelIdentifier),
BaseUrl: types.StringValue(mlModel.ExecConfig.BaseUrl),
Features: types.ListValueMust(
types.StringType,
func() []attr.Value {
var features []attr.Value
for _, feature := range mlModel.ExecConfig.Features {
features = append(features, types.StringValue(string(feature)))
}
return features
}(),
),
Metadata: types.MapValueMust(
types.StringType,
func() map[string]attr.Value {
metadata := map[string]attr.Value{}
for key, value := range mlModel.ExecConfig.Metadata {
metadata[key] = types.StringValue(value)
}
return metadata
}(),
),
},
}

return mlModelModel, nil
Expand Down
50 changes: 44 additions & 6 deletions internal/provider/ml_model/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package ml_model
import (
"context"
"fmt"

"github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/resource/schema"
Expand All @@ -28,13 +30,21 @@ func Resource() resource.Resource {
return &MLModelResource{}
}

type TfMLModelExecConfig struct {
ModelIdentifier types.String `tfsdk:"model_identifier"`
BaseUrl types.String `tfsdk:"base_url"`
Features types.List `tfsdk:"features"`
Metadata types.Map `tfsdk:"metadata"`
}

type TfMLModelResourceModel struct {
Id types.String `tfsdk:"id"`
Name types.String `tfsdk:"name"`
Visibility types.String `tfsdk:"visibility"`
HostedBy types.String `tfsdk:"hosted_by"`
DevelopedBy types.String `tfsdk:"developed_by"`
Family types.String `tfsdk:"family"`
Id types.String `tfsdk:"id"`
Name types.String `tfsdk:"name"`
Visibility types.String `tfsdk:"visibility"`
HostedBy types.String `tfsdk:"hosted_by"`
DevelopedBy types.String `tfsdk:"developed_by"`
Family types.String `tfsdk:"family"`
ExecConfig TfMLModelExecConfig `tfsdk:"exec_config"`
}

func (r *MLModelResource) Metadata(ctx context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) {
Expand Down Expand Up @@ -155,6 +165,34 @@ func (r *MLModelResource) Schema(ctx context.Context, req resource.SchemaRequest
),
},
},
"exec_config": schema.ObjectAttribute{
Description: "The execution configuration of the ML Model.",
MarkdownDescription: "The execution configuration of the ML Model.",
Required: true,
AttributeTypes: map[string]attr.Type{
"model_identifier": schema.StringAttribute{
Description: "The model identifier",
MarkdownDescription: "The model identifier",
Required: true,
}.GetType(),
"base_url": schema.StringAttribute{
Description: "The base URL",
MarkdownDescription: "The base URL",
Required: true,
}.GetType(),
"features": schema.ListAttribute{
Description: "The features",
MarkdownDescription: "The features",
Required: true,
ElementType: schema.StringAttribute{}.GetType(),
}.GetType(),
"metadata": schema.MapAttribute{
Description: "Arbitrary JSON object",
Required: true,
ElementType: types.StringType,
}.GetType(),
},
},
},
}
}
Expand Down
8 changes: 4 additions & 4 deletions internal/sdk/ml_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ func (m *MlModelDisplayTagEnumValueLabel) String() string {
type MlModelExecConfig struct {
ModelIdentifier string `json:"model_identifier"`
BaseUrl string `json:"base_url"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Features []MlModelFeature `json:"features,omitempty"`
Metadata map[string]string `json:"metadata"`
Features []MlModelFeature `json:"features"`
TokenizerConfig *MlModelTokenizerConfig `json:"tokenizer_config,omitempty"`
RequestConfig *MlModelRequestConfig `json:"request_config,omitempty"`
ResponseConfig *MlModelResponseConfig `json:"response_config,omitempty"`
Expand Down Expand Up @@ -419,8 +419,8 @@ func (m *MlModelExecConfig) String() string {
type MlModelExecConfigRequest struct {
ModelIdentifier string `json:"model_identifier"`
BaseUrl string `json:"base_url"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Features []MlModelFeature `json:"features,omitempty"`
Metadata map[string]interface{} `json:"metadata"`
Features []MlModelFeature `json:"features"`
TokenizerConfig *MlModelTokenizerConfigRequest `json:"tokenizer_config,omitempty"`
RequestConfig *MlModelRequestConfigRequest `json:"request_config,omitempty"`
ResponseConfig *MlModelResponseConfigRequest `json:"response_config,omitempty"`
Expand Down

0 comments on commit 002a03c

Please sign in to comment.