Skip to content

Commit

Permalink
add tests to parseDN (including fuzz tests) and apply changes require…
Browse files Browse the repository at this point in the history
…d to make roundtripping work

Signed-off-by: Tim Ramlot <42113979+inteon@users.noreply.github.com>
  • Loading branch information
inteon committed Apr 3, 2024
1 parent 4ca7b8e commit 2bc9b15
Show file tree
Hide file tree
Showing 6 changed files with 662 additions and 154 deletions.
210 changes: 142 additions & 68 deletions dn.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package ldap

import (
"bytes"
"encoding/asn1"
"encoding/hex"
"errors"
"fmt"
"sort"
"strings"
"unicode"
"unicode/utf8"
)

// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
Expand All @@ -34,6 +35,9 @@ func (a *AttributeTypeAndValue) setValue(s string) error {
// AttributeValue is represented by an number sign ('#' U+0023)
// character followed by the hexadecimal encoding of each of the octets
// of the BER encoding of the X.500 AttributeValue.
//
// WARNING: we only support hex-encoded ASN.1 DER values here, not
// BER encoding. This is a deviation from the RFC.
if len(s) > 0 && s[0] == '#' {
decodedString, err := decodeEncodedString(s[1:])
if err != nil {
Expand All @@ -56,59 +60,7 @@ func (a *AttributeTypeAndValue) setValue(s string) error {
// String returns a normalized string representation of this attribute type and
// value pair which is the lowercase join of the Type and Value with a "=".
func (a *AttributeTypeAndValue) String() string {
return strings.ToLower(a.Type) + "=" + a.encodeValue()
}

func (a *AttributeTypeAndValue) encodeValue() string {
// Normalize the value first.
// value := strings.ToLower(a.Value)
value := a.Value

encodedBuf := bytes.Buffer{}

escapeChar := func(c byte) {
encodedBuf.WriteByte('\\')
encodedBuf.WriteByte(c)
}

escapeHex := func(c byte) {
encodedBuf.WriteByte('\\')
encodedBuf.WriteString(hex.EncodeToString([]byte{c}))
}

for i := 0; i < len(value); i++ {
char := value[i]
if i == 0 && char == ' ' || char == '#' {
// Special case leading space or number sign.
escapeChar(char)
continue
}
if i == len(value)-1 && char == ' ' {
// Special case trailing space.
escapeChar(char)
continue
}

switch char {
case '"', '+', ',', ';', '<', '>', '\\':
// Each of these special characters must be escaped.
escapeChar(char)
continue
}

if char < ' ' || char > '~' {
// All special character escapes are handled first
// above. All bytes less than ASCII SPACE and all bytes
// greater than ASCII TILDE must be hex-escaped.
escapeHex(char)
continue
}

// Any other character does not require escaping.
encodedBuf.WriteByte(char)
}

return encodedBuf.String()
return encodeString(foldString(a.Type), false) + "=" + encodeString(a.Value, true)
}

// RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514
Expand All @@ -119,12 +71,29 @@ type RelativeDN struct {
// String returns a normalized string representation of this relative DN which
// is the a join of all attributes (sorted in increasing order) with a "+".
func (r *RelativeDN) String() string {
attrs := make([]string, len(r.Attributes))
for i := range r.Attributes {
attrs[i] = r.Attributes[i].String()
builder := strings.Builder{}
sortedAttributes := make([]*AttributeTypeAndValue, len(r.Attributes))
copy(sortedAttributes, r.Attributes)
sortAttributes(sortedAttributes)
for i, atv := range sortedAttributes {
builder.WriteString(atv.String())
if i < len(sortedAttributes)-1 {
builder.WriteByte('+')
}
}
sort.Strings(attrs)
return strings.Join(attrs, "+")
return builder.String()
}

func sortAttributes(atvs []*AttributeTypeAndValue) {
sort.Slice(atvs, func(i, j int) bool {
ti := foldString(atvs[i].Type)
tj := foldString(atvs[j].Type)
if ti != tj {
return ti < tj
}

return atvs[i].Value < atvs[j].Value
})
}

// DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514
Expand All @@ -135,23 +104,33 @@ type DN struct {
// String returns a normalized string representation of this DN which is the
// join of all relative DNs with a ",".
func (d *DN) String() string {
rdns := make([]string, len(d.RDNs))
for i := range d.RDNs {
rdns[i] = d.RDNs[i].String()
builder := strings.Builder{}
for i, rdn := range d.RDNs {
builder.WriteString(rdn.String())
if i < len(d.RDNs)-1 {
builder.WriteByte(',')
}
}
return strings.Join(rdns, ",")
return builder.String()
}

func stripLeadingAndTrailingSpaces(inVal string) string {
noSpaces := strings.Trim(inVal, " ")

// Re-add the trailing space if it was an escaped space
if len(noSpaces) > 0 && noSpaces[len(noSpaces)-1] == '\\' && inVal[len(inVal)-1] == ' ' {
noSpaces = noSpaces + " "
}

return noSpaces
}

// Remove leading and trailing spaces from the attribute type and value
// and unescape any escaped characters in these fields
//
// decodeString is based on https://github.com/inteon/cert-manager/blob/ed280d28cd02b262c5db46054d88e70ab518299c/pkg/util/pki/internal/dn.go#L170
func decodeString(str string) (string, error) {
s := []rune(strings.TrimSpace(str))
// Re-add the trailing space if the last character was an escaped space character
if len(s) > 0 && s[len(s)-1] == '\\' && str[len(str)-1] == ' ' {
s = append(s, ' ')
}
s := []rune(stripLeadingAndTrailingSpaces(str))

builder := strings.Builder{}
for i := 0; i < len(s); i++ {
Expand Down Expand Up @@ -212,6 +191,65 @@ func decodeString(str string) (string, error) {
return builder.String(), nil
}

// Escape a string according to RFC 4514
func encodeString(value string, isValue bool) string {
builder := strings.Builder{}

escapeChar := func(c byte) {
builder.WriteByte('\\')
builder.WriteByte(c)
}

escapeHex := func(c byte) {
builder.WriteByte('\\')
builder.WriteString(hex.EncodeToString([]byte{c}))
}

// Loop through each byte and escape as necessary.
// Runes that take up more than one byte are escaped
// byte by byte (since both bytes are non-ASCII).
for i := 0; i < len(value); i++ {
char := value[i]
if i == 0 && (char == ' ' || char == '#') {
// Special case leading space or number sign.
escapeChar(char)
continue
}
if i == len(value)-1 && char == ' ' {
// Special case trailing space.
escapeChar(char)
continue
}

switch char {
case '"', '+', ',', ';', '<', '>', '\\':
// Each of these special characters must be escaped.
escapeChar(char)
continue
}

if !isValue && char == '=' {
// Equal signs have to be escaped only in the type part of
// the attribute type and value pair.
escapeChar(char)
continue
}

if char < ' ' || char > '~' {
// All special character escapes are handled first
// above. All bytes less than ASCII SPACE and all bytes
// greater than ASCII TILDE must be hex-escaped.
escapeHex(char)
continue
}

// Any other character does not require escaping.
builder.WriteByte(char)
}

return builder.String()
}

func decodeEncodedString(str string) (string, error) {
decoded, err := hex.DecodeString(str)
if err != nil {
Expand Down Expand Up @@ -247,12 +285,17 @@ func ParseDN(str string) (*DN, error) {
rdn.Attributes = append(rdn.Attributes, attr)
attr = &AttributeTypeAndValue{}
if end {
sortAttributes(rdn.Attributes)
dn.RDNs = append(dn.RDNs, rdn)
rdn = &RelativeDN{}
}
}
)

// Loop through each character in the string and
// build up the attribute type and value pairs.
// We only check for ascii characters here, which
// allows us to iterate over the string byte by byte.
for i := 0; i < len(str); i++ {
char := str[i]
switch {
Expand Down Expand Up @@ -420,3 +463,34 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
}

// foldString returns a folded string such that foldString(x) == foldString(y)
// is identical to bytes.EqualFold(x, y).
// based on https://go.dev/src/encoding/json/fold.go
func foldString(s string) string {
builder := strings.Builder{}
for _, char := range s {
// Handle single-byte ASCII.
if char < utf8.RuneSelf {
if 'A' <= char && char <= 'Z' {
char += 'a' - 'A'
}
builder.WriteRune(char)
continue
}

builder.WriteRune(foldRune(char))
}
return builder.String()
}

// foldRune is returns the smallest rune for all runes in the same fold set.
func foldRune(r rune) rune {
for {
r2 := unicode.SimpleFold(r)
if r2 <= r {
return r
}
r = r2
}
}
Loading

0 comments on commit 2bc9b15

Please sign in to comment.