From 931887016ad7483b15e16bbf2dab7e5b5d9c886a Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Thu, 13 Jul 2023 09:13:32 +0200 Subject: [PATCH] SNOW-857829 Fix username and password requiredness --- dsn.go | 22 +++-- dsn_test.go | 264 ++++++++++++++++++++++++++++++++-------------------- 2 files changed, 177 insertions(+), 109 deletions(-) diff --git a/dsn.go b/dsn.go index c4df6a130..f5a34298f 100644 --- a/dsn.go +++ b/dsn.go @@ -390,16 +390,11 @@ func fillMissingConfigParameters(cfg *Config) error { return ErrEmptyAccount } - if cfg.Authenticator != AuthTypeOAuth && strings.Trim(cfg.User, " ") == "" { - // oauth does not require a username + if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" { return ErrEmptyUsername } - if cfg.Authenticator != AuthTypeExternalBrowser && - cfg.Authenticator != AuthTypeOAuth && - cfg.Authenticator != AuthTypeJwt && - strings.Trim(cfg.Password, " ") == "" { - // no password parameter is required for EXTERNALBROWSER, OAUTH or JWT. + if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" { return ErrEmptyPassword } if strings.Trim(cfg.Protocol, " ") == "" { @@ -467,6 +462,19 @@ func fillMissingConfigParameters(cfg *Config) error { return nil } +func authRequiresUser(cfg *Config) bool { + return cfg.Authenticator != AuthTypeOAuth && + cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser +} + +func authRequiresPassword(cfg *Config) bool { + return cfg.Authenticator != AuthTypeOAuth && + cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser && + cfg.Authenticator != AuthTypeJwt +} + // transformAccountToHost transforms host to account name func transformAccountToHost(cfg *Config) (err error) { if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" { diff --git a/dsn_test.go b/dsn_test.go index bc5dbd5f0..25fd8ede5 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -582,112 +582,172 @@ func TestParseDSN(t *testing.T) { }, } + for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("@host:777/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "", Password: "", + Protocol: "http", Host: "host", Port: 123, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }) + } + + for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA, AuthTypeJwt} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "", Password: "", + Protocol: "http", Host: "host", Port: 123, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: ErrEmptyUsername, + }) + } + + for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("user@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "user", Password: "", + Protocol: "http", Host: "host", Port: 123, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: ErrEmptyPassword, + }) + } + for i, test := range testcases { - // t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn) - cfg, err := ParseDSN(test.dsn) - switch { - case test.err == nil: - if err != nil { - t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) - } - if test.config.Host != cfg.Host { - t.Fatalf("%d: Failed to match host. expected: %v, got: %v", - i, test.config.Host, cfg.Host) - } - if test.config.Account != cfg.Account { - t.Fatalf("%d: Failed to match account. expected: %v, got: %v", - i, test.config.Account, cfg.Account) - } - if test.config.User != cfg.User { - t.Fatalf("%d: Failed to match user. expected: %v, got: %v", - i, test.config.User, cfg.User) - } - if test.config.Password != cfg.Password { - t.Fatalf("%d: Failed to match password. expected: %v, got: %v", - i, test.config.Password, cfg.Password) - } - if test.config.Database != cfg.Database { - t.Fatalf("%d: Failed to match database. expected: %v, got: %v", - i, test.config.Database, cfg.Database) - } - if test.config.Schema != cfg.Schema { - t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", - i, test.config.Schema, cfg.Schema) - } - if test.config.Warehouse != cfg.Warehouse { - t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", - i, test.config.Warehouse, cfg.Warehouse) - } - if test.config.Role != cfg.Role { - t.Fatalf("%d: Failed to match role. expected: %v, got: %v", - i, test.config.Role, cfg.Role) - } - if test.config.Region != cfg.Region { - t.Fatalf("%d: Failed to match region. expected: %v, got: %v", - i, test.config.Region, cfg.Region) - } - if test.config.Protocol != cfg.Protocol { - t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", - i, test.config.Protocol, cfg.Protocol) - } - if test.config.Passcode != cfg.Passcode { - t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", - i, test.config.Passcode, cfg.Passcode) - } - if test.config.PasscodeInPassword != cfg.PasscodeInPassword { - t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", - i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) - } - if test.config.Authenticator != cfg.Authenticator { - t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", - i, test.config.Authenticator.String(), cfg.Authenticator.String()) - } - if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { - t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", - i, test.config.OktaURL, cfg.OktaURL) - } - if test.config.OCSPFailOpen != cfg.OCSPFailOpen { - t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", - i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) - } - if test.ocspMode != cfg.ocspMode() { - t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", - i, test.ocspMode, cfg.ocspMode()) - } - if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { - t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", - i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) - } - if test.config.ClientTimeout != cfg.ClientTimeout { - t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", - i, test.config.ClientTimeout, cfg.ClientTimeout) - } - if test.config.JWTClientTimeout != cfg.JWTClientTimeout { - t.Fatalf("%d: Failed to match JWTClientTimeout. expected: %v, got: %v", - i, test.config.JWTClientTimeout, cfg.JWTClientTimeout) - } - if test.config.ExternalBrowserTimeout != cfg.ExternalBrowserTimeout { - t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", - i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) - } - case test.err != nil: - driverErrE, okE := test.err.(*SnowflakeError) - driverErrG, okG := err.(*SnowflakeError) - if okE && !okG || !okE && okG { - t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) - } - if okE && okG { - if driverErrE.Number != driverErrG.Number { - t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + t.Run("TestParseDSN", func(t *testing.T) { + // t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn) + cfg, err := ParseDSN(test.dsn) + switch { + case test.err == nil: + if err != nil { + t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) + } + if test.config.Host != cfg.Host { + t.Fatalf("%d: Failed to match host. expected: %v, got: %v", + i, test.config.Host, cfg.Host) + } + if test.config.Account != cfg.Account { + t.Fatalf("%d: Failed to match account. expected: %v, got: %v", + i, test.config.Account, cfg.Account) } - } else { - t1 := reflect.TypeOf(err) - t2 := reflect.TypeOf(test.err) - if t1 != t2 { - t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + if test.config.User != cfg.User { + t.Fatalf("%d: Failed to match user. expected: %v, got: %v", + i, test.config.User, cfg.User) + } + if test.config.Password != cfg.Password { + t.Fatalf("%d: Failed to match password. expected: %v, got: %v", + i, test.config.Password, cfg.Password) + } + if test.config.Database != cfg.Database { + t.Fatalf("%d: Failed to match database. expected: %v, got: %v", + i, test.config.Database, cfg.Database) + } + if test.config.Schema != cfg.Schema { + t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", + i, test.config.Schema, cfg.Schema) + } + if test.config.Warehouse != cfg.Warehouse { + t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", + i, test.config.Warehouse, cfg.Warehouse) + } + if test.config.Role != cfg.Role { + t.Fatalf("%d: Failed to match role. expected: %v, got: %v", + i, test.config.Role, cfg.Role) + } + if test.config.Region != cfg.Region { + t.Fatalf("%d: Failed to match region. expected: %v, got: %v", + i, test.config.Region, cfg.Region) + } + if test.config.Protocol != cfg.Protocol { + t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", + i, test.config.Protocol, cfg.Protocol) + } + if test.config.Passcode != cfg.Passcode { + t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", + i, test.config.Passcode, cfg.Passcode) + } + if test.config.PasscodeInPassword != cfg.PasscodeInPassword { + t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", + i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) + } + if test.config.Authenticator != cfg.Authenticator { + t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", + i, test.config.Authenticator.String(), cfg.Authenticator.String()) + } + if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { + t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", + i, test.config.OktaURL, cfg.OktaURL) + } + if test.config.OCSPFailOpen != cfg.OCSPFailOpen { + t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", + i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) + } + if test.ocspMode != cfg.ocspMode() { + t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", + i, test.ocspMode, cfg.ocspMode()) + } + if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { + t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", + i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) + } + if test.config.ClientTimeout != cfg.ClientTimeout { + t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", + i, test.config.ClientTimeout, cfg.ClientTimeout) + } + if test.config.JWTClientTimeout != cfg.JWTClientTimeout { + t.Fatalf("%d: Failed to match JWTClientTimeout. expected: %v, got: %v", + i, test.config.JWTClientTimeout, cfg.JWTClientTimeout) + } + if test.config.ExternalBrowserTimeout != cfg.ExternalBrowserTimeout { + t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", + i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) + } + case test.err != nil: + driverErrE, okE := test.err.(*SnowflakeError) + driverErrG, okG := err.(*SnowflakeError) + if okE && !okG || !okE && okG { + t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) + } + if okE && okG { + if driverErrE.Number != driverErrG.Number { + t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + } + } else { + t1 := reflect.TypeOf(err) + t2 := reflect.TypeOf(test.err) + if t1 != t2 { + t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + } } } - } + + }) } }