From 9182a9863c057c3b1974e613c1152c6e434a7b52 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 23 Nov 2020 13:01:18 +0100 Subject: [PATCH] Fill generics, when useful --- python/pyspark/ml/feature.pyi | 10 ++++++++-- python/pyspark/ml/pipeline.pyi | 4 ++-- python/pyspark/ml/regression.pyi | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index c1d3669b2479d..4999defdf8a70 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol): def getVarianceThreshold(self) -> float: ... class VarianceThresholdSelector( - JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaEstimator[VarianceThresholdSelectorModel], + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelector], + JavaMLWritable, ): def __init__( self, @@ -1621,7 +1624,10 @@ class VarianceThresholdSelector( def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... class VarianceThresholdSelectorModel( - JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaModel, + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelectorModel], + JavaMLWritable, ): def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ... def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ... diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi index 44680586d70d1..f47e9e012ae14 100644 --- a/python/pyspark/ml/pipeline.pyi +++ b/python/pyspark/ml/pipeline.pyi @@ -51,7 +51,7 @@ class PipelineWriter(MLWriter): def __init__(self, instance: Pipeline) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineReader(MLReader): +class PipelineReader(MLReader[Pipeline]): cls: Type[Pipeline] def __init__(self, cls: Type[Pipeline]) -> None: ... def load(self, path: str) -> Pipeline: ... @@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter): def __init__(self, instance: PipelineModel) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineModelReader(MLReader): +class PipelineModelReader(MLReader[PipelineModel]): cls: Type[PipelineModel] def __init__(self, cls: Type[PipelineModel]) -> None: ... def load(self, path: str) -> PipelineModel: ... diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi index c9f9b8c32d93d..b8f1e61859c72 100644 --- a/python/pyspark/ml/regression.pyi +++ b/python/pyspark/ml/regression.pyi @@ -414,7 +414,7 @@ class RandomForestRegressionModel( _TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable[RandomForestRegressionModel], ): @property def trees(self) -> List[DecisionTreeRegressionModel]: ...