Skip to content

Commit

Permalink
Remove unnecessary data loading procedures.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 422336610
  • Loading branch information
adria-p authored and copybara-github committed Jan 17, 2022
1 parent daf28bd commit a2ed70f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 74 deletions.
9 changes: 3 additions & 6 deletions clrs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

from clrs import models
from clrs._src import algorithms
from clrs._src.dataset import CLRSDataset
from clrs._src.dataset import create_dataset
from clrs._src.model import evaluate
from clrs._src.model import Model
from clrs._src.probing import DataPoint
from clrs._src.samplers import build_sampler
from clrs._src.samplers import CLRS21
from clrs._src.samplers import clrs21_test
from clrs._src.samplers import clrs21_train
from clrs._src.samplers import clrs21_val
from clrs._src.samplers import Features
from clrs._src.samplers import Feedback
from clrs._src.samplers import Sampler
Expand All @@ -41,9 +40,7 @@
__all__ = (
"build_sampler",
"CLRS21",
"clrs21_test",
"clrs21_train",
"clrs21_val",
"create_dataset",
"DataPoint",
"evaluate",
"Features",
Expand Down
72 changes: 4 additions & 68 deletions clrs/_src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

import abc
import collections
import os
import pickle as pkl
import types

from typing import Any, Callable, List, Optional, Tuple
Expand Down Expand Up @@ -69,7 +67,6 @@ def __init__(
num_samples: int,
*args,
seed: Optional[int] = None,
file_name: Optional[str] = None,
**kwargs,
):
"""Initializes a `Sampler`.
Expand All @@ -80,8 +77,6 @@ def __init__(
num_samples: Number of algorithm unrolls to sample.
*args: Algorithm args.
seed: RNG seed.
file_name: file_name where the dataset is stored. If None,
the dataset is generated on the fly (slower).
**kwargs: Algorithm kwargs.
"""

Expand All @@ -93,17 +88,9 @@ def __init__(
outputs = []
hints = []

file_probes = None
if file_name is not None:
with open(file_name, 'rb') as f:
file_probes = pkl.load(f)

for i in range(num_samples):
if file_name is None:
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
else:
probes = file_probes[i]
for _ in range(num_samples):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
inp, outp, hint = probing.split_stages(probes, spec)
inputs.append(inp)
outputs.append(outp)
Expand Down Expand Up @@ -219,61 +206,11 @@ def _random_bipartite_graph(self, n, m, p=0.25):
return mat


def _join_prefix(folder, prefix):
return None if folder is None else os.path.join(folder, prefix)


def clrs21_train(name: str, folder: Optional[str] = None) -> Tuple[
Sampler, specs.Spec]:
"""Builds a CLRS-21 training sampler for algorithm specified by `name`."""
if name not in specs.CLRS_21_ALGS:
raise NotImplementedError(f'Algorithm {name} not supported in CLRS-21.')
sampler = build_sampler(
name,
seed=CLRS21['train']['seed'],
num_samples=CLRS21['train']['num_samples'],
length=CLRS21['train']['length'],
file_name=_join_prefix(folder, 'train_{}.pkl'.format(name)),
)
return sampler


def clrs21_val(name: str, folder: Optional[str] = None) -> Tuple[
Sampler, specs.Spec]:
"""Builds a CLRS-21 validation sampler for algorithm specified by `name`."""
if name not in specs.CLRS_21_ALGS:
raise NotImplementedError(f'Algorithm {name} not supported in CLRS-21.')
sampler = build_sampler(
name,
seed=CLRS21['val']['seed'],
num_samples=CLRS21['val']['num_samples'],
length=CLRS21['val']['length'],
file_name=_join_prefix(folder, 'val_{}.pkl'.format(name)),
)
return sampler


def clrs21_test(name: str, folder: Optional[str] = None) -> Tuple[
Sampler, specs.Spec]:
"""Builds a CLRS-21 testing sampler for algorithm specified by `name`."""
if name not in specs.CLRS_21_ALGS:
raise NotImplementedError(f'Algorithm {name} not supported in CLRS-21.')
sampler = build_sampler(
name,
seed=CLRS21['test']['seed'],
num_samples=CLRS21['test']['num_samples'],
length=CLRS21['test']['length'],
file_name=_join_prefix(folder, 'test_{}.pkl'.format(name)),
)
return sampler


def build_sampler(
name: str,
num_samples: int,
*args,
seed: Optional[int] = None,
file_name: Optional[str] = None,
**kwargs,
) -> Tuple[Sampler, specs.Spec]:
"""Builds a sampler. See `Sampler` documentation."""
Expand All @@ -283,8 +220,7 @@ def build_sampler(
spec = specs.SPECS[name]
algorithm = getattr(algorithms, name)
sampler = SAMPLERS[name](
algorithm, spec, num_samples, seed=seed,
file_name=file_name, *args, **kwargs)
algorithm, spec, num_samples, seed=seed, *args, **kwargs)
return sampler, spec


Expand Down

0 comments on commit a2ed70f

Please sign in to comment.