diff --git a/cmd/aries-agent-rest/startcmd/start.go b/cmd/aries-agent-rest/startcmd/start.go index 6775c034c..3c5f39e9f 100644 --- a/cmd/aries-agent-rest/startcmd/start.go +++ b/cmd/aries-agent-rest/startcmd/start.go @@ -233,7 +233,8 @@ var ( } ) -type agentParameters struct { +// AgentParameters represents the various options to run an Aries Agent. +type AgentParameters struct { server server host, defaultLabel, transportReturnRoute string tlsCertFile, tlsKeyFile string @@ -289,144 +290,153 @@ func Cmd(server server) (*cobra.Command, error) { return startCmd, nil } -func createStartCMD(server server) *cobra.Command { //nolint: funlen,gocyclo,gocognit - return &cobra.Command{ - Use: "start", - Short: "Start an agent", - Long: `Start an Aries agent controller`, - RunE: func(cmd *cobra.Command, args []string) error { - // log level - logLevel, err := getUserSetVar(cmd, agentLogLevelFlagName, agentLogLevelEnvKey, true) - if err != nil { - return err - } +// NewAgentParameters constructs AgentParameters with the given cobra command. +func NewAgentParameters(server server, cmd *cobra.Command) (*AgentParameters, error) { //nolint: funlen,gocyclo + // log level + logLevel, err := getUserSetVar(cmd, agentLogLevelFlagName, agentLogLevelEnvKey, true) + if err != nil { + return nil, err + } - err = setLogLevel(logLevel) - if err != nil { - return err - } + err = setLogLevel(logLevel) + if err != nil { + return nil, err + } - host, err := getUserSetVar(cmd, agentHostFlagName, agentHostEnvKey, false) - if err != nil { - return err - } + host, err := getUserSetVar(cmd, agentHostFlagName, agentHostEnvKey, false) + if err != nil { + return nil, err + } - token, err := getUserSetVar(cmd, agentTokenFlagName, agentTokenEnvKey, true) - if err != nil { - return err - } + token, err := getUserSetVar(cmd, agentTokenFlagName, agentTokenEnvKey, true) + if err != nil { + return nil, err + } - inboundHosts, err := getUserSetVars(cmd, agentInboundHostFlagName, agentInboundHostEnvKey, true) - if err != nil { - return err - } + inboundHosts, err := getUserSetVars(cmd, agentInboundHostFlagName, agentInboundHostEnvKey, true) + if err != nil { + return nil, err + } - inboundHostExternals, err := getUserSetVars(cmd, agentInboundHostExternalFlagName, - agentInboundHostExternalEnvKey, true) - if err != nil { - return err - } + inboundHostExternals, err := getUserSetVars(cmd, agentInboundHostExternalFlagName, + agentInboundHostExternalEnvKey, true) + if err != nil { + return nil, err + } - websocketReadLimit, err := getWebSocketReadLimit(cmd) - if err != nil { - return err - } + websocketReadLimit, err := getWebSocketReadLimit(cmd) + if err != nil { + return nil, err + } - dbParam, err := getDBParam(cmd) - if err != nil { - return err - } + dbParam, err := getDBParam(cmd) + if err != nil { + return nil, err + } - defaultLabel, err := getUserSetVar(cmd, agentDefaultLabelFlagName, agentDefaultLabelEnvKey, true) - if err != nil { - return err - } + defaultLabel, err := getUserSetVar(cmd, agentDefaultLabelFlagName, agentDefaultLabelEnvKey, true) + if err != nil { + return nil, err + } - autoAccept, err := getAutoAcceptValue(cmd) - if err != nil { - return err - } + autoAccept, err := getAutoAcceptValue(cmd) + if err != nil { + return nil, err + } - webhookURLs, err := getUserSetVars(cmd, agentWebhookFlagName, agentWebhookEnvKey, autoAccept) - if err != nil { - return err - } + webhookURLs, err := getUserSetVars(cmd, agentWebhookFlagName, agentWebhookEnvKey, autoAccept) + if err != nil { + return nil, err + } - httpResolvers, err := getUserSetVars(cmd, agentHTTPResolverFlagName, agentHTTPResolverEnvKey, true) - if err != nil { - return err - } + httpResolvers, err := getUserSetVars(cmd, agentHTTPResolverFlagName, agentHTTPResolverEnvKey, true) + if err != nil { + return nil, err + } - outboundTransports, err := getUserSetVars(cmd, agentOutboundTransportFlagName, - agentOutboundTransportEnvKey, true) - if err != nil { - return err - } + outboundTransports, err := getUserSetVars(cmd, agentOutboundTransportFlagName, + agentOutboundTransportEnvKey, true) + if err != nil { + return nil, err + } - transportReturnRoute, err := getUserSetVar(cmd, agentTransportReturnRouteFlagName, - agentTransportReturnRouteEnvKey, true) - if err != nil { - return err - } + transportReturnRoute, err := getUserSetVar(cmd, agentTransportReturnRouteFlagName, + agentTransportReturnRouteEnvKey, true) + if err != nil { + return nil, err + } - contextProviderURLs, err := getUserSetVars(cmd, agentContextProviderFlagName, agentContextProviderEnvKey, true) - if err != nil { - return err - } + contextProviderURLs, err := getUserSetVars(cmd, agentContextProviderFlagName, agentContextProviderEnvKey, true) + if err != nil { + return nil, err + } - autoExecuteRFC0593, err := getAutoExecuteRFC0593(cmd) - if err != nil { - return err - } + autoExecuteRFC0593, err := getAutoExecuteRFC0593(cmd) + if err != nil { + return nil, err + } - tlsCertFile, err := getUserSetVar(cmd, agentTLSCertFileFlagName, agentTLSCertFileEnvKey, true) - if err != nil { - return err - } + tlsCertFile, err := getUserSetVar(cmd, agentTLSCertFileFlagName, agentTLSCertFileEnvKey, true) + if err != nil { + return nil, err + } - tlsKeyFile, err := getUserSetVar(cmd, agentTLSKeyFileFlagName, agentTLSKeyFileEnvKey, true) - if err != nil { - return err - } + tlsKeyFile, err := getUserSetVar(cmd, agentTLSKeyFileFlagName, agentTLSKeyFileEnvKey, true) + if err != nil { + return nil, err + } - keyType, err := getUserSetVar(cmd, agentKeyTypeFlagName, agentKeyTypeEnvKey, true) - if err != nil { - return err - } + keyType, err := getUserSetVar(cmd, agentKeyTypeFlagName, agentKeyTypeEnvKey, true) + if err != nil { + return nil, err + } - keyAgreementType, err := getUserSetVar(cmd, agentKeyAgreementTypeFlagName, agentKeyAgreementTypeEnvKey, true) - if err != nil { - return err - } + keyAgreementType, err := getUserSetVar(cmd, agentKeyAgreementTypeFlagName, agentKeyAgreementTypeEnvKey, true) + if err != nil { + return nil, err + } - mediaTypeProfiles, err := getUserSetVars(cmd, agentMediaTypeProfilesFlagName, agentMediaTypeProfilesEnvKey, true) + mediaTypeProfiles, err := getUserSetVars(cmd, agentMediaTypeProfilesFlagName, agentMediaTypeProfilesEnvKey, true) + if err != nil { + return nil, err + } + + parameters := &AgentParameters{ + server: server, + host: host, + token: token, + inboundHostInternals: inboundHosts, + inboundHostExternals: inboundHostExternals, + websocketReadLimit: websocketReadLimit, + dbParam: dbParam, + defaultLabel: defaultLabel, + webhookURLs: webhookURLs, + httpResolvers: httpResolvers, + outboundTransports: outboundTransports, + autoAccept: autoAccept, + transportReturnRoute: transportReturnRoute, + contextProviderURLs: contextProviderURLs, + tlsCertFile: tlsCertFile, + tlsKeyFile: tlsKeyFile, + autoExecuteRFC0593: autoExecuteRFC0593, + keyType: keyType, + keyAgreementType: keyAgreementType, + mediaTypeProfiles: mediaTypeProfiles, + } + + return parameters, nil +} + +func createStartCMD(server server) *cobra.Command { + return &cobra.Command{ + Use: "start", + Short: "Start an agent", + Long: `Start an Aries agent controller`, + RunE: func(cmd *cobra.Command, args []string) error { + parameters, err := NewAgentParameters(server, cmd) if err != nil { return err } - - parameters := &agentParameters{ - server: server, - host: host, - token: token, - inboundHostInternals: inboundHosts, - inboundHostExternals: inboundHostExternals, - websocketReadLimit: websocketReadLimit, - dbParam: dbParam, - defaultLabel: defaultLabel, - webhookURLs: webhookURLs, - httpResolvers: httpResolvers, - outboundTransports: outboundTransports, - autoAccept: autoAccept, - transportReturnRoute: transportReturnRoute, - contextProviderURLs: contextProviderURLs, - tlsCertFile: tlsCertFile, - tlsKeyFile: tlsKeyFile, - autoExecuteRFC0593: autoExecuteRFC0593, - keyType: keyType, - keyAgreementType: keyAgreementType, - mediaTypeProfiles: mediaTypeProfiles, - } - return startAgent(parameters) }, } @@ -585,7 +595,7 @@ func createFlags(startCmd *cobra.Command) { } func getUserSetVar(cmd *cobra.Command, flagName, envKey string, isOptional bool) (string, error) { - if cmd.Flags().Changed(flagName) { + if cmd != nil && cmd.Flags().Changed(flagName) { value, err := cmd.Flags().GetString(flagName) if err != nil { return "", fmt.Errorf(flagName+" flag not found: %s", err) @@ -605,7 +615,7 @@ func getUserSetVar(cmd *cobra.Command, flagName, envKey string, isOptional bool) } func getUserSetVars(cmd *cobra.Command, flagName, envKey string, isOptional bool) ([]string, error) { - if cmd.Flags().Changed(flagName) { + if cmd != nil && cmd.Flags().Changed(flagName) { value, err := cmd.Flags().GetStringSlice(flagName) if err != nil { return nil, fmt.Errorf(flagName+" flag not found: %s", err) @@ -775,9 +785,10 @@ func authorizationMiddleware(token string) mux.MiddlewareFunc { return middleware } -func startAgent(parameters *agentParameters) error { +// NewRouter returns a Router for the Aries Agent. +func (parameters *AgentParameters) NewRouter() (*mux.Router, error) { if parameters.host == "" { - return errMissingHost + return nil, errMissingHost } // set message handler @@ -785,7 +796,7 @@ func startAgent(parameters *agentParameters) error { ctx, err := createAriesAgent(parameters) if err != nil { - return err + return nil, err } // get all HTTP REST API handlers available for controller API @@ -794,7 +805,7 @@ func startAgent(parameters *agentParameters) error { controller.WithMessageHandler(parameters.msgHandler), controller.WithAutoExecuteRFC0593(parameters.autoExecuteRFC0593)) if err != nil { - return fmt.Errorf("failed to start aries agent rest on port [%s], failed to get rest service api : %w", + return nil, fmt.Errorf("failed to start aries agent rest on port [%s], failed to get rest service api : %w", parameters.host, err) } @@ -808,7 +819,17 @@ func startAgent(parameters *agentParameters) error { router.HandleFunc(handler.Path(), handler.Handle()).Methods(handler.Method()) } + return router, nil +} + +func startAgent(parameters *AgentParameters) error { logger.Infof("Starting aries agent rest on host [%s]", parameters.host) + + router, err := parameters.NewRouter() + if err != nil { + return err + } + // start server on given port and serve using given handlers handler := cors.New( cors.Options{ @@ -826,7 +847,7 @@ func startAgent(parameters *agentParameters) error { } //nolint:funlen,gocyclo -func createAriesAgent(parameters *agentParameters) (*context.Provider, error) { +func createAriesAgent(parameters *AgentParameters) (*context.Provider, error) { var opts []aries.Option storePro, err := createStoreProviders(parameters) @@ -898,7 +919,7 @@ func createAriesAgent(parameters *agentParameters) (*context.Provider, error) { return ctx, nil } -func createStoreProviders(parameters *agentParameters) (storage.Provider, error) { +func createStoreProviders(parameters *AgentParameters) (storage.Provider, error) { provider, supported := supportedStorageProviders[parameters.dbParam.dbType] if !supported { return nil, fmt.Errorf("key database type not set to a valid type." + diff --git a/cmd/aries-agent-rest/startcmd/start_test.go b/cmd/aries-agent-rest/startcmd/start_test.go index b742076e6..61a382ad5 100644 --- a/cmd/aries-agent-rest/startcmd/start_test.go +++ b/cmd/aries-agent-rest/startcmd/start_test.go @@ -104,7 +104,7 @@ func TestStartAriesDRequests(t *testing.T) { testInboundHostURL := randomURL() go func() { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL}, @@ -295,7 +295,7 @@ func TestStartCmdWithMissingHostArg(t *testing.T) { } func TestStartAgentWithBlankHost(t *testing.T) { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &mockServer{}, inboundHostInternals: []string{randomURL()}, } @@ -535,7 +535,7 @@ func TestStartMultipleAgentsWithSameHost(t *testing.T) { inboundHost2 := "localhost:8097" go func() { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: host, inboundHostInternals: []string{httpProtocol + "@" + inboundHost}, @@ -548,7 +548,7 @@ func TestStartMultipleAgentsWithSameHost(t *testing.T) { waitForServerToStart(t, host, inboundHost) - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: host, inboundHostInternals: []string{httpProtocol + "@" + inboundHost2}, @@ -568,7 +568,7 @@ func TestStartAriesErrorWithResolvers(t *testing.T) { testHostURL := randomURL() testInboundHostURL := randomURL() - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL}, @@ -586,7 +586,7 @@ func TestStartAriesErrorWithResolvers(t *testing.T) { testHostURL := randomURL() testInboundHostURL := randomURL() - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL}, @@ -606,7 +606,7 @@ func TestStartAriesWithOutboundTransports(t *testing.T) { testInboundHostURL := randomURL() go func() { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL}, @@ -627,7 +627,7 @@ func TestStartAriesWithOutboundTransports(t *testing.T) { testHostURL := randomURL() testInboundHostURL := randomURL() - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL}, @@ -647,7 +647,7 @@ func TestStartAriesWithInboundTransport(t *testing.T) { testInboundHostURL := randomURL() go func() { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{websocketProtocol + "@" + testInboundHostURL}, @@ -668,7 +668,7 @@ func TestStartAriesWithInboundTransport(t *testing.T) { testHostURL := randomURL() testInboundHostURL := randomURL() - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{"wss" + "@" + testInboundHostURL}, @@ -687,7 +687,7 @@ func TestStartAriesWithAutoAccept(t *testing.T) { testInboundHostURL := randomURL() go func() { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL}, @@ -706,7 +706,7 @@ func TestStartAriesWithAutoAccept(t *testing.T) { } func TestStartAriesTLS(t *testing.T) { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: ":0", dbParam: &dbParam{dbType: databaseTypeMemOption}, @@ -756,7 +756,7 @@ func TestCreateAriesWithKeyType(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - parameters := &agentParameters{ + parameters := &AgentParameters{ dbParam: &dbParam{dbType: databaseTypeMemOption}, keyType: tc.kt, } @@ -794,7 +794,7 @@ func TestCreateAriesWithKeyAgreementType(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - parameters := &agentParameters{ + parameters := &AgentParameters{ dbParam: &dbParam{dbType: databaseTypeMemOption}, keyAgreementType: tc.kt, } @@ -832,7 +832,7 @@ func TestCreateAriesWithMediaTypeProfiles(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - parameters := &agentParameters{ + parameters := &AgentParameters{ dbParam: &dbParam{dbType: databaseTypeMemOption}, mediaTypeProfiles: tc.mtp, } @@ -854,7 +854,7 @@ func TestStartAriesWithAuthorization(t *testing.T) { testInboundHostURL := randomURL() go func() { - parameters := &agentParameters{ + parameters := &AgentParameters{ server: &HTTPServer{}, host: testHostURL, token: goodToken, @@ -893,7 +893,7 @@ func TestStartAriesWithAuthorization(t *testing.T) { func TestStoreProvider(t *testing.T) { t.Run("test invalid database type", func(t *testing.T) { - _, err := createAriesAgent(&agentParameters{dbParam: &dbParam{dbType: "data1"}}) + _, err := createAriesAgent(&AgentParameters{dbParam: &dbParam{dbType: "data1"}}) require.Error(t, err) require.Contains(t, err.Error(), "database type not set to a valid type") }) @@ -926,6 +926,100 @@ func TestStartCmdInvalidAutoExecuteRFC0593Value(t *testing.T) { require.Contains(t, err.Error(), "invalid syntax") } +// nolint: errcheck,gosec +func TestNewAgentParametersUsingEnv(t *testing.T) { + os.Setenv(agentHostEnvKey, "agentHost") + defer os.Unsetenv(agentHostEnvKey) + + os.Setenv(agentTokenEnvKey, "agentToken") + defer os.Unsetenv(agentTokenEnvKey) + + os.Setenv(databaseTypeEnvKey, "databaseType") + defer os.Unsetenv(databaseTypeEnvKey) + + os.Setenv(databasePrefixEnvKey, "databasePrefix") + defer os.Unsetenv(databasePrefixEnvKey) + + os.Setenv(databaseTimeoutEnvKey, "1") + defer os.Unsetenv(databaseTimeoutEnvKey) + + os.Setenv(agentWebhookEnvKey, "agentWebhook") + defer os.Unsetenv(agentWebhookEnvKey) + + os.Setenv(agentDefaultLabelEnvKey, "agentDefaultLabel") + defer os.Unsetenv(agentDefaultLabelEnvKey) + + os.Setenv(agentLogLevelEnvKey, "DEBUG") + defer os.Unsetenv(agentLogLevelEnvKey) + + os.Setenv(agentHTTPResolverEnvKey, "agentHTTPResolver") + defer os.Unsetenv(agentHTTPResolverEnvKey) + + os.Setenv(agentOutboundTransportEnvKey, "agentOutboundTransport") + defer os.Unsetenv(agentOutboundTransportEnvKey) + + os.Setenv(agentTLSCertFileEnvKey, "agentTLSCertFile") + defer os.Unsetenv(agentTLSCertFileEnvKey) + + os.Setenv(agentTLSKeyFileEnvKey, "agentTLSKeyFile") + defer os.Unsetenv(agentTLSKeyFileEnvKey) + + os.Setenv(agentInboundHostEnvKey, "agentInboundHost") + defer os.Unsetenv(agentInboundHostEnvKey) + + os.Setenv(agentInboundHostExternalEnvKey, "agentInboundHostExternal") + defer os.Unsetenv(agentInboundHostExternalEnvKey) + + os.Setenv(agentWebSocketReadLimitEnvKey, "0") + defer os.Unsetenv(agentWebSocketReadLimitEnvKey) + + os.Setenv(agentAutoAcceptEnvKey, "true") + defer os.Unsetenv(agentAutoAcceptEnvKey) + + os.Setenv(agentTransportReturnRouteEnvKey, "agentTransportReturnRoute") + defer os.Unsetenv(agentTransportReturnRouteEnvKey) + + os.Setenv(agentAutoExecuteRFC0593EnvKey, "true") + defer os.Unsetenv(agentAutoExecuteRFC0593EnvKey) + + os.Setenv(agentContextProviderEnvKey, "agentContextProvider") + defer os.Unsetenv(agentContextProviderEnvKey) + + os.Setenv(agentKeyTypeEnvKey, "agentKeyType") + defer os.Unsetenv(agentKeyTypeEnvKey) + + os.Setenv(agentKeyAgreementTypeEnvKey, "agentKeyAgreementType") + defer os.Unsetenv(agentKeyAgreementTypeEnvKey) + + os.Setenv(agentMediaTypeProfilesEnvKey, "agentMediaTypeProfiles") + defer os.Unsetenv(agentMediaTypeProfilesEnvKey) + + parameters, err := NewAgentParameters(&mockServer{}, nil) + + require.Nil(t, err) + require.Equal(t, spi.DEBUG, log.GetLevel("")) + require.Equal(t, "agentHost", parameters.host) + require.Equal(t, "agentToken", parameters.token) + require.Equal(t, "agentInboundHost", parameters.inboundHostInternals[0]) + require.Equal(t, "agentInboundHostExternal", parameters.inboundHostExternals[0]) + require.Equal(t, int64(0), parameters.websocketReadLimit) + require.Equal(t, "databaseType", parameters.dbParam.dbType) + require.Equal(t, "databasePrefix", parameters.dbParam.prefix) + require.Equal(t, uint64(1), parameters.dbParam.timeout) + require.Equal(t, "agentDefaultLabel", parameters.defaultLabel) + require.Equal(t, true, parameters.autoAccept) + require.Equal(t, "agentWebhook", parameters.webhookURLs[0]) + require.Equal(t, "agentOutboundTransport", parameters.outboundTransports[0]) + require.Equal(t, "agentTransportReturnRoute", parameters.transportReturnRoute) + require.Equal(t, "agentContextProvider", parameters.contextProviderURLs[0]) + require.Equal(t, true, parameters.autoExecuteRFC0593) + require.Equal(t, "agentTLSCertFile", parameters.tlsCertFile) + require.Equal(t, "agentTLSKeyFile", parameters.tlsKeyFile) + require.Equal(t, "agentKeyType", parameters.keyType) + require.Equal(t, "agentKeyAgreementType", parameters.keyAgreementType) + require.Equal(t, "agentMediaTypeProfiles", parameters.mediaTypeProfiles[0]) +} + func waitForServerToStart(t *testing.T, host, inboundHost string) { if err := listenFor(host); err != nil { t.Fatal(err)