Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support SaslHandshakeRequest v1 #1354

Merged
merged 1 commit into from
Apr 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"sync/atomic"
"time"

"github.com/rcrowley/go-metrics"
metrics "github.com/rcrowley/go-metrics"
)

// Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
Expand Down Expand Up @@ -905,8 +905,10 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
return nil
}

// Kafka 0.10.0 plans to support SASL Plain and Kerberos as per PR #812 (KIP-43)/(JIRA KAFKA-3149)
// Some hosted kafka services such as IBM Message Hub already offer SASL/PLAIN auth with Kafka 0.9
// Kafka 0.10.x supported SASL PLAIN/Kerberos via KAFKA-3149 (KIP-43).
// Kafka 1.x.x onward added a SaslAuthenticate request/response message which
// wraps the SASL flow in the Kafka protocol, which allows for returning
// meaningful errors on authentication failure.
//
// In SASL Plain, Kafka expects the auth header to be in the following format
// Message format (from https://tools.ietf.org/html/rfc4616):
Expand All @@ -920,18 +922,37 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
// SAFE = UTF1 / UTF2 / UTF3 / UTF4
// ;; any UTF-8 encoded Unicode character except NUL
//
// With SASL v0 handshake and auth then:
// When credentials are valid, Kafka returns a 4 byte array of null characters.
// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
// of responding to bad credentials but thats how its being done today.
// When credentials are invalid, Kafka closes the connection.
//
// With SASL v1 handshake and auth then:
// When credentials are invalid, Kafka replies with a SaslAuthenticate response
// containing an error code and message detailing the authentication failure.
func (b *Broker) sendAndReceiveSASLPlainAuth() error {
// default to V0 to allow for backward compatability when SASL is enabled
// but not the handshake
saslHandshake := SASLHandshakeV0
if b.conf.Net.SASL.Handshake {
handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, SASLHandshakeV0)
if b.conf.Version.IsAtLeast(V1_0_0_0) {
saslHandshake = SASLHandshakeV1
}
handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, saslHandshake)
if handshakeErr != nil {
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
return handshakeErr
}
}

if saslHandshake == SASLHandshakeV1 {
return b.sendAndReceiveV1SASLPlainAuth()
}
return b.sendAndReceiveV0SASLPlainAuth()
}

// sendAndReceiveV0SASLPlainAuth flows the v0 sasl auth NOT wrapped in the kafka protocol
func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {

length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
authBytes := make([]byte, length+4) //4 byte length header + auth data
binary.BigEndian.PutUint32(authBytes, uint32(length))
Expand Down Expand Up @@ -965,6 +986,35 @@ func (b *Broker) sendAndReceiveSASLPlainAuth() error {
return nil
}

// sendAndReceiveV1SASLPlainAuth flows the v1 sasl authentication using the kafka protocol
func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
correlationID := b.correlationID

requestTime := time.Now()

bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)

b.updateOutgoingCommunicationMetrics(bytesWritten)

if err != nil {
Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
return err
}

b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(correlationID)
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))

// With v1 sasl we get an error message set in the response we can return
if err != nil {
Logger.Printf("Error returned from broker during SASL flow %s: %s\n", b.addr, err.Error())
return err
}

return nil
}

// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
Expand All @@ -988,7 +1038,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
b.updateOutgoingCommunicationMetrics(bytesWritten)
b.correlationID++

bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
bytesRead, err := b.receiveSASLServerResponse(correlationID)
if err != nil {
return err
}
Expand Down Expand Up @@ -1123,6 +1173,23 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
return strings.Join(buf, elemSep)
}

func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
authBytes := []byte("\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
rb := &SaslAuthenticateRequest{authBytes}
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
}

err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
if err != nil {
Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
return 0, err
}
return b.conn.Write(buf)
}

func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
initialResp, err := buildClientInitialResponse(token)
if err != nil {
Expand All @@ -1145,7 +1212,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati
return b.conn.Write(buf)
}

func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {

buf := make([]byte, responseLengthSize+correlationIDSize)

Expand Down
100 changes: 100 additions & 0 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,106 @@ func TestSASLSCRAMSHAXXX(t *testing.T) {
}
}

func TestSASLPlainAuth(t *testing.T) {

testTable := []struct {
name string
mockAuthErr KError // Mock and expect error returned from SaslAuthenticateRequest
mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
expectClientErr bool // Expect an internal client-side error
}{
{
name: "SASL Plain OK server response",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
},
{
name: "SASL Plain authentication failure response",
mockAuthErr: ErrSASLAuthenticationFailed,
mockHandshakeErr: ErrNoError,
},
{
name: "SASL Plain handshake failure response",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrSASLAuthenticationFailed,
},
}

for i, test := range testTable {

// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`))

if test.mockAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
}

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
})

// broker executes SASL requests against mockBroker
broker := NewBroker(mockBroker.Addr())
broker.requestRate = metrics.NilMeter{}
broker.outgoingByteRate = metrics.NilMeter{}
broker.incomingByteRate = metrics.NilMeter{}
broker.requestSize = metrics.NilHistogram{}
broker.responseSize = metrics.NilHistogram{}
broker.responseRate = metrics.NilMeter{}
broker.requestLatency = metrics.NilHistogram{}

conf := NewConfig()
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V1_0_0_0
dialer := net.Dialer{
Timeout: conf.Net.DialTimeout,
KeepAlive: conf.Net.KeepAlive,
LocalAddr: conf.Net.LocalAddr,
}

conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())

if err != nil {
t.Fatal(err)
}

broker.conn = conn

err = broker.authenticateViaSASL()

if test.mockAuthErr != ErrNoError {
if test.mockAuthErr != err {
t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
}
} else if test.mockHandshakeErr != ErrNoError {
if test.mockHandshakeErr != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
}
} else if test.expectClientErr && err == nil {
t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
} else if !test.expectClientErr && err != nil {
t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
}

mockBroker.Close()
}
}

func TestBuildClientInitialResponse(t *testing.T) {

testTable := []struct {
Expand Down
12 changes: 12 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,18 @@ func (client *client) tryRefreshMetadata(topics []string, attemptsRemaining int)
case PacketEncodingError:
// didn't even send, return the error
return err

case KError:
// if SASL auth error return as this _should_ be a non retryable err for all brokers
if err.(KError) == ErrSASLAuthenticationFailed {
Logger.Println("client/metadata failed SASL authentication")
return err
}
// else remove that broker and try again
Logger.Printf("client/metadata got error from broker %d while fetching metadata: %v\n", broker.ID(), err)
_ = broker.Close()
client.deregisterBroker(broker)

default:
// some other error, remove that broker and try again
Logger.Printf("client/metadata got error from broker %d while fetching metadata: %v\n", broker.ID(), err)
Expand Down