Skip to content

Commit

Permalink
Fill generics, when useful
Browse files Browse the repository at this point in the history
  • Loading branch information
zero323 committed Nov 23, 2020
1 parent 426c7bb commit 9182a98
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
10 changes: 8 additions & 2 deletions python/pyspark/ml/feature.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
def getVarianceThreshold(self) -> float: ...

class VarianceThresholdSelector(
JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
JavaEstimator[VarianceThresholdSelectorModel],
_VarianceThresholdSelectorParams,
JavaMLReadable[VarianceThresholdSelector],
JavaMLWritable,
):
def __init__(
self,
Expand All @@ -1621,7 +1624,10 @@ class VarianceThresholdSelector(
def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...

class VarianceThresholdSelectorModel(
JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
JavaModel,
_VarianceThresholdSelectorParams,
JavaMLReadable[VarianceThresholdSelectorModel],
JavaMLWritable,
):
def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ...
def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ...
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/ml/pipeline.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PipelineWriter(MLWriter):
def __init__(self, instance: Pipeline) -> None: ...
def saveImpl(self, path: str) -> None: ...

class PipelineReader(MLReader):
class PipelineReader(MLReader[Pipeline]):
cls: Type[Pipeline]
def __init__(self, cls: Type[Pipeline]) -> None: ...
def load(self, path: str) -> Pipeline: ...
Expand All @@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter):
def __init__(self, instance: PipelineModel) -> None: ...
def saveImpl(self, path: str) -> None: ...

class PipelineModelReader(MLReader):
class PipelineModelReader(MLReader[PipelineModel]):
cls: Type[PipelineModel]
def __init__(self, cls: Type[PipelineModel]) -> None: ...
def load(self, path: str) -> PipelineModel: ...
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/regression.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class RandomForestRegressionModel(
_TreeEnsembleModel,
_RandomForestRegressorParams,
JavaMLWritable,
JavaMLReadable,
JavaMLReadable[RandomForestRegressionModel],
):
@property
def trees(self) -> List[DecisionTreeRegressionModel]: ...
Expand Down

0 comments on commit 9182a98

Please sign in to comment.