TensorFlow Nearest Neighbours Op

Given an embedding matrix EM and a batch of word embeddings x, this op finds the nearest embedding in EM for each token x_ij.
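This README does not state which distance metric the op uses, or whether it returns the nearest vectors or their indices. As a rough mental model only, here is a plain NumPy sketch that assumes Euclidean distance and returns the nearest rows of EM. `nearest_neighbours_reference` is a hypothetical helper for illustration, not part of `tensorflow_nearest_neighbours`.

```python
import numpy as np


def nearest_neighbours_reference(x: np.ndarray, em: np.ndarray) -> np.ndarray:
    """Hypothetical reference: for each token vector x[b, t], return the row of em
    closest in Euclidean distance. The metric and return type are assumptions."""
    # Squared distances via ||a - e||^2 = ||a||^2 - 2 a.e + ||e||^2
    x_sq = np.sum(x ** 2, axis=-1, keepdims=True)   # [B, T, 1]
    em_sq = np.sum(em ** 2, axis=-1)                # [N]
    cross = np.einsum("btd,nd->btn", x, em)         # [B, T, N]
    dist_sq = x_sq - 2.0 * cross + em_sq            # [B, T, N] via broadcasting
    idx = np.argmin(dist_sq, axis=-1)               # [B, T] index of nearest row
    return em[idx]                                  # [B, T, D] nearest embeddings


rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 10, 32))
EM = rng.uniform(size=(500, 32))
result = nearest_neighbours_reference(x, EM)
print(result.shape)  # (8, 10, 32)
```

The expanded-square form avoids materialising a [B, T, N, D] difference tensor; a dedicated op may use a different strategy.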

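```python
import tensorflow as tf
from tensorflow_nearest_neighbours import nearest_neighbours

# Log the device (CPU or GPU) each op is placed on
tf.debugging.set_log_device_placement(True)

# x: batch of 8 sequences, 10 tokens each, 32-dimensional embeddings
x = tf.random.uniform(shape=[8, 10, 32])
# EM: embedding matrix with 500 entries of dimension 32
EM = tf.random.uniform(shape=[500, 32])

result = nearest_neighbours(x, EM)
print(result.shape)
```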

Installation

```shell
pip install tensorflow_nearest_neighbours
```