diff --git a/examples/server/run_server/run_server.go b/examples/server/run_server/run_server.go index 0e44b35..709bc6f 100644 --- a/examples/server/run_server/run_server.go +++ b/examples/server/run_server/run_server.go @@ -1,10 +1,15 @@ package main import ( - "github.com/bitcoin-sv/go-paymail/logging" + "encoding/json" + "fmt" + "net/http" "time" + "github.com/bitcoin-sv/go-paymail/logging" + "github.com/bitcoin-sv/go-paymail/server" + "github.com/julienschmidt/httprouter" ) func main() { @@ -19,14 +24,16 @@ func main() { config, err := server.NewConfig( new(demoServiceProvider), server.WithBasicRoutes(), - server.WithDomain("localhost"), // todo: make this work locally? + server.WithDomain("localhost"), server.WithDomain("another.com"), server.WithDomain("test.com"), server.WithGenericCapabilities(), server.WithPort(3000), server.WithServiceName("BsvAliasCustom"), server.WithTimeout(15*time.Second), + server.WithCapabilities(customCapabilities()), ) + config.Prefix = "http://" //normally paymail requires https, but for demo purposes we'll use http if err != nil { logger.Fatal().Msg(err.Error()) } @@ -34,3 +41,26 @@ func main() { // Create & start the server server.StartServer(server.CreateServer(config), config.Logger) } + +func customCapabilities() map[string]any { + exampleBrfcKey := "406cef0ae2d6" + return map[string]any{ + "custom_static_boolean": false, + "custom_static_int": 10, + exampleBrfcKey: true, + "custom_callable_cap": server.CallableCapability{ + Path: fmt.Sprintf("/display_paymail/%s", server.PaymailAddressTemplate), + Method: http.MethodGet, + Handler: func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + incomingPaymail := p.ByName(server.PaymailAddressParamName) + + response := map[string]string{ + "paymail": incomingPaymail, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + }, + }, + } +} diff --git a/server/capabilities.go b/server/capabilities.go index 0ac482b..bfd3bc4 100644 --- a/server/capabilities.go +++ b/server/capabilities.go @@ -1,38 +1,88 @@ package server import ( + "fmt" "net/http" + "strings" "github.com/bitcoin-sv/go-paymail" "github.com/julienschmidt/httprouter" ) -// GenericCapabilities will make generic capabilities -func GenericCapabilities(bsvAliasVersion string, senderValidation bool) *paymail.CapabilitiesPayload { - return &paymail.CapabilitiesPayload{ - BsvAlias: bsvAliasVersion, - Capabilities: map[string]interface{}{ - paymail.BRFCPaymentDestination: "/address/{alias}@{domain.tld}", - paymail.BRFCPki: "/id/{alias}@{domain.tld}", - paymail.BRFCPublicProfile: "/public-profile/{alias}@{domain.tld}", - paymail.BRFCSenderValidation: senderValidation, - paymail.BRFCVerifyPublicKeyOwner: "/verify-pubkey/{alias}@{domain.tld}/{pubkey}", +type CallableCapability struct { + Path string + Method string + Handler httprouter.Handle +} + +type CallableCapabilitiesMap map[string]CallableCapability +type StaticCapabilitiesMap map[string]any + +func (c *Configuration) SetGenericCapabilities() { + _addCapabilities(c.callableCapabilities, + CallableCapabilitiesMap{ + paymail.BRFCPaymentDestination: CallableCapability{ + Path: fmt.Sprintf("/address/%s", PaymailAddressTemplate), + Method: http.MethodPost, + Handler: c.resolveAddress, + }, + paymail.BRFCPki: CallableCapability{ + Path: fmt.Sprintf("/id/%s", PaymailAddressTemplate), + Method: http.MethodGet, + Handler: c.showPKI, + }, + paymail.BRFCPublicProfile: CallableCapability{ + Path: fmt.Sprintf("/public-profile/%s", PaymailAddressTemplate), + Method: http.MethodGet, + Handler: c.publicProfile, + }, + paymail.BRFCVerifyPublicKeyOwner: CallableCapability{ + Path: fmt.Sprintf("/verify-pubkey/%s/%s", PaymailAddressTemplate, PubKeyTemplate), + Method: http.MethodGet, + Handler: c.verifyPubKey, + }, }, - } + ) + _addCapabilities(c.staticCapabilities, + StaticCapabilitiesMap{ + paymail.BRFCSenderValidation: c.SenderValidationEnabled, + }, + ) +} + +func (c *Configuration) SetP2PCapabilities() { + _addCapabilities(c.callableCapabilities, + CallableCapabilitiesMap{ + paymail.BRFCP2PTransactions: CallableCapability{ + Path: fmt.Sprintf("/receive-transaction/%s", PaymailAddressTemplate), + Method: http.MethodPost, + Handler: c.p2pReceiveTx, + }, + paymail.BRFCP2PPaymentDestination: CallableCapability{ + Path: fmt.Sprintf("/p2p-payment-destination/%s", PaymailAddressTemplate), + Method: http.MethodPost, + Handler: c.p2pDestination, + }, + }, + ) } -// P2PCapabilities will make generic capabilities & add additional p2p capabilities -func P2PCapabilities(bsvAliasVersion string, senderValidation bool) *paymail.CapabilitiesPayload { - c := GenericCapabilities(bsvAliasVersion, senderValidation) - c.Capabilities[paymail.BRFCP2PTransactions] = "/receive-transaction/{alias}@{domain.tld}" - c.Capabilities[paymail.BRFCP2PPaymentDestination] = "/p2p-payment-destination/{alias}@{domain.tld}" - return c +func (c *Configuration) SetBeefCapabilities() { + _addCapabilities(c.callableCapabilities, + CallableCapabilitiesMap{ + paymail.BRFCBeefTransaction: CallableCapability{ + Path: fmt.Sprintf("/beef/%s", PaymailAddressTemplate), + Method: http.MethodPost, + Handler: c.p2pReceiveBeefTx, + }, + }, + ) } -// BeefCapabilities will add beef capabilities to given ones -func BeefCapabilities(c *paymail.CapabilitiesPayload) *paymail.CapabilitiesPayload { - c.Capabilities[paymail.BRFCBeefTransaction] = "/beef/{alias}@{domain.tld}" - return c +func _addCapabilities[T any](base map[string]T, newCaps map[string]T) { + for key, val := range newCaps { + base[key] = val + } } // showCapabilities will return the service discovery results for the server @@ -40,15 +90,63 @@ func BeefCapabilities(c *paymail.CapabilitiesPayload) *paymail.CapabilitiesPaylo // // Specs: http://bsvalias.org/02-02-capability-discovery.html func (c *Configuration) showCapabilities(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - // Check the domain (allowed, and used for capabilities response) - // todo: bake this into middleware? This is protecting the "req" domain name (like CORs) - domain := getHost(req) - if !c.IsAllowedDomain(domain) { - ErrorResponse(w, req, ErrorUnknownDomain, "domain unknown: "+domain, http.StatusBadRequest, c.Logger) + // Check the host (allowed, and used for capabilities response) + // todo: bake this into middleware? This is protecting the "req" host name (like CORs) + host := "" + if req.URL.IsAbs() || len(req.URL.Host) == 0 { + host = req.Host + } else { + host = req.URL.Host + } + + if !c.IsAllowedDomain(host) { + ErrorResponse(w, req, ErrorUnknownDomain, "domain unknown: "+host, http.StatusBadRequest, c.Logger) + return + } + + capabilities, err := c.EnrichCapabilities(host) + if err != nil { + ErrorResponse(w, req, ErrorEncodingResponse, err.Error(), http.StatusBadRequest, c.Logger) return } - // Set the service URL - capabilities := c.EnrichCapabilities(domain) writeJsonResponse(w, req, c.Logger, capabilities) } + +// EnrichCapabilities will update the capabilities with the appropriate service url +func (c *Configuration) EnrichCapabilities(host string) (*paymail.CapabilitiesPayload, error) { + serviceUrl, err := generateServiceURL(c.Prefix, host, c.APIVersion, c.ServiceName) + if err != nil { + return nil, err + } + payload := &paymail.CapabilitiesPayload{ + BsvAlias: c.BSVAliasVersion, + Capabilities: make(map[string]interface{}), + } + for key, cap := range c.staticCapabilities { + payload.Capabilities[key] = cap + } + for key, cap := range c.callableCapabilities { + payload.Capabilities[key] = serviceUrl + string(cap.Path) + } + return payload, nil +} + +func generateServiceURL(prefix, domain, apiVersion, serviceName string) (string, error) { + if len(prefix) == 0 || len(domain) == 0 { + return "", ErrPrefixOrDomainMissing + } + strBuilder := new(strings.Builder) + strBuilder.WriteString(prefix) + strBuilder.WriteString(domain) + if len(apiVersion) > 0 { + strBuilder.WriteString("/") + strBuilder.WriteString(apiVersion) + } + if len(serviceName) > 0 { + strBuilder.WriteString("/") + strBuilder.WriteString(serviceName) + } + + return strBuilder.String(), nil +} diff --git a/server/capabilities_test.go b/server/capabilities_test.go index 6252ebb..3b6a64a 100644 --- a/server/capabilities_test.go +++ b/server/capabilities_test.go @@ -3,62 +3,48 @@ package server import ( "testing" - "github.com/bitcoin-sv/go-paymail" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -// TestGenericCapabilities will test the method GenericCapabilities() -func TestGenericCapabilities(t *testing.T) { +func TestGenerateServiceURL(t *testing.T) { t.Parallel() t.Run("valid values", func(t *testing.T) { - c := GenericCapabilities("test", true) - require.NotNil(t, c) - assert.Equal(t, "test", c.BsvAlias) - assert.Equal(t, 5, len(c.Capabilities)) + u, err := generateServiceURL("https://", "test.com", "v1", "bsvalias") + assert.NoError(t, err) + assert.Equal(t, "https://test.com/v1/bsvalias", u) }) - t.Run("no alias version", func(t *testing.T) { - c := GenericCapabilities("", true) - require.NotNil(t, c) - assert.Equal(t, "", c.BsvAlias) + t.Run("all invalid values", func(t *testing.T) { + _, err := generateServiceURL("", "", "", "") + assert.Error(t, err) }) - t.Run("sender validation", func(t *testing.T) { - c := GenericCapabilities("", true) - require.NotNil(t, c) - assert.Equal(t, true, c.Capabilities[paymail.BRFCSenderValidation]) + t.Run("missing prefix", func(t *testing.T) { + _, err := generateServiceURL("", "test.com", "v1", "") + assert.Error(t, err) }) -} - -// TestP2PCapabilities will test the method P2PCapabilities() -func TestP2PCapabilities(t *testing.T) { - t.Parallel() - t.Run("valid values", func(t *testing.T) { - c := P2PCapabilities("test", true) - require.NotNil(t, c) - assert.Equal(t, "test", c.BsvAlias) - assert.Equal(t, 7, len(c.Capabilities)) + t.Run("missing domain", func(t *testing.T) { + _, err := generateServiceURL("https://", "", "v1", "") + assert.Error(t, err) }) - t.Run("no alias version", func(t *testing.T) { - c := P2PCapabilities("", true) - require.NotNil(t, c) - assert.Equal(t, "", c.BsvAlias) + t.Run("no api version", func(t *testing.T) { + u, err := generateServiceURL("https://", "test", "", "bsvalias") + assert.NoError(t, err) + assert.Equal(t, "https://test/bsvalias", u) }) - t.Run("sender validation", func(t *testing.T) { - c := P2PCapabilities("", true) - require.NotNil(t, c) - assert.Equal(t, true, c.Capabilities[paymail.BRFCSenderValidation]) + t.Run("no service name", func(t *testing.T) { + u, err := generateServiceURL("https://", "test", "v1", "") + assert.NoError(t, err) + assert.Equal(t, "https://test/v1", u) }) - t.Run("has p2p routes", func(t *testing.T) { - c := P2PCapabilities("", true) - require.NotNil(t, c) - assert.NotEmpty(t, c.Capabilities[paymail.BRFCP2PTransactions]) - assert.NotEmpty(t, c.Capabilities[paymail.BRFCP2PPaymentDestination]) + t.Run("service with explicit port", func(t *testing.T) { + u, err := generateServiceURL("https://", "test:1234", "v1", "bsvalias") + assert.NoError(t, err) + assert.Equal(t, "https://test:1234/v1/bsvalias", u) }) } diff --git a/server/config.go b/server/config.go index e6c4bed..969fe90 100644 --- a/server/config.go +++ b/server/config.go @@ -1,30 +1,36 @@ package server import ( - "github.com/rs/zerolog" + "slices" "strings" "time" + "github.com/rs/zerolog" + "github.com/bitcoin-sv/go-paymail" ) // Configuration paymail server configuration object type Configuration struct { - APIVersion string `json:"api_version"` - BasicRoutes *basicRoutes `json:"basic_routes"` - BSVAliasVersion string `json:"bsv_alias_version"` - Capabilities *paymail.CapabilitiesPayload `json:"capabilities"` - PaymailDomains []*Domain `json:"paymail_domains"` - PaymailDomainsValidationDisabled bool `json:"paymail_domains_validation_disabled"` - Port int `json:"port"` - Prefix string `json:"prefix"` - SenderValidationEnabled bool `json:"sender_validation_enabled"` - ServiceName string `json:"service_name"` - Timeout time.Duration `json:"timeout"` - Logger *zerolog.Logger `json:"logger"` + APIVersion string `json:"api_version"` + BasicRoutes *basicRoutes `json:"basic_routes"` + BSVAliasVersion string `json:"bsv_alias_version"` + PaymailDomains []*Domain `json:"paymail_domains"` + PaymailDomainsValidationDisabled bool `json:"paymail_domains_validation_disabled"` + Port int `json:"port"` + Prefix string `json:"prefix"` + SenderValidationEnabled bool `json:"sender_validation_enabled"` + GenericCapabilitiesEnabled bool `json:"generic_capabilities_enabled"` + P2PCapabilitiesEnabled bool `json:"p2p_capabilities_enabled"` + BeefCapabilitiesEnabled bool `json:"beef_capabilities_enabled"` + ServiceName string `json:"service_name"` + Timeout time.Duration `json:"timeout"` + Logger *zerolog.Logger `json:"logger"` // private - actions PaymailServiceProvider + actions PaymailServiceProvider + callableCapabilities CallableCapabilitiesMap + staticCapabilities StaticCapabilitiesMap } // Domain is the Paymail Domain information @@ -53,12 +59,11 @@ func (c *Configuration) Validate() error { return ErrServiceNameMissing } - // Validate (basic checks for existence of capabilities) - if c.Capabilities == nil { - return ErrCapabilitiesMissing - } else if len(c.Capabilities.BsvAlias) == 0 { + if c.BSVAliasVersion == "" { return ErrBsvAliasMissing - } else if len(c.Capabilities.Capabilities) == 0 { + } + + if c.callableCapabilities == nil || len(c.callableCapabilities) == 0 { return ErrCapabilitiesMissing } @@ -66,29 +71,20 @@ func (c *Configuration) Validate() error { } // IsAllowedDomain will return true if it's an allowed paymail domain -func (c *Configuration) IsAllowedDomain(domain string) (success bool) { - +func (c *Configuration) IsAllowedDomain(domain string) bool { if c.PaymailDomainsValidationDisabled { - success = true - return + return true } - // Sanitize the domain (standard) var err error if domain, err = paymail.SanitizeDomain(domain); err != nil { - // todo: log the error? This should rarely occur - return - } - - // Loop all domains check - for _, d := range c.PaymailDomains { - if strings.EqualFold(d.Name, domain) { - success = true - break - } + c.Logger.Warn().Err(err).Msg("failed to sanitize domain") + return false } - return + return slices.ContainsFunc(c.PaymailDomains, func(d *Domain) bool { + return strings.EqualFold(d.Name, domain) + }) } // AddDomain will add the domain if it does not exist @@ -115,44 +111,6 @@ func (c *Configuration) AddDomain(domain string) (err error) { return } -// EnrichCapabilities will update the capabilities with the appropriate service url -func (c *Configuration) EnrichCapabilities(domain string) *paymail.CapabilitiesPayload { - capabilities := &paymail.CapabilitiesPayload{ - BsvAlias: c.Capabilities.BsvAlias, - Capabilities: make(map[string]interface{}), - } - for key, val := range c.Capabilities.Capabilities { - if w, ok := val.(string); ok { - capabilities.Capabilities[key] = GenerateServiceURL(c.Prefix, domain, c.APIVersion, c.ServiceName) + w - } else { - capabilities.Capabilities[key] = val - } - } - return capabilities -} - -// GenerateServiceURL will create the service URL -func GenerateServiceURL(prefix, domain, apiVersion, serviceName string) string { - - // Require prefix or domain - if len(prefix) == 0 || len(domain) == 0 { - return "" - } - u := prefix + domain - - // Set the api version - if len(apiVersion) > 0 { - u = u + "/" + apiVersion - } - - // Set the service name - if len(serviceName) > 0 { - u = u + "/" + serviceName - } - - return u -} - // NewConfig will make a new server configuration func NewConfig(serviceProvider PaymailServiceProvider, opts ...ConfigOps) (*Configuration, error) { @@ -169,6 +127,16 @@ func NewConfig(serviceProvider PaymailServiceProvider, opts ...ConfigOps) (*Conf opt(config) } + if config.GenericCapabilitiesEnabled { + config.SetGenericCapabilities() + } + if config.P2PCapabilitiesEnabled { + config.SetP2PCapabilities() + } + if config.BeefCapabilitiesEnabled { + config.SetBeefCapabilities() + } + // Validate the configuration if err := config.Validate(); err != nil { return nil, err diff --git a/server/config_options.go b/server/config_options.go index 186eebe..e01d9b2 100644 --- a/server/config_options.go +++ b/server/config_options.go @@ -1,9 +1,10 @@ package server import ( + "time" + "github.com/bitcoin-sv/go-paymail/logging" "github.com/rs/zerolog" - "time" "github.com/bitcoin-sv/go-paymail" ) @@ -20,44 +21,53 @@ func defaultConfigOptions() *Configuration { APIVersion: DefaultAPIVersion, BasicRoutes: &basicRoutes{}, BSVAliasVersion: paymail.DefaultBsvAliasVersion, - Capabilities: GenericCapabilities(paymail.DefaultBsvAliasVersion, DefaultSenderValidation), PaymailDomainsValidationDisabled: false, Port: DefaultServerPort, Prefix: DefaultPrefix, SenderValidationEnabled: DefaultSenderValidation, + GenericCapabilitiesEnabled: true, + P2PCapabilitiesEnabled: false, + BeefCapabilitiesEnabled: false, ServiceName: paymail.DefaultServiceName, Timeout: DefaultTimeout, Logger: logging.GetDefaultLogger(), + callableCapabilities: make(CallableCapabilitiesMap), + staticCapabilities: make(StaticCapabilitiesMap), } } // WithGenericCapabilities will load the generic Paymail capabilities func WithGenericCapabilities() ConfigOps { return func(c *Configuration) { - c.Capabilities = GenericCapabilities(c.BSVAliasVersion, c.SenderValidationEnabled) + c.GenericCapabilitiesEnabled = true } } // WithP2PCapabilities will load the generic & p2p capabilities func WithP2PCapabilities() ConfigOps { return func(c *Configuration) { - c.Capabilities = P2PCapabilities(c.BSVAliasVersion, c.SenderValidationEnabled) + c.GenericCapabilitiesEnabled = true + c.P2PCapabilitiesEnabled = true } } // WithBeefCapabilities will load the beef capabilities func WithBeefCapabilities() ConfigOps { return func(c *Configuration) { - c.Capabilities = BeefCapabilities(c.Capabilities) + c.BeefCapabilitiesEnabled = true } } // WithCapabilities will modify the capabilities -func WithCapabilities(capabilities *paymail.CapabilitiesPayload) ConfigOps { +func WithCapabilities(customCapabilities map[string]any) ConfigOps { return func(c *Configuration) { - if capabilities != nil { - // todo: validate that these are valid capabilities (string->url path) - c.Capabilities = capabilities + for key, cap := range customCapabilities { + switch typedCap := cap.(type) { + case CallableCapability: + c.callableCapabilities[key] = typedCap + default: + c.staticCapabilities[key] = typedCap + } } } } diff --git a/server/config_test.go b/server/config_test.go index 0c9d3eb..9f5b390 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -14,7 +14,6 @@ func testConfig(t *testing.T, domain string) *Configuration { c, err := NewConfig( new(mockServiceProvider), WithDomain(domain), - WithGenericCapabilities(), ) require.NoError(t, err) require.NotNil(t, c) @@ -62,7 +61,7 @@ func TestConfiguration_Validate(t *testing.T) { assert.ErrorIs(t, err, ErrServiceNameMissing) }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing bsv alias", func(t *testing.T) { c := &Configuration{ Port: 12345, ServiceName: "test", @@ -70,31 +69,31 @@ func TestConfiguration_Validate(t *testing.T) { } err := c.Validate() require.Error(t, err) - assert.ErrorIs(t, err, ErrCapabilitiesMissing) + assert.ErrorIs(t, err, ErrBsvAliasMissing) }) - t.Run("invalid capabilities", func(t *testing.T) { + t.Run("missing capabilities", func(t *testing.T) { c := &Configuration{ - Port: 12345, - ServiceName: "test", - PaymailDomains: []*Domain{{Name: "test.com"}}, - Capabilities: &paymail.CapabilitiesPayload{ - BsvAlias: "", - }, + Port: 12345, + ServiceName: "test", + PaymailDomains: []*Domain{{Name: "test.com"}}, + BSVAliasVersion: paymail.DefaultBsvAliasVersion, + callableCapabilities: nil, + staticCapabilities: nil, } err := c.Validate() require.Error(t, err) - assert.ErrorIs(t, err, ErrBsvAliasMissing) + assert.ErrorIs(t, err, ErrCapabilitiesMissing) }) t.Run("zero capabilities", func(t *testing.T) { c := &Configuration{ - Port: 12345, - ServiceName: "test", - PaymailDomains: []*Domain{{Name: "test.com"}}, - Capabilities: &paymail.CapabilitiesPayload{ - BsvAlias: "test", - }, + Port: 12345, + ServiceName: "test", + PaymailDomains: []*Domain{{Name: "test.com"}}, + BSVAliasVersion: paymail.DefaultBsvAliasVersion, + callableCapabilities: make(CallableCapabilitiesMap), + staticCapabilities: make(StaticCapabilitiesMap), } err := c.Validate() require.Error(t, err) @@ -103,22 +102,29 @@ func TestConfiguration_Validate(t *testing.T) { t.Run("basic valid configuration", func(t *testing.T) { c := &Configuration{ - Port: 12345, - ServiceName: "test", - PaymailDomains: []*Domain{{Name: "test.com"}}, - Capabilities: GenericCapabilities("test", false), + Port: 12345, + ServiceName: "test", + BSVAliasVersion: paymail.DefaultBsvAliasVersion, + PaymailDomains: []*Domain{{Name: "test.com"}}, + callableCapabilities: make(CallableCapabilitiesMap), + staticCapabilities: make(StaticCapabilitiesMap), } + c.SetGenericCapabilities() err := c.Validate() require.NoError(t, err) }) t.Run("configuration with domain validation disabled", func(t *testing.T) { c := &Configuration{ - Port: 12345, - ServiceName: "test", - PaymailDomains: []*Domain{}, - Capabilities: GenericCapabilities("test", false), + Port: 12345, + ServiceName: "test", + BSVAliasVersion: paymail.DefaultBsvAliasVersion, + PaymailDomains: []*Domain{}, + PaymailDomainsValidationDisabled: false, + callableCapabilities: make(CallableCapabilitiesMap), + staticCapabilities: make(StaticCapabilitiesMap), } + c.SetGenericCapabilities() assert.False(t, c.PaymailDomainsValidationDisabled) err := c.Validate() assert.ErrorIs(t, err, ErrDomainMissing) @@ -127,6 +133,33 @@ func TestConfiguration_Validate(t *testing.T) { err = c.Validate() assert.NoError(t, err) }) + + t.Run("configuration with SenderValidationEnabled", func(t *testing.T) { + c := &Configuration{ + Port: 12345, + Prefix: "https://", + ServiceName: "test", + BSVAliasVersion: paymail.DefaultBsvAliasVersion, + PaymailDomains: []*Domain{{Name: "test.com"}}, + SenderValidationEnabled: false, + callableCapabilities: make(CallableCapabilitiesMap), + staticCapabilities: make(StaticCapabilitiesMap), + } + c.SetGenericCapabilities() + err := c.Validate() + assert.NoError(t, err) + caps, err := c.EnrichCapabilities("test.com") + assert.NoError(t, err) + assert.False(t, caps.Capabilities[paymail.BRFCSenderValidation].(bool)) + + c.SenderValidationEnabled = true + c.SetGenericCapabilities() + err = c.Validate() + assert.NoError(t, err) + caps, err = c.EnrichCapabilities("test.com") + assert.NoError(t, err) + assert.True(t, caps.Capabilities[paymail.BRFCSenderValidation].(bool)) + }) } // TestConfiguration_IsAllowedDomain will test the method IsAllowedDomain() @@ -235,14 +268,14 @@ func TestConfiguration_EnrichCapabilities(t *testing.T) { c := testConfig(t, testDomain) require.NotNil(t, c) - capabilities := c.EnrichCapabilities(testDomain) - assert.Equal(t, 5, len(capabilities.Capabilities)) - assert.Equal(t, paymail.DefaultBsvAliasVersion, c.Capabilities.BsvAlias) - assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/address/{alias}@{domain.tld}", capabilities.Capabilities[paymail.BRFCPaymentDestination]) - assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/id/{alias}@{domain.tld}", capabilities.Capabilities[paymail.BRFCPki]) - assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/public-profile/{alias}@{domain.tld}", capabilities.Capabilities[paymail.BRFCPublicProfile]) - assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/verify-pubkey/{alias}@{domain.tld}/{pubkey}", capabilities.Capabilities[paymail.BRFCVerifyPublicKeyOwner]) - assert.Equal(t, false, capabilities.Capabilities[paymail.BRFCSenderValidation]) + caps, err := c.EnrichCapabilities(testDomain) + assert.NoError(t, err) + assert.Equal(t, 5, len(caps.Capabilities)) + assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/address/{alias}@{domain.tld}", caps.Capabilities[paymail.BRFCPaymentDestination]) + assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/id/{alias}@{domain.tld}", caps.Capabilities[paymail.BRFCPki]) + assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/public-profile/{alias}@{domain.tld}", caps.Capabilities[paymail.BRFCPublicProfile]) + assert.Equal(t, "https://"+testDomain+"/v1/bsvalias/verify-pubkey/{alias}@{domain.tld}/{pubkey}", caps.Capabilities[paymail.BRFCVerifyPublicKeyOwner]) + assert.Equal(t, false, caps.Capabilities[paymail.BRFCSenderValidation]) }) t.Run("multiple times", func(t *testing.T) { @@ -250,46 +283,23 @@ func TestConfiguration_EnrichCapabilities(t *testing.T) { c := testConfig(t, testDomain) require.NotNil(t, c) - capabilities := c.EnrichCapabilities(testDomain) - assert.Equal(t, 5, len(capabilities.Capabilities)) - - capabilities = c.EnrichCapabilities(testDomain) - assert.Equal(t, 5, len(capabilities.Capabilities)) - }) -} - -// TestGenerateServiceURL will test the method GenerateServiceURL() -func TestGenerateServiceURL(t *testing.T) { - t.Parallel() - - t.Run("valid values", func(t *testing.T) { - u := GenerateServiceURL("https://", "test.com", "v1", "bsvalias") - assert.Equal(t, "https://test.com/v1/bsvalias", u) - }) - - t.Run("all invalid values", func(t *testing.T) { - u := GenerateServiceURL("", "", "", "") - assert.Equal(t, "", u) - }) - - t.Run("missing prefix", func(t *testing.T) { - u := GenerateServiceURL("", "test.com", "v1", "") - assert.Equal(t, "", u) - }) + caps, err := c.EnrichCapabilities(testDomain) + assert.NoError(t, err) + assert.Equal(t, 5, len(caps.Capabilities)) - t.Run("missing domain", func(t *testing.T) { - u := GenerateServiceURL("https://", "", "v1", "") - assert.Equal(t, "", u) + caps, err = c.EnrichCapabilities(testDomain) + assert.NoError(t, err) + assert.Equal(t, 5, len(caps.Capabilities)) }) - t.Run("no api version", func(t *testing.T) { - u := GenerateServiceURL("https://", "test", "", "bsvalias") - assert.Equal(t, "https://test/bsvalias", u) - }) + t.Run("empty domain and prefix", func(t *testing.T) { + testDomain := "test.com" + c := testConfig(t, testDomain) + require.NotNil(t, c) - t.Run("no service name", func(t *testing.T) { - u := GenerateServiceURL("https://", "test", "v1", "") - assert.Equal(t, "https://test/v1", u) + c.Prefix = "" + _, err := c.EnrichCapabilities("") + assert.Error(t, err) }) } @@ -318,7 +328,7 @@ func TestNewConfig(t *testing.T) { ) require.NoError(t, err) require.NotNil(t, c) - assert.Equal(t, 5, len(c.Capabilities.Capabilities)) + assert.Equal(t, 4, len(c.callableCapabilities)) assert.Equal(t, "test.com", c.PaymailDomains[0].Name) }) @@ -374,19 +384,28 @@ func TestNewConfig(t *testing.T) { ) require.NoError(t, err) require.NotNil(t, c) - assert.Equal(t, 7, len(c.Capabilities.Capabilities)) + assert.Equal(t, 6, len(c.callableCapabilities)) }) t.Run("with custom capabilities", func(t *testing.T) { c, err := NewConfig( new(mockServiceProvider), WithDomain("test.com"), - WithCapabilities(GenericCapabilities("test", false)), + WithCapabilities(map[string]any{ + "test": true, + "callable": CallableCapability{ + Path: "/test", + Method: "GET", + Handler: nil, + }, + }), ) require.NoError(t, err) require.NotNil(t, c) - assert.Equal(t, 5, len(c.Capabilities.Capabilities)) - assert.Equal(t, "test", c.Capabilities.BsvAlias) + assert.Equal(t, 5, len(c.callableCapabilities)) + assert.Equal(t, 2, len(c.staticCapabilities)) + assert.True(t, c.staticCapabilities["test"].(bool)) + assert.Equal(t, "/test", c.callableCapabilities["callable"].Path) }) t.Run("with beef capabilities", func(t *testing.T) { @@ -398,7 +417,7 @@ func TestNewConfig(t *testing.T) { ) require.NoError(t, err) require.NotNil(t, c) - assert.Equal(t, 8, len(c.Capabilities.Capabilities)) + assert.Equal(t, 7, len(c.callableCapabilities)) }) t.Run("with basic routes", func(t *testing.T) { diff --git a/server/definitions.go b/server/definitions.go index 7e0a06a..44a1760 100644 --- a/server/definitions.go +++ b/server/definitions.go @@ -15,6 +15,14 @@ const ( DefaultTimeout = 15 * time.Second // Default timeouts ) +// Url params +const ( + PaymailAddressParamName = "paymailAddress" // Used to get actual paymail address from the request url + PubKeyParamName = "pubKey" // Used to get actual pubkey from the request url + PaymailAddressTemplate = "{alias}@{domain.tld}" // Used as a placeholder in capabilities list + PubKeyTemplate = "{pubkey}" // Used as a placeholder in capabilities list +) + // basicRoutes is the configuration for basic server routes type basicRoutes struct { Add404Route bool `json:"add_404_route,omitempty"` diff --git a/server/error.go b/server/error.go index c034ce4..da63d59 100644 --- a/server/error.go +++ b/server/error.go @@ -3,9 +3,10 @@ package server import ( "encoding/json" "errors" + "net/http" + "github.com/bitcoin-sv/go-paymail/logging" "github.com/rs/zerolog" - "net/http" "github.com/bitcoin-sv/go-paymail" ) @@ -52,6 +53,9 @@ var ( // ErrFailedMarshalJSON is when the JSON marshal fails ErrFailedMarshalJSON = errors.New("failed to marshal JSON response") + + //GenerateServiceURL is when the service URL cannot be generated + ErrPrefixOrDomainMissing = errors.New("prefix or domain is missing") ) // ErrorResponse is a standard way to return errors to the client diff --git a/server/p2p_payment_destination.go b/server/p2p_payment_destination.go index d9583cf..8d6d681 100644 --- a/server/p2p_payment_destination.go +++ b/server/p2p_payment_destination.go @@ -23,7 +23,7 @@ type p2pDestinationRequestBody struct { // // Specs: https://docs.moneybutton.com/docs/paymail-07-p2p-payment-destination.html func (c *Configuration) p2pDestination(w http.ResponseWriter, req *http.Request, p httprouter.Params) { - incomingPaymail := p.ByName("paymailAddress") + incomingPaymail := p.ByName(PaymailAddressParamName) // Parse, sanitize and basic validation alias, domain, paymailAddress := paymail.SanitizePaymail(incomingPaymail) diff --git a/server/p2p_receive_transaction_request_parser.go b/server/p2p_receive_transaction_request_parser.go index 2cffda9..f657a29 100644 --- a/server/p2p_receive_transaction_request_parser.go +++ b/server/p2p_receive_transaction_request_parser.go @@ -13,7 +13,7 @@ type parseError struct { } func parseP2pReceiveTxRequest(c *Configuration, req *http.Request, params httprouter.Params, format p2pPayloadFormat) (*p2pReceiveTxReqPayload, *parseError) { - incomingPaymail := params.ByName("paymailAddress") + incomingPaymail := params.ByName(PaymailAddressParamName) alias, domain, paymailAddress := paymail.SanitizePaymail(incomingPaymail) if len(paymailAddress) == 0 { diff --git a/server/pki.go b/server/pki.go index 963450a..6b1fe4f 100644 --- a/server/pki.go +++ b/server/pki.go @@ -12,7 +12,7 @@ import ( // Specs: http://bsvalias.org/03-public-key-infrastructure.html func (c *Configuration) showPKI(w http.ResponseWriter, req *http.Request, p httprouter.Params) { - incomingPaymail := p.ByName("paymailAddress") + incomingPaymail := p.ByName(PaymailAddressParamName) // Parse, sanitize and basic validation alias, domain, address := paymail.SanitizePaymail(incomingPaymail) diff --git a/server/public_profile.go b/server/public_profile.go index 29b781d..d23ae9f 100644 --- a/server/public_profile.go +++ b/server/public_profile.go @@ -11,7 +11,7 @@ import ( // // Specs: https://github.com/bitcoin-sv-specs/brfc-paymail/pull/7/files func (c *Configuration) publicProfile(w http.ResponseWriter, req *http.Request, p httprouter.Params) { - incomingPaymail := p.ByName("paymailAddress") + incomingPaymail := p.ByName(PaymailAddressParamName) // Parse, sanitize and basic validation alias, domain, address := paymail.SanitizePaymail(incomingPaymail) diff --git a/server/resolve_address.go b/server/resolve_address.go index 674062f..f512625 100644 --- a/server/resolve_address.go +++ b/server/resolve_address.go @@ -29,7 +29,7 @@ Incoming Data Object Example: // // Specs: http://bsvalias.org/04-01-basic-address-resolution.html func (c *Configuration) resolveAddress(w http.ResponseWriter, req *http.Request, p httprouter.Params) { - incomingPaymail := p.ByName("paymailAddress") + incomingPaymail := p.ByName(PaymailAddressParamName) // Parse, sanitize and basic validation alias, domain, paymailAddress := paymail.SanitizePaymail(incomingPaymail) diff --git a/server/router.go b/server/router.go index 8e7c930..603584a 100644 --- a/server/router.go +++ b/server/router.go @@ -1,38 +1,25 @@ package server import ( + "fmt" "net/http" + "strings" "github.com/newrelic/go-agent/v3/integrations/nrhttprouter" ) // Handlers are used to isolate loading the routes (used for testing) func Handlers(configuration *Configuration) *nrhttprouter.Router { + router := nrhttprouter.New(nil) - // Create a new router - r := nrhttprouter.New(nil) + configuration.RegisterBasicRoutes(router) + configuration.RegisterRoutes(router) - // Register the routes - configuration.RegisterBasicRoutes(r) - configuration.RegisterRoutes(r) - - // Return the router - return r + return router } // RegisterBasicRoutes register the basic routes to the http router -func (c *Configuration) RegisterBasicRoutes(r *nrhttprouter.Router) { - c.registerBasicRoutes(r) -} - -// RegisterRoutes register all the available paymail routes to the http router -func (c *Configuration) RegisterRoutes(r *nrhttprouter.Router) { - c.registerPaymailRoutes(r) -} - -// registerBasicRoutes will register basic server related routes -func (c *Configuration) registerBasicRoutes(router *nrhttprouter.Router) { - +func (c *Configuration) RegisterBasicRoutes(router *nrhttprouter.Router) { // Skip if not set if c.BasicRoutes == nil { return @@ -62,54 +49,29 @@ func (c *Configuration) registerBasicRoutes(router *nrhttprouter.Router) { } } -// registerPaymailRoutes will register all paymail related routes -func (c *Configuration) registerPaymailRoutes(router *nrhttprouter.Router) { - - // Capabilities (service discovery) - router.GET( - "/.well-known/"+c.ServiceName, - c.showCapabilities, - ) - - // PKI request (public key information) - router.GET( - "/"+c.APIVersion+"/"+c.ServiceName+"/id/:paymailAddress", - c.showPKI, - ) - - // Verify PubKey request (public key verification to paymail address) - router.GET( - "/"+c.APIVersion+"/"+c.ServiceName+"/verify-pubkey/:paymailAddress/:pubKey", - c.verifyPubKey, - ) - - // Payment Destination request (address resolution) - router.POST( - "/"+c.APIVersion+"/"+c.ServiceName+"/address/:paymailAddress", - c.resolveAddress, - ) - - // Public Profile request (returns Name & Avatar) - router.GET( - "/"+c.APIVersion+"/"+c.ServiceName+"/public-profile/:paymailAddress", - c.publicProfile, - ) - - // P2P Destination request (returns output & reference) - router.POST( - "/"+c.APIVersion+"/"+c.ServiceName+"/p2p-payment-destination/:paymailAddress", - c.p2pDestination, - ) +// RegisterRoutes register all the available paymail routes to the http router +func (c *Configuration) RegisterRoutes(router *nrhttprouter.Router) { + router.GET("/.well-known/"+c.ServiceName, c.showCapabilities) // service discovery + + for key, cap := range c.callableCapabilities { + routerPath := c.templateToRouterPath(cap.Path) + router.Handle( + cap.Method, + routerPath, + cap.Handler, + ) + + c.Logger.Info().Msgf("Registering endpoint for capability: %s", key) + c.Logger.Debug().Msgf("Endpoint[%s]: %s %s", key, cap.Method, routerPath) + } +} - // P2P Receive Tx request (receives the P2P transaction, broadcasts, returns tx_id) - router.POST( - "/"+c.APIVersion+"/"+c.ServiceName+"/receive-transaction/:paymailAddress", - c.p2pReceiveTx, - ) +func (c *Configuration) templateToRouterPath(template string) string { + template = strings.ReplaceAll(template, PaymailAddressTemplate, _routerParam(PaymailAddressParamName)) + template = strings.ReplaceAll(template, PubKeyTemplate, _routerParam(PubKeyParamName)) + return fmt.Sprintf("/%s/%s/%s", c.APIVersion, c.ServiceName, strings.TrimPrefix(template, "/")) +} - // P2P BEEF capability Receive Tx request - router.POST( - "/"+c.APIVersion+"/"+c.ServiceName+"/beef/:paymailAddress", - c.p2pReceiveBeefTx, - ) +func _routerParam(name string) string { + return ":" + name } diff --git a/server/server.go b/server/server.go index 035df51..bb198b3 100644 --- a/server/server.go +++ b/server/server.go @@ -3,9 +3,9 @@ package server import ( "fmt" - "github.com/rs/zerolog" "net/http" - "strings" + + "github.com/rs/zerolog" ) // CreateServer will create a basic Paymail Server @@ -24,23 +24,3 @@ func StartServer(srv *http.Server, logger *zerolog.Logger) { logger.Info().Str("address", srv.Addr).Msg("starting go paymail server...") logger.Fatal().Msg(srv.ListenAndServe().Error()) } - -// getHost tries its best to return the request host -func getHost(r *http.Request) string { - if r.URL.IsAbs() { - return removePort(r.Host) - } - if len(r.URL.Host) == 0 { - return removePort(r.Host) - } - return r.URL.Host -} - -// removePort will attempt to remove the port if found -func removePort(host string) string { - // Slice off any port information. - if i := strings.Index(host, ":"); i != -1 { - host = host[:i] - } - return host -} diff --git a/server/server_test.go b/server/server_test.go index be08fb8..dbe3425 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,9 +1,12 @@ package server import ( - "context" + "encoding/json" "fmt" "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" "time" @@ -14,10 +17,9 @@ import ( // TestCreateServer will test the method CreateServer() func TestCreateServer(t *testing.T) { t.Run("valid config", func(t *testing.T) { - config := &Configuration{ - Port: 12345, - Timeout: 10 * time.Second, - } + config := testConfig(t, "localhost") + config.Port = 12345 + config.Timeout = 10 * time.Second s := CreateServer(config) require.NotNil(t, s) assert.IsType(t, &http.Server{}, s) @@ -27,68 +29,51 @@ func TestCreateServer(t *testing.T) { }) } -// TestStart will test the method Start() -func TestStart(t *testing.T) { - t.Run("run server", func(t *testing.T) { - /* - // todo: run in a non-blocking way to test - config := &Configuration{ - Port: 12345, - Timeout: 10 * time.Second, - } - s := CreateServer(config) - StartServer(s) - */ - }) -} +// TestWithServer will test if the server is running and responding to capabilities discovery & each capability is accessible +func TestWithServer(t *testing.T) { + t.Run("run server and check capabilities", func(t *testing.T) { + config, _ := NewConfig(new(mockServiceProvider), WithDomain("domain.com")) + config.Prefix = "http://" -// Test_removePort will test the method removePort() -func Test_removePort(t *testing.T) { - testDomain := "domain.com" + server := httptest.NewServer(Handlers(config)) + defer server.Close() - t.Run("valid removal", func(t *testing.T) { - host := testDomain + ":1234" - rp := removePort(host) - assert.Equal(t, rp, testDomain) - }) + err := config.AddDomain(server.URL) + assert.NoError(t, err) - t.Run("valid removal (no port)", func(t *testing.T) { - host := testDomain + ":" - rp := removePort(host) - assert.Equal(t, rp, testDomain) - }) + resp, err := http.Get(fmt.Sprintf("%s/.well-known/bsvalias", server.URL)) + if err != nil { + t.Fatalf("Failed to make GET request: %v", err) + } - t.Run("no port", func(t *testing.T) { - rp := removePort(testDomain) - assert.Equal(t, rp, testDomain) - }) -} + var result map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&result) + assert.NoError(t, err) + assert.Equal(t, result["bsvalias"], config.BSVAliasVersion) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() -// Test_getHost will test the method getHost() -func Test_getHost(t *testing.T) { - testDomain := "domain.com" + capabilities := result["capabilities"].(map[string]interface{}) + assert.NotNil(t, capabilities) + assert.Greater(t, len(capabilities), 0) - t.Run("valid host with port", func(t *testing.T) { - req, err := http.NewRequestWithContext( - context.Background(), http.MethodGet, - "http://"+testDomain+":1234", nil, - ) - require.NoError(t, err) - require.NotNil(t, req) + //Check if all callable capabilities are accessible by trying to make a request to each one of them + for _, cap := range capabilities { + capUrl, ok := cap.(string) + if !ok { + continue //skip static capabilities + } - host := getHost(req) - assert.Equal(t, testDomain, host) - }) + capUrl = strings.ReplaceAll(capUrl, PaymailAddressTemplate, "example@domain.com") + capUrl = strings.ReplaceAll(capUrl, PubKeyTemplate, "xpub") - t.Run("valid host with no port", func(t *testing.T) { - req, err := http.NewRequestWithContext( - context.Background(), http.MethodGet, - "http://"+testDomain+"/", nil, - ) - require.NoError(t, err) - require.NotNil(t, req) + _, err := url.Parse(capUrl) + assert.NoError(t, err, "Endpoint %s is not a valid URL", capUrl) - host := getHost(req) - assert.Equal(t, testDomain, host) + _, err = http.Get(capUrl) + + //Only verify if the current 'capUrl' endpoint is accessible, even if the 'GET' method is not permitted for it. + assert.NoError(t, err) + } }) } diff --git a/server/verify.go b/server/verify.go index 88e9e61..e86ed33 100644 --- a/server/verify.go +++ b/server/verify.go @@ -11,8 +11,8 @@ import ( // // Specs: https://bsvalias.org/05-verify-public-key-owner.html func (c *Configuration) verifyPubKey(w http.ResponseWriter, req *http.Request, p httprouter.Params) { - incomingPaymail := p.ByName("paymailAddress") - incomingPubKey := p.ByName("pubKey") + incomingPaymail := p.ByName(PaymailAddressParamName) + incomingPubKey := p.ByName(PubKeyParamName) // Parse, sanitize and basic validation alias, domain, address := paymail.SanitizePaymail(incomingPaymail) diff --git a/utilities.go b/utilities.go index 16bfe83..a0a20bd 100644 --- a/utilities.go +++ b/utilities.go @@ -8,12 +8,10 @@ import ( "time" ) -// emptySpace is an empty space for replacing -var emptySpace = []byte("") - var ( emailRegExp = regexp.MustCompile(`[^a-zA-Z0-9-_.@+]`) pathNameRegExp = regexp.MustCompile(`[^a-zA-Z0-9-_]`) + portRegExp = regexp.MustCompile(`:\d*$`) ) // SanitisedPaymail contains elements of a sanitized paymail address. @@ -47,7 +45,7 @@ func ValidateAndSanitisePaymail(paymail string, isBeta bool) (*SanitisedPaymail, func SanitizePaymail(paymailAddress string) (alias, domain, address string) { // Sanitize the paymail address - address = SanitizeEmail(paymailAddress, false) + address = SanitizeEmail(paymailAddress) // Split the email parts (alias @ domain) parts := strings.Split(address, "@") @@ -149,62 +147,55 @@ func SanitizeDomain(original string) (string, error) { return original, nil } - // Missing http? - if !strings.Contains(original, "http") { + if !strings.HasPrefix(original, "http") { + // The http part is temporary, we just need it for url.Parse to work original = "http://" + strings.TrimSpace(original) } - // Try to parse the url u, err := url.Parse(original) if err != nil { return original, err } - // Generally all domains should be uniform and lowercase - u.Host = strings.ToLower(u.Host) - - // Remove leading www. + u.Host = strings.ToLower(u.Host) // Generally all domains should be uniform and lowercase u.Host = strings.TrimPrefix(u.Host, "www.") + u.Host = removePort(u.Host) return u.Host, nil } +func removePort(host string) string { + return portRegExp.ReplaceAllString(host, "") +} + // SanitizeEmail will take an input and return the sanitized version // // This will sanitize the email address (force to lowercase, remove spaces, etc.) // Example: SanitizeEmail(" John.Doe@Gmail ", false) // Result: johndoe@gmail -func SanitizeEmail(original string, preserveCase bool) string { - - // Leave the email address in its original case - if preserveCase { - return string(emailRegExp.ReplaceAll( - []byte(strings.Replace(original, "mailto:", "", -1)), emptySpace), - ) - } +func SanitizeEmail(original string) string { + original = strings.ToLower(original) + original = strings.Replace(original, "mailto:", "", -1) + original = strings.TrimSpace(original) - // Standard is forced to lowercase - return string(emailRegExp.ReplaceAll( - []byte(strings.ToLower(strings.Replace(original, "mailto:", "", -1))), emptySpace), - ) + return emailRegExp.ReplaceAllString(original, "") } // SanitizePathName returns a formatted path compliant name. // // View examples: sanitize_test.go func SanitizePathName(original string) string { - return string(pathNameRegExp.ReplaceAll([]byte(original), emptySpace)) + return pathNameRegExp.ReplaceAllString(original, "") } // replaceAliasDomain will replace the alias and domain with the correct values func replaceAliasDomain(urlString, alias, domain string) string { - return strings.Replace( - strings.Replace(urlString, "{alias}", alias, -1), - "{domain.tld}", domain, -1, - ) + urlString = strings.ReplaceAll(urlString, "{alias}", alias) + urlString = strings.ReplaceAll(urlString, "{domain.tld}", domain) + return urlString } // replacePubKey will replace the PubKey with the correct values func replacePubKey(urlString, pubKey string) string { - return strings.Replace(urlString, "{pubkey}", pubKey, -1) + return strings.ReplaceAll(urlString, "{pubkey}", pubKey) } diff --git a/utilities_test.go b/utilities_test.go index f070e72..4784569 100644 --- a/utilities_test.go +++ b/utilities_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" "time" + + "github.com/stretchr/testify/assert" ) // TestSanitizePaymail will test the method SanitizePaymail() @@ -23,6 +25,8 @@ func TestSanitizePaymail(t *testing.T) { {"TEST@domain.com", "test", "domain.com", "test@domain.com"}, {"TEST@Domain.com", "test", "domain.com", "test@domain.com"}, {"TEST@DomaiN.COM", "test", "domain.com", "test@domain.com"}, + {"mailto:TEST@DomaiN.COM", "test", "domain.com", "test@domain.com"}, + {"MailTO:TEST@DomaiN.COM", "test", "domain.com", "test@domain.com"}, {"@DomaiN.COM", "", "domain.com", ""}, {"test@", "test", "", ""}, {"test@domain", "test", "domain", "test@domain"}, @@ -327,6 +331,32 @@ func TestValidateAndSanitisePaymail(t *testing.T) { } } +func Test_removePort(t *testing.T) { + testDomain := "domain.com" + + t.Run("valid removal", func(t *testing.T) { + host := testDomain + ":1234" + rp := removePort(host) + assert.Equal(t, rp, testDomain) + }) + + t.Run("valid removal (no port)", func(t *testing.T) { + host := testDomain + ":" + rp := removePort(host) + assert.Equal(t, rp, testDomain) + }) + + t.Run("no port", func(t *testing.T) { + rp := removePort(testDomain) + assert.Equal(t, rp, testDomain) + }) + + t.Run("remove port from full url", func(t *testing.T) { + rp := removePort("http://" + testDomain + ":1234") + assert.Equal(t, rp, "http://"+testDomain) + }) +} + // BenchmarkTestValidateAndSanitisePaymail benchmarks the method ValidateTimestamp() func BenchmarkTestValidateAndSanitisePaymail(b *testing.B) { for i := 0; i < b.N; i++ {