-
Notifications
You must be signed in to change notification settings - Fork 28.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
956 changed files
with
6,448 additions
and
2,612 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml | ||
|
||
import scala.annotation.varargs | ||
|
||
import org.apache.spark.annotation.Since | ||
import org.apache.spark.ml.param.{ParamMap, ParamPair} | ||
import org.apache.spark.sql.Dataset | ||
|
||
/** | ||
* Abstract class for estimators that fit models to data. | ||
*/ | ||
abstract class Estimator[M <: Model[M]] extends PipelineStage {

  /**
   * Trains one model on the input data, with params supplied as individual param pairs.
   *
   * The pairs are collected into a fresh [[ParamMap]] (first pair, then the rest) and the call
   * is forwarded to the `fit(dataset, paramMap)` overload.
   *
   * @param dataset
   *   input dataset
   * @param firstParamPair
   *   the first param pair, overrides embedded params
   * @param otherParamPairs
   *   remaining param pairs. These values override any specified in this Estimator's embedded
   *   ParamMap.
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  @varargs
  def fit(
      dataset: Dataset[_],
      firstParamPair: ParamPair[_],
      otherParamPairs: ParamPair[_]*): M =
    fit(dataset, new ParamMap().put(firstParamPair).put(otherParamPairs: _*))

  /**
   * Trains one model on the input data using the supplied parameter map.
   *
   * This estimator is first copied with `paramMap` as the extra params, and the copy is then
   * fitted on `dataset`.
   *
   * @param dataset
   *   input dataset
   * @param paramMap
   *   Parameter map. These values override any specified in this Estimator's embedded ParamMap.
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
    val configured = copy(paramMap)
    configured.fit(dataset)
  }

  /**
   * Trains a model on the input data using this estimator's current params.
   *
   * @param dataset
   *   input dataset
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_]): M

  /**
   * Trains one model per parameter map. The default implementation simply fits each parameter
   * map in turn via `fit(dataset, paramMap)`; subclasses could override this to optimize
   * multi-model training.
   *
   * @param dataset
   *   input dataset
   * @param paramMaps
   *   A sequence of parameter maps. These values override any specified in this Estimator's
   *   embedded ParamMap.
   * @return
   *   fitted models, in the same order as the input parameter maps
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] =
    for (paramMap <- paramMaps) yield fit(dataset, paramMap)

  /**
   * Copies this estimator with `extra` params; the return type is narrowed to [[Estimator]].
   */
  @Since("3.5.0")
  override def copy(extra: ParamMap): Estimator[M]
}
53 changes: 53 additions & 0 deletions
53
connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml | ||
|
||
import org.apache.spark.annotation.Since | ||
import org.apache.spark.ml.param.ParamMap | ||
|
||
/** | ||
* A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. | ||
* | ||
* @tparam M | ||
* model type | ||
*/ | ||
abstract class Model[M <: Model[M]] extends Transformer {

  /**
   * The estimator this model was produced by.
   * @note
   *   For ensembles' component Models, this value can be null.
   */
  @transient var parent: Estimator[M] = _

  /**
   * Records `parent` as this model's parent estimator and returns this model typed as `M`
   * (Java API).
   */
  @Since("3.5.0")
  def setParent(parent: Estimator[M]): M = {
    // Cast once up front, then assign through the typed reference; `self` is this same instance.
    val self = this.asInstanceOf[M]
    self.parent = parent
    self
  }

  /** Returns true when this [[Model]] has a non-null parent estimator. */
  @Since("3.5.0")
  def hasParent: Boolean = Option(parent).isDefined

  /** Copies this model with `extra` params; the return type is narrowed to `M`. */
  @Since("3.5.0")
  override def copy(extra: ParamMap): M
}
63 changes: 63 additions & 0 deletions
63
connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml | ||
|
||
import org.apache.spark.annotation.DeveloperApi | ||
import org.apache.spark.internal.Logging | ||
import org.apache.spark.ml.param.{ParamMap, Params} | ||
import org.apache.spark.sql.types.StructType | ||
|
||
/** | ||
* A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. | ||
*/ | ||
abstract class PipelineStage extends Params with Logging {

  /**
   * Validates the transform and computes the output schema for a given input schema.
   *
   * Checks that involve interactions between parameters belong here, and an exception should be
   * raised when a parameter value is invalid. Checks on a single parameter that do not depend
   * on other parameters are handled by `Param.validate()`.
   *
   * A typical implementation first verifies the schema change and parameter validity, including
   * complex parameter interaction checks.
   */
  def transformSchema(schema: StructType): StructType

  /**
   * :: DeveloperApi ::
   *
   * Computes the output schema via `transformSchema(schema)`. When `logging` is true, the input
   * schema and the resulting output schema are also written to the debug log as JSON.
   *
   * This should be optimistic. If it is unclear whether the schema will be valid, then it should
   * be assumed valid until proven otherwise.
   */
  @DeveloperApi
  protected def transformSchema(schema: StructType, logging: Boolean): StructType = {
    if (!logging) {
      transformSchema(schema)
    } else {
      logDebug(s"Input schema: ${schema.json}")
      val derived = transformSchema(schema)
      logDebug(s"Expected output schema: ${derived.json}")
      derived
    }
  }

  /** Copies this stage with `extra` params; the return type is narrowed to [[PipelineStage]]. */
  override def copy(extra: ParamMap): PipelineStage
}
Oops, something went wrong.