Skip to content

Commit

Permalink
fix bug with chunking and num target documents
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 8, 2021
1 parent 48b8f45 commit aa2612a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions marge_pytorch/marge_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,16 @@ def get_embeds(data):
train_embeds = get_embeds(np_data)
self.index.train(train_embeds)

total_chunks = math.ceil(self.num_docs / batch_size)
total_evidence_chunks = math.ceil(self.num_docs / batch_size)

for data_slice in tqdm(chunk(batch_size, self.num_docs), total=total_chunks, desc='Adding embedding to indexes'):
for data_slice in tqdm(chunk(batch_size, self.num_docs), total=total_evidence_chunks, desc='Adding embedding to indexes'):
np_data = torch.from_numpy(doc_pointer[data_slice, :]).cuda().long()
embeds = get_embeds(np_data)
self.index.add(embeds)

for data_slice in tqdm(chunk(batch_size, self.num_targets), total=total_chunks, desc='Fetching and storing nearest neighbors'):
total_target_chunks = math.ceil(self.num_targets / batch_size)

for data_slice in tqdm(chunk(batch_size, self.num_targets), total=total_target_chunks, desc='Fetching and storing nearest neighbors'):
np_data = torch.from_numpy(target_pointer[data_slice, :]).cuda().long()

embeds = get_embeds(np_data)
Expand Down

0 comments on commit aa2612a

Please sign in to comment.