From 3d1b4098cd9502538539ff963814fea8b3b96f65 Mon Sep 17 00:00:00 2001 From: Brian Flad Date: Thu, 14 Dec 2023 17:56:53 +0100 Subject: [PATCH] all: Initial provider defined functions implementation (#209) Reference: https://github.com/hashicorp/terraform-plugin-go/pull/351 The next versions of the plugin protocol (5.5/6.5) include support for provider defined functions. This change includes initial implementation of that support including: - Temporarily pointing at terraform-plugin-go with provider function support (will be pointed at final terraform-plugin-go release before merge) - Updates to all provider server packages for new `GetFunctions` and `CallFunction` RPCs --- .../unreleased/FEATURES-20231107-141509.yaml | 5 + internal/tf5testserver/tf5testserver.go | 24 + internal/tf6testserver/tf6testserver.go | 24 + internal/tfprotov5tov6/tfprotov5tov6.go | 118 +++++ internal/tfprotov5tov6/tfprotov5tov6_test.go | 460 +++++++++++++++++- internal/tfprotov6tov5/tfprotov6tov5.go | 118 +++++ internal/tfprotov6tov5/tfprotov6tov5_test.go | 460 +++++++++++++++++- tf5muxserver/diagnostics.go | 21 + tf5muxserver/mux_server.go | 68 ++- tf5muxserver/mux_server_CallFunction.go | 57 +++ tf5muxserver/mux_server_CallFunction_test.go | 82 ++++ tf5muxserver/mux_server_GetFunctions.go | 68 +++ tf5muxserver/mux_server_GetFunctions_test.go | 333 +++++++++++++ tf5muxserver/mux_server_GetMetadata.go | 22 + tf5muxserver/mux_server_GetMetadata_test.go | 80 +++ tf5muxserver/mux_server_GetProviderSchema.go | 12 + .../mux_server_GetProviderSchema_test.go | 88 ++++ tf5muxserver/mux_server_test.go | 430 ++++++++++++++++ tf5to6server/tf5to6server.go | 57 +++ tf5to6server/tf5to6server_test.go | 79 +++ tf6muxserver/diagnostics.go | 21 + tf6muxserver/mux_server.go | 68 ++- tf6muxserver/mux_server_CallFunction.go | 57 +++ tf6muxserver/mux_server_CallFunction_test.go | 82 ++++ tf6muxserver/mux_server_GetFunctions.go | 68 +++ tf6muxserver/mux_server_GetFunctions_test.go | 333 +++++++++++++ tf6muxserver/mux_server_GetMetadata.go | 22 + tf6muxserver/mux_server_GetMetadata_test.go | 80 +++ tf6muxserver/mux_server_GetProviderSchema.go | 12 + .../mux_server_GetProviderSchema_test.go | 88 ++++ tf6muxserver/mux_server_test.go | 430 ++++++++++++++++ tf6to5server/tf6to5server.go | 57 +++ tf6to5server/tf6to5server_test.go | 79 +++ 33 files changed, 3991 insertions(+), 12 deletions(-) create mode 100644 .changes/unreleased/FEATURES-20231107-141509.yaml create mode 100644 tf5muxserver/mux_server_CallFunction.go create mode 100644 tf5muxserver/mux_server_CallFunction_test.go create mode 100644 tf5muxserver/mux_server_GetFunctions.go create mode 100644 tf5muxserver/mux_server_GetFunctions_test.go create mode 100644 tf6muxserver/mux_server_CallFunction.go create mode 100644 tf6muxserver/mux_server_CallFunction_test.go create mode 100644 tf6muxserver/mux_server_GetFunctions.go create mode 100644 tf6muxserver/mux_server_GetFunctions_test.go diff --git a/.changes/unreleased/FEATURES-20231107-141509.yaml b/.changes/unreleased/FEATURES-20231107-141509.yaml new file mode 100644 index 0000000..2d6aada --- /dev/null +++ b/.changes/unreleased/FEATURES-20231107-141509.yaml @@ -0,0 +1,5 @@ +kind: FEATURES +body: 'all: Upgrade protocol versions to support provider-defined functions' +time: 2023-11-07T14:15:09.783296-05:00 +custom: + Issue: "209" diff --git a/internal/tf5testserver/tf5testserver.go b/internal/tf5testserver/tf5testserver.go index d8b04da..71abb9a 100644 --- a/internal/tf5testserver/tf5testserver.go +++ b/internal/tf5testserver/tf5testserver.go @@ -16,9 +16,14 @@ var _ tfprotov5.ProviderServer = &TestServer{} type TestServer struct { ApplyResourceChangeCalled map[string]bool + CallFunctionCalled map[string]bool + ConfigureProviderCalled bool ConfigureProviderResponse *tfprotov5.ConfigureProviderResponse + GetFunctionsCalled bool + GetFunctionsResponse *tfprotov5.GetFunctionsResponse + GetMetadataCalled bool GetMetadataResponse *tfprotov5.GetMetadataResponse @@ -59,6 +64,15 @@ func (s *TestServer) ApplyResourceChange(_ context.Context, req *tfprotov5.Apply return nil, nil } +func (s *TestServer) CallFunction(_ context.Context, req *tfprotov5.CallFunctionRequest) (*tfprotov5.CallFunctionResponse, error) { + if s.CallFunctionCalled == nil { + s.CallFunctionCalled = make(map[string]bool) + } + + s.CallFunctionCalled[req.Name] = true + return nil, nil +} + func (s *TestServer) ConfigureProvider(_ context.Context, _ *tfprotov5.ConfigureProviderRequest) (*tfprotov5.ConfigureProviderResponse, error) { s.ConfigureProviderCalled = true @@ -69,6 +83,16 @@ func (s *TestServer) ConfigureProvider(_ context.Context, _ *tfprotov5.Configure return &tfprotov5.ConfigureProviderResponse{}, nil } +func (s *TestServer) GetFunctions(_ context.Context, _ *tfprotov5.GetFunctionsRequest) (*tfprotov5.GetFunctionsResponse, error) { + s.GetFunctionsCalled = true + + if s.GetFunctionsResponse != nil { + return s.GetFunctionsResponse, nil + } + + return &tfprotov5.GetFunctionsResponse{}, nil +} + func (s *TestServer) GetMetadata(_ context.Context, _ *tfprotov5.GetMetadataRequest) (*tfprotov5.GetMetadataResponse, error) { s.GetMetadataCalled = true diff --git a/internal/tf6testserver/tf6testserver.go b/internal/tf6testserver/tf6testserver.go index 0cced96..64b2df2 100644 --- a/internal/tf6testserver/tf6testserver.go +++ b/internal/tf6testserver/tf6testserver.go @@ -16,9 +16,14 @@ var _ tfprotov6.ProviderServer = &TestServer{} type TestServer struct { ApplyResourceChangeCalled map[string]bool + CallFunctionCalled map[string]bool + ConfigureProviderCalled bool ConfigureProviderResponse *tfprotov6.ConfigureProviderResponse + GetFunctionsCalled bool + GetFunctionsResponse *tfprotov6.GetFunctionsResponse + GetMetadataCalled bool GetMetadataResponse *tfprotov6.GetMetadataResponse @@ -59,6 +64,15 @@ func (s *TestServer) ApplyResourceChange(_ context.Context, req *tfprotov6.Apply return nil, nil } +func (s *TestServer) CallFunction(_ context.Context, req *tfprotov6.CallFunctionRequest) (*tfprotov6.CallFunctionResponse, error) { + if s.CallFunctionCalled == nil { + s.CallFunctionCalled = make(map[string]bool) + } + + s.CallFunctionCalled[req.Name] = true + return nil, nil +} + func (s *TestServer) ConfigureProvider(_ context.Context, _ *tfprotov6.ConfigureProviderRequest) (*tfprotov6.ConfigureProviderResponse, error) { s.ConfigureProviderCalled = true @@ -69,6 +83,16 @@ func (s *TestServer) ConfigureProvider(_ context.Context, _ *tfprotov6.Configure return &tfprotov6.ConfigureProviderResponse{}, nil } +func (s *TestServer) GetFunctions(_ context.Context, _ *tfprotov6.GetFunctionsRequest) (*tfprotov6.GetFunctionsResponse, error) { + s.GetFunctionsCalled = true + + if s.GetFunctionsResponse != nil { + return s.GetFunctionsResponse, nil + } + + return &tfprotov6.GetFunctionsResponse{}, nil +} + func (s *TestServer) GetMetadata(_ context.Context, _ *tfprotov6.GetMetadataRequest) (*tfprotov6.GetMetadataResponse, error) { s.GetMetadataCalled = true diff --git a/internal/tfprotov5tov6/tfprotov5tov6.go b/internal/tfprotov5tov6/tfprotov5tov6.go index 8924ea8..52bc39a 100644 --- a/internal/tfprotov5tov6/tfprotov5tov6.go +++ b/internal/tfprotov5tov6/tfprotov5tov6.go @@ -36,6 +36,34 @@ func ApplyResourceChangeResponse(in *tfprotov5.ApplyResourceChangeResponse) *tfp } } +func CallFunctionRequest(in *tfprotov5.CallFunctionRequest) *tfprotov6.CallFunctionRequest { + if in == nil { + return nil + } + + out := &tfprotov6.CallFunctionRequest{ + Arguments: make([]*tfprotov6.DynamicValue, 0, len(in.Arguments)), + Name: in.Name, + } + + for _, argument := range in.Arguments { + out.Arguments = append(out.Arguments, DynamicValue(argument)) + } + + return out +} + +func CallFunctionResponse(in *tfprotov5.CallFunctionResponse) *tfprotov6.CallFunctionResponse { + if in == nil { + return nil + } + + return &tfprotov6.CallFunctionResponse{ + Diagnostics: Diagnostics(in.Diagnostics), + Result: DynamicValue(in.Result), + } +} + func ConfigureProviderRequest(in *tfprotov5.ConfigureProviderRequest) *tfprotov6.ConfigureProviderRequest { if in == nil { return nil @@ -98,6 +126,84 @@ func DynamicValue(in *tfprotov5.DynamicValue) *tfprotov6.DynamicValue { } } +func Function(in *tfprotov5.Function) *tfprotov6.Function { + if in == nil { + return nil + } + + out := &tfprotov6.Function{ + DeprecationMessage: in.DeprecationMessage, + Description: in.Description, + DescriptionKind: StringKind(in.DescriptionKind), + Parameters: make([]*tfprotov6.FunctionParameter, 0, len(in.Parameters)), + Return: FunctionReturn(in.Return), + Summary: in.Summary, + VariadicParameter: FunctionParameter(in.VariadicParameter), + } + + for _, parameter := range in.Parameters { + out.Parameters = append(out.Parameters, FunctionParameter(parameter)) + } + + return out +} + +func FunctionMetadata(in tfprotov5.FunctionMetadata) tfprotov6.FunctionMetadata { + return tfprotov6.FunctionMetadata{ + Name: in.Name, + } +} + +func FunctionParameter(in *tfprotov5.FunctionParameter) *tfprotov6.FunctionParameter { + if in == nil { + return nil + } + + return &tfprotov6.FunctionParameter{ + AllowNullValue: in.AllowNullValue, + AllowUnknownValues: in.AllowUnknownValues, + Description: in.Description, + DescriptionKind: StringKind(in.DescriptionKind), + Name: in.Name, + Type: in.Type, + } +} + +func FunctionReturn(in *tfprotov5.FunctionReturn) *tfprotov6.FunctionReturn { + if in == nil { + return nil + } + + return &tfprotov6.FunctionReturn{ + Type: in.Type, + } +} + +func GetFunctionsRequest(in *tfprotov5.GetFunctionsRequest) *tfprotov6.GetFunctionsRequest { + if in == nil { + return nil + } + + return &tfprotov6.GetFunctionsRequest{} +} + +func GetFunctionsResponse(in *tfprotov5.GetFunctionsResponse) *tfprotov6.GetFunctionsResponse { + if in == nil { + return nil + } + + functions := make(map[string]*tfprotov6.Function, len(in.Functions)) + + for name, function := range in.Functions { + functions[name] = Function(function) + } + + return &tfprotov6.GetFunctionsResponse{ + Diagnostics: Diagnostics(in.Diagnostics), + Functions: functions, + } +} + func GetMetadataRequest(in *tfprotov5.GetMetadataRequest) *tfprotov6.GetMetadataRequest { if in == nil { return nil @@ -114,6 +220,7 @@ func GetMetadataResponse(in *tfprotov5.GetMetadataResponse) *tfprotov6.GetMetada resp := &tfprotov6.GetMetadataResponse{ DataSources: make([]tfprotov6.DataSourceMetadata, 0, len(in.DataSources)), Diagnostics: Diagnostics(in.Diagnostics), + Functions: make([]tfprotov6.FunctionMetadata, 0, len(in.Functions)), Resources: make([]tfprotov6.ResourceMetadata, 0, len(in.Resources)), ServerCapabilities: ServerCapabilities(in.ServerCapabilities), } @@ -122,6 +229,10 @@ func GetMetadataResponse(in *tfprotov5.GetMetadataResponse) *tfprotov6.GetMetada resp.DataSources = append(resp.DataSources, DataSourceMetadata(datasource)) } + for _, function := range in.Functions { + resp.Functions = append(resp.Functions, FunctionMetadata(function)) + } + for _, resource := range in.Resources { resp.Resources = append(resp.Resources, ResourceMetadata(resource)) } @@ -148,6 +259,12 @@ func GetProviderSchemaResponse(in *tfprotov5.GetProviderSchemaResponse) *tfproto dataSourceSchemas[k] = Schema(v) } + functions := make(map[string]*tfprotov6.Function, len(in.Functions)) + + for name, function := range in.Functions { + functions[name] = Function(function) + } + resourceSchemas := make(map[string]*tfprotov6.Schema, len(in.ResourceSchemas)) for k, v := range in.ResourceSchemas { @@ -157,6 +274,7 @@ func GetProviderSchemaResponse(in *tfprotov5.GetProviderSchemaResponse) *tfproto return &tfprotov6.GetProviderSchemaResponse{ DataSourceSchemas: dataSourceSchemas, Diagnostics: Diagnostics(in.Diagnostics), + Functions: functions, Provider: Schema(in.Provider), ProviderMeta: Schema(in.ProviderMeta), ResourceSchemas: resourceSchemas, diff --git a/internal/tfprotov5tov6/tfprotov5tov6_test.go b/internal/tfprotov5tov6/tfprotov5tov6_test.go index b476851..ca33e4f 100644 --- a/internal/tfprotov5tov6/tfprotov5tov6_test.go +++ b/internal/tfprotov5tov6/tfprotov5tov6_test.go @@ -16,6 +16,14 @@ import ( var ( testBytes []byte = []byte("test") + testTfprotov5DataSourceMetadata tfprotov5.DataSourceMetadata = tfprotov5.DataSourceMetadata{ + TypeName: "test_data_source", + } + + testTfprotov6DataSourceMetadata tfprotov6.DataSourceMetadata = tfprotov6.DataSourceMetadata{ + TypeName: "test_data_source", + } + testTfprotov5Diagnostics []*tfprotov5.Diagnostic = []*tfprotov5.Diagnostic{ { Detail: "test detail", @@ -34,6 +42,36 @@ var ( testTfprotov5DynamicValue tfprotov5.DynamicValue testTfprotov6DynamicValue tfprotov6.DynamicValue + testTfprotov5Function *tfprotov5.Function = &tfprotov5.Function{ + Parameters: []*tfprotov5.FunctionParameter{}, + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + } + + testTfprotov5FunctionMetadata tfprotov5.FunctionMetadata = tfprotov5.FunctionMetadata{ + Name: "test_function", + } + + testTfprotov6Function *tfprotov6.Function = &tfprotov6.Function{ + Parameters: []*tfprotov6.FunctionParameter{}, + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + } + + testTfprotov6FunctionMetadata tfprotov6.FunctionMetadata = tfprotov6.FunctionMetadata{ + Name: "test_function", + } + + testTfprotov5ResourceMetadata tfprotov5.ResourceMetadata = tfprotov5.ResourceMetadata{ + TypeName: "test_resource", + } + + testTfprotov6ResourceMetadata tfprotov6.ResourceMetadata = tfprotov6.ResourceMetadata{ + TypeName: "test_resource", + } + testTfprotov5Schema *tfprotov5.Schema = &tfprotov5.Schema{ Block: &tfprotov5.SchemaBlock{ Attributes: []*tfprotov5.SchemaAttribute{ @@ -153,6 +191,86 @@ func TestApplyResourceChangeResponse(t *testing.T) { } } +func TestCallFunctionRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.CallFunctionRequest + expected *tfprotov6.CallFunctionRequest + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.CallFunctionRequest{ + Arguments: []*tfprotov5.DynamicValue{ + &testTfprotov5DynamicValue, + }, + Name: "test_function", + }, + expected: &tfprotov6.CallFunctionRequest{ + Arguments: []*tfprotov6.DynamicValue{ + &testTfprotov6DynamicValue, + }, + Name: "test_function", + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.CallFunctionRequest(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestCallFunctionResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.CallFunctionResponse + expected *tfprotov6.CallFunctionResponse + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.CallFunctionResponse{ + Diagnostics: testTfprotov5Diagnostics, + Result: &testTfprotov5DynamicValue, + }, + expected: &tfprotov6.CallFunctionResponse{ + Diagnostics: testTfprotov6Diagnostics, + Result: &testTfprotov6DynamicValue, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.CallFunctionResponse(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + func TestConfigureProviderRequest(t *testing.T) { t.Parallel() @@ -327,6 +445,338 @@ func TestDynamicValue(t *testing.T) { } } +func TestFunction(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.Function + expected *tfprotov6.Function + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.Function{ + DeprecationMessage: "test deprecation message", + Description: "test description", + DescriptionKind: tfprotov5.StringKindPlain, + Parameters: []*tfprotov5.FunctionParameter{ + { + Type: tftypes.String, + }, + }, + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + Summary: "test summary", + VariadicParameter: &tfprotov5.FunctionParameter{ + Type: tftypes.String, + }, + }, + expected: &tfprotov6.Function{ + DeprecationMessage: "test deprecation message", + Description: "test description", + DescriptionKind: tfprotov6.StringKindPlain, + Parameters: []*tfprotov6.FunctionParameter{ + { + Type: tftypes.String, + }, + }, + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + Summary: "test summary", + VariadicParameter: &tfprotov6.FunctionParameter{ + Type: tftypes.String, + }, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.Function(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestFunctionMetadata(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in tfprotov5.FunctionMetadata + expected tfprotov6.FunctionMetadata + }{ + "all-valid-fields": { + in: tfprotov5.FunctionMetadata{ + Name: "test_function", + }, + expected: tfprotov6.FunctionMetadata{ + Name: "test_function", + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.FunctionMetadata(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestFunctionParameter(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.FunctionParameter + expected *tfprotov6.FunctionParameter + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.FunctionParameter{ + Description: "test description", + DescriptionKind: tfprotov5.StringKindPlain, + Type: tftypes.String, + }, + expected: &tfprotov6.FunctionParameter{ + Description: "test description", + DescriptionKind: tfprotov6.StringKindPlain, + Type: tftypes.String, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.FunctionParameter(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestFunctionReturn(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.FunctionReturn + expected *tfprotov6.FunctionReturn + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + expected: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.FunctionReturn(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetFunctionsRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.GetFunctionsRequest + expected *tfprotov6.GetFunctionsRequest + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.GetFunctionsRequest{}, + expected: &tfprotov6.GetFunctionsRequest{}, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.GetFunctionsRequest(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetFunctionsResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.GetFunctionsResponse + expected *tfprotov6.GetFunctionsResponse + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.GetFunctionsResponse{ + Diagnostics: testTfprotov5Diagnostics, + Functions: map[string]*tfprotov5.Function{ + "test_function": testTfprotov5Function, + }, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: testTfprotov6Diagnostics, + Functions: map[string]*tfprotov6.Function{ + "test_function": testTfprotov6Function, + }, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.GetFunctionsResponse(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetMetadataRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.GetMetadataRequest + expected *tfprotov6.GetMetadataRequest + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.GetMetadataRequest{}, + expected: &tfprotov6.GetMetadataRequest{}, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.GetMetadataRequest(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetMetadataResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov5.GetMetadataResponse + expected *tfprotov6.GetMetadataResponse + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov5.GetMetadataResponse{ + DataSources: []tfprotov5.DataSourceMetadata{ + testTfprotov5DataSourceMetadata, + }, + Diagnostics: testTfprotov5Diagnostics, + Functions: []tfprotov5.FunctionMetadata{ + testTfprotov5FunctionMetadata, + }, + Resources: []tfprotov5.ResourceMetadata{ + testTfprotov5ResourceMetadata, + }, + }, + expected: &tfprotov6.GetMetadataResponse{ + DataSources: []tfprotov6.DataSourceMetadata{ + testTfprotov6DataSourceMetadata, + }, + Diagnostics: testTfprotov6Diagnostics, + Functions: []tfprotov6.FunctionMetadata{ + testTfprotov6FunctionMetadata, + }, + Resources: []tfprotov6.ResourceMetadata{ + testTfprotov6ResourceMetadata, + }, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov5tov6.GetMetadataResponse(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + func TestGetProviderSchemaRequest(t *testing.T) { t.Parallel() @@ -375,7 +825,10 @@ func TestGetProviderSchemaResponse(t *testing.T) { DataSourceSchemas: map[string]*tfprotov5.Schema{ "test_data_source": testTfprotov5Schema, }, - Diagnostics: testTfprotov5Diagnostics, + Diagnostics: testTfprotov5Diagnostics, + Functions: map[string]*tfprotov5.Function{ + "test_function": testTfprotov5Function, + }, Provider: testTfprotov5Schema, ProviderMeta: testTfprotov5Schema, ResourceSchemas: map[string]*tfprotov5.Schema{ @@ -386,7 +839,10 @@ func TestGetProviderSchemaResponse(t *testing.T) { DataSourceSchemas: map[string]*tfprotov6.Schema{ "test_data_source": testTfprotov6Schema, }, - Diagnostics: testTfprotov6Diagnostics, + Diagnostics: testTfprotov6Diagnostics, + Functions: map[string]*tfprotov6.Function{ + "test_function": testTfprotov6Function, + }, Provider: testTfprotov6Schema, ProviderMeta: testTfprotov6Schema, ResourceSchemas: map[string]*tfprotov6.Schema{ diff --git a/internal/tfprotov6tov5/tfprotov6tov5.go b/internal/tfprotov6tov5/tfprotov6tov5.go index 39356fb..53751b4 100644 --- a/internal/tfprotov6tov5/tfprotov6tov5.go +++ b/internal/tfprotov6tov5/tfprotov6tov5.go @@ -41,6 +41,34 @@ func ApplyResourceChangeResponse(in *tfprotov6.ApplyResourceChangeResponse) *tfp } } +func CallFunctionRequest(in *tfprotov6.CallFunctionRequest) *tfprotov5.CallFunctionRequest { + if in == nil { + return nil + } + + out := &tfprotov5.CallFunctionRequest{ + Arguments: make([]*tfprotov5.DynamicValue, 0, len(in.Arguments)), + Name: in.Name, + } + + for _, argument := range in.Arguments { + out.Arguments = append(out.Arguments, DynamicValue(argument)) + } + + return out +} + +func CallFunctionResponse(in *tfprotov6.CallFunctionResponse) *tfprotov5.CallFunctionResponse { + if in == nil { + return nil + } + + return &tfprotov5.CallFunctionResponse{ + Diagnostics: Diagnostics(in.Diagnostics), + Result: DynamicValue(in.Result), + } +} + func ConfigureProviderRequest(in *tfprotov6.ConfigureProviderRequest) *tfprotov5.ConfigureProviderRequest { if in == nil { return nil @@ -103,6 +131,84 @@ func DynamicValue(in *tfprotov6.DynamicValue) *tfprotov5.DynamicValue { } } +func Function(in *tfprotov6.Function) *tfprotov5.Function { + if in == nil { + return nil + } + + out := &tfprotov5.Function{ + DeprecationMessage: in.DeprecationMessage, + Description: in.Description, + DescriptionKind: StringKind(in.DescriptionKind), + Parameters: make([]*tfprotov5.FunctionParameter, 0, len(in.Parameters)), + Return: FunctionReturn(in.Return), + Summary: in.Summary, + VariadicParameter: FunctionParameter(in.VariadicParameter), + } + + for _, parameter := range in.Parameters { + out.Parameters = append(out.Parameters, FunctionParameter(parameter)) + } + + return out +} + +func FunctionMetadata(in tfprotov6.FunctionMetadata) tfprotov5.FunctionMetadata { + return tfprotov5.FunctionMetadata{ + Name: in.Name, + } +} + +func FunctionParameter(in *tfprotov6.FunctionParameter) *tfprotov5.FunctionParameter { + if in == nil { + return nil + } + + return &tfprotov5.FunctionParameter{ + AllowNullValue: in.AllowNullValue, + AllowUnknownValues: in.AllowUnknownValues, + Description: in.Description, + DescriptionKind: StringKind(in.DescriptionKind), + Name: in.Name, + Type: in.Type, + } +} + +func FunctionReturn(in *tfprotov6.FunctionReturn) *tfprotov5.FunctionReturn { + if in == nil { + return nil + } + + return &tfprotov5.FunctionReturn{ + Type: in.Type, + } +} + +func GetFunctionsRequest(in *tfprotov6.GetFunctionsRequest) *tfprotov5.GetFunctionsRequest { + if in == nil { + return nil + } + + return &tfprotov5.GetFunctionsRequest{} +} + +func GetFunctionsResponse(in *tfprotov6.GetFunctionsResponse) *tfprotov5.GetFunctionsResponse { + if in == nil { + return nil + } + + functions := make(map[string]*tfprotov5.Function, len(in.Functions)) + + for name, function := range in.Functions { + functions[name] = Function(function) + } + + return &tfprotov5.GetFunctionsResponse{ + Diagnostics: Diagnostics(in.Diagnostics), + Functions: functions, + } +} + func GetMetadataRequest(in *tfprotov6.GetMetadataRequest) *tfprotov5.GetMetadataRequest { if in == nil { return nil @@ -119,6 +225,7 @@ func GetMetadataResponse(in *tfprotov6.GetMetadataResponse) *tfprotov5.GetMetada resp := &tfprotov5.GetMetadataResponse{ DataSources: make([]tfprotov5.DataSourceMetadata, 0, len(in.DataSources)), Diagnostics: Diagnostics(in.Diagnostics), + Functions: make([]tfprotov5.FunctionMetadata, 0, len(in.Functions)), Resources: make([]tfprotov5.ResourceMetadata, 0, len(in.Resources)), ServerCapabilities: ServerCapabilities(in.ServerCapabilities), } @@ -127,6 +234,10 @@ func GetMetadataResponse(in *tfprotov6.GetMetadataResponse) *tfprotov5.GetMetada resp.DataSources = append(resp.DataSources, DataSourceMetadata(datasource)) } + for _, function := range in.Functions { + resp.Functions = append(resp.Functions, FunctionMetadata(function)) + } + for _, resource := range in.Resources { resp.Resources = append(resp.Resources, ResourceMetadata(resource)) } @@ -159,6 +270,12 @@ func GetProviderSchemaResponse(in *tfprotov6.GetProviderSchemaResponse) (*tfprot dataSourceSchemas[k] = v5Schema } + functions := make(map[string]*tfprotov5.Function, len(in.Functions)) + + for name, function := range in.Functions { + functions[name] = Function(function) + } + provider, err := Schema(in.Provider) if err != nil { @@ -186,6 +303,7 @@ func GetProviderSchemaResponse(in *tfprotov6.GetProviderSchemaResponse) (*tfprot return &tfprotov5.GetProviderSchemaResponse{ DataSourceSchemas: dataSourceSchemas, Diagnostics: Diagnostics(in.Diagnostics), + Functions: functions, Provider: provider, ProviderMeta: providerMeta, ResourceSchemas: resourceSchemas, diff --git a/internal/tfprotov6tov5/tfprotov6tov5_test.go b/internal/tfprotov6tov5/tfprotov6tov5_test.go index 8d76fd9..f400df9 100644 --- a/internal/tfprotov6tov5/tfprotov6tov5_test.go +++ b/internal/tfprotov6tov5/tfprotov6tov5_test.go @@ -18,6 +18,14 @@ import ( var ( testBytes []byte = []byte("test") + testTfprotov5DataSourceMetadata tfprotov5.DataSourceMetadata = tfprotov5.DataSourceMetadata{ + TypeName: "test_data_source", + } + + testTfprotov6DataSourceMetadata tfprotov6.DataSourceMetadata = tfprotov6.DataSourceMetadata{ + TypeName: "test_data_source", + } + testTfprotov5Diagnostics []*tfprotov5.Diagnostic = []*tfprotov5.Diagnostic{ { Detail: "test detail", @@ -36,6 +44,36 @@ var ( testTfprotov5DynamicValue tfprotov5.DynamicValue testTfprotov6DynamicValue tfprotov6.DynamicValue + testTfprotov5Function *tfprotov5.Function = &tfprotov5.Function{ + Parameters: []*tfprotov5.FunctionParameter{}, + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + } + + testTfprotov5FunctionMetadata tfprotov5.FunctionMetadata = tfprotov5.FunctionMetadata{ + Name: "test_function", + } + + testTfprotov6Function *tfprotov6.Function = &tfprotov6.Function{ + Parameters: []*tfprotov6.FunctionParameter{}, + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + } + + testTfprotov6FunctionMetadata tfprotov6.FunctionMetadata = tfprotov6.FunctionMetadata{ + Name: "test_function", + } + + testTfprotov5ResourceMetadata tfprotov5.ResourceMetadata = tfprotov5.ResourceMetadata{ + TypeName: "test_resource", + } + + testTfprotov6ResourceMetadata tfprotov6.ResourceMetadata = tfprotov6.ResourceMetadata{ + TypeName: "test_resource", + } + testTfprotov5Schema *tfprotov5.Schema = &tfprotov5.Schema{ Block: &tfprotov5.SchemaBlock{ Attributes: []*tfprotov5.SchemaAttribute{ @@ -155,6 +193,86 @@ func TestApplyResourceChangeResponse(t *testing.T) { } } +func TestCallFunctionRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.CallFunctionRequest + expected *tfprotov5.CallFunctionRequest + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.CallFunctionRequest{ + Arguments: []*tfprotov6.DynamicValue{ + &testTfprotov6DynamicValue, + }, + Name: "test_function", + }, + expected: &tfprotov5.CallFunctionRequest{ + Arguments: []*tfprotov5.DynamicValue{ + &testTfprotov5DynamicValue, + }, + Name: "test_function", + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.CallFunctionRequest(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestCallFunctionResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.CallFunctionResponse + expected *tfprotov5.CallFunctionResponse + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.CallFunctionResponse{ + Diagnostics: testTfprotov6Diagnostics, + Result: &testTfprotov6DynamicValue, + }, + expected: &tfprotov5.CallFunctionResponse{ + Diagnostics: testTfprotov5Diagnostics, + Result: &testTfprotov5DynamicValue, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.CallFunctionResponse(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + func TestConfigureProviderRequest(t *testing.T) { t.Parallel() @@ -329,6 +447,338 @@ func TestDynamicValue(t *testing.T) { } } +func TestFunction(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.Function + expected *tfprotov5.Function + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.Function{ + DeprecationMessage: "test deprecation message", + Description: "test description", + DescriptionKind: tfprotov6.StringKindPlain, + Parameters: []*tfprotov6.FunctionParameter{ + { + Type: tftypes.String, + }, + }, + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + Summary: "test summary", + VariadicParameter: &tfprotov6.FunctionParameter{ + Type: tftypes.String, + }, + }, + expected: &tfprotov5.Function{ + DeprecationMessage: "test deprecation message", + Description: "test description", + DescriptionKind: tfprotov5.StringKindPlain, + Parameters: []*tfprotov5.FunctionParameter{ + { + Type: tftypes.String, + }, + }, + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + Summary: "test summary", + VariadicParameter: &tfprotov5.FunctionParameter{ + Type: tftypes.String, + }, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.Function(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestFunctionMetadata(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in tfprotov6.FunctionMetadata + expected tfprotov5.FunctionMetadata + }{ + "all-valid-fields": { + in: tfprotov6.FunctionMetadata{ + Name: "test_function", + }, + expected: tfprotov5.FunctionMetadata{ + Name: "test_function", + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.FunctionMetadata(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestFunctionParameter(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.FunctionParameter + expected *tfprotov5.FunctionParameter + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.FunctionParameter{ + Description: "test description", + DescriptionKind: tfprotov6.StringKindPlain, + Type: tftypes.String, + }, + expected: &tfprotov5.FunctionParameter{ + Description: "test description", + DescriptionKind: tfprotov5.StringKindPlain, + Type: tftypes.String, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.FunctionParameter(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestFunctionReturn(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.FunctionReturn + expected *tfprotov5.FunctionReturn + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + expected: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.FunctionReturn(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetFunctionsRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.GetFunctionsRequest + expected *tfprotov5.GetFunctionsRequest + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.GetFunctionsRequest{}, + expected: &tfprotov5.GetFunctionsRequest{}, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.GetFunctionsRequest(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetFunctionsResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.GetFunctionsResponse + expected *tfprotov5.GetFunctionsResponse + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.GetFunctionsResponse{ + Diagnostics: testTfprotov6Diagnostics, + Functions: map[string]*tfprotov6.Function{ + "test_function": testTfprotov6Function, + }, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: testTfprotov5Diagnostics, + Functions: map[string]*tfprotov5.Function{ + "test_function": testTfprotov5Function, + }, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.GetFunctionsResponse(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetMetadataRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.GetMetadataRequest + expected *tfprotov5.GetMetadataRequest + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.GetMetadataRequest{}, + expected: &tfprotov5.GetMetadataRequest{}, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.GetMetadataRequest(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + +func TestGetMetadataResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in *tfprotov6.GetMetadataResponse + expected *tfprotov5.GetMetadataResponse + }{ + "nil": { + in: nil, + expected: nil, + }, + "all-valid-fields": { + in: &tfprotov6.GetMetadataResponse{ + DataSources: []tfprotov6.DataSourceMetadata{ + testTfprotov6DataSourceMetadata, + }, + Diagnostics: testTfprotov6Diagnostics, + Functions: []tfprotov6.FunctionMetadata{ + testTfprotov6FunctionMetadata, + }, + Resources: []tfprotov6.ResourceMetadata{ + testTfprotov6ResourceMetadata, + }, + }, + expected: &tfprotov5.GetMetadataResponse{ + DataSources: []tfprotov5.DataSourceMetadata{ + testTfprotov5DataSourceMetadata, + }, + Diagnostics: testTfprotov5Diagnostics, + Functions: []tfprotov5.FunctionMetadata{ + testTfprotov5FunctionMetadata, + }, + Resources: []tfprotov5.ResourceMetadata{ + testTfprotov5ResourceMetadata, + }, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tfprotov6tov5.GetMetadataResponse(testCase.in) + + if diff := cmp.Diff(got, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} + func TestGetProviderSchemaRequest(t *testing.T) { t.Parallel() @@ -378,7 +828,10 @@ func TestGetProviderSchemaResponse(t *testing.T) { DataSourceSchemas: map[string]*tfprotov6.Schema{ "test_data_source": testTfprotov6Schema, }, - Diagnostics: testTfprotov6Diagnostics, + Diagnostics: testTfprotov6Diagnostics, + Functions: map[string]*tfprotov6.Function{ + "test_function": testTfprotov6Function, + }, Provider: testTfprotov6Schema, ProviderMeta: testTfprotov6Schema, ResourceSchemas: map[string]*tfprotov6.Schema{ @@ -389,7 +842,10 @@ func TestGetProviderSchemaResponse(t *testing.T) { DataSourceSchemas: map[string]*tfprotov5.Schema{ "test_data_source": testTfprotov5Schema, }, - Diagnostics: testTfprotov5Diagnostics, + Diagnostics: testTfprotov5Diagnostics, + Functions: map[string]*tfprotov5.Function{ + "test_function": testTfprotov5Function, + }, Provider: testTfprotov5Schema, ProviderMeta: testTfprotov5Schema, ResourceSchemas: map[string]*tfprotov5.Schema{ diff --git a/tf5muxserver/diagnostics.go b/tf5muxserver/diagnostics.go index a9c1321..4348d23 100644 --- a/tf5muxserver/diagnostics.go +++ b/tf5muxserver/diagnostics.go @@ -40,6 +40,27 @@ func diagnosticsHasError(diagnostics []*tfprotov5.Diagnostic) bool { return false } +func functionDuplicateError(name string) *tfprotov5.Diagnostic { + return &tfprotov5.Diagnostic{ + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: " + name, + } +} + +func functionMissingError(name string) *tfprotov5.Diagnostic { + return &tfprotov5.Diagnostic{ + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Function Not Implemented", + Detail: "The combined provider does not implement the requested function. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Missing function: " + name, + } +} + func resourceDuplicateError(typeName string) *tfprotov5.Diagnostic { return &tfprotov5.Diagnostic{ Severity: tfprotov5.DiagnosticSeverityError, diff --git a/tf5muxserver/mux_server.go b/tf5muxserver/mux_server.go index 90b0ba9..91137db 100644 --- a/tf5muxserver/mux_server.go +++ b/tf5muxserver/mux_server.go @@ -22,6 +22,9 @@ type muxServer struct { // Routing for data source types dataSources map[string]tfprotov5.ProviderServer + // Routing for functions + functions map[string]tfprotov5.ProviderServer + // Routing for resource types resources map[string]tfprotov5.ProviderServer @@ -87,6 +90,41 @@ func (s *muxServer) getDataSourceServer(ctx context.Context, typeName string) (t return server, s.serverDiscoveryDiagnostics, nil } +func (s *muxServer) getFunctionServer(ctx context.Context, name string) (tfprotov5.ProviderServer, []*tfprotov5.Diagnostic, error) { + s.serverDiscoveryMutex.RLock() + server, ok := s.functions[name] + discoveryComplete := s.serverDiscoveryComplete + s.serverDiscoveryMutex.RUnlock() + + if discoveryComplete { + if ok { + return server, s.serverDiscoveryDiagnostics, nil + } + + return nil, []*tfprotov5.Diagnostic{ + functionMissingError(name), + }, nil + } + + err := s.serverDiscovery(ctx) + + if err != nil || diagnosticsHasError(s.serverDiscoveryDiagnostics) { + return nil, s.serverDiscoveryDiagnostics, err + } + + s.serverDiscoveryMutex.RLock() + server, ok = s.functions[name] + s.serverDiscoveryMutex.RUnlock() + + if !ok { + return nil, []*tfprotov5.Diagnostic{ + functionMissingError(name), + }, nil + } + + return server, s.serverDiscoveryDiagnostics, nil +} + func (s *muxServer) getResourceServer(ctx context.Context, typeName string) (tfprotov5.ProviderServer, []*tfprotov5.Diagnostic, error) { s.serverDiscoveryMutex.RLock() server, ok := s.resources[typeName] @@ -122,10 +160,10 @@ func (s *muxServer) getResourceServer(ctx context.Context, typeName string) (tfp return server, s.serverDiscoveryDiagnostics, nil } -// serverDiscovery will populate the mux server "routing" for resource types by -// calling all underlying server GetMetadata RPC and falling back to -// GetProviderSchema RPC. It is intended to only be called through -// getDataSourceServer and getResourceServer. +// serverDiscovery will populate the mux server "routing" for functions and +// resource types by calling all underlying server GetMetadata RPC and falling +// back to GetProviderSchema RPC. It is intended to only be called through +// getDataSourceServer, getFunctionServer, and getResourceServer. // // The error return represents gRPC errors, which except for the GetMetadata // call returning the gRPC unimplemented error, is always returned. @@ -163,6 +201,16 @@ func (s *muxServer) serverDiscovery(ctx context.Context) error { s.dataSources[serverDataSource.TypeName] = server } + for _, serverFunction := range metadataResp.Functions { + if _, ok := s.functions[serverFunction.Name]; ok { + s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, functionDuplicateError(serverFunction.Name)) + + continue + } + + s.functions[serverFunction.Name] = server + } + for _, serverResource := range metadataResp.Resources { if _, ok := s.resources[serverResource.TypeName]; ok { s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, resourceDuplicateError(serverResource.TypeName)) @@ -205,6 +253,16 @@ func (s *muxServer) serverDiscovery(ctx context.Context) error { s.dataSources[typeName] = server } + for name := range providerSchemaResp.Functions { + if _, ok := s.functions[name]; ok { + s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, functionDuplicateError(name)) + + continue + } + + s.functions[name] = server + } + for typeName := range providerSchemaResp.ResourceSchemas { if _, ok := s.resources[typeName]; ok { s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, resourceDuplicateError(typeName)) @@ -230,9 +288,11 @@ func (s *muxServer) serverDiscovery(ctx context.Context) error { // - All provider meta schemas exactly match // - Only one provider implements each managed resource // - Only one provider implements each data source +// - Only one provider implements each function func NewMuxServer(_ context.Context, servers ...func() tfprotov5.ProviderServer) (*muxServer, error) { result := muxServer{ dataSources: make(map[string]tfprotov5.ProviderServer), + functions: make(map[string]tfprotov5.ProviderServer), resources: make(map[string]tfprotov5.ProviderServer), resourceCapabilities: make(map[string]*tfprotov5.ServerCapabilities), } diff --git a/tf5muxserver/mux_server_CallFunction.go b/tf5muxserver/mux_server_CallFunction.go new file mode 100644 index 0000000..7e4ec8c --- /dev/null +++ b/tf5muxserver/mux_server_CallFunction.go @@ -0,0 +1,57 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf5muxserver + +import ( + "context" + + "github.com/hashicorp/terraform-plugin-go/tfprotov5" + "github.com/hashicorp/terraform-plugin-mux/internal/logging" +) + +// CallFunction calls the CallFunction method of the underlying provider +// serving the function. +func (s *muxServer) CallFunction(ctx context.Context, req *tfprotov5.CallFunctionRequest) (*tfprotov5.CallFunctionResponse, error) { + rpc := "CallFunction" + ctx = logging.InitContext(ctx) + ctx = logging.RpcContext(ctx, rpc) + + server, diags, err := s.getFunctionServer(ctx, req.Name) + + if err != nil { + return nil, err + } + + if diagnosticsHasError(diags) { + return &tfprotov5.CallFunctionResponse{ + Diagnostics: diags, + }, nil + } + + ctx = logging.Tfprotov5ProviderServerContext(ctx, server) + + // Remove and call server.CallFunction below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := server.(tfprotov5.FunctionServer) + + if !ok { + resp := &tfprotov5.CallFunctionResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Provider Functions Not Implemented", + Detail: "A provider-defined function call was received by the provider, however the provider does not implement functions. " + + "Either upgrade the provider to a version that implements provider-defined functions or this is a bug in Terraform that should be reported to the Terraform maintainers.", + }, + }, + } + + return resp, nil + } + + logging.MuxTrace(ctx, "calling downstream server") + + // return server.CallFunction(ctx, req) + return functionServer.CallFunction(ctx, req) +} diff --git a/tf5muxserver/mux_server_CallFunction_test.go b/tf5muxserver/mux_server_CallFunction_test.go new file mode 100644 index 0000000..66c1256 --- /dev/null +++ b/tf5muxserver/mux_server_CallFunction_test.go @@ -0,0 +1,82 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf5muxserver_test + +import ( + "context" + "testing" + + "github.com/hashicorp/terraform-plugin-go/tfprotov5" + + "github.com/hashicorp/terraform-plugin-mux/internal/tf5testserver" + "github.com/hashicorp/terraform-plugin-mux/tf5muxserver" +) + +func TestMuxServerCallFunction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function1": {}, + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function2": {}, + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Fatal("muxServer should implement tfprotov5.FunctionServer") + } + + // _, err = muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + _, err = functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function1", + }) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } + + if testServer2.CallFunctionCalled["test_function1"] { + t.Errorf("unexpected test_function1 CallFunction called on server2") + } + + // _, err = muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + _, err = functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function2", + }) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if testServer1.CallFunctionCalled["test_function2"] { + t.Errorf("unexpected test_function2 CallFunction called on server1") + } + + if !testServer2.CallFunctionCalled["test_function2"] { + t.Errorf("expected test_function2 CallFunction to be called on server2") + } +} diff --git a/tf5muxserver/mux_server_GetFunctions.go b/tf5muxserver/mux_server_GetFunctions.go new file mode 100644 index 0000000..c8927fc --- /dev/null +++ b/tf5muxserver/mux_server_GetFunctions.go @@ -0,0 +1,68 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf5muxserver + +import ( + "context" + "fmt" + + "github.com/hashicorp/terraform-plugin-go/tfprotov5" + + "github.com/hashicorp/terraform-plugin-mux/internal/logging" +) + +// GetFunctions merges the functions returned by the tfprotov5.ProviderServers +// associated with muxServer into a single response. Functions must be returned +// from only one server or an error diagnostic is returned. +func (s *muxServer) GetFunctions(ctx context.Context, req *tfprotov5.GetFunctionsRequest) (*tfprotov5.GetFunctionsResponse, error) { + rpc := "GetFunctions" + ctx = logging.InitContext(ctx) + ctx = logging.RpcContext(ctx, rpc) + + s.serverDiscoveryMutex.Lock() + defer s.serverDiscoveryMutex.Unlock() + + resp := &tfprotov5.GetFunctionsResponse{ + Functions: make(map[string]*tfprotov5.Function), + } + + for _, server := range s.servers { + ctx := logging.Tfprotov5ProviderServerContext(ctx, server) + + // Remove and call server.GetFunctions below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := server.(tfprotov5.FunctionServer) + + if !ok { + continue + } + + logging.MuxTrace(ctx, "calling downstream server") + + // serverResp, err := server.GetFunctions(ctx, &tfprotov5.GetFunctionsRequest{}) + serverResp, err := functionServer.GetFunctions(ctx, &tfprotov5.GetFunctionsRequest{}) + + if err != nil { + return resp, fmt.Errorf("error calling GetFunctions for %T: %w", server, err) + } + + resp.Diagnostics = append(resp.Diagnostics, serverResp.Diagnostics...) + + for name, definition := range serverResp.Functions { + if _, ok := resp.Functions[name]; ok { + resp.Diagnostics = append(resp.Diagnostics, functionDuplicateError(name)) + + continue + } + + s.functions[name] = server + resp.Functions[name] = definition + } + } + + // Intentionally not setting overall server discovery as complete, as data + // sources and resources are not discovered via this RPC. + + return resp, nil +} diff --git a/tf5muxserver/mux_server_GetFunctions_test.go b/tf5muxserver/mux_server_GetFunctions_test.go new file mode 100644 index 0000000..4482913 --- /dev/null +++ b/tf5muxserver/mux_server_GetFunctions_test.go @@ -0,0 +1,333 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf5muxserver_test + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/terraform-plugin-go/tfprotov5" + "github.com/hashicorp/terraform-plugin-go/tftypes" + + "github.com/hashicorp/terraform-plugin-mux/internal/tf5testserver" + "github.com/hashicorp/terraform-plugin-mux/tf5muxserver" +) + +func TestMuxServerGetFunctions(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + servers []func() tfprotov5.ProviderServer + expected *tfprotov5.GetFunctionsResponse + }{ + "combined": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function1": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function2": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function1": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function2": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }, + "duplicate-function": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + }, + Functions: map[string]*tfprotov5.Function{ + "test_function": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }, + "error-once": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + Functions: map[string]*tfprotov5.Function{}, + }, + }, + "error-multiple": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + Functions: map[string]*tfprotov5.Function{}, + }, + }, + "warning-once": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + Functions: map[string]*tfprotov5.Function{}, + }, + }, + "warning-multiple": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + Functions: map[string]*tfprotov5.Function{}, + }, + }, + "warning-then-error": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{}).ProviderServer, + (&tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov5.GetFunctionsResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + Functions: map[string]*tfprotov5.Function{}, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + muxServer, err := tf5muxserver.NewMuxServer(context.Background(), testCase.servers...) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Fatal("muxServer should implement tfprotov5.FunctionServer") + } + + // resp, err := muxServer.ProviderServer().GetFunctions(context.Background(), &tfprotov5.GetFunctionsRequest{}) + resp, err := functionServer.GetFunctions(context.Background(), &tfprotov5.GetFunctionsRequest{}) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if diff := cmp.Diff(resp, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} diff --git a/tf5muxserver/mux_server_GetMetadata.go b/tf5muxserver/mux_server_GetMetadata.go index 10b76d6..bd3e1b4 100644 --- a/tf5muxserver/mux_server_GetMetadata.go +++ b/tf5muxserver/mux_server_GetMetadata.go @@ -26,6 +26,7 @@ func (s *muxServer) GetMetadata(ctx context.Context, req *tfprotov5.GetMetadataR resp := &tfprotov5.GetMetadataResponse{ DataSources: make([]tfprotov5.DataSourceMetadata, 0), + Functions: make([]tfprotov5.FunctionMetadata, 0), Resources: make([]tfprotov5.ResourceMetadata, 0), ServerCapabilities: serverCapabilities, } @@ -53,6 +54,17 @@ func (s *muxServer) GetMetadata(ctx context.Context, req *tfprotov5.GetMetadataR resp.DataSources = append(resp.DataSources, datasource) } + for _, function := range serverResp.Functions { + if functionMetadataContainsName(resp.Functions, function.Name) { + resp.Diagnostics = append(resp.Diagnostics, functionDuplicateError(function.Name)) + + continue + } + + s.functions[function.Name] = server + resp.Functions = append(resp.Functions, function) + } + for _, resource := range serverResp.Resources { if resourceMetadataContainsTypeName(resp.Resources, resource.TypeName) { resp.Diagnostics = append(resp.Diagnostics, resourceDuplicateError(resource.TypeName)) @@ -79,6 +91,16 @@ func datasourceMetadataContainsTypeName(metadatas []tfprotov5.DataSourceMetadata return false } +func functionMetadataContainsName(metadatas []tfprotov5.FunctionMetadata, name string) bool { + for _, metadata := range metadatas { + if name == metadata.Name { + return true + } + } + + return false +} + func resourceMetadataContainsTypeName(metadatas []tfprotov5.ResourceMetadata, typeName string) bool { for _, metadata := range metadatas { if typeName == metadata.TypeName { diff --git a/tf5muxserver/mux_server_GetMetadata_test.go b/tf5muxserver/mux_server_GetMetadata_test.go index 0184a72..fe30d25 100644 --- a/tf5muxserver/mux_server_GetMetadata_test.go +++ b/tf5muxserver/mux_server_GetMetadata_test.go @@ -21,6 +21,7 @@ func TestMuxServerGetMetadata(t *testing.T) { servers []func() tfprotov5.ProviderServer expectedDataSources []tfprotov5.DataSourceMetadata expectedDiagnostics []*tfprotov5.Diagnostic + expectedFunctions []tfprotov5.FunctionMetadata expectedResources []tfprotov5.ResourceMetadata expectedServerCapabilities *tfprotov5.ServerCapabilities }{ @@ -41,6 +42,11 @@ func TestMuxServerGetMetadata(t *testing.T) { TypeName: "test_foo", }, }, + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function1", + }, + }, }, }).ProviderServer, (&tf5testserver.TestServer{ @@ -58,6 +64,14 @@ func TestMuxServerGetMetadata(t *testing.T) { TypeName: "test_quux", }, }, + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function2", + }, + { + Name: "test_function3", + }, + }, }, }).ProviderServer, }, @@ -83,6 +97,17 @@ func TestMuxServerGetMetadata(t *testing.T) { TypeName: "test_quux", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function1", + }, + { + Name: "test_function2", + }, + { + Name: "test_function3", + }, + }, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, PlanDestroy: true, @@ -124,6 +149,50 @@ func TestMuxServerGetMetadata(t *testing.T) { "Duplicate data source type: test_foo", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, + expectedResources: []tfprotov5.ResourceMetadata{}, + expectedServerCapabilities: &tfprotov5.ServerCapabilities{ + GetProviderSchemaOptional: true, + PlanDestroy: true, + }, + }, + "duplicate-function": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function", + }, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function", + }, + }, + }, + }).ProviderServer, + }, + expectedDataSources: []tfprotov5.DataSourceMetadata{}, + expectedDiagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + }, + expectedFunctions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function", + }, + }, expectedResources: []tfprotov5.ResourceMetadata{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -162,6 +231,7 @@ func TestMuxServerGetMetadata(t *testing.T) { "Duplicate resource type: test_foo", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{ { TypeName: "test_foo", @@ -198,6 +268,7 @@ func TestMuxServerGetMetadata(t *testing.T) { }).ProviderServer, }, expectedDataSources: []tfprotov5.DataSourceMetadata{}, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{ { TypeName: "test_with_server_capabilities", @@ -235,6 +306,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -280,6 +352,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -310,6 +383,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -355,6 +429,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -400,6 +475,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: []tfprotov5.FunctionMetadata{}, expectedResources: []tfprotov5.ResourceMetadata{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -434,6 +510,10 @@ func TestMuxServerGetMetadata(t *testing.T) { t.Errorf("diagnostics didn't match expectations: %s", diff) } + if diff := cmp.Diff(resp.Functions, testCase.expectedFunctions); diff != "" { + t.Errorf("functions didn't match expectations: %s", diff) + } + if diff := cmp.Diff(resp.Resources, testCase.expectedResources); diff != "" { t.Errorf("resources didn't match expectations: %s", diff) } diff --git a/tf5muxserver/mux_server_GetProviderSchema.go b/tf5muxserver/mux_server_GetProviderSchema.go index 77320c0..19bfd87 100644 --- a/tf5muxserver/mux_server_GetProviderSchema.go +++ b/tf5muxserver/mux_server_GetProviderSchema.go @@ -26,6 +26,7 @@ func (s *muxServer) GetProviderSchema(ctx context.Context, req *tfprotov5.GetPro resp := &tfprotov5.GetProviderSchemaResponse{ DataSourceSchemas: make(map[string]*tfprotov5.Schema), + Functions: make(map[string]*tfprotov5.Function), ResourceSchemas: make(map[string]*tfprotov5.Schema), ServerCapabilities: serverCapabilities, } @@ -94,6 +95,17 @@ func (s *muxServer) GetProviderSchema(ctx context.Context, req *tfprotov5.GetPro s.dataSources[dataSourceType] = server resp.DataSourceSchemas[dataSourceType] = schema } + + for name, definition := range serverResp.Functions { + if _, ok := resp.Functions[name]; ok { + resp.Diagnostics = append(resp.Diagnostics, functionDuplicateError(name)) + + continue + } + + s.functions[name] = server + resp.Functions[name] = definition + } } s.serverDiscoveryComplete = true diff --git a/tf5muxserver/mux_server_GetProviderSchema_test.go b/tf5muxserver/mux_server_GetProviderSchema_test.go index 9a23cc1..a923a17 100644 --- a/tf5muxserver/mux_server_GetProviderSchema_test.go +++ b/tf5muxserver/mux_server_GetProviderSchema_test.go @@ -22,6 +22,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { servers []func() tfprotov5.ProviderServer expectedDataSourceSchemas map[string]*tfprotov5.Schema expectedDiagnostics []*tfprotov5.Diagnostic + expectedFunctions map[string]*tfprotov5.Function expectedProviderSchema *tfprotov5.Schema expectedProviderMetaSchema *tfprotov5.Schema expectedResourceSchemas map[string]*tfprotov5.Schema @@ -144,6 +145,13 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }, }, }, + Functions: map[string]*tfprotov5.Function{ + "test_function1": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, }, }).ProviderServer, (&tf5testserver.TestServer{ @@ -259,6 +267,18 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }, }, }, + Functions: map[string]*tfprotov5.Function{ + "test_function2": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, }, }).ProviderServer, }, @@ -425,6 +445,23 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }, }, }, + expectedFunctions: map[string]*tfprotov5.Function{ + "test_function1": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function2": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov5.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, PlanDestroy: true, @@ -460,6 +497,44 @@ func TestMuxServerGetProviderSchema(t *testing.T) { "Duplicate data source type: test_foo", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, + expectedResourceSchemas: map[string]*tfprotov5.Schema{}, + expectedServerCapabilities: &tfprotov5.ServerCapabilities{ + GetProviderSchemaOptional: true, + PlanDestroy: true, + }, + }, + "duplicate-function": { + servers: []func() tfprotov5.ProviderServer{ + (&tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, + }, + }, + }).ProviderServer, + (&tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, + }, + }, + }).ProviderServer, + }, + expectedDataSourceSchemas: map[string]*tfprotov5.Schema{}, + expectedDiagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + }, + expectedFunctions: map[string]*tfprotov5.Function{ + "test_function": {}, + }, expectedResourceSchemas: map[string]*tfprotov5.Schema{}, expectedServerCapabilities: &tfprotov5.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -494,6 +569,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { "Duplicate resource type: test_foo", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{ "test_foo": {}, }, @@ -569,6 +645,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { ), }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedProviderSchema: &tfprotov5.Schema{ Block: &tfprotov5.SchemaBlock{ Attributes: []*tfprotov5.SchemaAttribute{ @@ -653,6 +730,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { ), }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedProviderMetaSchema: &tfprotov5.Schema{ Block: &tfprotov5.SchemaBlock{ Attributes: []*tfprotov5.SchemaAttribute{ @@ -692,6 +770,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }).ProviderServer, }, expectedDataSourceSchemas: map[string]*tfprotov5.Schema{}, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{ "test_with_server_capabilities": {}, "test_without_server_capabilities": {}, @@ -725,6 +804,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{}, }, "error-multiple": { @@ -766,6 +846,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{}, }, "warning-once": { @@ -792,6 +873,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{}, }, "warning-multiple": { @@ -833,6 +915,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{}, }, "warning-then-error": { @@ -874,6 +957,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: map[string]*tfprotov5.Function{}, expectedResourceSchemas: map[string]*tfprotov5.Schema{}, }, } @@ -904,6 +988,10 @@ func TestMuxServerGetProviderSchema(t *testing.T) { t.Errorf("diagnostics didn't match expectations: %s", diff) } + if diff := cmp.Diff(resp.Functions, testCase.expectedFunctions); diff != "" { + t.Errorf("functions didn't match expectations: %s", diff) + } + if diff := cmp.Diff(resp.Provider, testCase.expectedProviderSchema); diff != "" { t.Errorf("provider schema didn't match expectations: %s", diff) } diff --git a/tf5muxserver/mux_server_test.go b/tf5muxserver/mux_server_test.go index 78ebbb1..fd0a428 100644 --- a/tf5muxserver/mux_server_test.go +++ b/tf5muxserver/mux_server_test.go @@ -399,6 +399,436 @@ func TestMuxServerGetDataSourceServer_Missing(t *testing.T) { } } +func TestMuxServerGetFunctionServer_GetProviderSchema(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function1": {}, + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function2": {}, + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov5.FunctionServer") + } + + // _, _ = muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + _, _ = functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function1", + }) + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } +} + +func TestMuxServerGetFunctionServer_GetProviderSchema_Duplicate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, // intentionally duplicated + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, // intentionally duplicated + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + expectedDiags := []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + } + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov5.FunctionServer") + } + + // resp, _ := muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + resp, _ := functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function", + }) + + if diff := cmp.Diff(resp.Diagnostics, expectedDiags); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if testServer1.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server1") + } + + if testServer2.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server2") + } +} + +func TestMuxServerGetFunctionServer_GetMetadata(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function1", + }, + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function2", + }, + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov5.FunctionServer") + } + + // _, _ = muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + _, _ = functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function1", + }) + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } +} + +func TestMuxServerGetFunctionServer_GetMetadata_Duplicate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function", // intentionally duplicated + }, + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function", // intentionally duplicated + }, + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + expectedDiags := []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + } + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov5.FunctionServer") + } + + // resp, _ := muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + resp, _ := functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function", + }) + + if diff := cmp.Diff(resp.Diagnostics, expectedDiags); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if testServer1.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server1") + } + + if testServer2.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server2") + } +} + +func TestMuxServerGetFunctionServer_GetMetadata_Partial(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function1", + }, + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function2": {}, + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov5.FunctionServer") + } + + // _, _ = muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + _, _ = functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function1", + }) + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } +} + +func TestMuxServerGetFunctionServer_Missing(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function1", + }, + }, + }, + } + testServer2 := &tf5testserver.TestServer{ + GetMetadataResponse: &tfprotov5.GetMetadataResponse{ + Functions: []tfprotov5.FunctionMetadata{ + { + Name: "test_function2", + }, + }, + }, + } + + servers := []func() tfprotov5.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf5muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + expectedDiags := []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Function Not Implemented", + Detail: "The combined provider does not implement the requested function. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Missing function: test_function_nonexistent", + }, + } + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov5.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov5.FunctionServer") + } + + // resp, _ := muxServer.ProviderServer().CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + resp, _ := functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function_nonexistent", + }) + + if diff := cmp.Diff(resp.Diagnostics, expectedDiags); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if testServer1.CallFunctionCalled["test_function_nonexistent"] { + t.Errorf("unexpected test_function_nonexistent CallFunction called on server1") + } + + if testServer2.CallFunctionCalled["test_function_nonexistent"] { + t.Errorf("unexpected test_function_nonexistent CallFunction called on server2") + } +} + func TestMuxServerGetResourceServer_GetProviderSchema(t *testing.T) { t.Parallel() diff --git a/tf5to6server/tf5to6server.go b/tf5to6server/tf5to6server.go index dadd7b5..8fae458 100644 --- a/tf5to6server/tf5to6server.go +++ b/tf5to6server/tf5to6server.go @@ -43,6 +43,38 @@ func (s v5tov6Server) ApplyResourceChange(ctx context.Context, req *tfprotov6.Ap return tfprotov5tov6.ApplyResourceChangeResponse(v5Resp), nil } +func (s v5tov6Server) CallFunction(ctx context.Context, req *tfprotov6.CallFunctionRequest) (*tfprotov6.CallFunctionResponse, error) { + // Remove and call s.v5Server.CallFunction below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := s.v5Server.(tfprotov5.FunctionServer) + + if !ok { + v6Resp := &tfprotov6.CallFunctionResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Provider Functions Not Implemented", + Detail: "A provider-defined function call was received by the provider, however the provider does not implement functions. " + + "Either upgrade the provider to a version that implements provider-defined functions or this is a bug in Terraform that should be reported to the Terraform maintainers.", + }, + }, + } + + return v6Resp, nil + } + + v5Req := tfprotov6tov5.CallFunctionRequest(req) + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + // v5Resp, err := s.v5Server.CallFunction(ctx, v5Req) + v5Resp, err := functionServer.CallFunction(ctx, v5Req) + + if err != nil { + return nil, err + } + + return tfprotov5tov6.CallFunctionResponse(v5Resp), nil +} + func (s v5tov6Server) ConfigureProvider(ctx context.Context, req *tfprotov6.ConfigureProviderRequest) (*tfprotov6.ConfigureProviderResponse, error) { v5Req := tfprotov6tov5.ConfigureProviderRequest(req) v5Resp, err := s.v5Server.ConfigureProvider(ctx, v5Req) @@ -54,6 +86,31 @@ func (s v5tov6Server) ConfigureProvider(ctx context.Context, req *tfprotov6.Conf return tfprotov5tov6.ConfigureProviderResponse(v5Resp), nil } +func (s v5tov6Server) GetFunctions(ctx context.Context, req *tfprotov6.GetFunctionsRequest) (*tfprotov6.GetFunctionsResponse, error) { + // Remove and call s.v5Server.GetFunctions below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := s.v5Server.(tfprotov5.FunctionServer) + + if !ok { + v6Resp := &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{}, + } + + return v6Resp, nil + } + + v5Req := tfprotov6tov5.GetFunctionsRequest(req) + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + // v5Resp, err := s.v5Server.GetFunctions(ctx, v5Req) + v5Resp, err := functionServer.GetFunctions(ctx, v5Req) + + if err != nil { + return nil, err + } + + return tfprotov5tov6.GetFunctionsResponse(v5Resp), nil +} + func (s v5tov6Server) GetMetadata(ctx context.Context, req *tfprotov6.GetMetadataRequest) (*tfprotov6.GetMetadataResponse, error) { v5Req := tfprotov6tov5.GetMetadataRequest(req) v5Resp, err := s.v5Server.GetMetadata(ctx, v5Req) diff --git a/tf5to6server/tf5to6server_test.go b/tf5to6server/tf5to6server_test.go index 7f7d6b9..1856409 100644 --- a/tf5to6server/tf5to6server_test.go +++ b/tf5to6server/tf5to6server_test.go @@ -29,6 +29,9 @@ func TestUpgradeServer(t *testing.T) { DataSourceSchemas: map[string]*tfprotov5.Schema{ "test_data_source": {}, }, + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, + }, Provider: &tfprotov5.Schema{ Block: &tfprotov5.SchemaBlock{ Attributes: []*tfprotov5.SchemaAttribute{ @@ -157,6 +160,45 @@ func TestV6ToV5ServerApplyResourceChange(t *testing.T) { } } +func TestV6ToV5ServerCallFunction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + v5server := &tf5testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov5.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, + }, + }, + } + + v6server, err := tf5to6server.UpgradeServer(context.Background(), v5server.ProviderServer) + + if err != nil { + t.Fatalf("unexpected error upgrading server: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := v6server.(tfprotov6.FunctionServer) + + if !ok { + t.Fatal("v6server should implement tfprotov6.FunctionServer") + } + + //_, err = v6server.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + _, err = functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function", + }) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !v5server.CallFunctionCalled["test_function"] { + t.Errorf("expected test_function CallFunction to be called") + } +} + func TestV6ToV5ServerConfigureProvider(t *testing.T) { t.Parallel() @@ -186,6 +228,43 @@ func TestV6ToV5ServerConfigureProvider(t *testing.T) { } } +func TestV6ToV5ServerGetFunctions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + v5server := &tf5testserver.TestServer{ + GetFunctionsResponse: &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{ + "test_function": {}, + }, + }, + } + + v6server, err := tf5to6server.UpgradeServer(context.Background(), v5server.ProviderServer) + + if err != nil { + t.Fatalf("unexpected error upgrading server: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := v6server.(tfprotov6.FunctionServer) + + if !ok { + t.Fatal("v6server should implement tfprotov6.FunctionServer") + } + + //_, err = v6server.GetFunction(ctx, &tfprotov6.GetFunctionRequest{}) + _, err = functionServer.GetFunctions(ctx, &tfprotov6.GetFunctionsRequest{}) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !v5server.GetFunctionsCalled { + t.Errorf("expected GetFunctions to be called") + } +} + func TestV6ToV5ServerGetMetadata(t *testing.T) { t.Parallel() diff --git a/tf6muxserver/diagnostics.go b/tf6muxserver/diagnostics.go index 89d448f..80572fc 100644 --- a/tf6muxserver/diagnostics.go +++ b/tf6muxserver/diagnostics.go @@ -42,6 +42,27 @@ func diagnosticsHasError(diagnostics []*tfprotov6.Diagnostic) bool { return false } +func functionDuplicateError(name string) *tfprotov6.Diagnostic { + return &tfprotov6.Diagnostic{ + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: " + name, + } +} + +func functionMissingError(name string) *tfprotov6.Diagnostic { + return &tfprotov6.Diagnostic{ + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Function Not Implemented", + Detail: "The combined provider does not implement the requested function. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Missing function: " + name, + } +} + func resourceDuplicateError(typeName string) *tfprotov6.Diagnostic { return &tfprotov6.Diagnostic{ Severity: tfprotov6.DiagnosticSeverityError, diff --git a/tf6muxserver/mux_server.go b/tf6muxserver/mux_server.go index 7f13696..243f4c7 100644 --- a/tf6muxserver/mux_server.go +++ b/tf6muxserver/mux_server.go @@ -22,6 +22,9 @@ type muxServer struct { // Routing for data source types dataSources map[string]tfprotov6.ProviderServer + // Routing for functions + functions map[string]tfprotov6.ProviderServer + // Routing for resource types resources map[string]tfprotov6.ProviderServer @@ -87,6 +90,41 @@ func (s *muxServer) getDataSourceServer(ctx context.Context, typeName string) (t return server, s.serverDiscoveryDiagnostics, nil } +func (s *muxServer) getFunctionServer(ctx context.Context, name string) (tfprotov6.ProviderServer, []*tfprotov6.Diagnostic, error) { + s.serverDiscoveryMutex.RLock() + server, ok := s.functions[name] + discoveryComplete := s.serverDiscoveryComplete + s.serverDiscoveryMutex.RUnlock() + + if discoveryComplete { + if ok { + return server, s.serverDiscoveryDiagnostics, nil + } + + return nil, []*tfprotov6.Diagnostic{ + functionMissingError(name), + }, nil + } + + err := s.serverDiscovery(ctx) + + if err != nil || diagnosticsHasError(s.serverDiscoveryDiagnostics) { + return nil, s.serverDiscoveryDiagnostics, err + } + + s.serverDiscoveryMutex.RLock() + server, ok = s.functions[name] + s.serverDiscoveryMutex.RUnlock() + + if !ok { + return nil, []*tfprotov6.Diagnostic{ + functionMissingError(name), + }, nil + } + + return server, s.serverDiscoveryDiagnostics, nil +} + func (s *muxServer) getResourceServer(ctx context.Context, typeName string) (tfprotov6.ProviderServer, []*tfprotov6.Diagnostic, error) { s.serverDiscoveryMutex.RLock() server, ok := s.resources[typeName] @@ -122,10 +160,10 @@ func (s *muxServer) getResourceServer(ctx context.Context, typeName string) (tfp return server, s.serverDiscoveryDiagnostics, nil } -// serverDiscovery will populate the mux server "routing" for resource types by -// calling all underlying server GetMetadata RPC and falling back to -// GetProviderSchema RPC. It is intended to only be called through -// getDataSourceServer and getResourceServer. +// serverDiscovery will populate the mux server "routing" for functions and +// resource types by calling all underlying server GetMetadata RPC and falling +// back to GetProviderSchema RPC. It is intended to only be called through +// getDataSourceServer, getFunctionServer, and getResourceServer. // // The error return represents gRPC errors, which except for the GetMetadata // call returning the gRPC unimplemented error, is always returned. @@ -163,6 +201,16 @@ func (s *muxServer) serverDiscovery(ctx context.Context) error { s.dataSources[serverDataSource.TypeName] = server } + for _, serverFunction := range metadataResp.Functions { + if _, ok := s.functions[serverFunction.Name]; ok { + s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, functionDuplicateError(serverFunction.Name)) + + continue + } + + s.functions[serverFunction.Name] = server + } + for _, serverResource := range metadataResp.Resources { if _, ok := s.resources[serverResource.TypeName]; ok { s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, resourceDuplicateError(serverResource.TypeName)) @@ -205,6 +253,16 @@ func (s *muxServer) serverDiscovery(ctx context.Context) error { s.dataSources[typeName] = server } + for name := range providerSchemaResp.Functions { + if _, ok := s.functions[name]; ok { + s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, functionDuplicateError(name)) + + continue + } + + s.functions[name] = server + } + for typeName := range providerSchemaResp.ResourceSchemas { if _, ok := s.resources[typeName]; ok { s.serverDiscoveryDiagnostics = append(s.serverDiscoveryDiagnostics, resourceDuplicateError(typeName)) @@ -231,9 +289,11 @@ func (s *muxServer) serverDiscovery(ctx context.Context) error { // - All provider meta schemas exactly match // - Only one provider implements each managed resource // - Only one provider implements each data source +// - Only one provider implements each function func NewMuxServer(_ context.Context, servers ...func() tfprotov6.ProviderServer) (*muxServer, error) { result := muxServer{ dataSources: make(map[string]tfprotov6.ProviderServer), + functions: make(map[string]tfprotov6.ProviderServer), resources: make(map[string]tfprotov6.ProviderServer), resourceCapabilities: make(map[string]*tfprotov6.ServerCapabilities), servers: make([]tfprotov6.ProviderServer, 0, len(servers)), diff --git a/tf6muxserver/mux_server_CallFunction.go b/tf6muxserver/mux_server_CallFunction.go new file mode 100644 index 0000000..b600d49 --- /dev/null +++ b/tf6muxserver/mux_server_CallFunction.go @@ -0,0 +1,57 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf6muxserver + +import ( + "context" + + "github.com/hashicorp/terraform-plugin-go/tfprotov6" + "github.com/hashicorp/terraform-plugin-mux/internal/logging" +) + +// CallFunction calls the CallFunction method of the underlying provider +// serving the function. +func (s *muxServer) CallFunction(ctx context.Context, req *tfprotov6.CallFunctionRequest) (*tfprotov6.CallFunctionResponse, error) { + rpc := "CallFunction" + ctx = logging.InitContext(ctx) + ctx = logging.RpcContext(ctx, rpc) + + server, diags, err := s.getFunctionServer(ctx, req.Name) + + if err != nil { + return nil, err + } + + if diagnosticsHasError(diags) { + return &tfprotov6.CallFunctionResponse{ + Diagnostics: diags, + }, nil + } + + ctx = logging.Tfprotov6ProviderServerContext(ctx, server) + + // Remove and call server.CallFunction below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := server.(tfprotov6.FunctionServer) + + if !ok { + resp := &tfprotov6.CallFunctionResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Provider Functions Not Implemented", + Detail: "A provider-defined function call was received by the provider, however the provider does not implement functions. " + + "Either upgrade the provider to a version that implements provider-defined functions or this is a bug in Terraform that should be reported to the Terraform maintainers.", + }, + }, + } + + return resp, nil + } + + logging.MuxTrace(ctx, "calling downstream server") + + // return server.CallFunction(ctx, req) + return functionServer.CallFunction(ctx, req) +} diff --git a/tf6muxserver/mux_server_CallFunction_test.go b/tf6muxserver/mux_server_CallFunction_test.go new file mode 100644 index 0000000..17f6b3f --- /dev/null +++ b/tf6muxserver/mux_server_CallFunction_test.go @@ -0,0 +1,82 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf6muxserver_test + +import ( + "context" + "testing" + + "github.com/hashicorp/terraform-plugin-go/tfprotov6" + + "github.com/hashicorp/terraform-plugin-mux/internal/tf6testserver" + "github.com/hashicorp/terraform-plugin-mux/tf6muxserver" +) + +func TestMuxServerCallFunction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function1": {}, + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function2": {}, + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Fatal("muxServer should implement tfprotov6.FunctionServer") + } + + // _, err = muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + _, err = functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function1", + }) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } + + if testServer2.CallFunctionCalled["test_function1"] { + t.Errorf("unexpected test_function1 CallFunction called on server2") + } + + // _, err = muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + _, err = functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function2", + }) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if testServer1.CallFunctionCalled["test_function2"] { + t.Errorf("unexpected test_function2 CallFunction called on server1") + } + + if !testServer2.CallFunctionCalled["test_function2"] { + t.Errorf("expected test_function2 CallFunction to be called on server2") + } +} diff --git a/tf6muxserver/mux_server_GetFunctions.go b/tf6muxserver/mux_server_GetFunctions.go new file mode 100644 index 0000000..cbb484e --- /dev/null +++ b/tf6muxserver/mux_server_GetFunctions.go @@ -0,0 +1,68 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf6muxserver + +import ( + "context" + "fmt" + + "github.com/hashicorp/terraform-plugin-go/tfprotov6" + + "github.com/hashicorp/terraform-plugin-mux/internal/logging" +) + +// GetFunctions merges the functions returned by the tfprotov6.ProviderServers +// associated with muxServer into a single response. Functions must be returned +// from only one server or an error diagnostic is returned. +func (s *muxServer) GetFunctions(ctx context.Context, req *tfprotov6.GetFunctionsRequest) (*tfprotov6.GetFunctionsResponse, error) { + rpc := "GetFunctions" + ctx = logging.InitContext(ctx) + ctx = logging.RpcContext(ctx, rpc) + + s.serverDiscoveryMutex.Lock() + defer s.serverDiscoveryMutex.Unlock() + + resp := &tfprotov6.GetFunctionsResponse{ + Functions: make(map[string]*tfprotov6.Function), + } + + for _, server := range s.servers { + ctx := logging.Tfprotov6ProviderServerContext(ctx, server) + + // Remove and call server.GetFunctions below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := server.(tfprotov6.FunctionServer) + + if !ok { + continue + } + + logging.MuxTrace(ctx, "calling downstream server") + + // serverResp, err := server.GetFunctions(ctx, &tfprotov6.GetFunctionsRequest{}) + serverResp, err := functionServer.GetFunctions(ctx, &tfprotov6.GetFunctionsRequest{}) + + if err != nil { + return resp, fmt.Errorf("error calling GetFunctions for %T: %w", server, err) + } + + resp.Diagnostics = append(resp.Diagnostics, serverResp.Diagnostics...) + + for name, definition := range serverResp.Functions { + if _, ok := resp.Functions[name]; ok { + resp.Diagnostics = append(resp.Diagnostics, functionDuplicateError(name)) + + continue + } + + s.functions[name] = server + resp.Functions[name] = definition + } + } + + // Intentionally not setting overall server discovery as complete, as data + // sources and resources are not discovered via this RPC. + + return resp, nil +} diff --git a/tf6muxserver/mux_server_GetFunctions_test.go b/tf6muxserver/mux_server_GetFunctions_test.go new file mode 100644 index 0000000..185047f --- /dev/null +++ b/tf6muxserver/mux_server_GetFunctions_test.go @@ -0,0 +1,333 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tf6muxserver_test + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/terraform-plugin-go/tfprotov6" + "github.com/hashicorp/terraform-plugin-go/tftypes" + + "github.com/hashicorp/terraform-plugin-mux/internal/tf6testserver" + "github.com/hashicorp/terraform-plugin-mux/tf6muxserver" +) + +func TestMuxServerGetFunctions(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + servers []func() tfprotov6.ProviderServer + expected *tfprotov6.GetFunctionsResponse + }{ + "combined": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function1": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function2": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function1": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function2": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }, + "duplicate-function": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + }, + Functions: map[string]*tfprotov6.Function{ + "test_function": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, + }, + }, + "error-once": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + Functions: map[string]*tfprotov6.Function{}, + }, + }, + "error-multiple": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + Functions: map[string]*tfprotov6.Function{}, + }, + }, + "warning-once": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + Functions: map[string]*tfprotov6.Function{}, + }, + }, + "warning-multiple": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + Functions: map[string]*tfprotov6.Function{}, + }, + }, + "warning-then-error": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{}).ProviderServer, + (&tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + }, + }).ProviderServer, + }, + expected: &tfprotov6.GetFunctionsResponse{ + Diagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityWarning, + Summary: "test warning summary", + Detail: "test warning details", + }, + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "test error summary", + Detail: "test error details", + }, + }, + Functions: map[string]*tfprotov6.Function{}, + }, + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + muxServer, err := tf6muxserver.NewMuxServer(context.Background(), testCase.servers...) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Fatal("muxServer should implement tfprotov6.FunctionServer") + } + + // resp, err := muxServer.ProviderServer().GetFunctions(context.Background(), &tfprotov6.GetFunctionsRequest{}) + resp, err := functionServer.GetFunctions(context.Background(), &tfprotov6.GetFunctionsRequest{}) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if diff := cmp.Diff(resp, testCase.expected); diff != "" { + t.Errorf("unexpected difference: %s", diff) + } + }) + } +} diff --git a/tf6muxserver/mux_server_GetMetadata.go b/tf6muxserver/mux_server_GetMetadata.go index 14727b1..a181654 100644 --- a/tf6muxserver/mux_server_GetMetadata.go +++ b/tf6muxserver/mux_server_GetMetadata.go @@ -26,6 +26,7 @@ func (s *muxServer) GetMetadata(ctx context.Context, req *tfprotov6.GetMetadataR resp := &tfprotov6.GetMetadataResponse{ DataSources: make([]tfprotov6.DataSourceMetadata, 0), + Functions: make([]tfprotov6.FunctionMetadata, 0), Resources: make([]tfprotov6.ResourceMetadata, 0), ServerCapabilities: serverCapabilities, } @@ -53,6 +54,17 @@ func (s *muxServer) GetMetadata(ctx context.Context, req *tfprotov6.GetMetadataR resp.DataSources = append(resp.DataSources, datasource) } + for _, function := range serverResp.Functions { + if functionMetadataContainsName(resp.Functions, function.Name) { + resp.Diagnostics = append(resp.Diagnostics, functionDuplicateError(function.Name)) + + continue + } + + s.functions[function.Name] = server + resp.Functions = append(resp.Functions, function) + } + for _, resource := range serverResp.Resources { if resourceMetadataContainsTypeName(resp.Resources, resource.TypeName) { resp.Diagnostics = append(resp.Diagnostics, resourceDuplicateError(resource.TypeName)) @@ -79,6 +91,16 @@ func datasourceMetadataContainsTypeName(metadatas []tfprotov6.DataSourceMetadata return false } +func functionMetadataContainsName(metadatas []tfprotov6.FunctionMetadata, name string) bool { + for _, metadata := range metadatas { + if name == metadata.Name { + return true + } + } + + return false +} + func resourceMetadataContainsTypeName(metadatas []tfprotov6.ResourceMetadata, typeName string) bool { for _, metadata := range metadatas { if typeName == metadata.TypeName { diff --git a/tf6muxserver/mux_server_GetMetadata_test.go b/tf6muxserver/mux_server_GetMetadata_test.go index eb9dd14..dadccf2 100644 --- a/tf6muxserver/mux_server_GetMetadata_test.go +++ b/tf6muxserver/mux_server_GetMetadata_test.go @@ -21,6 +21,7 @@ func TestMuxServerGetMetadata(t *testing.T) { servers []func() tfprotov6.ProviderServer expectedDataSources []tfprotov6.DataSourceMetadata expectedDiagnostics []*tfprotov6.Diagnostic + expectedFunctions []tfprotov6.FunctionMetadata expectedResources []tfprotov6.ResourceMetadata expectedServerCapabilities *tfprotov6.ServerCapabilities }{ @@ -41,6 +42,11 @@ func TestMuxServerGetMetadata(t *testing.T) { TypeName: "test_foo", }, }, + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function1", + }, + }, }, }).ProviderServer, (&tf6testserver.TestServer{ @@ -58,6 +64,14 @@ func TestMuxServerGetMetadata(t *testing.T) { TypeName: "test_quux", }, }, + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function2", + }, + { + Name: "test_function3", + }, + }, }, }).ProviderServer, }, @@ -83,6 +97,17 @@ func TestMuxServerGetMetadata(t *testing.T) { TypeName: "test_quux", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function1", + }, + { + Name: "test_function2", + }, + { + Name: "test_function3", + }, + }, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, PlanDestroy: true, @@ -124,6 +149,50 @@ func TestMuxServerGetMetadata(t *testing.T) { "Duplicate data source type: test_foo", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, + expectedResources: []tfprotov6.ResourceMetadata{}, + expectedServerCapabilities: &tfprotov6.ServerCapabilities{ + GetProviderSchemaOptional: true, + PlanDestroy: true, + }, + }, + "duplicate-function": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function", + }, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function", + }, + }, + }, + }).ProviderServer, + }, + expectedDataSources: []tfprotov6.DataSourceMetadata{}, + expectedDiagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + }, + expectedFunctions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function", + }, + }, expectedResources: []tfprotov6.ResourceMetadata{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -162,6 +231,7 @@ func TestMuxServerGetMetadata(t *testing.T) { "Duplicate resource type: test_foo", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{ { TypeName: "test_foo", @@ -198,6 +268,7 @@ func TestMuxServerGetMetadata(t *testing.T) { }).ProviderServer, }, expectedDataSources: []tfprotov6.DataSourceMetadata{}, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{ { TypeName: "test_with_server_capabilities", @@ -235,6 +306,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -280,6 +352,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -310,6 +383,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -355,6 +429,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -400,6 +475,7 @@ func TestMuxServerGetMetadata(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: []tfprotov6.FunctionMetadata{}, expectedResources: []tfprotov6.ResourceMetadata{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -434,6 +510,10 @@ func TestMuxServerGetMetadata(t *testing.T) { t.Errorf("diagnostics didn't match expectations: %s", diff) } + if diff := cmp.Diff(resp.Functions, testCase.expectedFunctions); diff != "" { + t.Errorf("functions didn't match expectations: %s", diff) + } + if diff := cmp.Diff(resp.Resources, testCase.expectedResources); diff != "" { t.Errorf("resources didn't match expectations: %s", diff) } diff --git a/tf6muxserver/mux_server_GetProviderSchema.go b/tf6muxserver/mux_server_GetProviderSchema.go index 7b95c90..e725dcc 100644 --- a/tf6muxserver/mux_server_GetProviderSchema.go +++ b/tf6muxserver/mux_server_GetProviderSchema.go @@ -26,6 +26,7 @@ func (s *muxServer) GetProviderSchema(ctx context.Context, req *tfprotov6.GetPro resp := &tfprotov6.GetProviderSchemaResponse{ DataSourceSchemas: make(map[string]*tfprotov6.Schema), + Functions: make(map[string]*tfprotov6.Function), ResourceSchemas: make(map[string]*tfprotov6.Schema), ServerCapabilities: serverCapabilities, } @@ -94,6 +95,17 @@ func (s *muxServer) GetProviderSchema(ctx context.Context, req *tfprotov6.GetPro s.dataSources[dataSourceType] = server resp.DataSourceSchemas[dataSourceType] = schema } + + for name, definition := range serverResp.Functions { + if _, ok := resp.Functions[name]; ok { + resp.Diagnostics = append(resp.Diagnostics, functionDuplicateError(name)) + + continue + } + + s.functions[name] = server + resp.Functions[name] = definition + } } s.serverDiscoveryComplete = true diff --git a/tf6muxserver/mux_server_GetProviderSchema_test.go b/tf6muxserver/mux_server_GetProviderSchema_test.go index 89c430a..aa2c18e 100644 --- a/tf6muxserver/mux_server_GetProviderSchema_test.go +++ b/tf6muxserver/mux_server_GetProviderSchema_test.go @@ -22,6 +22,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { servers []func() tfprotov6.ProviderServer expectedDataSourceSchemas map[string]*tfprotov6.Schema expectedDiagnostics []*tfprotov6.Diagnostic + expectedFunctions map[string]*tfprotov6.Function expectedProviderSchema *tfprotov6.Schema expectedProviderMetaSchema *tfprotov6.Schema expectedResourceSchemas map[string]*tfprotov6.Schema @@ -144,6 +145,13 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }, }, }, + Functions: map[string]*tfprotov6.Function{ + "test_function1": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, }, }).ProviderServer, (&tf6testserver.TestServer{ @@ -259,6 +267,18 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }, }, }, + Functions: map[string]*tfprotov6.Function{ + "test_function2": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, }, }).ProviderServer, }, @@ -425,6 +445,23 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }, }, }, + expectedFunctions: map[string]*tfprotov6.Function{ + "test_function1": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function2": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + "test_function3": { + Return: &tfprotov6.FunctionReturn{ + Type: tftypes.String, + }, + }, + }, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, PlanDestroy: true, @@ -460,6 +497,44 @@ func TestMuxServerGetProviderSchema(t *testing.T) { "Duplicate data source type: test_foo", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, + expectedResourceSchemas: map[string]*tfprotov6.Schema{}, + expectedServerCapabilities: &tfprotov6.ServerCapabilities{ + GetProviderSchemaOptional: true, + PlanDestroy: true, + }, + }, + "duplicate-function": { + servers: []func() tfprotov6.ProviderServer{ + (&tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, + }, + }, + }).ProviderServer, + (&tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, + }, + }, + }).ProviderServer, + }, + expectedDataSourceSchemas: map[string]*tfprotov6.Schema{}, + expectedDiagnostics: []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + }, + expectedFunctions: map[string]*tfprotov6.Function{ + "test_function": {}, + }, expectedResourceSchemas: map[string]*tfprotov6.Schema{}, expectedServerCapabilities: &tfprotov6.ServerCapabilities{ GetProviderSchemaOptional: true, @@ -494,6 +569,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { "Duplicate resource type: test_foo", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{ "test_foo": {}, }, @@ -569,6 +645,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { ), }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedProviderSchema: &tfprotov6.Schema{ Block: &tfprotov6.SchemaBlock{ Attributes: []*tfprotov6.SchemaAttribute{ @@ -653,6 +730,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { ), }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedProviderMetaSchema: &tfprotov6.Schema{ Block: &tfprotov6.SchemaBlock{ Attributes: []*tfprotov6.SchemaAttribute{ @@ -691,6 +769,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { }).ProviderServer, }, expectedDataSourceSchemas: map[string]*tfprotov6.Schema{}, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{ "test_with_server_capabilities": {}, "test_without_server_capabilities": {}, @@ -724,6 +803,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{}, }, "error-multiple": { @@ -765,6 +845,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{}, }, "warning-once": { @@ -791,6 +872,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{}, }, "warning-multiple": { @@ -832,6 +914,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test warning details", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{}, }, "warning-then-error": { @@ -873,6 +956,7 @@ func TestMuxServerGetProviderSchema(t *testing.T) { Detail: "test error details", }, }, + expectedFunctions: map[string]*tfprotov6.Function{}, expectedResourceSchemas: map[string]*tfprotov6.Schema{}, }, } @@ -903,6 +987,10 @@ func TestMuxServerGetProviderSchema(t *testing.T) { t.Errorf("diagnostics didn't match expectations: %s", diff) } + if diff := cmp.Diff(resp.Functions, testCase.expectedFunctions); diff != "" { + t.Errorf("functions didn't match expectations: %s", diff) + } + if diff := cmp.Diff(resp.Provider, testCase.expectedProviderSchema); diff != "" { t.Errorf("provider schema didn't match expectations: %s", diff) } diff --git a/tf6muxserver/mux_server_test.go b/tf6muxserver/mux_server_test.go index 088160b..5a0362a 100644 --- a/tf6muxserver/mux_server_test.go +++ b/tf6muxserver/mux_server_test.go @@ -399,6 +399,436 @@ func TestMuxServerGetDataSourceServer_Missing(t *testing.T) { } } +func TestMuxServerGetFunctionServer_GetProviderSchema(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function1": {}, + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function2": {}, + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov6.FunctionServer") + } + + //_, _ = muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + _, _ = functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function1", + }) + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } +} + +func TestMuxServerGetFunctionServer_GetProviderSchema_Duplicate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, // intentionally duplicated + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, // intentionally duplicated + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + expectedDiags := []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + } + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov6.FunctionServer") + } + + // resp, _ := muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + resp, _ := functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function", + }) + + if diff := cmp.Diff(resp.Diagnostics, expectedDiags); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if testServer1.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server1") + } + + if testServer2.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server2") + } +} + +func TestMuxServerGetFunctionServer_GetMetadata(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function1", + }, + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function2", + }, + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov6.FunctionServer") + } + + // _, _ = muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + _, _ = functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function1", + }) + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } +} + +func TestMuxServerGetFunctionServer_GetMetadata_Duplicate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function", // intentionally duplicated + }, + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function", // intentionally duplicated + }, + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + expectedDiags := []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Invalid Provider Server Combination", + Detail: "The combined provider has multiple implementations of the same function name across underlying providers. " + + "Functions must be implemented by only one underlying provider. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Duplicate function: test_function", + }, + } + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov6.FunctionServer") + } + + // resp, _ := muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + resp, _ := functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function", + }) + + if diff := cmp.Diff(resp.Diagnostics, expectedDiags); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if testServer1.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server1") + } + + if testServer2.CallFunctionCalled["test_function"] { + t.Errorf("unexpected test_function CallFunction called on server2") + } +} + +func TestMuxServerGetFunctionServer_GetMetadata_Partial(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function1", + }, + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function2": {}, + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov6.FunctionServer") + } + + // _, _ = muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + _, _ = functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function1", + }) + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if !testServer1.CallFunctionCalled["test_function1"] { + t.Errorf("expected test_function1 CallFunction to be called on server1") + } +} + +func TestMuxServerGetFunctionServer_Missing(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testServer1 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function1", + }, + }, + }, + } + testServer2 := &tf6testserver.TestServer{ + GetMetadataResponse: &tfprotov6.GetMetadataResponse{ + Functions: []tfprotov6.FunctionMetadata{ + { + Name: "test_function2", + }, + }, + }, + } + + servers := []func() tfprotov6.ProviderServer{testServer1.ProviderServer, testServer2.ProviderServer} + muxServer, err := tf6muxserver.NewMuxServer(ctx, servers...) + + if err != nil { + t.Fatalf("unexpected error setting up factory: %s", err) + } + + // When GetProviderSchemaOptional is enabled, the secondary provider + // instances will receive non-GetProviderSchema RPCs such as + // CallFunction which will cause getFunctionServer to perform + // server discovery. This testing also simulates concurrent operations from + // Terraform to verify the mutex does not deadlock. + var wg sync.WaitGroup + + expectedDiags := []*tfprotov6.Diagnostic{ + { + Severity: tfprotov6.DiagnosticSeverityError, + Summary: "Function Not Implemented", + Detail: "The combined provider does not implement the requested function. " + + "This is always an issue in the provider implementation and should be reported to the provider developers.\n\n" + + "Missing function: test_function_nonexistent", + }, + } + + terraformOp := func() { + defer wg.Done() + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := muxServer.ProviderServer().(tfprotov6.FunctionServer) + + if !ok { + t.Error("muxServer should implement tfprotov6.FunctionServer") + } + + //resp, _ := muxServer.ProviderServer().CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + resp, _ := functionServer.CallFunction(ctx, &tfprotov6.CallFunctionRequest{ + Name: "test_function_nonexistent", + }) + + if diff := cmp.Diff(resp.Diagnostics, expectedDiags); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + } + + wg.Add(2) + go terraformOp() + go terraformOp() + + wg.Wait() + + if testServer1.CallFunctionCalled["test_function_nonexistent"] { + t.Errorf("unexpected test_function_nonexistent CallFunction called on server1") + } + + if testServer2.CallFunctionCalled["test_function_nonexistent"] { + t.Errorf("unexpected test_function_nonexistent CallFunction called on server2") + } +} + func TestMuxServerGetResourceServer_GetProviderSchema(t *testing.T) { t.Parallel() diff --git a/tf6to5server/tf6to5server.go b/tf6to5server/tf6to5server.go index 6d13b48..d7cd394 100644 --- a/tf6to5server/tf6to5server.go +++ b/tf6to5server/tf6to5server.go @@ -55,6 +55,38 @@ func (s v6tov5Server) ApplyResourceChange(ctx context.Context, req *tfprotov5.Ap return tfprotov6tov5.ApplyResourceChangeResponse(v6Resp), nil } +func (s v6tov5Server) CallFunction(ctx context.Context, req *tfprotov5.CallFunctionRequest) (*tfprotov5.CallFunctionResponse, error) { + // Remove and call s.v6Server.CallFunction below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := s.v6Server.(tfprotov6.FunctionServer) + + if !ok { + v5Resp := &tfprotov5.CallFunctionResponse{ + Diagnostics: []*tfprotov5.Diagnostic{ + { + Severity: tfprotov5.DiagnosticSeverityError, + Summary: "Provider Functions Not Implemented", + Detail: "A provider-defined function call was received by the provider, however the provider does not implement functions. " + + "Either upgrade the provider to a version that implements provider-defined functions or this is a bug in Terraform that should be reported to the Terraform maintainers.", + }, + }, + } + + return v5Resp, nil + } + + v6Req := tfprotov5tov6.CallFunctionRequest(req) + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + // v6Resp, err := s.v6Server.CallFunction(ctx, v6Req) + v6Resp, err := functionServer.CallFunction(ctx, v6Req) + + if err != nil { + return nil, err + } + + return tfprotov6tov5.CallFunctionResponse(v6Resp), nil +} + func (s v6tov5Server) ConfigureProvider(ctx context.Context, req *tfprotov5.ConfigureProviderRequest) (*tfprotov5.ConfigureProviderResponse, error) { v6Req := tfprotov5tov6.ConfigureProviderRequest(req) v6Resp, err := s.v6Server.ConfigureProvider(ctx, v6Req) @@ -66,6 +98,31 @@ func (s v6tov5Server) ConfigureProvider(ctx context.Context, req *tfprotov5.Conf return tfprotov6tov5.ConfigureProviderResponse(v6Resp), nil } +func (s v6tov5Server) GetFunctions(ctx context.Context, req *tfprotov5.GetFunctionsRequest) (*tfprotov5.GetFunctionsResponse, error) { + // Remove and call s.v6Server.GetFunctions below directly. + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := s.v6Server.(tfprotov6.FunctionServer) + + if !ok { + v5Resp := &tfprotov5.GetFunctionsResponse{ + Functions: map[string]*tfprotov5.Function{}, + } + + return v5Resp, nil + } + + v6Req := tfprotov5tov6.GetFunctionsRequest(req) + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + // v6Resp, err := s.v6Server.GetFunctions(ctx, v6Req) + v6Resp, err := functionServer.GetFunctions(ctx, v6Req) + + if err != nil { + return nil, err + } + + return tfprotov6tov5.GetFunctionsResponse(v6Resp), nil +} + func (s v6tov5Server) GetMetadata(ctx context.Context, req *tfprotov5.GetMetadataRequest) (*tfprotov5.GetMetadataResponse, error) { v6Req := tfprotov5tov6.GetMetadataRequest(req) v6Resp, err := s.v6Server.GetMetadata(ctx, v6Req) diff --git a/tf6to5server/tf6to5server_test.go b/tf6to5server/tf6to5server_test.go index 006cfec..de8ba35 100644 --- a/tf6to5server/tf6to5server_test.go +++ b/tf6to5server/tf6to5server_test.go @@ -31,6 +31,9 @@ func TestDowngradeServer(t *testing.T) { DataSourceSchemas: map[string]*tfprotov6.Schema{ "test_data_source": {}, }, + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, + }, Provider: &tfprotov6.Schema{ Block: &tfprotov6.SchemaBlock{ Attributes: []*tfprotov6.SchemaAttribute{ @@ -255,6 +258,45 @@ func TestV6ToV5ServerApplyResourceChange(t *testing.T) { } } +func TestV6ToV5ServerCallFunction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + v6server := &tf6testserver.TestServer{ + GetProviderSchemaResponse: &tfprotov6.GetProviderSchemaResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, + }, + }, + } + + v5server, err := tf6to5server.DowngradeServer(context.Background(), v6server.ProviderServer) + + if err != nil { + t.Fatalf("unexpected error downgrading server: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := v5server.(tfprotov5.FunctionServer) + + if !ok { + t.Fatal("v5server should implement tfprotov5.FunctionServer") + } + + // _, err = v5server.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + _, err = functionServer.CallFunction(ctx, &tfprotov5.CallFunctionRequest{ + Name: "test_function", + }) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !v6server.CallFunctionCalled["test_function"] { + t.Errorf("expected test_function CallFunction to be called") + } +} + func TestV6ToV5ServerConfigureProvider(t *testing.T) { t.Parallel() @@ -284,6 +326,43 @@ func TestV6ToV5ServerConfigureProvider(t *testing.T) { } } +func TestV5ToV6ServerGetFunctions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + v6server := &tf6testserver.TestServer{ + GetFunctionsResponse: &tfprotov6.GetFunctionsResponse{ + Functions: map[string]*tfprotov6.Function{ + "test_function": {}, + }, + }, + } + + v5server, err := tf6to5server.DowngradeServer(context.Background(), v6server.ProviderServer) + + if err != nil { + t.Fatalf("unexpected error downgrading server: %s", err) + } + + // Reference: https://github.com/hashicorp/terraform-plugin-mux/issues/210 + functionServer, ok := v5server.(tfprotov5.FunctionServer) + + if !ok { + t.Fatal("v5server should implement tfprotov5.FunctionServer") + } + + //_, err = v5server.GetFunctions(ctx, &tfprotov5.GetFunctionsRequest{}) + _, err = functionServer.GetFunctions(ctx, &tfprotov5.GetFunctionsRequest{}) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !v6server.GetFunctionsCalled { + t.Errorf("expected GetFunctions to be called") + } +} + func TestV5ToV6ServerGetMetadata(t *testing.T) { t.Parallel()