Skip to content

Commit

Permalink
Refactor API so that keys are generic (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
ammario authored Jun 14, 2024
1 parent 34992a1 commit e942b4f
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 298 deletions.
32 changes: 16 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ go get github.com/coder/hnsw@main
```

```go
g := hnsw.NewGraph[hnsw.Vector]()
g := hnsw.NewGraph[int]()
g.Add(
hnsw.MakeVector("1", []float32{1, 1, 1}),
hnsw.MakeVector("2", []float32{1, -1, 0.999}),
hnsw.MakeVector("3", []float32{1, 0, -0.5}),
hnsw.MakeNode(1, []float32{1, 1, 1}),
hnsw.MakeNode(2, []float32{1, -1, 0.999}),
hnsw.MakeNode(3, []float32{1, 0, -0.5}),
)

neighbors := g.Search(
[]float32{0.5, 0.5, 0.5},
1,
)
fmt.Printf("best friend: %v\n", neighbors[0].Embedding())
fmt.Printf("best friend: %v\n", neighbors[0].Vec)
// Output: best friend: [1 1 1]
```

Expand All @@ -59,13 +59,13 @@ If you're using a single file as the backend, hnsw provides a convenient `SavedG

```go
path := "some.graph"
g1, err := LoadSavedGraph[hnsw.Vector](path)
g1, err := LoadSavedGraph[int](path)
if err != nil {
panic(err)
}
// Insert some vectors
for i := 0; i < 128; i++ {
g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)}))
g1.Add(hnsw.MakeNode(i, []float32{float32(i)}))
}

// Save to disk
Expand All @@ -76,7 +76,7 @@ if err != nil {

// Later...
// g2 is a copy of g1
g2, err := LoadSavedGraph[Vector](path)
g2, err := LoadSavedGraph[int](path)
if err != nil {
panic(err)
}
Expand All @@ -94,10 +94,10 @@ nearly at disk speed. On my M3 Macbook I get these benchmark results:
goos: darwin
goarch: arm64
pkg: github.com/coder/hnsw
BenchmarkGraph_Import-16 2733 369803 ns/op 228.65 MB/s 352041 B/op 9880 allocs/op
BenchmarkGraph_Export-16 6046 194441 ns/op 1076.65 MB/s 261854 B/op 3760 allocs/op
BenchmarkGraph_Import-16 4029 259927 ns/op 796.85 MB/s 496022 B/op 3212 allocs/op
BenchmarkGraph_Export-16 7042 168028 ns/op 1232.49 MB/s 239886 B/op 2388 allocs/op
PASS
ok github.com/coder/hnsw 2.530s
ok github.com/coder/hnsw 2.624s
```

when saving/loading a graph of 100 vectors with 256 dimensions.
Expand Down Expand Up @@ -130,18 +130,18 @@ $$

where:
* $n$ is the number of vectors in the graph
* $\text{size(id)}$ is the average size of the ID in bytes
* $\text{size(key)}$ is the average size of the key in bytes
* $M$ is the maximum number of neighbors each node can have
* $d$ is the dimensionality of the vectors
* $mem_{graph}$ is the memory used by the graph structure across all layers
* $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer

You can infer that:
* Connectivity ($M$) is very expensive if IDs are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure
* Connectivity ($M$) is very expensive if keys are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure

In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes:
In the example of a graph with 256 dimensions and $M = 16$, with 8-byte keys, each vector takes:

* $256 \cdot 4 = 1024$ data bytes
* $16 \cdot 8 = 128$ metadata bytes
Expand Down
14 changes: 8 additions & 6 deletions analyzer.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package hnsw

import "cmp"

// Analyzer is a struct that holds a graph and provides
// methods for analyzing it. It offers no compatibility guarantee
// as the methods of measuring the graph's health will change
// with the implementation.
type Analyzer[T Embeddable] struct {
Graph *Graph[T]
type Analyzer[K cmp.Ordered] struct {
Graph *Graph[K]
}

func (a *Analyzer[T]) Height() int {
Expand All @@ -17,16 +19,16 @@ func (a *Analyzer[T]) Height() int {
func (a *Analyzer[T]) Connectivity() []float64 {
var layerConnectivity []float64
for _, layer := range a.Graph.layers {
if len(layer.Nodes) == 0 {
if len(layer.nodes) == 0 {
continue
}

var sum float64
for _, node := range layer.Nodes {
for _, node := range layer.nodes {
sum += float64(len(node.neighbors))
}

layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes)))
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes)))
}

return layerConnectivity
Expand All @@ -36,7 +38,7 @@ func (a *Analyzer[T]) Connectivity() []float64 {
func (a *Analyzer[T]) Topography() []int {
var topography []int
for _, layer := range a.Graph.layers {
topography = append(topography, len(layer.Nodes))
topography = append(topography, len(layer.nodes))
}
return topography
}
82 changes: 47 additions & 35 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package hnsw

import (
"bufio"
"cmp"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) {
*v = string(s)
return len(s), err

case *[]float32:
var ln int
_, err := binaryRead(r, &ln)
if err != nil {
return 0, err
}

*v = make([]float32, ln)
return binary.Size(*v), binary.Read(r, byteOrder, *v)

case io.ReaderFrom:
n, err := v.ReadFrom(r)
return int(n), err
Expand Down Expand Up @@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) {
}

return n + n2, nil
case []float32:
n, err := binaryWrite(w, len(v))
if err != nil {
return n, err
}
return n + binary.Size(v), binary.Write(w, byteOrder, v)

default:
sz := binary.Size(data)
Expand Down Expand Up @@ -113,7 +130,7 @@ const encodingVersion = 1
// Export writes the graph to a writer.
//
// T must implement io.WriterTo.
func (h *Graph[T]) Export(w io.Writer) error {
func (h *Graph[K]) Export(w io.Writer) error {
distFuncName, ok := distanceFuncToName(h.Distance)
if !ok {
return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance)
Expand All @@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error {
return fmt.Errorf("encode number of layers: %w", err)
}
for _, layer := range h.layers {
_, err = binaryWrite(w, len(layer.Nodes))
_, err = binaryWrite(w, len(layer.nodes))
if err != nil {
return fmt.Errorf("encode number of nodes: %w", err)
}
for _, node := range layer.Nodes {
_, err = binaryWrite(w, node.Point)
for _, node := range layer.nodes {
_, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors))
if err != nil {
return fmt.Errorf("encode node point: %w", err)
}

if _, err = binaryWrite(w, len(node.neighbors)); err != nil {
return fmt.Errorf("encode number of neighbors: %w", err)
return fmt.Errorf("encode node data: %w", err)
}

for neighbor := range node.neighbors {
_, err = binaryWrite(w, neighbor)
if err != nil {
return fmt.Errorf("encode neighbor %q: %w", neighbor, err)
return fmt.Errorf("encode neighbor %v: %w", neighbor, err)
}
}
}
Expand All @@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error {
// T must implement io.ReaderFrom.
// The imported graph does not have to match the exported graph's parameters (except for
// dimensionality). The graph will converge onto the new parameters.
func (h *Graph[T]) Import(r io.Reader) error {
func (h *Graph[K]) Import(r io.Reader) error {
var (
version int
dist string
Expand Down Expand Up @@ -195,55 +208,54 @@ func (h *Graph[T]) Import(r io.Reader) error {
return err
}

h.layers = make([]*layer[T], nLayers)
h.layers = make([]*layer[K], nLayers)
for i := 0; i < nLayers; i++ {
var nNodes int
_, err = binaryRead(r, &nNodes)
if err != nil {
return err
}

nodes := make(map[string]*layerNode[T], nNodes)
nodes := make(map[K]*layerNode[K], nNodes)
for j := 0; j < nNodes; j++ {
var point T
_, err = binaryRead(r, &point)
if err != nil {
return fmt.Errorf("decoding node %d: %w", j, err)
}

var key K
var vec Vector
var nNeighbors int
_, err = binaryRead(r, &nNeighbors)
_, err = multiBinaryRead(r, &key, &vec, &nNeighbors)
if err != nil {
return fmt.Errorf("decoding number of neighbors for node %d: %w", j, err)
return fmt.Errorf("decoding node %d: %w", j, err)
}

neighbors := make([]string, nNeighbors)
neighbors := make([]K, nNeighbors)
for k := 0; k < nNeighbors; k++ {
var neighbor string
var neighbor K
_, err = binaryRead(r, &neighbor)
if err != nil {
return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err)
}
neighbors[k] = neighbor
}

node := &layerNode[T]{
Point: point,
neighbors: make(map[string]*layerNode[T]),
node := &layerNode[K]{
Node: Node[K]{
Key: key,
Value: vec,
},
neighbors: make(map[K]*layerNode[K]),
}

nodes[point.ID()] = node
nodes[key] = node
for _, neighbor := range neighbors {
node.neighbors[neighbor] = nil
}
}
// Fill in neighbor pointers
for _, node := range nodes {
for id := range node.neighbors {
node.neighbors[id] = nodes[id]
for key := range node.neighbors {
node.neighbors[key] = nodes[key]
}
}
h.layers[i] = &layer[T]{Nodes: nodes}
h.layers[i] = &layer[K]{nodes: nodes}
}

return nil
Expand All @@ -253,8 +265,8 @@ func (h *Graph[T]) Import(r io.Reader) error {
// changes to a file upon calls to Save. It is more convenient
// but less powerful than calling Graph.Export and Graph.Import
// directly.
type SavedGraph[T Embeddable] struct {
*Graph[T]
type SavedGraph[K cmp.Ordered] struct {
*Graph[K]
Path string
}

Expand All @@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct {
//
// It does not hold open a file descriptor, so SavedGraph can be forgotten
// without ever calling Save.
func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) {
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
if err != nil {
return nil, err
Expand All @@ -276,15 +288,15 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
return nil, err
}

g := NewGraph[T]()
g := NewGraph[K]()
if info.Size() > 0 {
err = g.Import(bufio.NewReader(f))
if err != nil {
return nil, fmt.Errorf("import: %w", err)
}
}

return &SavedGraph[T]{Graph: g, Path: path}, nil
return &SavedGraph[K]{Graph: g, Path: path}, nil
}

// Save writes the graph to the file.
Expand Down
Loading

0 comments on commit e942b4f

Please sign in to comment.