Merge pull request #165 from AlexanderPfefferle/onnx_exportable
enable ONNX export
LeoGrin authored Feb 7, 2025
2 parents 3d0647f + 133e5ab commit 634efcd
Showing 6 changed files with 198 additions and 78 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/pull_request.yml
@@ -112,8 +112,10 @@ jobs:
- name: Install dependencies
run: |
uv pip install --system --no-deps .
uv pip install --system pytest
# onnx is required for onnx export tests
# we don't install all dev dependencies here for speed
uv pip install --system -r requirements.txt
uv pip install --system pytest onnx
- name: Initialize submodules
run: git submodule update --init --recursive
1 change: 1 addition & 0 deletions pyproject.toml
@@ -61,6 +61,7 @@ dev = [
"mypy",
# Test
"pytest",
"onnx", # required for onnx export tests
# Docs
"mkdocs",
"mkdocs-material",
20 changes: 13 additions & 7 deletions src/tabpfn/model/encoders.py
@@ -9,11 +9,17 @@
from torch import nn


# TODO(eddiebergman): These were used before but I have no idea why.
# We use the implementations given by torch for now.
# TODO(Arjun): Enabling these again because their behaviour is a little
# different from torch's implementation (see Issue #2). We should check if this makes
# a difference in the results.
# usage of custom implementations is required to support ONNX export
def torch_nansum(x: torch.Tensor, axis=None, keepdim=False, dtype=None):
nan_mask = torch.isnan(x)
masked_input = torch.where(
nan_mask,
torch.tensor(0.0, device=x.device, dtype=x.dtype),
x,
)
return torch.sum(masked_input, axis=axis, keepdim=keepdim, dtype=dtype)


def torch_nanmean(
x: torch.Tensor,
axis: int = 0,
@@ -46,7 +52,7 @@ def torch_nanstd(x: torch.Tensor, axis: int = 0):
dim=axis,
)
return torch.sqrt(
torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1), # type: ignore
torch_nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1), # type: ignore
)


@@ -465,7 +471,7 @@ def _fit(self, x: torch.Tensor, single_eval_pos: int, **kwargs: Any) -> None:
single_eval_pos: The position to use for single evaluation.
**kwargs: Additional keyword arguments (unused).
"""
self.feature_means_ = torch.nanmean(x[:single_eval_pos], dim=0)
self.feature_means_ = torch_nanmean(x[:single_eval_pos], axis=0)

def _transform(
self,
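As a quick illustration of why the diff swaps torch.nansum for the mask-based reimplementation, the sketch below (not part of this commit) checks that torch_nansum from encoders.py matches the built-in on a tensor containing NaNs; the copy of torch_nansum is taken verbatim from the diff above, while the test tensor is made up.

import torch

# Copied from the diff above: zero out NaNs, then use a plain sum instead of
# torch.nansum, which (per the comment in the diff) blocks ONNX export.
def torch_nansum(x: torch.Tensor, axis=None, keepdim=False, dtype=None):
    nan_mask = torch.isnan(x)
    masked_input = torch.where(
        nan_mask,
        torch.tensor(0.0, device=x.device, dtype=x.dtype),
        x,
    )
    return torch.sum(masked_input, axis=axis, keepdim=keepdim, dtype=dtype)

# Sanity check on a small tensor with NaNs: both variants agree.
x = torch.tensor([[1.0, float("nan"), 3.0], [float("nan"), 2.0, 4.0]])
assert torch.allclose(torch_nansum(x, axis=0), torch.nansum(x, dim=0))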
121 changes: 52 additions & 69 deletions src/tabpfn/model/transformer.py
@@ -4,7 +4,8 @@

import random
import warnings
from collections.abc import Callable, Iterable
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal

@@ -25,6 +26,20 @@
DEFAULT_EMSIZE = 128


@contextmanager
def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]:
torch_rng_state = torch.get_rng_state()
if torch.cuda.is_available():
torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
torch.manual_seed(seed)
try:
yield
finally:
torch.set_rng_state(torch_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(torch_cuda_rng_state, device=device)


class LayerStack(nn.Module):
"""Similar to nn.Sequential, but with support for passing keyword arguments
to layers and stacks the same layer multiple times.
@@ -294,17 +309,6 @@ def __init__( # noqa: C901, D417, PLR0913
self.cached_feature_positional_embeddings: torch.Tensor | None = None
self.seed = seed if seed is not None else random.randint(0, 1_000_000) # noqa: S311

# Device on which the generator was last initialized.
# If loading from a checkpoint, this might be false,
# but it will be set to the correct device on the first forward pass.
self.generator_device = "cpu"
self._init_rnd()

def _init_rnd(self) -> None:
self.generator = SerializableGenerator(device=self.generator_device)
if self.seed: # This can be none if set outside of the model.
self.generator.manual_seed(self.seed)

def reset_save_peak_mem_factor(self, factor: int | None = None) -> None:
"""Sets the save_peak_mem_factor for all layers.
@@ -377,7 +381,6 @@ def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: # noqa
Returns:
The output of the model, which can be a tensor or a dictionary of tensors.
"""
self._init_rnd()
half_layers = kwargs.pop("half_layers", False)
assert half_layers is False

@@ -694,57 +697,47 @@ def add_embeddings( # noqa: C901, PLR0912
x += self.cached_embeddings[None, None]
return x, y

if (
self.generator_device != self.generator.device
or self.generator_device != x.device
):
self.generator_device = x.device
self._init_rnd()

if self.feature_positional_embedding == "normal_rand_vec":
embs = torch.randn(
(x.shape[2], x.shape[3]),
device=x.device,
dtype=x.dtype,
generator=self.generator,
)
x += embs[None, None]
elif self.feature_positional_embedding == "uni_rand_vec":
embs = (
torch.rand(
with isolate_torch_rng(self.seed, device=x.device):
if self.feature_positional_embedding == "normal_rand_vec":
embs = torch.randn(
(x.shape[2], x.shape[3]),
device=x.device,
dtype=x.dtype,
generator=self.generator,
)
* 2
- 1
)
x += embs[None, None]
elif self.feature_positional_embedding == "learned":
w = self.feature_positional_embedding_embeddings.weight
embs = w[
torch.randint(
0,
w.shape[0],
(x.shape[2],),
generator=self.generator,
x += embs[None, None]
elif self.feature_positional_embedding == "uni_rand_vec":
embs = (
torch.rand(
(x.shape[2], x.shape[3]),
device=x.device,
dtype=x.dtype,
)
* 2
- 1
)
]
x += embs[None, None]
elif self.feature_positional_embedding == "subspace":
embs = torch.randn(
(x.shape[2], x.shape[3] // 4),
device=x.device,
dtype=x.dtype,
generator=self.generator,
)
embs = self.feature_positional_embedding_embeddings(embs)
x += embs[None, None]
elif self.feature_positional_embedding is None:
embs = None
else:
raise ValueError(f"Unknown {self.feature_positional_embedding=}")
x += embs[None, None]
elif self.feature_positional_embedding == "learned":
w = self.feature_positional_embedding_embeddings.weight
embs = w[
torch.randint(
0,
w.shape[0],
(x.shape[2],),
)
]
x += embs[None, None]
elif self.feature_positional_embedding == "subspace":
embs = torch.randn(
(x.shape[2], x.shape[3] // 4),
device=x.device,
dtype=x.dtype,
)
embs = self.feature_positional_embedding_embeddings(embs)
x += embs[None, None]
elif self.feature_positional_embedding is None:
embs = None
else:
raise ValueError(f"Unknown {self.feature_positional_embedding=}")

self.cached_embeddings = None
if cache_embeddings and embs is not None:
@@ -869,13 +862,3 @@ def _add_pos_emb(
# TODO(old) Double check the ordering is right
for n, pe_ in zip(graph.nodes(), pe):
graph.nodes[n]["positional_encoding"] = pe_


class SerializableGenerator(torch.Generator):
"""A serializable version of the torch.Generator, that cna be saved and pickled."""

def __getstate__(self) -> Any:
return self.__dict__

def __setstate__(self, d: Any) -> None:
self.__dict__ = d
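
For context on the new isolate_torch_rng context manager that replaces the per-model SerializableGenerator, here is a minimal usage sketch (not part of this commit, and assuming the function is importable from tabpfn.model.transformer): draws inside the block are reproducible for a fixed seed, and the global RNG state is restored on exit.

import torch

from tabpfn.model.transformer import isolate_torch_rng  # assumed import path

before = torch.get_rng_state()

with isolate_torch_rng(seed=42, device=torch.device("cpu")):
    a = torch.randn(3)
with isolate_torch_rng(seed=42, device=torch.device("cpu")):
    b = torch.randn(3)

# The same seed inside the context gives the same draws ...
assert torch.equal(a, b)
# ... and the surrounding global RNG state is left untouched.
assert torch.equal(before, torch.get_rng_state())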
65 changes: 65 additions & 0 deletions tests/test_classifier_interface.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
from itertools import product
from typing import Callable, Literal

@@ -11,6 +12,7 @@
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
from torch import nn

from tabpfn import TabPFNClassifier
from tabpfn.preprocessing import PreprocessorConfig
@@ -219,3 +221,66 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
prob_dict = model_dict.predict_proba(X)
prob_obj = model_obj.predict_proba(X)
np.testing.assert_array_almost_equal(prob_dict, prob_obj)


class ModelWrapper(nn.Module):
def __init__(self, original_model): # noqa: D107
super().__init__()
self.model = original_model

def forward(
self,
X,
y,
single_eval_pos,
only_return_standard_out,
categorical_inds,
):
return self.model(
None,
X,
y,
single_eval_pos=single_eval_pos,
only_return_standard_out=only_return_standard_out,
categorical_inds=categorical_inds,
)


@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
X, y = X_y
with torch.no_grad():
classifier = TabPFNClassifier(n_estimators=1, device="cpu", random_state=42)
# load the model so we can access it via classifier.model_
classifier.fit(X, y)
# this is necessary if cuda is available
classifier.predict(X)
# replicate the above call with random tensors of same shape
X = torch.randn(
(X.shape[0] * 2, 1, X.shape[1] + 1),
generator=torch.Generator().manual_seed(42),
)
y = (
torch.rand(y.shape, generator=torch.Generator().manual_seed(42))
.round()
.to(torch.float32)
)
dynamic_axes = {
"X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
"y": {0: "num_labels"},
}
torch.onnx.export(
ModelWrapper(classifier.model_).eval(),
(X, y, y.shape[0], True, []),
io.BytesIO(),
input_names=[
"X",
"y",
"single_eval_pos",
"only_return_standard_out",
"categorical_inds",
],
output_names=["output"],
opset_version=17, # using 17 since we use torch>=2.1
dynamic_axes=dynamic_axes,
)
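
The test above exports into an in-memory io.BytesIO buffer because it only needs to exercise the export path. A variant that writes the graph to disk and runs the basic structural check from the onnx package could look roughly like the sketch below (illustrative only; it assumes the same fitted classifier, X, y, and dynamic_axes as in the test, and the output file name is made up).

import onnx
import torch

torch.onnx.export(
    ModelWrapper(classifier.model_).eval(),
    (X, y, y.shape[0], True, []),
    "tabpfn_classifier.onnx",  # hypothetical output path
    input_names=[
        "X",
        "y",
        "single_eval_pos",
        "only_return_standard_out",
        "categorical_inds",
    ],
    output_names=["output"],
    opset_version=17,
    dynamic_axes=dynamic_axes,
)

# Structural validation of the exported graph with the onnx package.
onnx.checker.check_model(onnx.load("tabpfn_classifier.onnx"))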
65 changes: 64 additions & 1 deletion tests/test_regressor_interface.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
from itertools import product
from typing import Callable, Literal

@@ -11,6 +12,7 @@
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
from torch import nn

from tabpfn import TabPFNRegressor
from tabpfn.preprocessing import PreprocessorConfig
@@ -111,7 +113,6 @@ def test_sklearn_compatible_estimator(
"check_methods_sample_order_invariance",
):
estimator.inference_precision = torch.float64
if check.func.__name__ == "check_methods_sample_order_invariance": # type: ignore
pytest.xfail("We're not at 1e-7 difference yet")
check(estimator)

@@ -217,3 +218,65 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
q_obj,
err_msg="Quantile predictions differ",
)


class ModelWrapper(nn.Module):
def __init__(self, original_model): # noqa: D107
super().__init__()
self.model = original_model

def forward(
self,
X,
y,
single_eval_pos,
only_return_standard_out,
categorical_inds,
):
return self.model(
None,
X,
y,
single_eval_pos=single_eval_pos,
only_return_standard_out=only_return_standard_out,
categorical_inds=categorical_inds,
)


# WARNING: unstable for scipy<1.11.0
@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
X, y = X_y
with torch.no_grad():
regressor = TabPFNRegressor(n_estimators=1, device="cpu", random_state=43)
# load the model so we can access it via regressor.model_
regressor.fit(X, y)
# this is necessary if cuda is available
regressor.predict(X)
# replicate the above call with random tensors of same shape
X = torch.randn(
(X.shape[0] * 2, 1, X.shape[1] + 1),
generator=torch.Generator().manual_seed(42),
)
y = (torch.randn(y.shape, generator=torch.Generator().manual_seed(42)) > 0).to(
torch.float32,
)
dynamic_axes = {
"X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
"y": {0: "num_labels"},
}
torch.onnx.export(
ModelWrapper(regressor.model_).eval(),
(X, y, y.shape[0], True, []),
io.BytesIO(),
input_names=[
"X",
"y",
"single_eval_pos",
"only_return_standard_out",
"categorical_inds",
],
output_names=["output"],
opset_version=17, # using 17 since we use torch>=2.1
dynamic_axes=dynamic_axes,
)
