Skip to content

Commit

Permalink
Added inference
Browse files Browse the repository at this point in the history
  • Loading branch information
dmmiller612 committed Dec 4, 2019
1 parent 8bee217 commit 1b855f9
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 1 deletion.
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ Then you can perform predictions, etc with:
predictions = p.transform(df)
```

#### Getting the pytorch model from the training session
#### Getting the Pytorch model from the training session

If you just want to get the Pytorch model after training, you can execute the following code:

Expand All @@ -202,6 +202,24 @@ stm = SparkTorch(
py_model = stm.getPytorchModel()
```


#### Using a pretrained Pytorch model for inference

If you already have a trained Pytorch model, you can attach it to your existing pipeline by directly creating a SparkTorchModel.
This can be done by running the following:

```python
from sparktorch import create_spark_torch_model

net = ... # Pretrained Network

spark_torch_model = create_spark_torch_model(
net,
inputCol='features',
predictionCol='predictions'
)
```

## Running

One big thing to remember is to add the `--executor-cores 1` option to spark to ensure
Expand Down
1 change: 1 addition & 0 deletions sparktorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sparktorch.util import serialize_torch_obj, serialize_torch_obj_lazy
from sparktorch.torch_distributed import SparkTorch
from sparktorch.pipeline_util import PysparkPipelineWrapper
from sparktorch.inference import create_spark_torch_model
61 changes: 61 additions & 0 deletions sparktorch/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from sparktorch.torch_distributed import SparkTorchModel
import torch.nn as nn
import codecs
import dill
from pyspark.ml.pipeline import PipelineModel


def convert_to_serialized_torch(network: nn.Module) -> str:
    """
    Serializes a torch network into a base64 text string.

    The network is first pickled with dill, the resulting bytes are base64
    encoded, and the encoded bytes are decoded into a str.
    :param network: the nn.Module to serialize
    :return: the base64 text form of the dill-pickled network
    """
    pickled_bytes = dill.dumps(network)
    encoded_bytes = codecs.encode(pickled_bytes, "base64")
    return encoded_bytes.decode()


def create_spark_torch_model(
    network: nn.Module,
    inputCol: str = 'features',
    predictionCol: str = 'predicted',
    useVectorOut: bool = False
) -> SparkTorchModel:
    """
    Builds a SparkTorchModel around an already trained network so it can be
    used for inference on a spark dataframe.
    :param network: an already trained network
    :param inputCol: name of the dataframe column passed as inputCol to the model
    :param predictionCol: name of the dataframe column passed as predictionCol to the model
    :param useVectorOut: forwarded to SparkTorchModel as useVectorOut
    :return: a SparkTorchModel holding the serialized network
    """
    # The model is handed the network in its serialized (string) form.
    serialized_network = convert_to_serialized_torch(network)

    return SparkTorchModel(
        inputCol=inputCol,
        predictionCol=predictionCol,
        modStr=serialized_network,
        useVectorOut=useVectorOut
    )


def attach_pytorch_model_to_pipeline(
    network: nn.Module,
    pipeline_model: PipelineModel,
    inputCol: str = 'features',
    predictionCol: str = 'predicted',
    useVectorOut: bool = False
) -> PipelineModel:
    """
    Combines an existing pipeline model and a pytorch network into a new PipelineModel.

    The given pipeline model becomes the first stage of the returned pipeline
    and the SparkTorchModel created from the network becomes the second stage.
    :param network: Pytorch network to wrap in a SparkTorchModel
    :param pipeline_model: an existing (fitted) spark pipeline model
    :param inputCol: input column name handed to the SparkTorchModel
    :param predictionCol: prediction column name handed to the SparkTorchModel
    :param useVectorOut: vector output option handed to the SparkTorchModel
    :return: a new PipelineModel with two stages
    """
    torch_stage = create_spark_torch_model(
        network,
        inputCol=inputCol,
        predictionCol=predictionCol,
        useVectorOut=useVectorOut
    )
    combined_stages = [pipeline_model, torch_stage]
    return PipelineModel(stages=combined_stages)
25 changes: 25 additions & 0 deletions sparktorch/tests/test_sparktorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pyspark.ml.linalg import Vectors
import torch.nn as nn
import torch
from sparktorch.inference import create_spark_torch_model
from sparktorch.util import serialize_torch_obj, serialize_torch_obj_lazy
from sparktorch.torch_distributed import SparkTorch
from sparktorch.tests.simple_net import Net, AutoEncoder, ClassificationNet, NetworkWithParameters
Expand Down Expand Up @@ -79,6 +80,29 @@ def test_model_parameters(data, network_with_params):
assert py_model.fc2 is not None


def test_inference(lazy_model, data):
    # Train a model and record the first row it produces.
    fitted_model = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=lazy_model,
        verbose=1,
        iters=10
    ).fit(data)
    expected_rows = fitted_model.transform(data).take(1)

    # Pull the trained network out and wrap it in a fresh inference-only model.
    trained_network = fitted_model.getPytorchModel()
    inference_model = create_spark_torch_model(
        trained_network,
        'features',
        'predictions'
    )
    actual_rows = inference_model.transform(data).take(1)

    # The rebuilt model must give the same first row as the fitted one.
    assert expected_rows == actual_rows


def test_lazy(lazy_model, data):
stm = SparkTorch(
inputCol='features',
Expand Down Expand Up @@ -259,3 +283,4 @@ def test_validation_pct(data, general_model):

res = stm.transform(data).take(1)
assert 'predictions' in res[0]

0 comments on commit 1b855f9

Please sign in to comment.