-
Notifications
You must be signed in to change notification settings - Fork 181
/
Copy pathvalue_rescale.py
89 lines (75 loc) · 3.91 KB
/
value_rescale.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
Typically, we need to apply normalization functions in RL training to reduce the scale of some predictions of neural networks (e.g. value function) to enhance the RL training process.
In this document, we will demonstrate two kinds of data normalization methods and their corresponding inverse operations.
- The first one is ``value_transform`` , which can reduce the scale of the action-value function. Its corresponding inverse operation is ``value_inv_transform`` . <link https://arxiv.org/pdf/1805.11593.pdf link>
- The second one is ``symlog`` , which is another approach to normalize the input tensor. Its corresponding inverse operation is ``inv_symlog`` . <link https://arxiv.org/pdf/2301.04104.pdf link>
"""
import torch
def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """
    **Overview**:
        Compress the scale of action-value targets with the invertible squashing function from
        "Observe and Look Further: Achieving Consistent Performance on Atari" <link https://arxiv.org/abs/1805.11593 link>.
        The input tensor ``x`` is mapped element-wise to its normalized counterpart; the output has the same shape.
        ``eps`` weights the additive linear term ``eps * x``. This is the regularizer that keeps the inverse
        mapping (``value_inv_transform``) Lipschitz continuous.
    """
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    magnitude = x.abs()
    # Square-root compression of the magnitude, shifted so that h(0) == 0.
    compressed = (magnitude + 1).sqrt() - 1
    return x.sign() * compressed + eps * x
# delimiter
def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """
    **Overview**:
        The inverse form of ``value_transform``. Given the input tensor ``x`` , this function will return the
        unnormalized tensor (element-wise, same shape as ``x``).
        ``eps`` must be the same value that was passed to ``value_transform``. ``eps = 0`` is supported and
        inverts the pure square-root squashing sign(x) * (sqrt(|x| + 1) - 1).
    """
    # The formula of the unnormalization is:
    #   $$h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\epsilon(|x|+1+\epsilon)}-1}{2\epsilon})}^2-1)$$
    # Let s = |x| + 1 + eps and a = 4 * eps * s. Multiply the fraction's numerator and denominator
    # by (sqrt(1 + a) + 1). The eps factors then cancel:
    #   (sqrt(1 + a) - 1) / (2 * eps) == 2 * s / (sqrt(1 + a) + 1)
    # The rewritten form avoids two problems of the literal formula:
    #   - it does not subtract two nearly-equal numbers, which loses precision for small eps;
    #   - it does not divide by 2 * eps, which made eps = 0 return NaN.
    shifted = torch.abs(x) + 1 + eps
    ratio = 2 * shifted / (torch.sqrt(1 + 4 * eps * shifted) + 1)
    return torch.sign(x) * (ratio ** 2 - 1)
# delimiter
def symlog(x: torch.Tensor) -> torch.Tensor:
    """
    **Overview**:
        A function to normalize the targets. For extensive reading, please refer to: Mastering Diverse Domains through World Models <link https://arxiv.org/abs/2301.04104 link>
        Given the input tensor ``x`` , this function will return the normalized tensor (element-wise, same shape).
        Its inverse is ``inv_symlog``.
    """
    # The formula of the normalization is: $$symlog(x) = sign(x)\ln(|x|+1)$$
    # torch.log1p evaluates ln(1 + v) without first rounding 1 + v. With torch.log(|x| + 1),
    # tiny inputs collapse to exactly 0 (in float32, 1 + 1e-8 == 1).
    return torch.sign(x) * torch.log1p(torch.abs(x))
# delimiter
def inv_symlog(x: torch.Tensor) -> torch.Tensor:
    """
    **Overview**:
        The inverse form of ``symlog`` (also called symexp). Given the input tensor ``x`` , this function will
        return the unnormalized tensor (element-wise, same shape).
    """
    # The formula of the unnormalization is: $$symexp(x) = sign(x)(\exp(|x|)-1)$$
    # torch.expm1 evaluates exp(v) - 1 without the cancellation of the literal form. With
    # torch.exp(|x|) - 1, small inputs lose precision and, in float32, 1e-8 maps to exactly 0.
    return torch.sign(x) * torch.expm1(torch.abs(x))
# delimiter
def test_value_transform():
    """
    **Overview**:
        Round-trip check for ``value_transform`` / ``value_inv_transform`` on randomly generated data.
        The forward transform must preserve the shape. Applying the inverse must then recover the
        original values, with a summed absolute error below 1e-3.
    """
    # Ten standard-normal samples act as fake value targets.
    samples = torch.randn(10)
    encoded = value_transform(samples)
    assert encoded.shape == (10,)
    decoded = value_inv_transform(encoded)
    # Total absolute reconstruction error over all elements.
    total_error = (samples - decoded).abs().sum()
    assert total_error < 1e-3
# delimiter
def test_symlog():
    """
    **Overview**:
        Round-trip check for ``symlog`` / ``inv_symlog`` on randomly generated data.
        ``symlog`` must preserve the shape. Applying ``inv_symlog`` must then recover the original
        values, with a summed absolute error below 1e-3.
    """
    # Ten standard-normal samples act as fake targets.
    samples = torch.randn(10)
    encoded = symlog(samples)
    assert encoded.shape == (10,)
    decoded = inv_symlog(encoded)
    # Total absolute reconstruction error over all elements.
    total_error = (samples - decoded).abs().sum()
    assert total_error < 1e-3
if __name__ == '__main__':
    # Run both round-trip checks, in order, when this file is executed as a script.
    for run_check in (test_value_transform, test_symlog):
        run_check()