Skip to content

Commit

Permalink
sync from master
Browse files Browse the repository at this point in the history
  • Loading branch information
dtenedor committed Mar 20, 2023
2 parents 7dcb503 + 67a254c commit d9d8abd
Show file tree
Hide file tree
Showing 956 changed files with 6,448 additions and 2,612 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ jobs:
catalyst, hive-thriftserver
- >-
streaming, sql-kafka-0-10, streaming-kafka-0-10,
mllib-local, mllib,
mllib-local, mllib-common, mllib,
yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl,
connect, protobuf
# Here, we split the Hive and SQL tests into the slow ones and the rest.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public KVTypeInfo(Class<?> type) {
KVIndex idx = m.getAnnotation(KVIndex.class);
if (idx != null) {
checkIndex(idx, indices);
Preconditions.checkArgument(m.getParameterTypes().length == 0,
Preconditions.checkArgument(m.getParameterCount() == 0,
"Annotated method %s::%s should not have any parameters.", type.getName(), m.getName());
m.setAccessible(true);
indices.put(idx.value(), idx);
Expand Down
12 changes: 12 additions & 0 deletions connector/connect/client/jvm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
<exclusions>
<exclusion>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import scala.annotation.varargs

import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.sql.Dataset

/**
 * Base class for estimators, i.e. pipeline stages that fit a model of type `M` to data.
 *
 * @tparam M
 *   type of the model returned by the `fit` methods
 */
abstract class Estimator[M <: Model[M]] extends PipelineStage {

  /**
   * Fits one model to the input data, using the given param pairs as overrides.
   *
   * The pairs are collected into a fresh [[ParamMap]] (first pair, then the remaining ones)
   * and the call is delegated to the `fit(dataset, paramMap)` overload.
   *
   * @param dataset
   *   input dataset
   * @param firstParamPair
   *   the first param pair, overrides embedded params
   * @param otherParamPairs
   *   other param pairs. These values override any specified in this Estimator's embedded
   *   ParamMap.
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  @varargs
  def fit(
      dataset: Dataset[_],
      firstParamPair: ParamPair[_],
      otherParamPairs: ParamPair[_]*): M = {
    val seeded = new ParamMap().put(firstParamPair)
    val overrides = seeded.put(otherParamPairs: _*)
    fit(dataset, overrides)
  }

  /**
   * Fits one model to the input data, using the given parameter map as overrides.
   *
   * This is done by calling `copy(paramMap)` and fitting the resulting estimator on `dataset`.
   *
   * @param dataset
   *   input dataset
   * @param paramMap
   *   Parameter map. These values override any specified in this Estimator's embedded ParamMap.
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
    val configured = copy(paramMap)
    configured.fit(dataset)
  }

  /**
   * Fits a model to the input data.
   *
   * @param dataset
   *   input dataset
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_]): M

  /**
   * Fits one model per parameter map to the input data. The default implementation simply
   * fits each parameter map in turn; subclasses could override this to optimize multi-model
   * training.
   *
   * @param dataset
   *   input dataset
   * @param paramMaps
   *   A sequence of parameter maps. These values override any specified in this Estimator's
   *   embedded ParamMap.
   * @return
   *   fitted models, matching the input parameter maps
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = {
    for (pm <- paramMaps) yield fit(dataset, pm)
  }

  @Since("3.5.0")
  override def copy(extra: ParamMap): Estimator[M]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.ParamMap

/**
 * A fitted model: a [[Transformer]] that was produced by an [[Estimator]].
 *
 * @tparam M
 *   model type
 */
abstract class Model[M <: Model[M]] extends Transformer {

  /**
   * The estimator that produced this model.
   * @note
   *   This value can be null, e.g. for the component Models of ensembles.
   */
  @transient var parent: Estimator[M] = _

  /**
   * Sets the parent of this model (Java API).
   *
   * @param parent
   *   estimator to record as this model's parent
   * @return
   *   this model, cast to `M`
   */
  @Since("3.5.0")
  def setParent(parent: Estimator[M]): M = {
    this.parent = parent
    val self = this.asInstanceOf[M]
    self
  }

  /** Returns true when this [[Model]] has a non-null `parent`. */
  @Since("3.5.0")
  def hasParent: Boolean = Option(parent).isDefined

  @Since("3.5.0")
  override def copy(extra: ParamMap): M
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.types.StructType

/**
 * One stage of a pipeline: either an [[Estimator]] or a [[Transformer]].
 */
abstract class PipelineStage extends Params with Logging {

  /**
   * Checks that the transformation is valid and derives the output schema from the input
   * schema.
   *
   * Validity of interactions between parameters is checked during `transformSchema`, and an
   * exception is raised if any parameter value is invalid. Checks on a parameter value that do
   * not depend on other parameters are handled by `Param.validate()`.
   *
   * A typical implementation should first verify the schema change and parameter validity,
   * including complex parameter interaction checks.
   *
   * @param schema
   *   input schema
   * @return
   *   output schema derived from `schema`
   */
  def transformSchema(schema: StructType): StructType

  /**
   * :: DeveloperApi ::
   *
   * Derives the output schema from the input schema and parameters. When `logging` is true,
   * the input schema and the expected output schema are also logged as JSON at debug level.
   *
   * This should be optimistic. If it is unclear whether the schema will be valid, then it should
   * be assumed valid until proven otherwise.
   *
   * @param schema
   *   input schema
   * @param logging
   *   whether to log the input and expected output schemas
   * @return
   *   the result of `transformSchema(schema)`
   */
  @DeveloperApi
  protected def transformSchema(schema: StructType, logging: Boolean): StructType = {
    if (!logging) {
      transformSchema(schema)
    } else {
      logDebug(s"Input schema: ${schema.json}")
      val derived = transformSchema(schema)
      logDebug(s"Expected output schema: ${derived.json}")
      derived
    }
  }

  override def copy(extra: ParamMap): PipelineStage
}
Loading

0 comments on commit d9d8abd

Please sign in to comment.