-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8bee217
commit 1b855f9
Showing 4 changed files with 106 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from sparktorch.util import serialize_torch_obj, serialize_torch_obj_lazy | ||
from sparktorch.torch_distributed import SparkTorch | ||
from sparktorch.pipeline_util import PysparkPipelineWrapper | ||
from sparktorch.inference import create_spark_torch_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from sparktorch.torch_distributed import SparkTorchModel | ||
import torch.nn as nn | ||
import codecs | ||
import dill | ||
from pyspark.ml.pipeline import PipelineModel | ||
|
||
|
||
def convert_to_serialized_torch(network: nn.Module) -> str:
    """
    Serialize a torch network into a base64-encoded text string.

    The network is pickled with ``dill``, the resulting bytes are
    base64-encoded, and the encoded bytes are decoded into a ``str``.

    :param network: the nn.Module to serialize
    :return: the base64 text representation of the pickled network
    """
    # dill produces raw bytes; base64 makes them safe to carry as text.
    pickled_bytes = dill.dumps(network)
    encoded_bytes = codecs.encode(pickled_bytes, "base64")
    return encoded_bytes.decode()
|
||
|
||
def create_spark_torch_model(
    network: nn.Module,
    inputCol: str = 'features',
    predictionCol: str = 'predicted',
    useVectorOut: bool = False
) -> SparkTorchModel:
    """
    Wrap a torch network in a SparkTorchModel.

    The network is serialized with ``convert_to_serialized_torch`` and passed
    to the SparkTorchModel constructor as ``modStr``. Intended for networks
    that are already trained, so they can be used for inference on spark
    dataframes.

    :param network: the (already trained) network to wrap
    :param inputCol: name of the spark dataframe input column
    :param predictionCol: name of the spark dataframe prediction column
    :param useVectorOut: forwarded to SparkTorchModel; selects whether the
        output should be returned as a spark vector
    :return: a SparkTorchModel holding the serialized network
    """
    # Serialize first so the constructor call below only wires parameters.
    serialized_network = convert_to_serialized_torch(network)
    return SparkTorchModel(
        inputCol=inputCol,
        predictionCol=predictionCol,
        modStr=serialized_network,
        useVectorOut=useVectorOut,
    )
|
||
|
||
def attach_pytorch_model_to_pipeline(
    network: nn.Module,
    pipeline_model: PipelineModel,
    inputCol: str = 'features',
    predictionCol: str = 'predicted',
    useVectorOut: bool = False
) -> PipelineModel:
    """
    Build a new PipelineModel that runs an existing pipeline, then a torch model.

    The network is wrapped via ``create_spark_torch_model`` and appended as a
    second stage after ``pipeline_model``. The existing pipeline model is used
    as a single stage of the new PipelineModel rather than being unpacked.

    :param network: the pytorch network to attach
    :param pipeline_model: an existing (fitted) spark pipeline model
    :param inputCol: dataframe column fed to the pytorch network
    :param predictionCol: dataframe column that receives the predictions
    :param useVectorOut: option to use a vector output
    :return: a spark PipelineModel with two stages
    """
    torch_stage = create_spark_torch_model(
        network,
        inputCol=inputCol,
        predictionCol=predictionCol,
        useVectorOut=useVectorOut,
    )
    # Order matters: the original pipeline runs before the torch model.
    combined_stages = [pipeline_model, torch_stage]
    return PipelineModel(stages=combined_stages)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters