From b7a7b9b105357e3d8725bdb295388abff762648a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 May 2015 11:51:19 -0700 Subject: [PATCH] simplify grid build --- python/pyspark/ml/tuning.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index a383bd0c0d26f..45ee862bff164 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,6 +15,8 @@ # limitations under the License. # +import itertools + __all__ = ['ParamGridBuilder'] @@ -76,17 +78,9 @@ def build(self): Builds and returns all combinations of parameters specified by the param grid. """ - param_maps = [{}] - for (param, values) in self._param_grid.items(): - new_param_maps = [] - for value in values: - for old_map in param_maps: - copied_map = old_map.copy() - copied_map[param] = value - new_param_maps.append(copied_map) - param_maps = new_param_maps - - return param_maps + keys = self._param_grid.keys() + grid_values = self._param_grid.values() + return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] if __name__ == "__main__":