Skip to content

Commit

Permalink
WIP: graph.go compiles!
Browse files Browse the repository at this point in the history
  • Loading branch information
ammario committed May 31, 2024
1 parent 241409c commit 4875617
Showing 1 changed file with 62 additions and 44 deletions.
106 changes: 62 additions & 44 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ import (

type Vector = []float32

// Node is a node in the graph.
type Node[K cmp.Ordered] struct {
ID K
Vec Vector
}

// layerNode is a node in a layer of the graph.
type layerNode[K cmp.Ordered] struct {
id K
vec Vector
Node[K]

// neighbors is map of neighbor IDs to neighbor nodes.
// It is a map and not a slice to allow for efficient deletes, esp.
Expand All @@ -32,7 +37,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
n.neighbors = make(map[K]*layerNode[K], m)
}

n.neighbors[newNode.id] = newNode
n.neighbors[newNode.ID] = newNode
if len(n.neighbors) <= m {
return
}
Expand All @@ -43,7 +48,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
worst *layerNode[K]
)
for _, neighbor := range n.neighbors {
d := dist(neighbor.vec, n.vec)
d := dist(neighbor.Vec, n.Vec)
// d > worstDist may always be false if the distance function
// returns NaN, e.g., when the embeddings are zero.
if d > worstDist || worst == nil {
Expand All @@ -52,9 +57,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
}
}

delete(n.neighbors, worst.id)
delete(n.neighbors, worst.ID)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.id)
delete(worst.neighbors, n.ID)
worst.replenish(m)
}

Expand Down Expand Up @@ -83,7 +88,7 @@ func (n *layerNode[K]) search(
candidates.Push(
searchCandidate[K]{
node: n,
dist: distance(n.vec, target),
dist: distance(n.Vec, target),
},
)
var (
Expand All @@ -94,7 +99,7 @@ func (n *layerNode[K]) search(

// Begin with the entry node in the result set.
result.Push(candidates.Min())
visited[n.id] = true
visited[n.ID] = true

for candidates.Len() > 0 {
var (
Expand All @@ -113,7 +118,7 @@ func (n *layerNode[K]) search(
}
visited[neighborID] = true

dist := distance(neighbor.vec, target)
dist := distance(neighbor.Vec, target)
improved = improved || dist < result.Min().dist
if result.Len() < k {
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
Expand Down Expand Up @@ -168,7 +173,7 @@ func (n *layerNode[K]) replenish(m int) {
// to neighbors.
func (n *layerNode[K]) isolate(m int) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.id)
delete(neighbor.neighbors, n.ID)
neighbor.replenish(m)
}
}
Expand All @@ -179,7 +184,7 @@ type layer[K cmp.Ordered] struct {
// property of the graph.
//
// nodes is exported for interop with encoding/gob.
nodes map[string]*layerNode[K]
nodes map[K]*layerNode[K]
}

// entry returns the entry node of the layer.
Expand Down Expand Up @@ -237,8 +242,8 @@ func defaultRand() *rand.Rand {

// NewGraph returns a new graph with default parameters, roughly designed for
// storing OpenAI embeddings.
func NewGraph[K cmp.Ordered, V Embeddable[K]]() *Graph[K, V] {
return &Graph[K, V]{
func NewGraph[K cmp.Ordered]() *Graph[K] {
return &Graph[K]{
M: 16,
Ml: 0.25,
Distance: CosineDistance,
Expand Down Expand Up @@ -307,38 +312,48 @@ func (g *Graph[T]) Dims() int {
if len(g.layers) == 0 {
return 0
}
return len(g.layers[0].entry().Point.Embedding())
return len(g.layers[0].entry().Vec)
}

func ptr[T any](v T) *T {
return &v
}

// Add inserts nodes into the graph.
// If another node with the same ID exists, it is replaced.
func (g *Graph[T]) Add(nodes ...T) {
for _, n := range nodes {
g.assertDims(n.Embedding())
func (g *Graph[K]) Add(nodes ...Node[K]) {
for _, node := range nodes {
id := node.ID
vec := node.Vec

g.assertDims(vec)
insertLevel := g.randomLevel()
// Create layers that don't exist yet.
for insertLevel >= len(g.layers) {
g.layers = append(g.layers, &layer[T]{})
g.layers = append(g.layers, &layer[K]{})
}

if insertLevel < 0 {
panic("invalid level")
}

var elevator string
var elevator *K

preLen := g.Len()

// Insert node at each layer, beginning with the highest.
for i := len(g.layers) - 1; i >= 0; i-- {
layer := g.layers[i]
newNode := &layerNode[T]{
vec: n,
newNode := &layerNode[K]{
Node: Node[K]{
ID: id,
Vec: vec,
},
}

// Insert the new node into the layer.
if layer.entry() == nil {
layer.Nodes = map[string]*layerNode[T]{n.ID(): newNode}
layer.nodes = map[K]*layerNode[K]{id: newNode}
continue
}

Expand All @@ -348,30 +363,30 @@ func (g *Graph[T]) Add(nodes ...T) {

// On subsequent layers, we use the elevator node to enter the graph
// at the best point.
if elevator != "" {
searchPoint = layer.Nodes[elevator]
if elevator != nil {
searchPoint = layer.nodes[*elevator]
}

if g.Distance == nil {
panic("(*Graph).Distance must be set")
}

neighborhood := searchPoint.search(g.M, g.EfSearch, n.Embedding(), g.Distance)
neighborhood := searchPoint.search(g.M, g.EfSearch, vec, g.Distance)
if len(neighborhood) == 0 {
// This should never happen because the searchPoint itself
// should be in the result set.
panic("no nodes found")
}

// Re-set the elevator node for the next layer.
elevator = neighborhood[0].node.Point.ID()
elevator = ptr(neighborhood[0].node.ID)

if insertLevel >= i {
if _, ok := layer.Nodes[n.ID()]; ok {
g.Delete(n.ID())
if _, ok := layer.nodes[id]; ok {
g.Delete(id)
}
// Insert the new node into the layer.
layer.Nodes[n.ID()] = newNode
layer.nodes[id] = newNode
for _, node := range neighborhood {
// Create a bi-directional edge between the new node and the best node.
node.node.addNeighbor(newNode, g.M, g.Distance)
Expand All @@ -388,7 +403,7 @@ func (g *Graph[T]) Add(nodes ...T) {
}

// Search finds the k nearest neighbors from the target node.
func (h *Graph[T]) Search(near Vector, k int) []T {
func (h *Graph[K]) Search(near Vector, k int) []Node[K] {
h.assertDims(near)
if len(h.layers) == 0 {
return nil
Expand All @@ -397,27 +412,27 @@ func (h *Graph[T]) Search(near Vector, k int) []T {
var (
efSearch = h.EfSearch

elevator string
elevator *K
)

for layer := len(h.layers) - 1; layer >= 0; layer-- {
searchPoint := h.layers[layer].entry()
if elevator != "" {
searchPoint = h.layers[layer].Nodes[elevator]
if elevator != nil {
searchPoint = h.layers[layer].nodes[*elevator]
}

// Descending hierarchies
if layer > 0 {
nodes := searchPoint.search(1, efSearch, near, h.Distance)
elevator = nodes[0].node.Point.ID()
elevator = ptr(nodes[0].node.ID)
continue
}

nodes := searchPoint.search(k, efSearch, near, h.Distance)
out := make([]T, 0, len(nodes))
out := make([]Node[K], 0, len(nodes))

for _, node := range nodes {
out = append(out, node.node.Point.(T))
out = append(out, node.node.Node)
}

return out
Expand All @@ -437,31 +452,34 @@ func (h *Graph[T]) Len() int {
// Delete removes a node from the graph by ID.
// It tries to preserve the clustering properties of the graph by
// replenishing connectivity in the affected neighborhoods.
func (h *Graph[T]) Delete(id string) bool {
func (h *Graph[K]) Delete(id K) bool {
if len(h.layers) == 0 {
return false
}

var deleted bool
for _, layer := range h.layers {
node, ok := layer.Nodes[id]
node, ok := layer.nodes[id]
if !ok {
continue
}
delete(layer.Nodes, id)
delete(layer.nodes, id)
node.isolate(h.M)
deleted = true
}

return deleted
}

// Lookup returns the node with the given ID.
func (h *Graph[T]) Lookup(id string) (T, bool) {
var zero T
// Lookup returns the vector with the given ID.
func (h *Graph[K]) Lookup(id K) (Vector, bool) {
if len(h.layers) == 0 {
return zero, false
return nil, false
}

return h.layers[0].Nodes[id].Point.(T), true
node, ok := h.layers[0].nodes[id]
if !ok {
return nil, false
}
return node.Vec, ok
}

0 comments on commit 4875617

Please sign in to comment.