Skip to content

Commit

Permalink
fix: handle extra arguments to validate in internal functions
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Nov 28, 2023
1 parent 265389c commit 22ad2fe
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
20 changes: 10 additions & 10 deletions lantern/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def validate(cls, data, config=None, field=None) -> np.ndarray:
def ndim(cls, ndim) -> Numpy:
class InheritNumpy(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.ndim != ndim:
raise ValueError(f"Expected {ndim} dims, got {data.ndim}")
Expand All @@ -36,7 +36,7 @@ def validate(cls, data):
def dims(cls, dims) -> Numpy:
class InheritNumpy(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.ndim != len(dims):
raise ValueError(
Expand All @@ -50,7 +50,7 @@ def validate(cls, data):
def shape(cls, *sizes) -> Numpy:
class InheritNumpy(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
for data_size, size in zip(data.shape, sizes):
if size != -1 and data_size != size:
Expand All @@ -63,7 +63,7 @@ def validate(cls, data):
def between(cls, ge, le) -> Numpy:
class InheritNumpy(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.min() < ge:
Expand All @@ -83,7 +83,7 @@ def validate(cls, data):
def ge(cls, ge) -> Numpy:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.min() < ge:
raise ValueError(
Expand All @@ -96,7 +96,7 @@ def validate(cls, data):
def le(cls, le) -> Numpy:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.max() > le:
Expand All @@ -111,7 +111,7 @@ def validate(cls, data):
def gt(cls, gt) -> Numpy:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.min() <= gt:
Expand All @@ -123,7 +123,7 @@ def validate(cls, data):
def lt(cls, lt) -> Numpy:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.max() >= lt:
Expand All @@ -136,7 +136,7 @@ def validate(cls, data):
def ne(cls, ne) -> Numpy:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if (data == ne).any():
Expand All @@ -149,7 +149,7 @@ def validate(cls, data):
def dtype(cls, dtype) -> Numpy:
class InheritNumpy(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.dtype == dtype:
return data
Expand Down
22 changes: 11 additions & 11 deletions lantern/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def validate(cls, data, config=None, field=None) -> torch.Tensor:
def ndim(cls, ndim) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.ndim != ndim:
raise ValueError(f"Expected {ndim} dims, got {data.ndim}")
Expand All @@ -46,7 +46,7 @@ def validate(cls, data):
def dims(cls, dims) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.ndim != len(dims):
raise ValueError(
Expand All @@ -60,7 +60,7 @@ def validate(cls, data):
def shape(cls, *sizes) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
for data_size, size in zip(data.shape, sizes):
if size != -1 and data_size != size:
Expand All @@ -73,7 +73,7 @@ def validate(cls, data):
def between(cls, ge, le) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.min() < ge:
raise ValueError(
Expand All @@ -92,7 +92,7 @@ def validate(cls, data):
def ge(cls, ge) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.min() < ge:
raise ValueError(
Expand All @@ -105,7 +105,7 @@ def validate(cls, data):
def le(cls, le) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.max() > le:
Expand All @@ -120,7 +120,7 @@ def validate(cls, data):
def gt(cls, gt) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.min() <= gt:
Expand All @@ -132,7 +132,7 @@ def validate(cls, data):
def lt(cls, lt) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if data.max() >= lt:
Expand All @@ -145,7 +145,7 @@ def validate(cls, data):
def ne(cls, ne) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)

if (data == ne).any():
Expand All @@ -158,7 +158,7 @@ def validate(cls, data):
def device(cls, device) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
return super().validate(data).to(device)

return InheritTensor
Expand All @@ -175,7 +175,7 @@ def cuda(cls) -> Tensor:
def dtype(cls, dtype) -> Tensor:
class InheritTensor(cls):
@classmethod
def validate(cls, data):
def validate(cls, data, config=None, field=None):
data = super().validate(data)
if data.dtype == dtype:
return data
Expand Down

0 comments on commit 22ad2fe

Please sign in to comment.