diff --git a/graph.go b/graph.go index d7f0a29..2c9604a 100644 --- a/graph.go +++ b/graph.go @@ -14,10 +14,15 @@ import ( type Vector = []float32 +// Node is a node in the graph. +type Node[K cmp.Ordered] struct { + ID K + Vec Vector +} + // layerNode is a node in a layer of the graph. type layerNode[K cmp.Ordered] struct { - id K - vec Vector + Node[K] // neighbors is map of neighbor IDs to neighbor nodes. // It is a map and not a slice to allow for efficient deletes, esp. @@ -32,7 +37,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu n.neighbors = make(map[K]*layerNode[K], m) } - n.neighbors[newNode.id] = newNode + n.neighbors[newNode.ID] = newNode if len(n.neighbors) <= m { return } @@ -43,7 +48,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu worst *layerNode[K] ) for _, neighbor := range n.neighbors { - d := dist(neighbor.vec, n.vec) + d := dist(neighbor.Vec, n.Vec) // d > worstDist may always be false if the distance function // returns NaN, e.g., when the embeddings are zero. if d > worstDist || worst == nil { @@ -52,9 +57,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu } } - delete(n.neighbors, worst.id) + delete(n.neighbors, worst.ID) // Delete backlink from the worst neighbor. - delete(worst.neighbors, n.id) + delete(worst.neighbors, n.ID) worst.replenish(m) } @@ -83,7 +88,7 @@ func (n *layerNode[K]) search( candidates.Push( searchCandidate[K]{ node: n, - dist: distance(n.vec, target), + dist: distance(n.Vec, target), }, ) var ( @@ -94,7 +99,7 @@ func (n *layerNode[K]) search( // Begin with the entry node in the result set. result.Push(candidates.Min()) - visited[n.id] = true + visited[n.ID] = true for candidates.Len() > 0 { var ( @@ -113,7 +118,7 @@ func (n *layerNode[K]) search( } visited[neighborID] = true - dist := distance(neighbor.vec, target) + dist := distance(neighbor.Vec, target) improved = improved || dist < result.Min().dist if result.Len() < k { result.Push(searchCandidate[K]{node: neighbor, dist: dist}) @@ -168,7 +173,7 @@ func (n *layerNode[K]) replenish(m int) { // to neighbors. func (n *layerNode[K]) isolate(m int) { for _, neighbor := range n.neighbors { - delete(neighbor.neighbors, n.id) + delete(neighbor.neighbors, n.ID) neighbor.replenish(m) } } @@ -179,7 +184,7 @@ type layer[K cmp.Ordered] struct { // property of the graph. // // nodes is exported for interop with encoding/gob. - nodes map[string]*layerNode[K] + nodes map[K]*layerNode[K] } // entry returns the entry node of the layer. @@ -237,8 +242,8 @@ func defaultRand() *rand.Rand { // NewGraph returns a new graph with default parameters, roughly designed for // storing OpenAI embeddings. -func NewGraph[K cmp.Ordered, V Embeddable[K]]() *Graph[K, V] { - return &Graph[K, V]{ +func NewGraph[K cmp.Ordered]() *Graph[K] { + return &Graph[K]{ M: 16, Ml: 0.25, Distance: CosineDistance, @@ -307,38 +312,48 @@ func (g *Graph[T]) Dims() int { if len(g.layers) == 0 { return 0 } - return len(g.layers[0].entry().Point.Embedding()) + return len(g.layers[0].entry().Vec) +} + +func ptr[T any](v T) *T { + return &v } // Add inserts nodes into the graph. // If another node with the same ID exists, it is replaced. -func (g *Graph[T]) Add(nodes ...T) { - for _, n := range nodes { - g.assertDims(n.Embedding()) +func (g *Graph[K]) Add(nodes ...Node[K]) { + for _, node := range nodes { + id := node.ID + vec := node.Vec + + g.assertDims(vec) insertLevel := g.randomLevel() // Create layers that don't exist yet. for insertLevel >= len(g.layers) { - g.layers = append(g.layers, &layer[T]{}) + g.layers = append(g.layers, &layer[K]{}) } if insertLevel < 0 { panic("invalid level") } - var elevator string + var elevator *K preLen := g.Len() // Insert node at each layer, beginning with the highest. for i := len(g.layers) - 1; i >= 0; i-- { layer := g.layers[i] - newNode := &layerNode[T]{ - vec: n, + newNode := &layerNode[K]{ + Node: Node[K]{ + ID: id, + Vec: vec, + }, } // Insert the new node into the layer. if layer.entry() == nil { - layer.Nodes = map[string]*layerNode[T]{n.ID(): newNode} + layer.nodes = map[K]*layerNode[K]{id: newNode} continue } @@ -348,15 +363,15 @@ func (g *Graph[T]) Add(nodes ...T) { // On subsequent layers, we use the elevator node to enter the graph // at the best point. - if elevator != "" { - searchPoint = layer.Nodes[elevator] + if elevator != nil { + searchPoint = layer.nodes[*elevator] } if g.Distance == nil { panic("(*Graph).Distance must be set") } - neighborhood := searchPoint.search(g.M, g.EfSearch, n.Embedding(), g.Distance) + neighborhood := searchPoint.search(g.M, g.EfSearch, vec, g.Distance) if len(neighborhood) == 0 { // This should never happen because the searchPoint itself // should be in the result set. @@ -364,14 +379,14 @@ func (g *Graph[T]) Add(nodes ...T) { } // Re-set the elevator node for the next layer. - elevator = neighborhood[0].node.Point.ID() + elevator = ptr(neighborhood[0].node.ID) if insertLevel >= i { - if _, ok := layer.Nodes[n.ID()]; ok { - g.Delete(n.ID()) + if _, ok := layer.nodes[id]; ok { + g.Delete(id) } // Insert the new node into the layer. - layer.Nodes[n.ID()] = newNode + layer.nodes[id] = newNode for _, node := range neighborhood { // Create a bi-directional edge between the new node and the best node. node.node.addNeighbor(newNode, g.M, g.Distance) @@ -388,7 +403,7 @@ func (g *Graph[T]) Add(nodes ...T) { } // Search finds the k nearest neighbors from the target node. -func (h *Graph[T]) Search(near Vector, k int) []T { +func (h *Graph[K]) Search(near Vector, k int) []Node[K] { h.assertDims(near) if len(h.layers) == 0 { return nil @@ -397,27 +412,27 @@ func (h *Graph[T]) Search(near Vector, k int) []T { var ( efSearch = h.EfSearch - elevator string + elevator *K ) for layer := len(h.layers) - 1; layer >= 0; layer-- { searchPoint := h.layers[layer].entry() - if elevator != "" { - searchPoint = h.layers[layer].Nodes[elevator] + if elevator != nil { + searchPoint = h.layers[layer].nodes[*elevator] } // Descending hierarchies if layer > 0 { nodes := searchPoint.search(1, efSearch, near, h.Distance) - elevator = nodes[0].node.Point.ID() + elevator = ptr(nodes[0].node.ID) continue } nodes := searchPoint.search(k, efSearch, near, h.Distance) - out := make([]T, 0, len(nodes)) + out := make([]Node[K], 0, len(nodes)) for _, node := range nodes { - out = append(out, node.node.Point.(T)) + out = append(out, node.node.Node) } return out @@ -437,18 +452,18 @@ func (h *Graph[T]) Len() int { // Delete removes a node from the graph by ID. // It tries to preserve the clustering properties of the graph by // replenishing connectivity in the affected neighborhoods. -func (h *Graph[T]) Delete(id string) bool { +func (h *Graph[K]) Delete(id K) bool { if len(h.layers) == 0 { return false } var deleted bool for _, layer := range h.layers { - node, ok := layer.Nodes[id] + node, ok := layer.nodes[id] if !ok { continue } - delete(layer.Nodes, id) + delete(layer.nodes, id) node.isolate(h.M) deleted = true } @@ -456,12 +471,15 @@ func (h *Graph[T]) Delete(id string) bool { return deleted } -// Lookup returns the node with the given ID. -func (h *Graph[T]) Lookup(id string) (T, bool) { - var zero T +// Lookup returns the vector with the given ID. +func (h *Graph[K]) Lookup(id K) (Vector, bool) { if len(h.layers) == 0 { - return zero, false + return nil, false } - return h.layers[0].Nodes[id].Point.(T), true + node, ok := h.layers[0].nodes[id] + if !ok { + return nil, false + } + return node.Vec, ok }