Skip to content

Commit

Permalink
refactor random_permutations for simplicity (#300)
Browse files Browse the repository at this point in the history
Refactor random_permutations ensemble strategy to simplify code

[ committed by @ankona ]
[ reviewed by @billschereriii ]
  • Loading branch information
ankona authored Jun 15, 2023
1 parent 395ffb0 commit 3f2dfa4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 18 deletions.
2 changes: 2 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ A full list of changes and detailed notes can be found below:

Detailed notes

- Simplify code in `random_permutations` parameter generation strategy (PR300_)
- Remove wait time associated with Experiment launch summary (PR298_)
- Update Redis conf file to conform with Redis v7.0.5 conf file (PR293_)
- Migrate from redis-py-cluster to redis-py for cluster status checks (PR292_)
Expand All @@ -55,6 +56,7 @@ argument name is still `interface` for backward compatibility reasons. (PR281_)
- Typehints have been added to public APIs. A makefile target to execute static
analysis with mypy is available `make check-mypy`. (PR295_)

.. _PR300: https://github.com/CrayLabs/SmartSim/pull/300
.. _PR298: https://github.com/CrayLabs/SmartSim/pull/298
.. _PR293: https://github.com/CrayLabs/SmartSim/pull/293
.. _PR292: https://github.com/CrayLabs/SmartSim/pull/292
Expand Down
25 changes: 7 additions & 18 deletions smartsim/entity/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,10 @@ def step_values(
def random_permutations(
param_names: t.List[str], param_values: t.List[t.List[str]], n_models: int = 0
) -> t.List[t.Dict[str, str]]:
# first, check if we've requested more values than possible.
perms = list(product(*param_values))
if n_models >= len(perms):
return create_all_permutations(param_names, param_values)
else:
permutations: t.List[t.Dict[str, str]] = []
permutation_strings = set()
while len(permutations) < n_models:
model_dict = dict(
zip(
param_names,
map(lambda x: x[random.randint(0, len(x) - 1)], param_values),
)
)
if str(model_dict) not in permutation_strings:
permutation_strings.add(str(model_dict))
permutations.append(model_dict)
return permutations
permutations = create_all_permutations(param_names, param_values)

# sample from available permutations if n_models is specified
if n_models and n_models < len(permutations):
permutations = random.sample(permutations, n_models)

return permutations

0 comments on commit 3f2dfa4

Please sign in to comment.