Skip to content

Commit

Permalink
Benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Apr 22, 2024
1 parent f1dc676 commit 2a39e90
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions python/lbann/contrib/data/molecule_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __init__(self, file_or_files: Union[str, List[str]]):
np.memmap(f + '.offsets', dtype=np.uint64) for f in file_or_files
]
self.samples = [o.shape[0] for o in self.offsets]
self.cs = np.cumsum(np.array(self.samples, dtype=np.uint64), dtype=np.uint64)
self.cs = np.cumsum(np.array(self.samples, dtype=np.uint64),
dtype=np.uint64)
self.total_samples = sum(self.samples)

# Clean memmapped files so that the object can be pickled
Expand Down Expand Up @@ -222,17 +223,41 @@ def trim_and_pad(self, sample, random: bool):

if __name__ == '__main__':
import sys
if len(sys.argv) != 4:
print('USAGE: dataloader_mlm.py <dataset file> <vocabulary file> '
'<smiles/selfies/ais>')
if len(sys.argv) < 4:
print('USAGE: dataloader_mlm.py <vocabulary file> <smiles/selfies/ais>'
' <dataset file> [other dataset files]')
exit(1)

dataset = ChemTextDataset(
fname=[sys.argv[1]],
vocab=sys.argv[2],
seqlen=64,
tokenizer_type=ChemTokenType[sys.argv[3].upper()])
print('Dataset samples:', len(dataset))
_, vocab, toktype, *files = sys.argv

dataset = ChemTextDataset(fname=files,
vocab=vocab,
seqlen=64,
tokenizer_type=ChemTokenType[toktype.upper()])

# Test 1: Arbitrary sample retrieval
# Report the dataset size, then decode the last sample (negative index)
# back to text, truncated to the dataset's sequence length.
num_samples = len(dataset)
print('Dataset samples:', num_samples)
print('Dataset sample -1:')
print(
    dataset.tokenizer.decode(dataset[-1].sample[:dataset.sequence_length]))

# Test 2: Retrieval bandwidth
# Time random-access reads and report throughput in samples per second.
import time
try:
    from tqdm import trange
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    # tqdm is optional: fall back to a plain range (no progress bar)
    trange = range

# Warmup: a few untimed random reads before the measurement starts.
# np.random.randint excludes its upper bound, so passing num_samples
# (not num_samples - 1) makes every index in [0, num_samples) reachable
# and keeps the call valid for a single-sample dataset.
for _ in range(10):
    samp = np.random.randint(0, num_samples)
    _ = dataset[samp]

SAMPLES = 5000
# perf_counter is the high-resolution clock meant for timing intervals
start = time.perf_counter()
for _ in trange(SAMPLES):
    samp = np.random.randint(0, num_samples)
    _ = dataset[samp]
end = time.perf_counter()
print('Samples per second:', SAMPLES / (end - start))

0 comments on commit 2a39e90

Please sign in to comment.