From 45c51e9f04235ca0bcfdecf5c2f5c5d434b26c96 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 May 2015 18:06:48 -0700 Subject: [PATCH] [SPARK-7329] [MLLIB] simplify ParamGridBuilder impl as suggested by justinuang on #5601. Author: Xiangrui Meng Closes #5873 from mengxr/SPARK-7329 and squashes the following commits: d08f9cf [Xiangrui Meng] simplify tests b7a7b9b [Xiangrui Meng] simplify grid build --- python/pyspark/ml/tuning.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index a383bd0c0d26f..1773ab5bdcdb1 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,6 +15,8 @@ # limitations under the License. # +import itertools + __all__ = ['ParamGridBuilder'] @@ -37,14 +39,10 @@ class ParamGridBuilder(object): {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] - >>> fail_count = 0 - >>> for e in expected: - ... if e not in output: - ... fail_count += 1 - >>> if len(expected) != len(output): - ... fail_count += 1 - >>> fail_count - 0 + >>> len(output) == len(expected) + True + >>> all([m in expected for m in output]) + True """ def __init__(self): @@ -76,17 +74,9 @@ def build(self): Builds and returns all combinations of parameters specified by the param grid. """ - param_maps = [{}] - for (param, values) in self._param_grid.items(): - new_param_maps = [] - for value in values: - for old_map in param_maps: - copied_map = old_map.copy() - copied_map[param] = value - new_param_maps.append(copied_map) - param_maps = new_param_maps - - return param_maps + keys = self._param_grid.keys() + grid_values = self._param_grid.values() + return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] if __name__ == "__main__":