Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor ParseDN function to fix resource usage and invalid parsings (fixes #487) #497

Merged
merged 4 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 168 additions & 94 deletions dn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package ldap

import (
"bytes"
enchex "encoding/hex"
"encoding/asn1"
"encoding/hex"
"errors"
"fmt"
"sort"
"strings"

ber "github.com/go-asn1-ber/asn1-ber"
)

// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
Expand All @@ -19,8 +18,43 @@ type AttributeTypeAndValue struct {
Value string
}

func (a *AttributeTypeAndValue) setType(str string) error {
result, err := decodeString(str)
if err != nil {
return err
}
a.Type = result

return nil
}

func (a *AttributeTypeAndValue) setValue(s string) error {
// https://www.ietf.org/rfc/rfc4514.html#section-2.4
// If the AttributeType is of the dotted-decimal form, the
// AttributeValue is represented by an number sign ('#' U+0023)
// character followed by the hexadecimal encoding of each of the octets
// of the BER encoding of the X.500 AttributeValue.
if len(s) > 0 && s[0] == '#' {
decodedString, err := decodeEncodedString(s[1:])
if err != nil {
return err
}

a.Value = decodedString
return nil
} else {
decodedString, err := decodeString(s)
if err != nil {
return err
}

a.Value = decodedString
return nil
}
}

// String returns a normalized string representation of this attribute type and
// value pair which is the a lowercased join of the Type and Value with a "=".
// value pair which is the lowercase join of the Type and Value with a "=".
func (a *AttributeTypeAndValue) String() string {
return strings.ToLower(a.Type) + "=" + a.encodeValue()
}
Expand All @@ -39,7 +73,7 @@ func (a *AttributeTypeAndValue) encodeValue() string {

escapeHex := func(c byte) {
encodedBuf.WriteByte('\\')
encodedBuf.WriteString(enchex.EncodeToString([]byte{c}))
encodedBuf.WriteString(hex.EncodeToString([]byte{c}))
}

for i := 0; i < len(value); i++ {
Expand Down Expand Up @@ -108,114 +142,154 @@ func (d *DN) String() string {
return strings.Join(rdns, ",")
}

// Remove leading and trailing spaces from the attribute type and value
// and unescape any escaped characters in these fields
// Remove leading and trailing spaces from the attribute type and value
// and unescape any escaped characters in these fields
cpuschma marked this conversation as resolved.
Show resolved Hide resolved
//
// decodeString is based on https://github.com/inteon/cert-manager/blob/ed280d28cd02b262c5db46054d88e70ab518299c/pkg/util/pki/internal/dn.go#L170
func decodeString(str string) (string, error) {
s := []rune(strings.TrimSpace(str))
// Re-add the trailing space if the last character was an escaped space character
if len(s) > 0 && s[len(s)-1] == '\\' && str[len(str)-2] == ' ' {
s = append(s, ' ')
}

builder := strings.Builder{}
for i := 0; i < len(s); i++ {
char := s[i]

// If the character is not an escape character, just add it to the
// builder and continue
if char != '\\' {
builder.WriteRune(char)
continue
}

// If the escape character is the last character, it's a corrupted
// escaped character
if i+1 >= len(s) {
return "", fmt.Errorf("got corrupted escaped character: '%s'", string(s))
}

// If the escaped character is a special character, just add it to
// the builder and continue
switch s[i+1] {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
builder.WriteRune(s[i+1])
i++
continue
}

// If the escaped character is not a special character, it should
// be a hex-encoded character of the form \XX if it's not at least
// two characters long, it's a corrupted escaped character
if i+2 >= len(s) {
return "", errors.New("failed to decode escaped character: encoding/hex: invalid byte: " + string(s[i+1]))
}

// Get the runes for the two characters after the escape character
// and convert them to a byte slice
xx := []byte(string(s[i+1 : i+3]))

// If the two runes are not hex characters and result in more than
// two bytes when converted to a byte slice, it's a corrupted
// escaped character
if len(xx) != 2 {
return "", errors.New("failed to decode escaped character: invalid byte: " + string(xx))
}

// Decode the hex-encoded character and add it to the builder
dst := []byte{0}
if n, err := hex.Decode(dst, xx); err != nil {
return "", errors.New("failed to decode escaped character: " + err.Error())
} else if n != 1 {
return "", fmt.Errorf("failed to decode escaped character: encoding/hex: expected 1 byte when un-escaping, got %d", n)
}

builder.WriteByte(dst[0])
i += 2
}

return builder.String(), nil
}

func decodeEncodedString(str string) (string, error) {
decoded, err := hex.DecodeString(str)
if err != nil {
return "", fmt.Errorf("failed to decode BER encoding: %s", err)
}

var rawValue asn1.RawValue
result, err := asn1.Unmarshal(decoded, &rawValue)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cpuschma @johnweldon
Great to see you are improving this function!
There is one small issue with this approach that I also figured out too late: asn1.Unmarshal( only supports DER parsing, while the github.com/go-asn1-ber/asn1-ber library supports BER parsing too.
This means that parseDN deviates from the RFC, possibly resulting in unexpected limitations for the user.

Copy link
Member Author

@cpuschma cpuschma Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into this and also extend the test cases. We didn't catch on that either. Thank you for pointing this out, @inteon!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also have some extra testcases that might be useful. Will create a PR soon.

if err != nil {
return "", fmt.Errorf("failed to unmarshal hex-encoded string: %s", err)
}
if len(result) != 0 {
return "", errors.New("trailing data after unmarshalling hex-encoded string")
}

return string(rawValue.Bytes), nil
}

// ParseDN returns a distinguishedName or an error.
// The function respects https://tools.ietf.org/html/rfc4514
func ParseDN(str string) (*DN, error) {
dn := new(DN)
dn.RDNs = make([]*RelativeDN, 0)
rdn := new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
buffer := bytes.Buffer{}
attribute := new(AttributeTypeAndValue)
escaping := false

unescapedTrailingSpaces := 0
stringFromBuffer := func() string {
s := buffer.String()
s = s[0 : len(s)-unescapedTrailingSpaces]
buffer.Reset()
unescapedTrailingSpaces = 0
return s
var dn = &DN{RDNs: make([]*RelativeDN, 0)}
if str = strings.TrimSpace(str); len(str) == 0 {
return dn, nil
}

var (
rdn = &RelativeDN{}
attr = &AttributeTypeAndValue{}
escaping bool
startPos int
appendAttributesToRDN = func(end bool) {
rdn.Attributes = append(rdn.Attributes, attr)
attr = &AttributeTypeAndValue{}
if end {
dn.RDNs = append(dn.RDNs, rdn)
rdn = &RelativeDN{}
}
}
)

for i := 0; i < len(str); i++ {
char := str[i]
switch {
case escaping:
unescapedTrailingSpaces = 0
escaping = false
switch char {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
buffer.WriteByte(char)
continue
}
// Not a special character, assume hex encoded octet
if len(str) == i+1 {
return nil, errors.New("got corrupted escaped character")
}

dst := []byte{0}
n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
if err != nil {
return nil, fmt.Errorf("failed to decode escaped character: %s", err)
} else if n != 1 {
return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n)
}
buffer.WriteByte(dst[0])
i++
case char == '\\':
unescapedTrailingSpaces = 0
escaping = true
case char == '=' && attribute.Type == "":
attribute.Type = stringFromBuffer()
// Special case: If the first character in the value is # the
// following data is BER encoded so we can just fast forward
// and decode.
if len(str) > i+1 && str[i+1] == '#' {
i += 2
index := strings.IndexAny(str[i:], ",+")
var data string
if index > 0 {
data = str[i : i+index]
} else {
data = str[i:]
}
rawBER, err := enchex.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("failed to decode BER encoding: %s", err)
}
packet, err := ber.DecodePacketErr(rawBER)
if err != nil {
return nil, fmt.Errorf("failed to decode BER packet: %s", err)
}
buffer.WriteString(packet.Data.String())
i += len(data) - 1
case char == '=' && len(attr.Type) == 0:
if err := attr.setType(str[startPos:i]); err != nil {
return nil, err
}
startPos = i + 1
case char == ',' || char == '+' || char == ';':
// We're done with this RDN or value, push it
if len(attribute.Type) == 0 {
return nil, errors.New("incomplete type, value pair")
if len(attr.Type) == 0 {
return dn, errors.New("incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue)
if char == ',' || char == ';' {
dn.RDNs = append(dn.RDNs, rdn)
rdn = new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
if err := attr.setValue(str[startPos:i]); err != nil {
return nil, err
}
case char == ' ' && buffer.Len() == 0:
// ignore unescaped leading spaces
continue
default:
if char == ' ' {
// Track unescaped spaces in case they are trailing and we need to remove them
unescapedTrailingSpaces++
} else {
// Reset if we see a non-space char
unescapedTrailingSpaces = 0
}
buffer.WriteByte(char)

startPos = i + 1
last := char == ',' || char == ';'
appendAttributesToRDN(last)
}
}
if buffer.Len() > 0 {
if len(attribute.Type) == 0 {
return nil, errors.New("DN ended with incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
dn.RDNs = append(dn.RDNs, rdn)

if len(attr.Type) == 0 {
return dn, errors.New("DN ended with incomplete type, value pair")
}

if err := attr.setValue(str[startPos:]); err != nil {
return dn, err
}
appendAttributesToRDN(true)

return dn, nil
}

Expand Down
18 changes: 9 additions & 9 deletions dn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,21 @@ func TestSuccessfulDNParsing(t *testing.T) {
for test, answer := range testcases {
dn, err := ParseDN(test)
if err != nil {
t.Errorf(err.Error())
t.Errorf("ParseDN failed for DN test '%s': %s", test, err)
continue
}
if !reflect.DeepEqual(dn, &answer) {
t.Errorf("Parsed DN %s is not equal to the expected structure", test)
t.Errorf("Parsed DN '%s' is not equal to the expected structure", test)
t.Logf("Expected:")
for _, rdn := range answer.RDNs {
for _, attribs := range rdn.Attributes {
t.Logf("#%v\n", attribs)
for _, attribute := range rdn.Attributes {
t.Logf(" #%v\n", attribute)
}
}
t.Logf("Actual:")
for _, rdn := range dn.RDNs {
for _, attribs := range rdn.Attributes {
t.Logf("#%v\n", attribs)
for _, attribute := range rdn.Attributes {
t.Logf(" #%v\n", attribute)
}
}
}
Expand All @@ -107,7 +107,7 @@ func TestErrorDNParsing(t *testing.T) {
testcases := map[string]string{
"*": "DN ended with incomplete type, value pair",
"cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'",
"cn=Jim\\0": "got corrupted escaped character",
"cn=Jim\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0",
"DC=example,=net": "DN ended with incomplete type, value pair",
"1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string",
"test,DC=example,DC=com": "incomplete type, value pair",
Expand All @@ -117,9 +117,9 @@ func TestErrorDNParsing(t *testing.T) {
for test, answer := range testcases {
_, err := ParseDN(test)
if err == nil {
t.Errorf("Expected %s to fail parsing but succeeded\n", test)
t.Errorf("Expected '%s' to fail parsing but succeeded\n", test)
} else if err.Error() != answer {
t.Errorf("Unexpected error on %s:\n%s\nvs.\n%s\n", test, answer, err.Error())
t.Errorf("Unexpected error on: '%s':\nExpected: %s\nGot: %s\n", test, answer, err.Error())
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,6 @@ func DebugBinaryFile(fileName string) error {
return nil
}

var hex = "0123456789abcdef"

func mustEscape(c byte) bool {
return c > 0x7f || c == '(' || c == ')' || c == '\\' || c == '*' || c == 0
}
Expand All @@ -324,6 +322,7 @@ func mustEscape(c byte) bool {
// characters in the set `()*\` and those out of the range 0 < c < 0x80,
// as defined in RFC4515.
func EscapeFilter(filter string) string {
const hexValues = "0123456789abcdef"
escape := 0
for i := 0; i < len(filter); i++ {
if mustEscape(filter[i]) {
Expand All @@ -338,8 +337,8 @@ func EscapeFilter(filter string) string {
c := filter[i]
if mustEscape(c) {
buf[j+0] = '\\'
buf[j+1] = hex[c>>4]
buf[j+2] = hex[c&0xf]
buf[j+1] = hexValues[c>>4]
buf[j+2] = hexValues[c&0xf]
j += 3
} else {
buf[j] = c
Expand Down
Loading
Loading