diff --git a/dn.go b/dn.go index 7478919..a6083d6 100644 --- a/dn.go +++ b/dn.go @@ -1,13 +1,14 @@ package ldap import ( - "bytes" "encoding/asn1" "encoding/hex" "errors" "fmt" "sort" "strings" + "unicode" + "unicode/utf8" ) // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 @@ -34,6 +35,9 @@ func (a *AttributeTypeAndValue) setValue(s string) error { // AttributeValue is represented by an number sign ('#' U+0023) // character followed by the hexadecimal encoding of each of the octets // of the BER encoding of the X.500 AttributeValue. + // + // WARNING: we only support hex-encoded ASN.1 DER values here, not + // BER encoding. This is a deviation from the RFC. if len(s) > 0 && s[0] == '#' { decodedString, err := decodeEncodedString(s[1:]) if err != nil { @@ -56,59 +60,7 @@ func (a *AttributeTypeAndValue) setValue(s string) error { // String returns a normalized string representation of this attribute type and // value pair which is the lowercase join of the Type and Value with a "=". func (a *AttributeTypeAndValue) String() string { - return strings.ToLower(a.Type) + "=" + a.encodeValue() -} - -func (a *AttributeTypeAndValue) encodeValue() string { - // Normalize the value first. - // value := strings.ToLower(a.Value) - value := a.Value - - encodedBuf := bytes.Buffer{} - - escapeChar := func(c byte) { - encodedBuf.WriteByte('\\') - encodedBuf.WriteByte(c) - } - - escapeHex := func(c byte) { - encodedBuf.WriteByte('\\') - encodedBuf.WriteString(hex.EncodeToString([]byte{c})) - } - - for i := 0; i < len(value); i++ { - char := value[i] - if i == 0 && char == ' ' || char == '#' { - // Special case leading space or number sign. - escapeChar(char) - continue - } - if i == len(value)-1 && char == ' ' { - // Special case trailing space. - escapeChar(char) - continue - } - - switch char { - case '"', '+', ',', ';', '<', '>', '\\': - // Each of these special characters must be escaped. - escapeChar(char) - continue - } - - if char < ' ' || char > '~' { - // All special character escapes are handled first - // above. All bytes less than ASCII SPACE and all bytes - // greater than ASCII TILDE must be hex-escaped. - escapeHex(char) - continue - } - - // Any other character does not require escaping. - encodedBuf.WriteByte(char) - } - - return encodedBuf.String() + return encodeString(foldString(a.Type), false) + "=" + encodeString(a.Value, true) } // RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514 @@ -119,12 +71,29 @@ type RelativeDN struct { // String returns a normalized string representation of this relative DN which // is the a join of all attributes (sorted in increasing order) with a "+". func (r *RelativeDN) String() string { - attrs := make([]string, len(r.Attributes)) - for i := range r.Attributes { - attrs[i] = r.Attributes[i].String() + builder := strings.Builder{} + sortedAttributes := make([]*AttributeTypeAndValue, len(r.Attributes)) + copy(sortedAttributes, r.Attributes) + sortAttributes(sortedAttributes) + for i, atv := range sortedAttributes { + builder.WriteString(atv.String()) + if i < len(sortedAttributes)-1 { + builder.WriteByte('+') + } } - sort.Strings(attrs) - return strings.Join(attrs, "+") + return builder.String() +} + +func sortAttributes(atvs []*AttributeTypeAndValue) { + sort.Slice(atvs, func(i, j int) bool { + ti := foldString(atvs[i].Type) + tj := foldString(atvs[j].Type) + if ti != tj { + return ti < tj + } + + return atvs[i].Value < atvs[j].Value + }) } // DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514 @@ -135,11 +104,25 @@ type DN struct { // String returns a normalized string representation of this DN which is the // join of all relative DNs with a ",". func (d *DN) String() string { - rdns := make([]string, len(d.RDNs)) - for i := range d.RDNs { - rdns[i] = d.RDNs[i].String() + builder := strings.Builder{} + for i, rdn := range d.RDNs { + builder.WriteString(rdn.String()) + if i < len(d.RDNs)-1 { + builder.WriteByte(',') + } } - return strings.Join(rdns, ",") + return builder.String() +} + +func stripLeadingAndTrailingSpaces(inVal string) string { + noSpaces := strings.Trim(inVal, " ") + + // Re-add the trailing space if it was an escaped space + if len(noSpaces) > 0 && noSpaces[len(noSpaces)-1] == '\\' && inVal[len(inVal)-1] == ' ' { + noSpaces = noSpaces + " " + } + + return noSpaces } // Remove leading and trailing spaces from the attribute type and value @@ -147,11 +130,7 @@ func (d *DN) String() string { // // decodeString is based on https://github.com/inteon/cert-manager/blob/ed280d28cd02b262c5db46054d88e70ab518299c/pkg/util/pki/internal/dn.go#L170 func decodeString(str string) (string, error) { - s := []rune(strings.TrimSpace(str)) - // Re-add the trailing space if the last character was an escaped space character - if len(s) > 0 && s[len(s)-1] == '\\' && str[len(str)-1] == ' ' { - s = append(s, ' ') - } + s := []rune(stripLeadingAndTrailingSpaces(str)) builder := strings.Builder{} for i := 0; i < len(s); i++ { @@ -212,6 +191,65 @@ func decodeString(str string) (string, error) { return builder.String(), nil } +// Escape a string according to RFC 4514 +func encodeString(value string, isValue bool) string { + builder := strings.Builder{} + + escapeChar := func(c byte) { + builder.WriteByte('\\') + builder.WriteByte(c) + } + + escapeHex := func(c byte) { + builder.WriteByte('\\') + builder.WriteString(hex.EncodeToString([]byte{c})) + } + + // Loop through each byte and escape as necessary. + // Runes that take up more than one byte are escaped + // byte by byte (since both bytes are non-ASCII). + for i := 0; i < len(value); i++ { + char := value[i] + if i == 0 && (char == ' ' || char == '#') { + // Special case leading space or number sign. + escapeChar(char) + continue + } + if i == len(value)-1 && char == ' ' { + // Special case trailing space. + escapeChar(char) + continue + } + + switch char { + case '"', '+', ',', ';', '<', '>', '\\': + // Each of these special characters must be escaped. + escapeChar(char) + continue + } + + if !isValue && char == '=' { + // Equal signs have to be escaped only in the type part of + // the attribute type and value pair. + escapeChar(char) + continue + } + + if char < ' ' || char > '~' { + // All special character escapes are handled first + // above. All bytes less than ASCII SPACE and all bytes + // greater than ASCII TILDE must be hex-escaped. + escapeHex(char) + continue + } + + // Any other character does not require escaping. + builder.WriteByte(char) + } + + return builder.String() +} + func decodeEncodedString(str string) (string, error) { decoded, err := hex.DecodeString(str) if err != nil { @@ -247,12 +285,17 @@ func ParseDN(str string) (*DN, error) { rdn.Attributes = append(rdn.Attributes, attr) attr = &AttributeTypeAndValue{} if end { + sortAttributes(rdn.Attributes) dn.RDNs = append(dn.RDNs, rdn) rdn = &RelativeDN{} } } ) + // Loop through each character in the string and + // build up the attribute type and value pairs. + // We only check for ascii characters here, which + // allows us to iterate over the string byte by byte. for i := 0; i < len(str); i++ { char := str[i] switch { @@ -420,3 +463,34 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool { func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool { return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value) } + +// foldString returns a folded string such that foldString(x) == foldString(y) +// is identical to bytes.EqualFold(x, y). +// based on https://go.dev/src/encoding/json/fold.go +func foldString(s string) string { + builder := strings.Builder{} + for _, char := range s { + // Handle single-byte ASCII. + if char < utf8.RuneSelf { + if 'A' <= char && char <= 'Z' { + char += 'a' - 'A' + } + builder.WriteRune(char) + continue + } + + builder.WriteRune(foldRune(char)) + } + return builder.String() +} + +// foldRune is returns the smallest rune for all runes in the same fold set. +func foldRune(r rune) rune { + for { + r2 := unicode.SimpleFold(r) + if r2 <= r { + return r + } + r = r2 + } +} diff --git a/dn_test.go b/dn_test.go index 4fc9019..d3c2c18 100644 --- a/dn_test.go +++ b/dn_test.go @@ -3,6 +3,8 @@ package ldap import ( "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestSuccessfulDNParsing(t *testing.T) { @@ -20,8 +22,8 @@ func TestSuccessfulDNParsing(t *testing.T) { }}, "OU=Sales+CN=J. Smith,DC=example,DC=net": {[]*RelativeDN{ {[]*AttributeTypeAndValue{ - {"OU", "Sales"}, {"CN", "J. Smith"}, + {"OU", "Sales"}, }}, {[]*AttributeTypeAndValue{{"DC", "example"}}}, {[]*AttributeTypeAndValue{{"DC", "net"}}}, @@ -59,12 +61,26 @@ func TestSuccessfulDNParsing(t *testing.T) { {" B ", " 2 "}, }}, }}, - + "A = 88 \t": {[]*RelativeDN{ + {[]*AttributeTypeAndValue{ + {"A", "88 \t"}, + }}, + }}, + "A = 88 \n": {[]*RelativeDN{ + {[]*AttributeTypeAndValue{ + {"A", "88 \n"}, + }}, + }}, `cn=john.doe;dc=example,dc=net`: {[]*RelativeDN{ {[]*AttributeTypeAndValue{{"cn", "john.doe"}}}, {[]*AttributeTypeAndValue{{"dc", "example"}}}, {[]*AttributeTypeAndValue{{"dc", "net"}}}, }}, + `cn=⭐;dc=❤️=\==,dc=❤️\\`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"cn", "⭐"}}}, + {[]*AttributeTypeAndValue{{"dc", "❤️==="}}}, + {[]*AttributeTypeAndValue{{"dc", "❤️\\"}}}, + }}, // Escaped `;` should not be treated as RDN `cn=john.doe\;weird name,dc=example,dc=net`: {[]*RelativeDN{ @@ -77,6 +93,29 @@ func TestSuccessfulDNParsing(t *testing.T) { {[]*AttributeTypeAndValue{{"dc", "dummy"}}}, {[]*AttributeTypeAndValue{{"dc", "com"}}}, }}, + `1.3.6.1.4.1.1466.0=test`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"1.3.6.1.4.1.1466.0", "test"}}}, + }}, + `1=#04024869`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"1", "Hi"}}}, + }}, + `CN=James \"Jim\" Smith\, III,DC=example,DC=net`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"CN", "James \"Jim\" Smith, III"}}}, + {[]*AttributeTypeAndValue{{"DC", "example"}}}, + {[]*AttributeTypeAndValue{{"DC", "net"}}}, + }}, + `CN=Before\0dAfter,DC=example,DC=net`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"CN", "Before\x0dAfter"}}}, + {[]*AttributeTypeAndValue{{"DC", "example"}}}, + {[]*AttributeTypeAndValue{{"DC", "net"}}}, + }}, + `cn=foo-lon\e2\9d\a4\ef\b8\8f\,g.com,OU=Foo===Long;ou=Ba # rq,ou=Baz,o=C\; orp.+c=US`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"cn", "foo-lon❤️,g.com"}}}, + {[]*AttributeTypeAndValue{{"OU", "Foo===Long"}}}, + {[]*AttributeTypeAndValue{{"ou", "Ba # rq"}}}, + {[]*AttributeTypeAndValue{{"ou", "Baz"}}}, + {[]*AttributeTypeAndValue{{"c", "US"}, {"o", "C; orp."}}}, + }}, } for test, answer := range testcases { @@ -105,13 +144,17 @@ func TestSuccessfulDNParsing(t *testing.T) { func TestErrorDNParsing(t *testing.T) { testcases := map[string]string{ - "*": "DN ended with incomplete type, value pair", - "cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", - "cn=Jim\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0", - "DC=example,=net": "DN ended with incomplete type, value pair", - "1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string", - "test,DC=example,DC=com": "incomplete type, value pair", - "=test,DC=example,DC=com": "incomplete type, value pair", + "*": "DN ended with incomplete type, value pair", + "cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", + "cn=Jim\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0", + "DC=example,=net": "DN ended with incomplete type, value pair", + "1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string", + "test,DC=example,DC=com": "incomplete type, value pair", + "=test,DC=example,DC=com": "incomplete type, value pair", + "1.3.6.1.4.1.1466.0=test+": "DN ended with incomplete type, value pair", + `1.3.6.1.4.1.1466.0=test;`: "DN ended with incomplete type, value pair", + "1.3.6.1.4.1.1466.0=test+,": "incomplete type, value pair", + "DF=#6666666666665006838820013100000746939546349182108463491821809FBFFFFFFFFF": "failed to unmarshal hex-encoded string: asn1: syntax error: data truncated", } for test, answer := range testcases { @@ -290,3 +333,86 @@ func TestDNAncestor(t *testing.T) { } } } + +func BenchmarkParseSubject(b *testing.B) { + for n := 0; n < b.N; n++ { + _, err := ParseDN("DF=#6666666666665006838820013100000746939546349182108463491821809FBFFFFFFFFF") + if err == nil { + b.Fatal("expected error, but got none") + } + } +} + +func TestMustKeepOrderInRawDerBytes(t *testing.T) { + subject := "cn=foo-long.com,ou=FooLong,ou=Barq,ou=Baz,ou=Dept.,o=Corp.,c=US" + rdnSeq, err := ParseDN(subject) + if err != nil { + t.Fatal(err) + } + + expectedRdnSeq := &DN{ + []*RelativeDN{ + {[]*AttributeTypeAndValue{{Type: "cn", Value: "foo-long.com"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "FooLong"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "Barq"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "Baz"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "Dept."}}}, + {[]*AttributeTypeAndValue{{Type: "o", Value: "Corp."}}}, + {[]*AttributeTypeAndValue{{Type: "c", Value: "US"}}}, + }, + } + + assert.Equal(t, expectedRdnSeq, rdnSeq) + assert.Equal(t, subject, rdnSeq.String()) +} + +func TestRoundTripLiteralSubject(t *testing.T) { + rdnSequences := map[string]string{ + "cn=foo-long.com,ou=FooLong,ou=Barq,ou=Baz,ou=Dept.,o=Corp.,c=US": "cn=foo-long.com,ou=FooLong,ou=Barq,ou=Baz,ou=Dept.,o=Corp.,c=US", + "cn=foo-lon❤️\\,g.com,ou=Foo===Long,ou=Ba # rq,ou=Baz,o=C\\; orp.,c=US": "cn=foo-lon\\e2\\9d\\a4\\ef\\b8\\8f\\,g.com,ou=Foo===Long,ou=Ba # rq,ou=Baz,o=C\\; orp.,c=US", + "cn=fo\x00o-long.com,ou=\x04FooLong": "cn=fo\\00o-long.com,ou=\\04FooLong", + } + + for subjIn, subjOut := range rdnSequences { + t.Logf("Testing subject: %s", subjIn) + + newRDNSeq, err := ParseDN(subjIn) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, subjOut, newRDNSeq.String()) + } +} + +func TestDecodeString(t *testing.T) { + successTestcases := map[string]string{ + "foo-long.com": "foo-long.com", + "foo-lon❤️\\,g.com": "foo-lon❤️,g.com", + "fo\x00o-long.com": "fo\x00o-long.com", + "fo\\00o-long.com": "fo\x00o-long.com", + } + + for encoded, decoded := range successTestcases { + t.Logf("Testing encoded string: %s", encoded) + decodedString, err := decodeString(encoded) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, decoded, decodedString) + } + + errorTestcases := map[string]string{ + "fo\\": "got corrupted escaped character: 'fo\\'", + "fo\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0", + "fo\\UU️o-long.com": "failed to decode escaped character: encoding/hex: invalid byte: U+0055 'U'", + "fo\\0❤️o-long.com": "failed to decode escaped character: invalid byte: 0❤", + } + + for encoded, expectedError := range errorTestcases { + t.Logf("Testing encoded string: %s", encoded) + _, err := decodeString(encoded) + assert.EqualError(t, err, expectedError) + } +} diff --git a/fuzz_test.go b/fuzz_test.go index 74e1a66..e2d3008 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -3,7 +3,11 @@ package ldap -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func FuzzParseDN(f *testing.F) { f.Add("*") @@ -42,3 +46,50 @@ func FuzzEscapeDN(f *testing.F) { _ = EscapeDN(input_data) }) } + +func FuzzRoundTripRDNSequence(f *testing.F) { + f.Add("CN=foo-long.com,OU=FooLong,OU=Barq,OU=Baz,OU=Dept.,O=Corp.,C=US") + f.Add("CN=foo-lon❤️\\,g.com,OU=Foo===Long,OU=Ba # rq,OU=Baz,O=C\\; orp.,C=US") + f.Add("CN=fo\x00o-long.com,OU=\x04FooLong") + f.Add("İ=") + + f.Fuzz(func(t *testing.T, subjectString string) { + t.Parallel() + rdnSeq, err := ParseDN(subjectString) + if err != nil { + t.Skip() + } + + newRDNSeq, err := ParseDN(rdnSeq.String()) + if err != nil { + t.Fatal(err) + } + + assert.True(t, rdnSeq.Equal(newRDNSeq)) + assert.True(t, rdnSeq.EqualFold(newRDNSeq)) + }) +} + +func FuzzRoundTripEncodeDecode(f *testing.F) { + f.Add("dffad=-fasdfsd") + f.Add("❤️\\,") + f.Add("aaa\x00o-long.c\x04FooLong") + f.Add("İ") + + f.Fuzz(func(t *testing.T, rawString string) { + t.Parallel() + keyEncoded := encodeString(rawString, true) + keyDecoded, err := decodeString(keyEncoded) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, rawString, keyDecoded) + + valueEncoded := encodeString(rawString, false) + valueDecoded, err := decodeString(valueEncoded) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, rawString, valueDecoded) + }) +} diff --git a/v3/dn.go b/v3/dn.go index 7478919..5e2683e 100644 --- a/v3/dn.go +++ b/v3/dn.go @@ -8,6 +8,8 @@ import ( "fmt" "sort" "strings" + "unicode" + "unicode/utf8" ) // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 @@ -34,6 +36,9 @@ func (a *AttributeTypeAndValue) setValue(s string) error { // AttributeValue is represented by an number sign ('#' U+0023) // character followed by the hexadecimal encoding of each of the octets // of the BER encoding of the X.500 AttributeValue. + // + // WARNING: we only support hex-encoded ASN.1 DER values here, not + // BER encoding. This is a deviation from the RFC. if len(s) > 0 && s[0] == '#' { decodedString, err := decodeEncodedString(s[1:]) if err != nil { @@ -56,59 +61,7 @@ func (a *AttributeTypeAndValue) setValue(s string) error { // String returns a normalized string representation of this attribute type and // value pair which is the lowercase join of the Type and Value with a "=". func (a *AttributeTypeAndValue) String() string { - return strings.ToLower(a.Type) + "=" + a.encodeValue() -} - -func (a *AttributeTypeAndValue) encodeValue() string { - // Normalize the value first. - // value := strings.ToLower(a.Value) - value := a.Value - - encodedBuf := bytes.Buffer{} - - escapeChar := func(c byte) { - encodedBuf.WriteByte('\\') - encodedBuf.WriteByte(c) - } - - escapeHex := func(c byte) { - encodedBuf.WriteByte('\\') - encodedBuf.WriteString(hex.EncodeToString([]byte{c})) - } - - for i := 0; i < len(value); i++ { - char := value[i] - if i == 0 && char == ' ' || char == '#' { - // Special case leading space or number sign. - escapeChar(char) - continue - } - if i == len(value)-1 && char == ' ' { - // Special case trailing space. - escapeChar(char) - continue - } - - switch char { - case '"', '+', ',', ';', '<', '>', '\\': - // Each of these special characters must be escaped. - escapeChar(char) - continue - } - - if char < ' ' || char > '~' { - // All special character escapes are handled first - // above. All bytes less than ASCII SPACE and all bytes - // greater than ASCII TILDE must be hex-escaped. - escapeHex(char) - continue - } - - // Any other character does not require escaping. - encodedBuf.WriteByte(char) - } - - return encodedBuf.String() + return encodeString(foldString(a.Type), false) + "=" + encodeString(a.Value, true) } // RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514 @@ -119,12 +72,29 @@ type RelativeDN struct { // String returns a normalized string representation of this relative DN which // is the a join of all attributes (sorted in increasing order) with a "+". func (r *RelativeDN) String() string { - attrs := make([]string, len(r.Attributes)) - for i := range r.Attributes { - attrs[i] = r.Attributes[i].String() + builder := strings.Builder{} + sortedAttributes := make([]*AttributeTypeAndValue, len(r.Attributes)) + copy(sortedAttributes, r.Attributes) + sortAttributes(sortedAttributes) + for i, atv := range sortedAttributes { + builder.WriteString(atv.String()) + if i < len(sortedAttributes)-1 { + builder.WriteByte('+') + } } - sort.Strings(attrs) - return strings.Join(attrs, "+") + return builder.String() +} + +func sortAttributes(atvs []*AttributeTypeAndValue) { + sort.Slice(atvs, func(i, j int) bool { + ti := foldString(atvs[i].Type) + tj := foldString(atvs[j].Type) + if ti != tj { + return ti < tj + } + + return atvs[i].Value < atvs[j].Value + }) } // DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514 @@ -135,11 +105,25 @@ type DN struct { // String returns a normalized string representation of this DN which is the // join of all relative DNs with a ",". func (d *DN) String() string { - rdns := make([]string, len(d.RDNs)) - for i := range d.RDNs { - rdns[i] = d.RDNs[i].String() + builder := strings.Builder{} + for i, rdn := range d.RDNs { + builder.WriteString(rdn.String()) + if i < len(d.RDNs)-1 { + builder.WriteByte(',') + } } - return strings.Join(rdns, ",") + return builder.String() +} + +func stripLeadingAndTrailingSpaces(inVal string) string { + noSpaces := strings.Trim(inVal, " ") + + // Re-add the trailing space if it was an escaped space + if len(noSpaces) > 0 && noSpaces[len(noSpaces)-1] == '\\' && inVal[len(inVal)-1] == ' ' { + noSpaces = noSpaces + " " + } + + return noSpaces } // Remove leading and trailing spaces from the attribute type and value @@ -147,11 +131,7 @@ func (d *DN) String() string { // // decodeString is based on https://github.com/inteon/cert-manager/blob/ed280d28cd02b262c5db46054d88e70ab518299c/pkg/util/pki/internal/dn.go#L170 func decodeString(str string) (string, error) { - s := []rune(strings.TrimSpace(str)) - // Re-add the trailing space if the last character was an escaped space character - if len(s) > 0 && s[len(s)-1] == '\\' && str[len(str)-1] == ' ' { - s = append(s, ' ') - } + s := []rune(stripLeadingAndTrailingSpaces(str)) builder := strings.Builder{} for i := 0; i < len(s); i++ { @@ -212,6 +192,62 @@ func decodeString(str string) (string, error) { return builder.String(), nil } +// Escape a string according to RFC 4514 +func encodeString(value string, isValue bool) string { + encodedBuf := bytes.Buffer{} + + escapeChar := func(c byte) { + encodedBuf.WriteByte('\\') + encodedBuf.WriteByte(c) + } + + escapeHex := func(c byte) { + encodedBuf.WriteByte('\\') + encodedBuf.WriteString(hex.EncodeToString([]byte{c})) + } + + for i := 0; i < len(value); i++ { + char := value[i] + if i == 0 && (char == ' ' || char == '#') { + // Special case leading space or number sign. + escapeChar(char) + continue + } + if i == len(value)-1 && char == ' ' { + // Special case trailing space. + escapeChar(char) + continue + } + + switch char { + case '"', '+', ',', ';', '<', '>', '\\': + // Each of these special characters must be escaped. + escapeChar(char) + continue + } + + if !isValue && char == '=' { + // Equal signs have to be escaped only in the type part of + // the attribute type and value pair. + escapeChar(char) + continue + } + + if char < ' ' || char > '~' { + // All special character escapes are handled first + // above. All bytes less than ASCII SPACE and all bytes + // greater than ASCII TILDE must be hex-escaped. + escapeHex(char) + continue + } + + // Any other character does not require escaping. + encodedBuf.WriteByte(char) + } + + return encodedBuf.String() +} + func decodeEncodedString(str string) (string, error) { decoded, err := hex.DecodeString(str) if err != nil { @@ -247,6 +283,7 @@ func ParseDN(str string) (*DN, error) { rdn.Attributes = append(rdn.Attributes, attr) attr = &AttributeTypeAndValue{} if end { + sortAttributes(rdn.Attributes) dn.RDNs = append(dn.RDNs, rdn) rdn = &RelativeDN{} } @@ -420,3 +457,34 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool { func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool { return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value) } + +// foldString returns a folded string such that foldString(x) == foldString(y) +// is identical to bytes.EqualFold(x, y). +// based on https://go.dev/src/encoding/json/fold.go +func foldString(s string) string { + builder := strings.Builder{} + for _, char := range s { + // Handle single-byte ASCII. + if char < utf8.RuneSelf { + if 'A' <= char && char <= 'Z' { + char += 'a' - 'A' + } + builder.WriteRune(char) + continue + } + + builder.WriteRune(foldRune(char)) + } + return builder.String() +} + +// foldRune is returns the smallest rune for all runes in the same fold set. +func foldRune(r rune) rune { + for { + r2 := unicode.SimpleFold(r) + if r2 <= r { + return r + } + r = r2 + } +} diff --git a/v3/dn_test.go b/v3/dn_test.go index 4fc9019..8693d63 100644 --- a/v3/dn_test.go +++ b/v3/dn_test.go @@ -3,6 +3,8 @@ package ldap import ( "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestSuccessfulDNParsing(t *testing.T) { @@ -20,8 +22,8 @@ func TestSuccessfulDNParsing(t *testing.T) { }}, "OU=Sales+CN=J. Smith,DC=example,DC=net": {[]*RelativeDN{ {[]*AttributeTypeAndValue{ - {"OU", "Sales"}, {"CN", "J. Smith"}, + {"OU", "Sales"}, }}, {[]*AttributeTypeAndValue{{"DC", "example"}}}, {[]*AttributeTypeAndValue{{"DC", "net"}}}, @@ -59,12 +61,26 @@ func TestSuccessfulDNParsing(t *testing.T) { {" B ", " 2 "}, }}, }}, - + "A = 88 \t": {[]*RelativeDN{ + {[]*AttributeTypeAndValue{ + {"A", "88 \t"}, + }}, + }}, + "A = 88 \n": {[]*RelativeDN{ + {[]*AttributeTypeAndValue{ + {"A", "88 \n"}, + }}, + }}, `cn=john.doe;dc=example,dc=net`: {[]*RelativeDN{ {[]*AttributeTypeAndValue{{"cn", "john.doe"}}}, {[]*AttributeTypeAndValue{{"dc", "example"}}}, {[]*AttributeTypeAndValue{{"dc", "net"}}}, }}, + `cn=⭐;dc=❤️=\==,dc=❤️\\`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"cn", "⭐"}}}, + {[]*AttributeTypeAndValue{{"dc", "❤️==="}}}, + {[]*AttributeTypeAndValue{{"dc", "❤️\\"}}}, + }}, // Escaped `;` should not be treated as RDN `cn=john.doe\;weird name,dc=example,dc=net`: {[]*RelativeDN{ @@ -77,6 +93,29 @@ func TestSuccessfulDNParsing(t *testing.T) { {[]*AttributeTypeAndValue{{"dc", "dummy"}}}, {[]*AttributeTypeAndValue{{"dc", "com"}}}, }}, + `1.3.6.1.4.1.1466.0=test`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"1.3.6.1.4.1.1466.0", "test"}}}, + }}, + `1=#04024869`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"1", "Hi"}}}, + }}, + `CN=James \"Jim\" Smith\, III,DC=example,DC=net`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"CN", "James \"Jim\" Smith, III"}}}, + {[]*AttributeTypeAndValue{{"DC", "example"}}}, + {[]*AttributeTypeAndValue{{"DC", "net"}}}, + }}, + `CN=Before\0dAfter,DC=example,DC=net`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"CN", "Before\x0dAfter"}}}, + {[]*AttributeTypeAndValue{{"DC", "example"}}}, + {[]*AttributeTypeAndValue{{"DC", "net"}}}, + }}, + `cn=foo-lon\e2\9d\a4\ef\b8\8f\,g.com,OU=Foo===Long;ou=Ba # rq,ou=Baz,o=C\; orp.+c=US`: {[]*RelativeDN{ + {[]*AttributeTypeAndValue{{"cn", "foo-lon❤️,g.com"}}}, + {[]*AttributeTypeAndValue{{"OU", "Foo===Long"}}}, + {[]*AttributeTypeAndValue{{"ou", "Ba # rq"}}}, + {[]*AttributeTypeAndValue{{"ou", "Baz"}}}, + {[]*AttributeTypeAndValue{{"c", "US"}, {"o", "C; orp."}}}, + }}, } for test, answer := range testcases { @@ -105,13 +144,17 @@ func TestSuccessfulDNParsing(t *testing.T) { func TestErrorDNParsing(t *testing.T) { testcases := map[string]string{ - "*": "DN ended with incomplete type, value pair", - "cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", - "cn=Jim\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0", - "DC=example,=net": "DN ended with incomplete type, value pair", - "1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string", - "test,DC=example,DC=com": "incomplete type, value pair", - "=test,DC=example,DC=com": "incomplete type, value pair", + "*": "DN ended with incomplete type, value pair", + "cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", + "cn=Jim\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0", + "DC=example,=net": "DN ended with incomplete type, value pair", + "1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string", + "test,DC=example,DC=com": "incomplete type, value pair", + "=test,DC=example,DC=com": "incomplete type, value pair", + "1.3.6.1.4.1.1466.0=test+": "DN ended with incomplete type, value pair", + `1.3.6.1.4.1.1466.0=test;`: "DN ended with incomplete type, value pair", + "1.3.6.1.4.1.1466.0=test+,": "incomplete type, value pair", + "DF=#6666666666665006838820013100000746939546349182108463491821809FBFFFFFFFFF": "failed to unmarshal hex-encoded string: asn1: syntax error: data truncated", } for test, answer := range testcases { @@ -290,3 +333,54 @@ func TestDNAncestor(t *testing.T) { } } } + +func BenchmarkParseSubject(b *testing.B) { + for n := 0; n < b.N; n++ { + _, err := ParseDN("DF=#6666666666665006838820013100000746939546349182108463491821809FBFFFFFFFFF") + if err == nil { + b.Fatal("expected error, but got none") + } + } +} + +func TestMustKeepOrderInRawDerBytes(t *testing.T) { + subject := "cn=foo-long.com,ou=FooLong,ou=Barq,ou=Baz,ou=Dept.,o=Corp.,c=US" + rdnSeq, err := ParseDN(subject) + if err != nil { + t.Fatal(err) + } + + expectedRdnSeq := &DN{ + []*RelativeDN{ + {[]*AttributeTypeAndValue{{Type: "cn", Value: "foo-long.com"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "FooLong"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "Barq"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "Baz"}}}, + {[]*AttributeTypeAndValue{{Type: "ou", Value: "Dept."}}}, + {[]*AttributeTypeAndValue{{Type: "o", Value: "Corp."}}}, + {[]*AttributeTypeAndValue{{Type: "c", Value: "US"}}}, + }, + } + + assert.Equal(t, expectedRdnSeq, rdnSeq) + assert.Equal(t, subject, rdnSeq.String()) +} + +func TestRoundTripLiteralSubject(t *testing.T) { + rdnSequences := map[string]string{ + "cn=foo-long.com,ou=FooLong,ou=Barq,ou=Baz,ou=Dept.,o=Corp.,c=US": "cn=foo-long.com,ou=FooLong,ou=Barq,ou=Baz,ou=Dept.,o=Corp.,c=US", + "cn=foo-lon❤️\\,g.com,ou=Foo===Long,ou=Ba # rq,ou=Baz,o=C\\; orp.,c=US": "cn=foo-lon\\e2\\9d\\a4\\ef\\b8\\8f\\,g.com,ou=Foo===Long,ou=Ba # rq,ou=Baz,o=C\\; orp.,c=US", + "cn=fo\x00o-long.com,ou=\x04FooLong": "cn=fo\\00o-long.com,ou=\\04FooLong", + } + + for subjIn, subjOut := range rdnSequences { + t.Logf("Testing subject: %s", subjIn) + + newRDNSeq, err := ParseDN(subjIn) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, subjOut, newRDNSeq.String()) + } +} diff --git a/v3/fuzz_test.go b/v3/fuzz_test.go new file mode 100644 index 0000000..e2d3008 --- /dev/null +++ b/v3/fuzz_test.go @@ -0,0 +1,95 @@ +//go:build go1.18 +// +build go1.18 + +package ldap + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func FuzzParseDN(f *testing.F) { + f.Add("*") + f.Add("cn=Jim\\0Test") + f.Add("cn=Jim\\0") + f.Add("DC=example,=net") + f.Add("o=a+o=B") + + f.Fuzz(func(t *testing.T, input_data string) { + _, _ = ParseDN(input_data) + }) +} + +func FuzzDecodeEscapedSymbols(f *testing.F) { + f.Add([]byte("a\u0100\x80")) + f.Add([]byte(`start\d`)) + f.Add([]byte(`\`)) + f.Add([]byte(`start\--end`)) + f.Add([]byte(`start\d0\hh`)) + + f.Fuzz(func(t *testing.T, input_data []byte) { + _, _ = decodeEscapedSymbols(input_data) + }) +} + +func FuzzEscapeDN(f *testing.F) { + f.Add("test,user") + f.Add("#test#user#") + f.Add("\\test\\user\\") + f.Add(" test user ") + f.Add("\u0000te\x00st\x00user" + string(rune(0))) + f.Add("test\"+,;<>\\-_user") + f.Add("test\u0391user ") + + f.Fuzz(func(t *testing.T, input_data string) { + _ = EscapeDN(input_data) + }) +} + +func FuzzRoundTripRDNSequence(f *testing.F) { + f.Add("CN=foo-long.com,OU=FooLong,OU=Barq,OU=Baz,OU=Dept.,O=Corp.,C=US") + f.Add("CN=foo-lon❤️\\,g.com,OU=Foo===Long,OU=Ba # rq,OU=Baz,O=C\\; orp.,C=US") + f.Add("CN=fo\x00o-long.com,OU=\x04FooLong") + f.Add("İ=") + + f.Fuzz(func(t *testing.T, subjectString string) { + t.Parallel() + rdnSeq, err := ParseDN(subjectString) + if err != nil { + t.Skip() + } + + newRDNSeq, err := ParseDN(rdnSeq.String()) + if err != nil { + t.Fatal(err) + } + + assert.True(t, rdnSeq.Equal(newRDNSeq)) + assert.True(t, rdnSeq.EqualFold(newRDNSeq)) + }) +} + +func FuzzRoundTripEncodeDecode(f *testing.F) { + f.Add("dffad=-fasdfsd") + f.Add("❤️\\,") + f.Add("aaa\x00o-long.c\x04FooLong") + f.Add("İ") + + f.Fuzz(func(t *testing.T, rawString string) { + t.Parallel() + keyEncoded := encodeString(rawString, true) + keyDecoded, err := decodeString(keyEncoded) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, rawString, keyDecoded) + + valueEncoded := encodeString(rawString, false) + valueDecoded, err := decodeString(valueEncoded) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, rawString, valueDecoded) + }) +}