-
Notifications
You must be signed in to change notification settings - Fork 7
/
go2vec.go
379 lines (307 loc) · 8.93 KB
/
go2vec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
// Copyright 2015, 2017 Daniël de Kok
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package go2vec
import (
"bufio"
"container/heap"
"encoding/binary"
"fmt"
"math"
"sort"
"strings"
"github.com/gonum/blas"
cblas "github.com/gonum/blas/cgo"
)
// IterFunc is a function for iterating over word embeddings. The function
// should return 'false' if the iteration should be stopped.
//
// It is called with a word and the embedding of that word.
type IterFunc func(word string, embedding []float32) bool
// WordSimilarity stores the similarity of a word compared to a query word.
type WordSimilarity struct {
	Word       string
	Similarity float32
}

// similarityHeap is a min-heap of WordSimilarity values (it implements
// container/heap.Interface). The least similar candidate is at index 0, which
// lets a top-k selection cheaply find and evict the weakest candidate.
type similarityHeap []WordSimilarity

// Len returns the number of candidates in the heap.
func (h similarityHeap) Len() int { return len(h) }

// Less orders candidates by ascending similarity. Candidates with equal
// similarity are ordered by word, so that the ordering is deterministic.
func (h similarityHeap) Less(i, j int) bool {
	if h[i].Similarity == h[j].Similarity {
		return h[i].Word < h[j].Word
	}
	return h[i].Similarity < h[j].Similarity
}

// Swap exchanges the candidates at positions i and j.
func (h similarityHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a WordSimilarity. It is meant to be called
// through heap.Push, which restores the heap invariant afterwards.
func (h *similarityHeap) Push(x interface{}) {
	*h = append(*h, x.(WordSimilarity))
}

// Pop removes and returns the last element. It is meant to be called through
// heap.Pop, which first moves the minimum to the end.
func (h *similarityHeap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	// Clear the vacated slot so the backing array does not keep the word
	// string reachable.
	old[n-1] = WordSimilarity{}
	*h = old[:n-1]
	return x
}
// Embeddings is used to store a set of word embeddings, such that common
// operations can be performed on these embeddings (such as retrieving
// similar words).
type Embeddings struct {
// blas is the BLAS implementation used for the matrix-vector products
// in similarity queries (replaceable via SetBLAS).
blas blas.Float32Level2
// matrix holds all embeddings row by row: the embedding of the word at
// index i occupies matrix[i*embedSize : (i+1)*embedSize].
matrix []float32
// embedSize is the number of components in each embedding.
embedSize int
// indices maps a word to its row index in matrix. If a loaded file
// contained a word twice, the later row wins.
indices map[string]int
// words maps a row index back to its word; it has one entry per row.
words []string
}
// NewEmbeddings returns an empty set of word embeddings in which every
// embedding has embedSize components. Populate it with Put. Similarity
// queries use the C BLAS implementation unless SetBLAS is called.
func NewEmbeddings(embedSize int) *Embeddings {
	e := new(Embeddings)
	e.blas = cblas.Implementation{}
	e.matrix = []float32{}
	e.embedSize = embedSize
	e.indices = map[string]int{}
	e.words = []string{}
	return e
}
// ReadWord2VecBinary reads word embeddings from a binary file that is produced
// by word2vec. The embeddings can be normalized using their L2 norms.
//
// The input starts with a text header holding the number of words and the
// embedding size. It is followed by one entry per word: the word terminated
// by a space, then embedding-size little-endian float32 values. Whitespace
// around a word (such as the newline ending the previous entry) is stripped.
// If a word occurs more than once, lookups by word resolve to its last
// occurrence.
//
// Any error from the header scan or from reading an entry is returned as-is.
func ReadWord2VecBinary(r *bufio.Reader, normalize bool) (*Embeddings, error) {
	var nWords uint64
	if _, err := fmt.Fscanf(r, "%d", &nWords); err != nil {
		return nil, err
	}

	var vSize uint64
	if _, err := fmt.Fscanf(r, "%d", &vSize); err != nil {
		return nil, err
	}

	n := int(nWords)
	dims := int(vSize)

	matrix := make([]float32, n*dims)
	indices := make(map[string]int, n)
	words := make([]string, n)

	for idx := 0; idx < n; idx++ {
		word, err := r.ReadString(' ')
		if err != nil {
			// Stop at the first read error rather than recording a partial
			// word and carrying on with the embedding data.
			return nil, err
		}
		word = strings.TrimSpace(word)
		indices[word] = idx
		words[idx] = word

		row := matrix[idx*dims : (idx+1)*dims]
		if err := binary.Read(r, binary.LittleEndian, row); err != nil {
			return nil, err
		}
		if normalize {
			normalizeEmbeddings(row)
		}
	}

	return &Embeddings{
		blas:      cblas.Implementation{},
		matrix:    matrix,
		embedSize: dims,
		indices:   indices,
		words:     words,
	}, nil
}
// Write writes the embeddings to w in the binary format accepted by
// word2vec. The output is a "<words> <size>\n" header followed by one entry
// per word: the word, a space, the embedding as little-endian float32 values,
// and a newline byte. Nothing is written when there are no words or the
// embedding size is zero. w is not flushed; the caller must call w.Flush.
func (e *Embeddings) Write(w *bufio.Writer) error {
	if len(e.words) == 0 || e.embedSize == 0 {
		return nil
	}

	if _, err := fmt.Fprintf(w, "%d %d\n", len(e.words), e.embedSize); err != nil {
		return err
	}

	for i := range e.words {
		if _, err := w.WriteString(e.words[i] + " "); err != nil {
			return err
		}
		if err := binary.Write(w, binary.LittleEndian, e.lookupIdx(i)); err != nil {
			return err
		}
		if err := w.WriteByte('\n'); err != nil {
			return err
		}
	}

	return nil
}
// Analogy answers queries of the form "word1 is to word2 as word3 is to ?".
//
// With e1, e2 and e3 the embeddings of the three query words, the target
// vector e4 = (e2 - e1) + e3 is computed. At most limit words are returned,
// ranked by the dot product of their embedding with e4, most similar first.
// None of the three query words is ever part of the result. An error is
// returned for the first query word (in argument order) that is unknown.
func (e *Embeddings) Analogy(word1, word2, word3 string, limit int) ([]WordSimilarity, error) {
	query := [3]string{word1, word2, word3}

	var rows [3]int
	for i, w := range query {
		row, known := e.indices[w]
		if !known {
			return nil, fmt.Errorf("Unknown word: %s", w)
		}
		rows[i] = row
	}

	offset := minus(e.lookupIdx(rows[1]), e.lookupIdx(rows[0]))
	target := plus(offset, e.lookupIdx(rows[2]))

	exclude := make(map[int]interface{}, len(rows))
	for _, row := range rows {
		exclude[row] = nil
	}

	return e.similarity(target, exclude, limit)
}
// SetBLAS sets the BLAS implementation to use (default: C BLAS).
//
// The implementation performs the matrix-vector product behind Similarity
// and Analogy queries.
func (e *Embeddings) SetBLAS(impl blas.Float32Level2) {
e.blas = impl
}
// Iterate calls f for every word and its embedding, in index order. It stops
// early as soon as f returns false. The embedding passed to f is a view of
// the internal storage, not a copy.
func (e *Embeddings) Iterate(f IterFunc) {
	for i := range e.words {
		if keepGoing := f(e.words[i], e.lookupIdx(i)); !keepGoing {
			return
		}
	}
}
// Matrix returns the matrix backing these embeddings. It holds one row of
// EmbeddingSize() values per word; the row of the word with index i (see
// WordIdx) starts at offset i*EmbeddingSize(). The slice is returned without
// copying, so modifying it modifies the embeddings.
func (e *Embeddings) Matrix() []float32 {
return e.matrix
}
// Put stores embedding as the embedding of word; the word can be queried as
// soon as Put returns. If word is already known, its embedding is
// overwritten in place. Otherwise word is appended with the next free index.
// The values are copied, so the caller may reuse the embedding slice. An
// error is returned when len(embedding) differs from the embedding size.
func (e *Embeddings) Put(word string, embedding []float32) error {
	if got := len(embedding); got != e.embedSize {
		return fmt.Errorf("Expected embedding size: %d, got: %d", e.embedSize, got)
	}

	idx, known := e.indices[word]
	if !known {
		// New word: assign the next row and grow the matrix by one embedding.
		e.indices[word] = len(e.words)
		e.words = append(e.words, word)
		e.matrix = append(e.matrix, embedding...)
		return nil
	}

	// Known word: overwrite its existing row.
	copy(e.matrix[idx*e.embedSize:], embedding)
	return nil
}
// Similarity finds the words whose embeddings are most similar to that of
// word. The limit argument specifies how many words are returned at most.
// The result is ordered from most to least similar, where similarity is the
// dot product of the embeddings. The query word itself is never returned.
// An error is returned when word is unknown.
func (e *Embeddings) Similarity(word string, limit int) ([]WordSimilarity, error) {
	idx, ok := e.indices[word]
	if ok {
		return e.similarity(e.lookupIdx(idx), map[int]interface{}{idx: nil}, limit)
	}
	return nil, fmt.Errorf("Unknown word: %s", word)
}
// Size returns the number of words in the embeddings.
//
// NOTE(review): this counts distinct words in the word index. If the
// embeddings were read from a file that lists a word twice, it is smaller
// than the number of stored rows.
func (e *Embeddings) Size() int {
return len(e.indices)
}
// Embedding looks up the embedding of word. The boolean result reports
// whether the word is known; for an unknown word the slice is nil. The
// returned slice is a view of the internal storage, not a copy.
func (e *Embeddings) Embedding(word string) ([]float32, bool) {
	idx, ok := e.indices[word]
	if !ok {
		return nil, false
	}
	return e.lookupIdx(idx), true
}
// EmbeddingSize returns the embedding size: the number of float32 components
// in each embedding (Put rejects embeddings of any other length).
func (e *Embeddings) EmbeddingSize() int {
return e.embedSize
}
// WordIdx returns the row index of word within the embedding matrix. The
// boolean result reports whether the word is known; for an unknown word the
// index is 0.
func (e *Embeddings) WordIdx(word string) (int, bool) {
	idx, ok := e.indices[word]
	return idx, ok
}
// similarity scores every stored embedding against embed and returns up to
// limit of the highest-scoring words, most similar first.
//
// The score of a row is its dot product with embed, computed for all rows at
// once as the matrix-vector product matrix*embed via BLAS. Rows whose index
// is a key of skips are never returned. A limit of zero or less, or an empty
// set of embeddings, yields an empty (non-nil) result.
func (e *Embeddings) similarity(embed []float32, skips map[int]interface{}, limit int) ([]WordSimilarity, error) {
	// The matrix has one row per entry of e.words. Size() counts distinct
	// words instead, which is smaller if a loaded file repeated a word, so it
	// cannot be used as the row count here.
	rows := len(e.words)

	// Without this guard a zero limit would index into the empty heap below,
	// and a negative limit would make the capacity negative; both panic.
	if limit <= 0 || rows == 0 {
		return []WordSimilarity{}, nil
	}

	dps := make([]float32, rows)
	e.blas.Sgemv(blas.NoTrans, rows, e.embedSize,
		1, e.matrix, e.embedSize, embed, 1, 0, dps, 1)

	// Min-heap holding the best candidates so far; results[0] is the weakest.
	results := make(similarityHeap, 0, minInt(limit, rows))
	for idx, sim := range dps {
		if _, skip := skips[idx]; skip {
			continue
		}
		switch {
		case results.Len() < limit:
			heap.Push(&results, WordSimilarity{e.words[idx], sim})
		case results[0].Similarity < sim:
			// Replace the weakest candidate and restore the heap invariant.
			results[0] = WordSimilarity{e.words[idx], sim}
			heap.Fix(&results, 0)
		}
	}

	// Only at most limit candidates remain, so a plain sort is cheap.
	sort.Sort(sort.Reverse(results))
	return results, nil
}
// dotProduct returns the dot product of v and w: the sum of their
// element-wise products. Only the first len(v) elements of w are used; a w
// shorter than v causes an index-out-of-range panic.
func dotProduct(v, w []float32) float32 {
	var acc float32
	for i := 0; i < len(v); i++ {
		acc += v[i] * w[i]
	}
	return acc
}
// lookupIdx returns the embedding stored in row idx of the matrix.
//
// The result is a view of the matrix, not a copy, so writes through it change
// the stored embedding. Its capacity is capped at the embedding size. As a
// result, a caller that appends to a slice obtained from Embedding or Iterate
// gets a fresh allocation instead of silently overwriting the next word's
// embedding.
func (e *Embeddings) lookupIdx(idx int) []float32 {
	start := idx * e.embedSize
	end := start + e.embedSize
	return e.matrix[start:end:end]
}
// minus returns the element-wise difference v - w as a newly allocated slice
// of length len(v). w must have at least len(v) elements.
func minus(v, w []float32) []float32 {
	diff := make([]float32, len(v))
	for i := range diff {
		diff[i] = v[i] - w[i]
	}
	return diff
}
// normalizeEmbeddings scales embedding in place to unit L2 norm. If the sum
// of squares is zero (for example an all-zero vector), the embedding cannot
// be normalized and is left unchanged.
func normalizeEmbeddings(embedding []float32) {
	var sumSq float32
	for _, x := range embedding {
		sumSq += x * x
	}

	if sumSq == 0 {
		return
	}

	norm := float32(math.Sqrt(float64(sumSq)))
	for i := range embedding {
		embedding[i] /= norm
	}
}
// plus returns the element-wise sum v + w as a newly allocated slice of
// length len(v). w must have at least len(v) elements.
func plus(v, w []float32) []float32 {
	sum := make([]float32, len(v))
	for i := range sum {
		sum[i] = v[i] + w[i]
	}
	return sum
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}