Skip to content

Commit

Permalink
fix: do not serialize to list on model dump
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Dec 1, 2023
1 parent f30633e commit 8802126
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
15 changes: 15 additions & 0 deletions lantern/functional_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ def staticmethod(cls, fn):
@classmethod
def classmethod(cls, fn):
return cls.setattr(fn.__name__, classmethod(fn))


def test_replace_same_device():
import torch

from .tensor import Tensor

class A(FunctionalBase):
x: Tensor
y: int

a = A(x=torch.tensor([1, 2, 3]).to("meta"), y=2)
b = a.replace(y=2)

assert b.x.device == a.x.device
2 changes: 1 addition & 1 deletion lantern/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __get_pydantic_core_schema__(
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.tolist()
lambda instance: instance
),
)

Expand Down
2 changes: 1 addition & 1 deletion lantern/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __get_pydantic_core_schema__(
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.tolist()
lambda instance: instance
),
)

Expand Down

0 comments on commit 8802126

Please sign in to comment.