From 5722d69b8e434b9909b79dbdd59d2a514264f61b Mon Sep 17 00:00:00 2001 From: Julien LEFEVRE Date: Mon, 1 Apr 2019 16:44:01 +0200 Subject: [PATCH] [sasl] use a SCRAM client for each connection --- broker.go | 2 +- broker_test.go | 2 +- config.go | 8 ++++---- config_test.go | 8 ++++---- examples/sasl_scram_client/main.go | 4 ++-- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/broker.go b/broker.go index 12ee4f07d..00199bc28 100644 --- a/broker.go +++ b/broker.go @@ -1002,7 +1002,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error { return err } - scramClient := b.conf.Net.SASL.SCRAMClient + scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc() if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil { return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error()) } diff --git a/broker_test.go b/broker_test.go index a3b17af4f..ce8149d58 100644 --- a/broker_test.go +++ b/broker_test.go @@ -340,7 +340,7 @@ func TestSASLSCRAMSHAXXX(t *testing.T) { conf := NewConfig() conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 - conf.Net.SASL.SCRAMClient = test.scramClient + conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient } broker.conf = conf dialer := net.Dialer{ diff --git a/config.go b/config.go index 6fa8bb940..3adc186af 100644 --- a/config.go +++ b/config.go @@ -66,9 +66,9 @@ type Config struct { Password string // authz id used for SASL/SCRAM authentication SCRAMAuthzID string - // SCRAMClient is a user provided implementation of a SCRAM + // SCRAMClientGeneratorFunc is a generator of a user provided implementation of a SCRAM // client used to perform the SCRAM exchange with the server. - SCRAMClient SCRAMClient + SCRAMClientGeneratorFunc func() SCRAMClient // TokenProvider is a user-defined callback for generating // access tokens for SASL/OAUTHBEARER auth. See the // AccessTokenProvider interface docs for proper implementation @@ -503,8 +503,8 @@ func (c *Config) Validate() error { if c.Net.SASL.Password == "" { return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled") } - if c.Net.SASL.SCRAMClient == nil { - return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient") + if c.Net.SASL.SCRAMClientGeneratorFunc == nil { + return ConfigurationError("A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc") } default: msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`", diff --git a/config_test.go b/config_test.go index 3ba9a2023..23538e2f3 100644 --- a/config_test.go +++ b/config_test.go @@ -103,20 +103,20 @@ func TestNetConfigValidates(t *testing.T) { func(cfg *Config) { cfg.Net.SASL.Enable = true cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA256 - cfg.Net.SASL.SCRAMClient = nil + cfg.Net.SASL.SCRAMClientGeneratorFunc = nil cfg.Net.SASL.User = "user" cfg.Net.SASL.Password = "stong_password" }, - "A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"}, + "A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc"}, {"SASL.Mechanism SCRAM-SHA-512 - Missing SCRAM client", func(cfg *Config) { cfg.Net.SASL.Enable = true cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 - cfg.Net.SASL.SCRAMClient = nil + cfg.Net.SASL.SCRAMClientGeneratorFunc = nil cfg.Net.SASL.User = "user" cfg.Net.SASL.Password = "stong_password" }, - "A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"}, + "A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc"}, } for i, test := range tests { diff --git a/examples/sasl_scram_client/main.go b/examples/sasl_scram_client/main.go index 2a28fffb8..f14acb351 100644 --- a/examples/sasl_scram_client/main.go +++ b/examples/sasl_scram_client/main.go @@ -86,10 +86,10 @@ func main() { conf.Net.SASL.Password = *passwd conf.Net.SASL.Handshake = true if *algorithm == "sha512" { - conf.Net.SASL.SCRAMClient = &XDGSCRAMClient{HashGeneratorFcn: SHA512} + conf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA512} } conf.Net.SASL.Mechanism = sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA512) } else if *algorithm == "sha256" { - conf.Net.SASL.SCRAMClient = &XDGSCRAMClient{HashGeneratorFcn: SHA256} + conf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA256} } conf.Net.SASL.Mechanism = sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA256) } else {