diff --git a/protocol/assertion.go b/protocol/assertion.go index a34e7dc6..a82313a2 100644 --- a/protocol/assertion.go +++ b/protocol/assertion.go @@ -60,8 +60,8 @@ func ParseCredentialRequestResponse(response *http.Request) (*ParsedCredentialAs func ParseCredentialRequestResponseBody(body io.Reader) (par *ParsedCredentialAssertionData, err error) { var car CredentialAssertionResponse - if err = json.NewDecoder(body).Decode(&car); err != nil { - return nil, ErrBadRequest.WithDetails("Parse error for Assertion") + if err = decodeBody(body, &car); err != nil { + return nil, ErrBadRequest.WithDetails("Parse error for Assertion").WithInfo(err.Error()) } return car.Parse() diff --git a/protocol/assertion_test.go b/protocol/assertion_test.go index 4fc015d8..335f9a13 100644 --- a/protocol/assertion_test.go +++ b/protocol/assertion_test.go @@ -27,10 +27,13 @@ func TestParseCredentialRequestResponse(t *testing.T) { } testCases := []struct { - name string - args args - expected *ParsedCredentialAssertionData - errString string + name string + args args + expected *ParsedCredentialAssertionData + errString string + errType string + errDetails string + errInfo string }{ { name: "ShouldParseCredentialAssertion", @@ -91,6 +94,17 @@ func TestParseCredentialRequestResponse(t *testing.T) { }, errString: "", }, + { + name: "ShouldHandleTrailingData", + args: args{ + "trailingData", + }, + expected: nil, + errString: "Parse error for Assertion", + errType: "invalid_request", + errDetails: "Parse error for Assertion", + errInfo: "The body contains trailing data", + }, } for _, tc := range testCases { @@ -104,6 +118,8 @@ func TestParseCredentialRequestResponse(t *testing.T) { if tc.errString != "" { assert.EqualError(t, err, tc.errString) + AssertIsProtocolError(t, err, tc.errType, tc.errDetails, tc.errInfo) + return } @@ -185,4 +201,18 @@ var testAssertionResponses = map[string]string{ "userHandle":"0ToAAAAAAAAAAA"} } `, + `trailingData`: `{ + "id":"AI7D5q2P0LS-Fal9ZT7CHM2N5BLbUunF92T8b6iYC199bO2kagSuU05-5dZGqb1SP0A0lyTWng", + "rawId":"AI7D5q2P0LS-Fal9ZT7CHM2N5BLbUunF92T8b6iYC199bO2kagSuU05-5dZGqb1SP0A0lyTWng", + "clientExtensionResults":{"appID":"example.com"}, + "type":"public-key", + "response":{ + "authenticatorData":"dKbqkhPJnC90siSSsyDPQCYqlMGpUKA5fyklC2CEHvBFXJJiGa3OAAI1vMYKZIsLJfHwVQMANwCOw-atj9C0vhWpfWU-whzNjeQS21Lpxfdk_G-omAtffWztpGoErlNOfuXWRqm9Uj9ANJck1p6lAQIDJiABIVggKAhfsdHcBIc0KPgAcRyAIK_-Vi-nCXHkRHPNaCMBZ-4iWCBxB8fGYQSBONi9uvq0gv95dGWlhJrBwCsj_a4LJQKVHQ", + "clientDataJSON":"eyJjaGFsbGVuZ2UiOiJFNFBUY0lIX0hmWDFwQzZTaWdrMVNDOU5BbGdlenROMDQzOXZpOHpfYzlrIiwibmV3X2tleXNfbWF5X2JlX2FkZGVkX2hlcmUiOiJkbyBub3QgY29tcGFyZSBjbGllbnREYXRhSlNPTiBhZ2FpbnN0IGEgdGVtcGxhdGUuIFNlZSBodHRwczovL2dvby5nbC95YWJQZXgiLCJvcmlnaW4iOiJodHRwczovL3dlYmF1dGhuLmlvIiwidHlwZSI6IndlYmF1dGhuLmdldCJ9", + "signature":"MEUCIBtIVOQxzFYdyWQyxaLR0tik1TnuPhGVhXVSNgFwLmN5AiEAnxXdCq0UeAVGWxOaFcjBZ_mEZoXqNboY5IkQDdlWZYc", + "userHandle":"0ToAAAAAAAAAAA"} + } + +trailing + `, } diff --git a/protocol/credential.go b/protocol/credential.go index d3753af9..bb9782b0 100644 --- a/protocol/credential.go +++ b/protocol/credential.go @@ -3,7 +3,6 @@ package protocol import ( "crypto/sha256" "encoding/base64" - "encoding/json" "io" "net/http" ) @@ -48,7 +47,7 @@ type CredentialCreationResponse struct { PublicKeyCredential AttestationResponse AuthenticatorAttestationResponse `json:"response"` - // Deprecated: Transports is deprecated due to upstream changes to the API. + // Deprecated: Transports is deprecated due to upstream changes to the API. // Use the Transports field of AuthenticatorAttestationResponse // instead. Transports is kept for backward compatibility, and should not // be used by new clients. @@ -61,18 +60,25 @@ type ParsedCredentialCreationData struct { Raw CredentialCreationResponse } +// ParseCredentialCreationResponse is a non-agnostic function for parsing a registration response from the http library +// from stdlib. It handles some standard cleanup operations. func ParseCredentialCreationResponse(response *http.Request) (*ParsedCredentialCreationData, error) { if response == nil || response.Body == nil { return nil, ErrBadRequest.WithDetails("No response given") } + defer response.Body.Close() + defer io.Copy(io.Discard, response.Body) + return ParseCredentialCreationResponseBody(response.Body) } +// ParseCredentialCreationResponseBody is an agnostic version of ParseCredentialCreationResponse. Implementers are +// therefore responsible for managing cleanup. func ParseCredentialCreationResponseBody(body io.Reader) (pcc *ParsedCredentialCreationData, err error) { var ccr CredentialCreationResponse - if err = json.NewDecoder(body).Decode(&ccr); err != nil { + if err = decodeBody(body, &ccr); err != nil { return nil, ErrBadRequest.WithDetails("Parse error for Registration").WithInfo(err.Error()) } diff --git a/protocol/credential_test.go b/protocol/credential_test.go index 89aca29e..7c0b819d 100644 --- a/protocol/credential_test.go +++ b/protocol/credential_test.go @@ -24,10 +24,13 @@ func TestParseCredentialCreationResponse(t *testing.T) { byteClientDataJSON, _ := base64.RawURLEncoding.DecodeString("eyJjaGFsbGVuZ2UiOiJXOEd6RlU4cEdqaG9SYldyTERsYW1BZnFfeTRTMUNaRzFWdW9lUkxBUnJFIiwib3JpZ2luIjoiaHR0cHM6Ly93ZWJhdXRobi5pbyIsInR5cGUiOiJ3ZWJhdXRobi5jcmVhdGUifQ") testCases := []struct { - name string - args args - expected *ParsedCredentialCreationData - errString string + name string + args args + expected *ParsedCredentialCreationData + errString string + errType string + errDetails string + errInfo string }{ { name: "ShouldParseCredentialRequest", @@ -215,6 +218,17 @@ func TestParseCredentialCreationResponse(t *testing.T) { }, errString: "", }, + { + name: "ShouldHandleTrailingData", + args: args{ + responseName: "trailingData", + }, + expected: nil, + errString: "Parse error for Registration", + errType: "invalid_request", + errDetails: "Parse error for Registration", + errInfo: "The body contains trailing data", + }, } for _, tc := range testCases { @@ -226,6 +240,8 @@ func TestParseCredentialCreationResponse(t *testing.T) { if tc.errString != "" { assert.EqualError(t, err, tc.errString) + AssertIsProtocolError(t, err, tc.errType, tc.errDetails, tc.errInfo) + return } @@ -371,6 +387,24 @@ var testCredentialRequestResponses = map[string]string{ "transports":["usb","nfc","fake"] } } +`, + `trailingData`: ` +{ + "id":"6xrtBhJQW6QU4tOaB4rrHaS2Ks0yDDL_q8jDC16DEjZ-VLVf4kCRkvl2xp2D71sTPYns-exsHQHTy3G-zJRK8g", + "rawId":"6xrtBhJQW6QU4tOaB4rrHaS2Ks0yDDL_q8jDC16DEjZ-VLVf4kCRkvl2xp2D71sTPYns-exsHQHTy3G-zJRK8g", + "type":"public-key", + "authenticatorAttachment":"platform", + "clientExtensionResults":{ + "appid":true + }, + "response":{ + "attestationObject":"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YVjEdKbqkhPJnC90siSSsyDPQCYqlMGpUKA5fyklC2CEHvBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAQOsa7QYSUFukFOLTmgeK6x2ktirNMgwy_6vIwwtegxI2flS1X-JAkZL5dsadg-9bEz2J7PnsbB0B08txvsyUSvKlAQIDJiABIVggLKF5xS0_BntttUIrm2Z2tgZ4uQDwllbdIfrrBMABCNciWCDHwin8Zdkr56iSIh0MrB5qZiEzYLQpEOREhMUkY6q4Vw", + "clientDataJSON":"eyJjaGFsbGVuZ2UiOiJXOEd6RlU4cEdqaG9SYldyTERsYW1BZnFfeTRTMUNaRzFWdW9lUkxBUnJFIiwib3JpZ2luIjoiaHR0cHM6Ly93ZWJhdXRobi5pbyIsInR5cGUiOiJ3ZWJhdXRobi5jcmVhdGUifQ", + "transports":["usb","nfc","fake"] + } +} + +trailing `, `successDeprecatedTransports`: ` { diff --git a/protocol/decoder.go b/protocol/decoder.go new file mode 100644 index 00000000..92e8a81c --- /dev/null +++ b/protocol/decoder.go @@ -0,0 +1,23 @@ +package protocol + +import ( + "encoding/json" + "errors" + "io" +) + +func decodeBody(body io.Reader, v any) (err error) { + decoder := json.NewDecoder(body) + + if err = decoder.Decode(v); err != nil { + return err + } + + _, err = decoder.Token() + + if !errors.Is(err, io.EOF) { + return errors.New("The body contains trailing data") + } + + return nil +} diff --git a/protocol/func_test.go b/protocol/func_test.go new file mode 100644 index 00000000..933b4ad5 --- /dev/null +++ b/protocol/func_test.go @@ -0,0 +1,19 @@ +package protocol + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func AssertIsProtocolError(t *testing.T, err error, errType, errDetails, errInfo string) { + var e *Error + + require.True(t, errors.As(err, &e)) + + assert.Equal(t, errType, e.Type) + assert.Equal(t, errDetails, e.Details) + assert.Equal(t, errInfo, e.DevInfo) +}