Skip to content

Commit

Permalink
[sasl] use a SCRAM client for each connection
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsinham authored and Julien LEFEVRE committed Apr 15, 2019
1 parent e4271b1 commit 5722d69
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
return err
}

scramClient := b.conf.Net.SASL.SCRAMClient
scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc()
if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func TestSASLSCRAMSHAXXX(t *testing.T) {

conf := NewConfig()
conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
conf.Net.SASL.SCRAMClient = test.scramClient
conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient }

broker.conf = conf
dialer := net.Dialer{
Expand Down
8 changes: 4 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ type Config struct {
Password string
// authz id used for SASL/SCRAM authentication
SCRAMAuthzID string
// SCRAMClient is a user provided implementation of a SCRAM
// SCRAMClientGeneratorFunc is a generator of a user provided implementation of a SCRAM
// client used to perform the SCRAM exchange with the server.
SCRAMClient SCRAMClient
SCRAMClientGeneratorFunc func() SCRAMClient
// TokenProvider is a user-defined callback for generating
// access tokens for SASL/OAUTHBEARER auth. See the
// AccessTokenProvider interface docs for proper implementation
Expand Down Expand Up @@ -503,8 +503,8 @@ func (c *Config) Validate() error {
if c.Net.SASL.Password == "" {
return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
}
if c.Net.SASL.SCRAMClient == nil {
return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient")
if c.Net.SASL.SCRAMClientGeneratorFunc == nil {
return ConfigurationError("A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc")
}
default:
msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`",
Expand Down
8 changes: 4 additions & 4 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,20 @@ func TestNetConfigValidates(t *testing.T) {
func(cfg *Config) {
cfg.Net.SASL.Enable = true
cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA256
cfg.Net.SASL.SCRAMClient = nil
cfg.Net.SASL.SCRAMClientGeneratorFunc = nil
cfg.Net.SASL.User = "user"
cfg.Net.SASL.Password = "stong_password"
},
"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
"A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc"},
{"SASL.Mechanism SCRAM-SHA-512 - Missing SCRAM client",
func(cfg *Config) {
cfg.Net.SASL.Enable = true
cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
cfg.Net.SASL.SCRAMClient = nil
cfg.Net.SASL.SCRAMClientGeneratorFunc = nil
cfg.Net.SASL.User = "user"
cfg.Net.SASL.Password = "stong_password"
},
"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
"A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc"},
}

for i, test := range tests {
Expand Down
4 changes: 2 additions & 2 deletions examples/sasl_scram_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ func main() {
conf.Net.SASL.Password = *passwd
conf.Net.SASL.Handshake = true
if *algorithm == "sha512" {
conf.Net.SASL.SCRAMClient = &XDGSCRAMClient{HashGeneratorFcn: SHA512}
conf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA512} }
conf.Net.SASL.Mechanism = sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA512)
} else if *algorithm == "sha256" {
conf.Net.SASL.SCRAMClient = &XDGSCRAMClient{HashGeneratorFcn: SHA256}
conf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA256} }
conf.Net.SASL.Mechanism = sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA256)

} else {
Expand Down

0 comments on commit 5722d69

Please sign in to comment.