From 447c2df7698fd9080efe94dce3bb6e777b0d64a8 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Wed, 22 Nov 2017 21:55:10 -0800 Subject: [PATCH] Compare request header tokens with ASCII case folding Compare request header tokens with ASCII case folding per the WebSocket RFC. --- util.go | 28 ++++++++++++++++++++++++++-- util_test.go | 18 ++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/util.go b/util.go index 262e647b..60ca26da 100644 --- a/util.go +++ b/util.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "strings" + "unicode/utf8" ) var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") @@ -127,8 +128,31 @@ func nextTokenOrQuoted(s string) (value string, rest string) { return "", "" } +// equalASCIIFold returns true if s is equal to t with ASCII case folding. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + // tokenListContainsValue returns true if the 1#token header with the given -// name contains token. +// name contains a token equal to value with ASCII case folding. func tokenListContainsValue(header http.Header, name string, value string) bool { headers: for _, s := range header[name] { @@ -142,7 +166,7 @@ headers: if s != "" && s[0] != ',' { continue headers } - if strings.EqualFold(t, value) { + if equalASCIIFold(t, value) { return true } if s == "" { diff --git a/util_test.go b/util_test.go index ab11c3f9..cdc3d36d 100644 --- a/util_test.go +++ b/util_test.go @@ -10,6 +10,24 @@ import ( "testing" ) +var equalASCIIFoldTests = []struct { + t, s string + eq bool +}{ + {"WebSocket", "websocket", true}, + {"websocket", "WebSocket", true}, + {"Öyster", "öyster", false}, +} + +func TestEqualASCIIFold(t *testing.T) { + for _, tt := range equalASCIIFoldTests { + eq := equalASCIIFold(tt.s, tt.t) + if eq != tt.eq { + t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq) + } + } +} + var tokenListContainsValueTests = []struct { value string ok bool