diff --git a/pfrl/functions/bound_by_tanh.py b/pfrl/functions/bound_by_tanh.py index 0218f1d26..2b9f656ac 100644 --- a/pfrl/functions/bound_by_tanh.py +++ b/pfrl/functions/bound_by_tanh.py @@ -20,4 +20,4 @@ def bound_by_tanh(x, low, high): high = torch.as_tensor(high, dtype=x.dtype, device=x.device) scale = (high - low) / 2 loc = (high + low) / 2 - return nn.functional.tanh(x) * scale + loc + return torch.tanh(x) * scale + loc