From a439c5c61b893573f7d1600a62c30fc023b611ec Mon Sep 17 00:00:00 2001 From: Hanno Hecker Date: Sat, 13 Feb 2016 18:50:24 +0100 Subject: [PATCH 1/5] util functions for DNs --- dn_util.go | 216 ++++++++++++++++++++++++++++++++++++++++++++++++ dn_util_test.go | 93 +++++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 dn_util.go create mode 100644 dn_util_test.go diff --git a/dn_util.go b/dn_util.go new file mode 100644 index 00000000..e23abaf9 --- /dev/null +++ b/dn_util.go @@ -0,0 +1,216 @@ +package ldap + +import ( + enchex "encoding/hex" + "errors" + "strings" +) + +// When true, uses strings.EqualFold to compare the values of an RDN. +// This is usually needed as true for most values, set to false when +// your DNs contain case sensitive values. +var RDNCompareFold bool = true + +var ErrDNNotSubordinate = errors.New("Not a subordinate") + +// Returns the stringified version of a *DN, the RDN values are escaped +func (dn *DN) String() string { + var rdns []string + for _, r := range dn.RDNs { + var tv []string + for _, av := range r.Attributes { + tv = append(tv, strings.ToLower(av.Type)+"="+EscapeValue(av.Value)) + } + rdns = append(rdns, strings.Join(tv, "+")) + } + return strings.Join(rdns, ",") +} + +func EscapeValue(value string) (escaped string) { + for _, r := range value { + switch r { + case ',', '+', '"', '\\', '<', '>', ';', '#', '=': + escaped += "\\" + string(r) + default: + if uint(r) < 32 { + escaped += "\\" + enchex.EncodeToString([]byte(string(r))) + } else { + escaped += string(r) + } + } + } + return +} + +// check if all RDNs of both DNs are equal +func (dn *DN) Equal(other *DN) bool { + if len(dn.RDNs) != len(other.RDNs) { + return false + } + for i, rdn := range dn.RDNs { + if !rdn.Equal(other.RDNs[i]) { + return false + } + } + return true +} + +// Check if all types and values of both RDNs are equal, the result may be +// influenced by the value of RDNCompareFold. +func (r *RelativeDN) Equal(o *RelativeDN) bool { + if len(r.Attributes) != len(o.Attributes) { + return false + } + for i, av := range r.Attributes { + if strings.ToLower(av.Type) != strings.ToLower(o.Attributes[i].Type) { + return false + } + if RDNCompareFold { + if !strings.EqualFold(av.Value, o.Attributes[i].Value) { + return false + } + } else { + if av.Value != o.Attributes[i].Value { + return false + } + } + } + return true +} + +// Returns true if the "other" DN is a parent of "dn" +func (dn *DN) IsSubordinate(other *DN) bool { + if off := len(dn.RDNs) - len(other.RDNs); off <= 0 { + return false + } else { + for i, rdn := range other.RDNs { + if !rdn.Equal(dn.RDNs[i+off]) { + return false + } + } + } + return true +} + +// appends the "other" DN to the "dn", e.g. +// +// dn, err := ldap.ParseDN("CN=Someone") +// base, err := ldap.ParseDN("ou=people,dc=example,dc=org") +// dn.Append(base) -> "cn=Someone,ou=people,dc=example,dc=org" +func (dn *DN) Append(other *DN) { + dn.RDNs = append(dn.RDNs, other.RDNs...) +} + +// removes the "other" DN from the "dn", e.g. +// +// dn, err := ldap.ParseDN(""cn=Someone,ou=people,dc=example,dc=org") +// base, err := ldap.ParseDN("ou=people,dc=example,dc=org") +// dn.Strip(base) -> "cn=Someone" +// +// Note: the "other" DN must be a parent of the "dn" +func (dn *DN) Strip(base *DN) error { + if !dn.IsSubordinate(base) { + return ErrDNNotSubordinate + } + dn.RDNs = dn.RDNs[0 : len(dn.RDNs)-len(base.RDNs)] + return nil +} + +// Changes the first RDN of DN to the given one +func (dn *DN) Rename(rdn *RelativeDN) { + dn.RDNs[0] = rdn +} + +// Moves the first RDN to the new base +func (dn *DN) Move(newBase *DN) { + dn.RDNs = dn.RDNs[:1] + dn.Append(newBase) +} + +// Returns the value of the first RDN, e.g. +// +// dn, err := ldap.ParseDN("uid=someone,ou=people,dc=example,dc=org") +// dn.RDN() -> "someone" +func (dn *DN) RDN() string { + if len(dn.RDNs) == 0 || len(dn.RDNs[0].Attributes) == 0 { + return "" + } + return dn.RDNs[0].Attributes[0].Value +} + +// Returns the parent of the "dn" as a cloned *DN +func (dn *DN) Parent() *DN { + c := dn.Clone() + if len(c.RDNs) > 0 { + c.RDNs = c.RDNs[1:] + return c + } + c.RDNs = []*RelativeDN{} + return c +} + +// Returns a clone of the DN +func (dn *DN) Clone() *DN { + c := &DN{} + for _, r := range dn.RDNs { + rc := &RelativeDN{} + for _, tv := range r.Attributes { + rc.Attributes = append(rc.Attributes, &AttributeTypeAndValue{Type: tv.Type, Value: tv.Value}) + } + c.RDNs = append(c.RDNs, rc) + } + return c +} + +// Sorting DNs: +// all := []*ldap.DN{dn1, dn2, dn3, dn4} +// sort.Sort(DNs(all)) +// for _, dn := range all { +// println(dn.String()) +// } +// +// The result order from deepest part in tree upwards, so you could +// easily search for all dns in a base, sort them and then remove +// every DN in that order to remove the tree (including the search base) +type DNs []*DN + +func (d DNs) Len() int { + return len(([]*DN)(d)) +} + +func (d DNs) Swap(i, j int) { + ([]*DN)(d)[i], ([]*DN)(d)[j] = ([]*DN)(d)[j], ([]*DN)(d)[i] +} + +func (d DNs) Less(i, j int) bool { + if ([]*DN)(d)[i].IsSubordinate(([]*DN)(d)[j]) { + return true + } + if ([]*DN)(d)[i].Parent().Equal(([]*DN)(d)[j].Parent()) { + return ([]*DN)(d)[i].RDNs[0].Less(([]*DN)(d)[j].RDNs[0]) + } + return false +} + +func (r *RelativeDN) Less(o *RelativeDN) bool { + if len(r.Attributes) != len(o.Attributes) { + return len(r.Attributes) < len(o.Attributes) + } + for i, a := range r.Attributes { + if strings.ToLower(a.Type) < strings.ToLower(o.Attributes[i].Type) { + return true + } + if RDNCompareFold { + if strings.ToLower(a.Value) < strings.ToLower(o.Attributes[i].Value) { + return true + } + } else { + if a.Value < o.Attributes[i].Value { + return true + } + } + } + return false +} + +// vim: ts=4 sw=4 noexpandtab diff --git a/dn_util_test.go b/dn_util_test.go new file mode 100644 index 00000000..8b796756 --- /dev/null +++ b/dn_util_test.go @@ -0,0 +1,93 @@ +package ldap_test + +import ( + "fmt" + "gopkg.in/ldap.v2" + "sort" + "testing" +) + +func TestDNString(t *testing.T) { + fmt.Printf("DNString: starting...\n") + dn, _ := ldap.ParseDN("OU=Sales+CN=J. Smith,DC=example,DC=net") + strdn := dn.String() + if strdn != "ou=Sales+cn=J. Smith,dc=example,dc=net" { + t.Errorf("Failed to stringify: %v\n", strdn) + } + fmt.Printf("DNString: -> %v\n", strdn) + dn2, _ := ldap.ParseDN("CN=Lučić\\+Ma\\=> %s\n", dn) + } + if parent.String() != "dc=example,dc=net" { + t.Errorf("wrong parent -> %s\n", parent) + } + fmt.Printf("DN Parent: %s -> %s\n", dn, parent) +} + +func TestDNMove(t *testing.T) { + fmt.Printf("DN Rename and Move: starting...\n") + dn, _ := ldap.ParseDN("OU=Sales+CN=J. Smith,DC=example,DC=net") + base, _ := ldap.ParseDN("OU=People,DC=example,DC=net") + rdn, _ := ldap.ParseDN("cn=J. Smith") + dn.Move(base) + if dn.String() != "ou=Sales+cn=J. Smith,ou=People,dc=example,dc=net" { + t.Errorf("Failed to move: %s\n", dn) + } + dn.Rename(rdn.RDNs[0]) + if dn.String() != "cn=J. Smith,ou=People,dc=example,dc=net" { + t.Errorf("Failed to rename: %s\n", dn) + } + fmt.Printf("DN Rename and Move: %s\n", dn) +} + +func TestDNEqual(t *testing.T) { + dn1, _ := ldap.ParseDN("OU=people,DC=example,DC=org") + dn2, _ := ldap.ParseDN("ou=People,dc=Example,dc=ORG") + ldap.RDNCompareFold = true + if !dn1.Equal(dn2) { + t.Errorf("both dns not equal") + } + ldap.RDNCompareFold = false + if dn1.Equal(dn2) { + t.Errorf("both dns equal with ldap.RDNCompareFold = false") + } + ldap.RDNCompareFold = true +} + +func TestDNSort(t *testing.T) { + var dns []*ldap.DN + dnStrings := []string{ + "ou=people,dc=example,dc=org", + "uid=another,ou=people,dc=example,dc=org", + "uid=another+cn=one,ou=people,dc=example,dc=org", + "dc=example,dc=org", + "uid=someone,ou=people,dc=example,dc=org", + "ou=robots,dc=example,dc=org", + "uid=someone,ou=robots,dc=example,dc=org", + } + + for _, s := range dnStrings { + dn, _ := ldap.ParseDN(s) + dns = append(dns, dn) + } + sort.Sort(ldap.DNs(dns)) + for _, dn := range dns { + fmt.Printf("DN: %s\n", dn.String()) + } + if dns[len(dns)-1].String() != "dc=example,dc=org" { + t.Errorf("DN dc=example,dc=org is not last") + } + if dns[0].String() != "uid=another,ou=people,dc=example,dc=org" { + t.Errorf("DN uid=another,ou=people,dc=example,dc=org is not first") + } +} From a018e5c3b9cb1261c360d72698b4355c973f5c81 Mon Sep 17 00:00:00 2001 From: Hanno Hecker Date: Sun, 14 Feb 2016 09:57:30 +0100 Subject: [PATCH 2/5] pool - initial commit --- client.go | 1 + conn.go | 14 +++++ pool.go | 26 ++++++++ pool_channel.go | 161 ++++++++++++++++++++++++++++++++++++++++++++++++ pool_conn.go | 93 ++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+) create mode 100644 pool.go create mode 100644 pool_channel.go create mode 100644 pool_conn.go diff --git a/client.go b/client.go index d3401f9e..f2de556d 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ type Client interface { Start() StartTLS(config *tls.Config) error Close() + Alive() bool Bind(username, password string) error SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) diff --git a/conn.go b/conn.go index 2f16443f..577faf94 100644 --- a/conn.go +++ b/conn.go @@ -106,6 +106,20 @@ func NewConn(conn net.Conn, isTLS bool) *Conn { } } +func (l *Conn) Alive() bool { + _, err := l.Search(NewSearchRequest( + "", + ScopeBaseObject, + NeverDerefAliases, + 1, // max 1 result + int(DefaultTimeout.Seconds()), // FIXME? + false, + "(objectClass=*)", + []string{"dn"}, + []Control{})) + return err == nil +} + func (l *Conn) Start() { go l.reader() go l.processMessages() diff --git a/pool.go b/pool.go new file mode 100644 index 00000000..c576587b --- /dev/null +++ b/pool.go @@ -0,0 +1,26 @@ +package ldap + +import ( + "errors" +) + +var ( + // ErrClosed is the error resulting if the pool is closed via pool.Close(). + ErrClosed = errors.New("pool is closed") +) + +// Pool interface describes a pool implementation. A pool should have maximum +// capacity. An ideal pool is threadsafe and easy to use. +type Pool interface { + // Get returns a new connection from the pool. Closing the connections puts + // it back to the Pool. Closing it when the pool is destroyed or full will + // be counted as an error. + Get() (*PoolConn, error) + + // Close closes the pool and all its connections. After Close() the pool is + // no longer usable. + Close() + + // Len returns the current number of connections of the pool. + Len() int +} diff --git a/pool_channel.go b/pool_channel.go new file mode 100644 index 00000000..15edac5f --- /dev/null +++ b/pool_channel.go @@ -0,0 +1,161 @@ +package ldap + +import ( + "errors" + "log" + "sync" +) + +// channelPool implements the Pool interface based on buffered channels. +type channelPool struct { + // storage for our net.Conn connections + mu sync.Mutex + conns chan Client + + name string + + // net.Conn generator + factory PoolFactory + closeAt []uint8 +} + +// PoolFactory is a function to create new connections. +type PoolFactory func(string) (Client, error) + +// NewChannelPool returns a new pool based on buffered channels with an initial +// capacity and maximum capacity. Factory is used when initial capacity is +// greater than zero to fill the pool. A zero initialCap doesn't fill the Pool +// until a new Get() is called. During a Get(), If there is no new connection +// available in the pool, a new connection will be created via the Factory() +// method. +// +// closeAt will automagically mark the connection as unusable if the return code +// of the call is one of those passed, most likely you want to set this to something +// like +// []uint8{ldap.LDAPResultTimeLimitExceeded, ldap.ErrorNetwork} +func NewChannelPool(initialCap, maxCap int, name string, factory PoolFactory, closeAt []uint8) (Pool, error) { + if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { + return nil, errors.New("invalid capacity settings") + } + + c := &channelPool{ + conns: make(chan Client, maxCap), + name: name, + factory: factory, + closeAt: closeAt, + } + + // create initial connections, if something goes wrong, + // just close the pool error out. + for i := 0; i < initialCap; i++ { + conn, err := factory(c.name) + log.Printf("init connection: %v", conn) + if err != nil { + c.Close() + return nil, errors.New("factory is not able to fill the pool: " + err.Error()) + } + c.conns <- conn + } + + return c, nil +} + +func (c *channelPool) getConns() chan Client { + c.mu.Lock() + conns := c.conns + c.mu.Unlock() + return conns +} + +// Get implements the Pool interfaces Get() method. If there is no new +// connection available in the pool, a new connection will be created via the +// Factory() method. +func (c *channelPool) Get() (*PoolConn, error) { + conns := c.getConns() + if conns == nil { + return nil, ErrClosed + } + + // wrap our connections with our ldap.Client implementation (wrapConn + // method) that puts the connection back to the pool if it's closed. + select { + case conn := <-conns: + if conn == nil { + return nil, ErrClosed + } + // log.Printf("existing conn: %v", conn) + if conn.Alive() { + return c.wrapConn(conn, c.closeAt), nil + } + + log.Printf("connection dead: %v", conn) + conn.Close() + return c.NewConn() + default: + return c.NewConn() + } +} + +func (c *channelPool) NewConn() (*PoolConn, error) { + conn, err := c.factory(c.name) + log.Printf("new connection: %v", conn) + if err != nil { + return nil, err + } + return c.wrapConn(conn, c.closeAt), nil +} + +// put puts the connection back to the pool. If the pool is full or closed, +// conn is simply closed. A nil conn will be rejected. +func (c *channelPool) put(conn Client) { + if conn == nil { + log.Printf("connection is nil. rejecting") + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + if c.conns == nil { + // pool is closed, close passed connection + conn.Close() + return + } + + // put the resource back into the pool. If the pool is full, this will + // block and the default case will be executed. + select { + case c.conns <- conn: + return + default: + // pool is full, close passed connection + conn.Close() + return + } +} + +func (c *channelPool) Close() { + c.mu.Lock() + conns := c.conns + c.conns = nil + c.factory = nil + c.mu.Unlock() + + if conns == nil { + return + } + + close(conns) + for conn := range conns { + conn.Close() + } + return +} + +func (c *channelPool) Len() int { return len(c.getConns()) } + +func (c *channelPool) wrapConn(conn Client, closeAt []uint8) *PoolConn { + p := &PoolConn{c: c, closeAt: closeAt} + p.Conn = conn + return p +} diff --git a/pool_conn.go b/pool_conn.go new file mode 100644 index 00000000..e797b969 --- /dev/null +++ b/pool_conn.go @@ -0,0 +1,93 @@ +package ldap + +import ( + "crypto/tls" + "log" +) + +// PoolConn implements Client to override the Close() method +type PoolConn struct { + Conn Client + c *channelPool + unusable bool + closeAt []uint8 +} + +func (p *PoolConn) Start() { + p.Conn.Start() +} + +func (p *PoolConn) StartTLS(config *tls.Config) error { + // FIXME - check if already TLS and then ignore? + return p.Conn.StartTLS(config) +} + +// Close() puts the given connects back to the pool instead of closing it. +func (p *PoolConn) Close() { + if p.unusable { + log.Printf("Closing unusable connection %v", p.Conn) + if p.Conn != nil { + p.Conn.Close() + } + return + } + p.c.put(p.Conn) +} + +func (p *PoolConn) Alive() bool { + if !p.Conn.Alive() { + p.MarkUnusable() + return false + } + return true +} + +func (p *PoolConn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { + return p.Conn.SimpleBind(simpleBindRequest) +} + +func (p *PoolConn) Bind(username, password string) error { + return p.Conn.Bind(username, password) +} + +// MarkUnusable() marks the connection not usable any more, to let the pool close it +// instead of returning it to pool. +func (p *PoolConn) MarkUnusable() { + p.unusable = true +} + +func (p *PoolConn) autoClose(err error) { + for _, code := range p.closeAt { + if IsErrorWithCode(err, code) { + p.MarkUnusable() + return + } + } +} + +func (p *PoolConn) Add(addRequest *AddRequest) error { + return p.Conn.Add(addRequest) +} + +func (p *PoolConn) Del(delRequest *DelRequest) error { + return p.Conn.Del(delRequest) +} + +func (p *PoolConn) Modify(modifyRequest *ModifyRequest) error { + return p.Conn.Modify(modifyRequest) +} + +func (p *PoolConn) Compare(dn, attribute, value string) (bool, error) { + return p.Conn.Compare(dn, attribute, value) +} + +func (p *PoolConn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { + return p.Conn.PasswordModify(passwordModifyRequest) +} + +func (p *PoolConn) Search(searchRequest *SearchRequest) (*SearchResult, error) { + return p.Conn.Search(searchRequest) +} +func (p *PoolConn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) { + return p.Conn.SearchWithPaging(searchRequest, pagingSize) +} From 626a90b71eaad9a426f0da950d12426cc4715ca4 Mon Sep 17 00:00:00 2001 From: Hanno Hecker Date: Sun, 14 Feb 2016 12:13:37 +0100 Subject: [PATCH 3/5] add basic ldif reader --- ldif.go | 256 +++++++++++++++++++++++++++++++++++++++++++++++++++ ldif_test.go | 118 ++++++++++++++++++++++++ 2 files changed, 374 insertions(+) create mode 100644 ldif.go create mode 100644 ldif_test.go diff --git a/ldif.go b/ldif.go new file mode 100644 index 00000000..9265a23a --- /dev/null +++ b/ldif.go @@ -0,0 +1,256 @@ +package ldap + +import ( + "bufio" + "bytes" + "encoding/base64" + "errors" + "fmt" + "io" + // "os" + "strconv" +) + +// A basic LDIF parser. This one does currently just support LDIFs like +// they're generated by tools like ldapsearch(1) / slapcat(8). Change +// records are not supported. +type LDIF struct { + RelaxedParser bool // ignore parsing errors in a line + Entries []*Entry + curLine int + version int + changeType string +} + +var CR byte = '\x0D' +var LF byte = '\x0A' +var SEP string = string([]byte{CR, LF}) +var Comment byte = '#' +var SPACE byte = ' ' +var SPACES = string(SPACE) + +func (l *LDIF) newError(msg string) error { + return errors.New(fmt.Sprintf("error on line %d: %s\n", l.curLine, msg)) +} + +// Parses the LDIF, the caller is responsible for closing the io.Reader if that's +// needed. This Parser may be called several times with different io.Readers to +// combine the files. +func (l *LDIF) Parse(r io.Reader) (err error) { + if r == nil { + return errors.New("No reader present") + } + reader := bufio.NewReader(r) + l.curLine = 0 + l.changeType = "" + l.version = 0 + + var lines [][]byte + var line, nextLine []byte + + for { + l.curLine++ + nextLine, err = reader.ReadBytes(LF) + nextLine = bytes.TrimRight(nextLine, SEP) + // fmt.Fprintf(os.Stderr, "NEXT=>%s<\n", nextLine) + switch err { + case nil, io.EOF: + switch len(nextLine) { + case 0: + if len(line) == 0 && err == io.EOF { + return nil + } + lines = append(lines, line) + entry, perr := l.parseEntry(lines) + if perr != nil { + if !l.RelaxedParser { + return l.newError(perr.Error()) + } + } + l.Entries = append(l.Entries, entry) + line = []byte{} + lines = [][]byte{} + if err == io.EOF { + return nil + } + default: + switch nextLine[0] { + case Comment: + continue + case SPACE: + if len(nextLine) > 1 { + line = append(line, nextLine[1:]...) + continue + } else { + return l.newError("space only line") + } + default: + if len(line) != 0 { + lines = append(lines, line) + } + line = nextLine + continue + } + } + default: + return l.newError(err.Error()) + } + } + if len(lines) != 0 { + entry, perr := l.parseEntry(lines) + if perr != nil { + return l.newError(perr.Error()) + } + l.Entries = append(l.Entries, entry) + } + return nil +} + +func (l *LDIF) parseEntry(lines [][]byte) (entry *Entry, err error) { + // for i, line := range lines { + // fmt.Fprintf(os.Stderr, "% 2d %s\n", i, line) + // } + if l.version == 0 && bytes.HasPrefix(lines[0], []byte("version:")) { + line := bytes.TrimLeft(lines[0][8:], SPACES) + if version, err := strconv.Atoi(string(line)); err != nil { + return nil, err + } else { + if version != 1 { + return nil, errors.New("Invalid version spec " + string(line)) + } + l.version = 1 + lines = lines[1:] + } + } + + if !bytes.HasPrefix(lines[0], []byte("dn:")) { + return nil, errors.New("Missing dn:") + } + _, val, err := l.parseLine(lines[0]) + if err != nil { + return nil, err + } + dn := val + + lines = lines[1:] + if bytes.HasPrefix(lines[0], []byte("changetype:")) { + _, val, err := l.parseLine(lines[0]) + if err != nil { + return nil, err + } + l.changeType = val + lines = lines[1:] + } + if l.changeType != "" { + return nil, errors.New("change records not supported") + } + + attrs := make(map[string][]string) + for i := 0; i < len(lines); i++ { + attr, val, err := l.parseLine(lines[i]) + if err != nil { + if !l.RelaxedParser { + return nil, err + } else { + continue + } + } + if _, ok := attrs[attr]; ok { + attrs[attr] = append(attrs[attr], string(val)) + } else { + attrs[attr] = []string{string(val)} + } + } + return NewEntry(dn, attrs), nil +} + +func (l *LDIF) parseLine(line []byte) (attr, val string, err error) { + off := 0 + for len(line) > off && line[off] != ':' { + off++ + if off >= len(line) { + return + } + } + if off == len(line) { + err = errors.New("Missing : in line") + return + } + if off > len(line)-2 { + err = errors.New("empty value") + return + } + + attr = string(line[0:off]) + if err = validAttr(attr); err != nil { + attr = "" + val = "" + return + } + + switch line[off+1] { + case ':': + var n int + value := bytes.TrimLeft(line[off+2:], SPACES) + // fmt.Fprintf(os.Stderr, "LINE=%s\nVALUE=%s\n", line, value) + dec := make([]byte, base64.StdEncoding.DecodedLen(len(value))) + n, err = base64.StdEncoding.Decode(dec, value) + if err != nil { + return + } + val = string(dec[:n]) + case '<': // FIXME missing return for *net.URL type + val = string(bytes.TrimLeft(line[off+2:], SPACES)) + default: + val = string(bytes.TrimLeft(line[off+2:], SPACES)) + } + + return +} + +func validOID(oid string) error { + lastDot := true + for _, c := range oid { + switch { + case c == '.' && lastDot: + return errors.New("OID with at least 2 consecutive dots") + case c == '.': + lastDot = true + case c >= 0x30 && c <= 0x39: + lastDot = false + default: + return errors.New("Invalid character in OID") + } + } + return nil +} + +func validAttr(attr string) error { + if len(attr) == 0 { + return errors.New("empty attribute name") + } + switch { + case attr[0] >= 0x41 && attr[0] <= 0x5A: + // A-Z + case attr[0] >= 0x61 && attr[0] <= 0x7A: + // a-z + default: + if attr[0] >= 0x30 && attr[0] <= 0x39 { + return validOID(attr) + } + return errors.New("invalid first character in attribute") + } + for i := 1; i < len(attr); i++ { + c := attr[i] + switch { + case c >= 0x30 && c <= 0x39: + case c >= 0x41 && c <= 0x5A: + case c >= 0x61 && c <= 0x7A: + case c == '-': + case c == ';': + default: + return errors.New("invalid character in attribute name") + } + } + return nil +} diff --git a/ldif_test.go b/ldif_test.go new file mode 100644 index 00000000..e97dc3f9 --- /dev/null +++ b/ldif_test.go @@ -0,0 +1,118 @@ +package ldap_test + +import ( + "bytes" + "gopkg.in/ldap.v2" + "testing" +) + +var ldifRFC2849Example = `version: 1 +dn: cn=Barbara Jensen, ou=Product Development, dc=airius, dc=com +objectclass: top +objectclass: person +objectclass: organizationalPerson +cn: Barbara Jensen +cn: Barbara J Jensen +cn: Babs Jensen +sn: Jensen +uid: bjensen +telephonenumber: +1 408 555 1212 +description: A big sailing fan. + +dn: cn=Bjorn Jensen, ou=Accounting, dc=airius, dc=com +objectclass: top +objectclass: person +objectclass: organizationalPerson +cn: Bjorn Jensen +sn: Jensen +telephonenumber: +1 408 555 1212 +` + +func TestLDIFParseRFC2849Example(t *testing.T) { + ex := bytes.NewBuffer([]byte(ldifRFC2849Example)) + l := &ldap.LDIF{} + err := l.Parse(ex) + if err != nil { + t.Errorf("Failed to parse RFC 2849 example: %s", err) + } +} + +var ldifEmpty = `dn: uid=someone,dc=example,dc=org +cn: +cn: Some User +` + +func TestLDIFParseEmptyAttr(t *testing.T) { + ex := bytes.NewBuffer([]byte(ldifEmpty)) + l := &ldap.LDIF{} + err := l.Parse(ex) + if err == nil { + t.Errorf("Did not fail to parse empty attribute") + } +} + +var ldifMissingDN = `objectclass: top +cn: Some User +` + +func TestLDIFParseMissingDN(t *testing.T) { + ex := bytes.NewBuffer([]byte(ldifMissingDN)) + l := &ldap.LDIF{} + err := l.Parse(ex) + if err == nil { + t.Errorf("Did not fail to parse missing DN attribute") + } +} + +var ldifContinuation = `dn: uid=someone,dc=example,dc=org +sn: Some + One +cn: Someone +` + +func TestLDIFContinuation(t *testing.T) { + ex := bytes.NewBuffer([]byte(ldifContinuation)) + l := &ldap.LDIF{} + err := l.Parse(ex) + if err != nil { + t.Errorf("Failed to parse LDIF: %s", err) + } + e := l.Entries[0] + if e.GetAttributeValues("sn")[0] != "Some One" { + t.Errorf("Value of continuation line wrong") + } +} + +var ldifBase64 = `dn: uid=someone,dc=example,dc=org +sn:: U29tZSBPbmU= +` + +func TestLDIFBase64(t *testing.T) { + ex := bytes.NewBuffer([]byte(ldifBase64)) + l := &ldap.LDIF{} + err := l.Parse(ex) + if err != nil { + t.Errorf("Failed to parse LDIF: %s", err) + } + + e := l.Entries[0] + val := e.GetAttributeValues("sn")[0] + cmp := "Some One" + if val != cmp { + t.Errorf("Value of base64 value wrong: >%v< >%v<", []byte(val), []byte(cmp)) + } +} + +var ldifTrailingBlank = `dn: uid=someone,dc=example,dc=org +sn:: U29tZSBPbmU= + +` + +func TestLDIFTrailingBlank(t *testing.T) { + ex := bytes.NewBuffer([]byte(ldifTrailingBlank)) + l := &ldap.LDIF{} + err := l.Parse(ex) + if err != nil { + t.Errorf("Failed to parse LDIF: %s", err) + } +} From 6044b27c69d080dab72ad7e7c599e3251091321d Mon Sep 17 00:00:00 2001 From: Hanno Hecker Date: Thu, 5 May 2016 12:13:30 +0200 Subject: [PATCH 4/5] Add Proxied Authorization (RFC 4370), Who Am I? (RFC 4532) --- client.go | 1 + control.go | 34 +++++++++++++++++ error.go | 6 +++ whoami.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++ whoami_test.go | 27 ++++++++++++++ 5 files changed, 167 insertions(+) create mode 100644 whoami.go create mode 100644 whoami_test.go diff --git a/client.go b/client.go index 055b27b5..be178a95 100644 --- a/client.go +++ b/client.go @@ -21,6 +21,7 @@ type Client interface { Compare(dn, attribute, value string) (bool, error) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) + WhoAmI(controls []Control) (*WhoAmIResult, error) Search(searchRequest *SearchRequest) (*SearchResult, error) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) diff --git a/control.go b/control.go index 4d829809..99c6995f 100644 --- a/control.go +++ b/control.go @@ -17,12 +17,14 @@ const ( ControlTypeVChuPasswordMustChange = "2.16.840.1.113730.3.4.4" ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5" ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2" + ControlTypeProxiedAuthorization = "2.16.840.1.113730.3.4.18" ) var ControlTypeMap = map[string]string{ ControlTypePaging: "Paging", ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", ControlTypeManageDsaIT: "Manage DSA IT", + ControlTypeProxiedAuthorization: "Proxied Authorization", } type Control interface { @@ -197,6 +199,38 @@ func NewControlManageDsaIT(Criticality bool) *ControlManageDsaIT { return &ControlManageDsaIT{Criticality: Criticality} } +type ControlProxiedAuthorization struct { + Criticality bool + AuthzId string +} + +func (c *ControlProxiedAuthorization) GetControlType() string { + return ControlTypeProxiedAuthorization +} + +func (c *ControlProxiedAuthorization) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeProxiedAuthorization, "Control Type ("+ControlTypeMap[ControlTypeProxiedAuthorization]+")")) + packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality")) + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.AuthzId, "AuthzId")) + return packet +} + +func (c *ControlProxiedAuthorization) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %t", + ControlTypeMap[ControlTypeProxiedAuthorization], + ControlTypeProxiedAuthorization, + c.Criticality) +} + +func NewControlProxiedAuthoization(authzId string) *ControlProxiedAuthorization { + return &ControlProxiedAuthorization{ + Criticality: true, + AuthzId: authzId, + } +} + func FindControl(controls []Control, controlType string) Control { for _, c := range controls { if c.GetControlType() == controlType { diff --git a/error.go b/error.go index 97404eb6..420bae5e 100644 --- a/error.go +++ b/error.go @@ -47,6 +47,11 @@ const ( LDAPResultObjectClassModsProhibited = 69 LDAPResultAffectsMultipleDSAs = 71 LDAPResultOther = 80 + // https://tools.ietf.org/html/rfc4370 chap 6: + // "A result code (123) has been assigned by the IANA for the case where + // the server does not execute a request using the proxy authorization + // identity." + LDAPResultAuthorizationDenied = 123 ErrorNetwork = 200 ErrorFilterCompile = 201 @@ -95,6 +100,7 @@ var LDAPResultCodeMap = map[uint8]string{ LDAPResultEntryAlreadyExists: "Entry Already Exists", LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited", LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs", + LDAPResultAuthorizationDenied: "Authorization Denied", LDAPResultOther: "Other", } diff --git a/whoami.go b/whoami.go new file mode 100644 index 00000000..3a21bbe0 --- /dev/null +++ b/whoami.go @@ -0,0 +1,99 @@ +// This file contains the "Who Am I?" extended operation as specified in rfc 4532 +// +// https://tools.ietf.org/html/rfc4532 +// + +package ldap + +import ( + "errors" + "fmt" + + "gopkg.in/asn1-ber.v1" +) + +const ( + whoamiOID = "1.3.6.1.4.1.4203.1.11.3" +) + +type WhoAmIRequest bool + +type WhoAmIResult struct { + AuthzId string +} + +func (r WhoAmIRequest) encode() (*ber.Packet, error) { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Who Am I? Extended Operation") + request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, whoamiOID, "Extended Request Name: Who Am I? OID")) + return request, nil +} + +func (l *Conn) WhoAmI(controls []Control) (*WhoAmIResult, error) { + messageID := l.nextMessageID() + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + req := WhoAmIRequest(true) + encodedWhoAmIRequest, err := req.encode() + if err != nil { + return nil, err + } + packet.AppendChild(encodedWhoAmIRequest) + + if len(controls) != 0 { + packet.AppendChild(encodeControls(controls)) + } + + l.Debug.PrintPacket(packet) + + channel, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + if channel == nil { + return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message")) + } + defer l.finishMessage(messageID) + + result := &WhoAmIResult{} + + l.Debug.Printf("%d: waiting for response", messageID) + packetResponse, ok := <-channel + if !ok { + return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", messageID, packet) + if err != nil { + return nil, err + } + + if packet == nil { + return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + + if packet.Children[1].Tag == ApplicationExtendedResponse { + resultCode, resultDescription := getLDAPResultCode(packet) + if resultCode != 0 { + return nil, NewError(resultCode, errors.New(resultDescription)) + } + } else { + return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag)) + } + + extendedResponse := packet.Children[1] + for _, child := range extendedResponse.Children { + if child.Tag == 11 { + result.AuthzId = ber.DecodeString(child.Data.Bytes()) + } + } + + return result, nil +} diff --git a/whoami_test.go b/whoami_test.go new file mode 100644 index 00000000..7a9a590e --- /dev/null +++ b/whoami_test.go @@ -0,0 +1,27 @@ +package ldap_test + +import ( + "fmt" + "gopkg.in/ldap.v2" +) + +func ExampleWhoAmI() { + conn, err := ldap.Dial("tcp", "ldap.example.org:389") + if err != nil { + fmt.Errorf("Failed to connect: %s\n", err) + } + + _, err = conn.SimpleBind(&ldap.SimpleBindRequest{ + Username: "uid=someone,ou=people,dc=example,dc=org", + Password: "MySecretPass", + }) + if err != nil { + fmt.Errorf("Failed to bind: %s\n", err) + } + + res, err := conn.WhoAmI(nil) + if err != nil { + fmt.Errorf("Failed to call WhoAmI(): %s\n", err) + } + fmt.Printf("%s\n", res.AuthzId) +} From 0f49b424415e0b17957976e698c899d06faa3368 Mon Sep 17 00:00:00 2001 From: Hanno Hecker Date: Fri, 6 May 2016 08:38:50 +0200 Subject: [PATCH 5/5] add proxy auth whoami example --- whoami_test.go | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/whoami_test.go b/whoami_test.go index 7a9a590e..9ad5d041 100644 --- a/whoami_test.go +++ b/whoami_test.go @@ -2,6 +2,7 @@ package ldap_test import ( "fmt" + "gopkg.in/ldap.v2" ) @@ -23,5 +24,28 @@ func ExampleWhoAmI() { if err != nil { fmt.Errorf("Failed to call WhoAmI(): %s\n", err) } - fmt.Printf("%s\n", res.AuthzId) + fmt.Printf("I am: %s\n", res.AuthzId) +} + +func ExampleWhoAmIProxied() { + conn, err := ldap.Dial("tcp", "ldap.example.org:389") + if err != nil { + fmt.Errorf("Failed to connect: %s\n", err) + } + + _, err = conn.SimpleBind(&ldap.SimpleBindRequest{ + Username: "uid=someone,ou=people,dc=example,dc=org", + Password: "MySecretPass", + }) + if err != nil { + fmt.Errorf("Failed to bind: %s\n", err) + } + + pa := ldap.NewProxyAuthControl("dn:uid=other,ou=people,dc=example,dc=org") + + res, err := conn.WhoAmI([]ldap.Control{pa}) + if err != nil { + fmt.Errorf("Failed to call WhoAmI(): %s\n", err) + } + fmt.Printf("For this call only I am now: %s\n", res.AuthzId) }