Skip to content

Commit

Permalink
TensorSpec.encode domain check parametrization (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 26, 2022
1 parent aad5d04 commit 571d5aa
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import os
from dataclasses import dataclass
from textwrap import indent
from typing import (
Expand Down Expand Up @@ -41,6 +42,17 @@

INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]

_NO_CHECK_SPEC_ENCODE = os.environ.get("NO_CHECK_SPEC_ENCODE", False)
if _NO_CHECK_SPEC_ENCODE in ("0", "False", False):
_NO_CHECK_SPEC_ENCODE = False
elif _NO_CHECK_SPEC_ENCODE in ("1", "True", True):
_NO_CHECK_SPEC_ENCODE = True
else:
raise NotImplementedError(
"NO_CHECK_SPEC_ENCODE should be in 'True', 'False', '0' or '1'. "
f"Got {_NO_CHECK_SPEC_ENCODE} instead."
)


def _default_dtype_and_device(
dtype: Union[None, torch.dtype],
Expand Down Expand Up @@ -214,7 +226,8 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
):
val = val.copy()
val = torch.as_tensor(val, dtype=self.dtype, device=self.device)
self.assert_is_in(val)
if not _NO_CHECK_SPEC_ENCODE:
self.assert_is_in(val)
return val

def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
Expand Down

0 comments on commit 571d5aa

Please sign in to comment.