From 6a6c88817a88bab35611abd52403857d6a639030 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Nov 2021 17:35:42 +0900 Subject: [PATCH 1/7] icm refactorized --- jorldy/core/network/icm.py | 564 +++++++++++++++++++++---------------- 1 file changed, 315 insertions(+), 249 deletions(-) diff --git a/jorldy/core/network/icm.py b/jorldy/core/network/icm.py index a87cbb83..ce8de40d 100644 --- a/jorldy/core/network/icm.py +++ b/jorldy/core/network/icm.py @@ -4,6 +4,169 @@ from .rnd import * +def mlp_head_weight(D_in, D_hidden, feature_size): + fc1 = torch.nn.Linear(D_in, D_hidden) + fc2 = torch.nn.Linear(D_hidden, feature_size) + + return fc1, fc2 + + +def mlp_batch_norm(D_hidden, feature_size): + bn1 = torch.nn.BatchNorm1d(D_hidden) + bn2 = torch.nn.BatchNorm1d(feature_size) + + bn1_next = torch.nn.BatchNorm1d(D_hidden) + + return bn1, bn2, bn1_next + + +def mlp_head(s, s_next, batch_norm, fc1, fc2, bn1, bn2, bn1_next): + if batch_norm: + s = F.elu(bn1(fc1(s))) + s = F.elu(bn2(fc2(s))) + + s_next = F.elu(bn1_next(fc1(s_next))) + else: + s = F.elu(fc1(s)) + s = F.elu(fc2(s)) + + s_next = F.elu(fc1(s_next)) + + s_next = F.elu(fc2(s_next)) + + return s, s_next + + +def conv_head_weight(D_in): + conv1 = torch.nn.Conv2d( + in_channels=D_in[0], out_channels=32, kernel_size=3, stride=2 + ) + conv2 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2) + conv3 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2) + conv4 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2) + + dim1 = ((D_in[1] - 3) // 2 + 1, (D_in[2] - 3) // 2 + 1) + dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1) + dim3 = ((dim2[0] - 3) // 2 + 1, (dim2[1] - 3) // 2 + 1) + dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1) + + feature_size = 32 * dim4[0] * dim4[1] + + return conv1, conv2, conv3, conv4, feature_size + + +def conv_batch_norm(): + bn1 = torch.nn.BatchNorm2d(32) + bn2 = torch.nn.BatchNorm2d(32) + bn3 = torch.nn.BatchNorm2d(32) + bn4 = torch.nn.BatchNorm2d(32) + + bn1_next = torch.nn.BatchNorm2d(32) + bn2_next = torch.nn.BatchNorm2d(32) + bn3_next = torch.nn.BatchNorm2d(32) + + return bn1, bn2, bn3, bn4, bn1_next, bn2_next, bn3_next + + +def conv_head( + s, + s_next, + batch_norm, + conv1, + conv2, + conv3, + conv4, + bn1, + bn2, + bn3, + bn4, + bn1_next, + bn2_next, + bn3_next, +): + if batch_norm: + s = F.elu(bn1(conv1(s))) + s = F.elu(bn2(conv2(s))) + s = F.elu(bn3(conv3(s))) + s = F.elu(bn4(conv4(s))) + + s_next = F.elu(bn1_next(conv1(s_next))) + s_next = F.elu(bn2_next(conv2(s_next))) + s_next = F.elu(bn3_next(conv3(s_next))) + + else: + s = F.elu(conv1(s)) + s = F.elu(conv2(s)) + s = F.elu(conv3(s)) + s = F.elu(conv4(s)) + + s_next = F.elu(conv1(s_next)) + s_next = F.elu(conv2(s_next)) + s_next = F.elu(conv3(s_next)) + + s_next = F.elu(conv4(s_next)) + s = s.view(s.size(0), -1) + s_next = s_next.view(s_next.size(0), -1) + + return s, s_next + + +def forward_weight(feature_size, D_hidden, D_out, action_type): + if action_type == "discrete": + forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) + forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) + else: + forward_fc1 = torch.nn.Linear(feature_size + D_out, D_hidden) + forward_fc2 = torch.nn.Linear(D_hidden + D_out, feature_size) + + forward_loss = torch.nn.MSELoss() + + return forward_fc1, forward_fc2, forward_loss + + +def inverse_weight(feature_size, D_hidden, D_out, action_type): + inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) + inverse_fc2 = torch.nn.Linear(D_hidden, D_out) + + inverse_loss = ( + torch.nn.CrossEntropyLoss() if action_type == "discrete" else torch.nn.MSELoss() + ) + + return inverse_fc1, inverse_fc2, inverse_loss + + +def forward_model(s, a, s_next, forward_loss, forward_fc1, forward_fc2): + x_forward = torch.cat((s, a), axis=1) + x_forward = F.relu(forward_fc1(x_forward)) + x_forward = torch.cat((x_forward, a), axis=1) + x_forward = forward_fc2(x_forward) + + l_f = forward_loss(x_forward, s_next.detach()) + + return x_forward, l_f + + +def inverse_model(s, a, s_next, action_type, inverse_loss, inverse_fc1, inverse_fc2): + x_inverse = torch.cat((s, s_next), axis=1) + x_inverse = F.relu(inverse_fc1(x_inverse)) + x_inverse = inverse_fc2(x_inverse) + + if action_type == "discrete": + l_i = inverse_loss(x_inverse, a.view(-1).long()) + else: + l_i = inverse_loss(x_inverse, a) + + return l_i + + +def ri_update(r_i, num_workers, rff, update_rms_ri): + ri_T = r_i.view(num_workers, -1).T # (n_batch, n_workers) + rewems = torch.stack( + [rff.update(rit.detach()) for rit in ri_T] + ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) + update_rms_ri(rewems) + + class ICM_MLP(torch.nn.Module): def __init__( self, @@ -34,30 +197,15 @@ def __init__( feature_size = 256 - self.fc1 = torch.nn.Linear(self.D_in, D_hidden) - self.fc2 = torch.nn.Linear(D_hidden, feature_size) - - self.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) - self.inverse_fc2 = torch.nn.Linear(D_hidden, self.D_out) - - self.forward_loss = torch.nn.MSELoss() - - if self.action_type == "discrete": - self.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) - self.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) - - self.inverse_loss = torch.nn.CrossEntropyLoss() - else: - self.forward_fc1 = torch.nn.Linear(feature_size + self.D_out, D_hidden) - self.forward_fc2 = torch.nn.Linear(D_hidden + self.D_out, feature_size) - - self.inverse_loss = torch.nn.MSELoss() - - if self.batch_norm: - self.bn1 = torch.nn.BatchNorm1d(D_hidden) - self.bn2 = torch.nn.BatchNorm1d(feature_size) + self.fc1, self.fc2 = mlp_head_weight(D_in, D_hidden, feature_size) + self.forward_fc1, self.forward_fc2, self.forward_loss = forward_weight( + feature_size, D_hidden, D_out, action_type + ) + self.inverse_fc1, self.inverse_fc2, self.inverse_loss = inverse_weight( + feature_size, D_hidden, D_out, action_type + ) - self.bn1_next = torch.nn.BatchNorm1d(D_hidden) + self.bn1, self.bn2, self.bn1_next = mlp_batch_norm(D_hidden, feature_size) def update_rms_obs(self, v): self.rms_obs.update(v) @@ -68,51 +216,44 @@ def update_rms_ri(self, v): def forward(self, s, a, s_next, update_ri=False): if self.obs_normalize: s = normalize_obs(s, self.rms_obs.mean, self.rms_obs.var) - if self.obs_normalize: s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - if self.batch_norm: - s = F.elu(self.bn1(self.fc1(s))) - s = F.elu(self.bn2(self.fc2(s))) - - s_next = F.elu(self.bn1_next(self.fc1(s_next))) - else: - s = F.elu(self.fc1(s)) - s = F.elu(self.fc2(s)) - - s_next = F.elu(self.fc1(s_next)) - - s_next = F.elu(self.fc2(s_next)) + s, s_next = mlp_head( + s, + s_next, + self.batch_norm, + self.fc1, + self.fc2, + self.bn1, + self.bn2, + self.bn1_next, + ) # Forward Model - x_forward = torch.cat((s, a), axis=1) - x_forward = F.relu(self.forward_fc1(x_forward)) - x_forward = torch.cat((x_forward, a), axis=1) - x_forward = self.forward_fc2(x_forward) + x_forward, l_f = forward_model( + s, a, s_next, self.forward_loss, self.forward_fc1, self.forward_fc2 + ) + + # Inverse Model + l_i = inverse_model( + s, + a, + s_next, + self.action_type, + self.inverse_loss, + self.inverse_fc1, + self.inverse_fc2, + ) + # Get Ri r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1) if update_ri: - ri_T = r_i.view(self.num_workers, -1).T # (n_batch, n_workers) - rewems = torch.stack( - [self.rff.update(rit.detach()) for rit in ri_T] - ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) - self.update_rms_ri(rewems) + ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri) + if self.ri_normalize: r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7) - l_f = self.forward_loss(x_forward, s_next.detach()) - - # Inverse Model - x_inverse = torch.cat((s, s_next), axis=1) - x_inverse = F.relu(self.inverse_fc1(x_inverse)) - x_inverse = self.inverse_fc2(x_inverse) - - if self.action_type == "discrete": - l_i = self.inverse_loss(x_inverse, a.view(-1).long()) - else: - l_i = self.inverse_loss(x_inverse, a) - return r_i, l_f, l_i @@ -144,51 +285,26 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - self.conv1 = torch.nn.Conv2d( - in_channels=self.D_in[0], out_channels=32, kernel_size=3, stride=2 + self.conv1, self.conv2, self.conv3, self.conv4, feature_size = conv_head_weight( + self.D_in ) - self.conv2 = torch.nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=3, stride=2 + self.forward_fc1, self.forward_fc2, self.forward_loss = forward_weight( + feature_size, D_hidden, D_out, action_type ) - self.conv3 = torch.nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=3, stride=2 + self.inverse_fc1, self.inverse_fc2, self.inverse_loss = inverse_weight( + feature_size, D_hidden, D_out, action_type ) - self.conv4 = torch.nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=3, stride=2 - ) - - dim1 = ((self.D_in[1] - 3) // 2 + 1, (self.D_in[2] - 3) // 2 + 1) - dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1) - dim3 = ((dim2[0] - 3) // 2 + 1, (dim2[1] - 3) // 2 + 1) - dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1) - - feature_size = 32 * dim4[0] * dim4[1] - - self.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) - self.inverse_fc2 = torch.nn.Linear(D_hidden, self.D_out) - - self.forward_loss = torch.nn.MSELoss() - - if self.action_type == "discrete": - self.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) - self.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) - - self.inverse_loss = torch.nn.CrossEntropyLoss() - else: - self.forward_fc1 = torch.nn.Linear(feature_size + self.D_out, D_hidden) - self.forward_fc2 = torch.nn.Linear(D_hidden + self.D_out, feature_size) - - self.inverse_loss = torch.nn.MSELoss() if self.batch_norm: - self.bn1 = torch.nn.BatchNorm2d(32) - self.bn2 = torch.nn.BatchNorm2d(32) - self.bn3 = torch.nn.BatchNorm2d(32) - self.bn4 = torch.nn.BatchNorm2d(32) - - self.bn1_next = torch.nn.BatchNorm2d(32) - self.bn2_next = torch.nn.BatchNorm2d(32) - self.bn3_next = torch.nn.BatchNorm2d(32) + ( + self.bn1, + self.bn2, + self.bn3, + self.bn4, + self.bn1_next, + self.bn2_next, + self.bn3_next, + ) = conv_batch_norm() def update_rms_obs(self, v): self.rms_obs.update(v / 255.0) @@ -199,62 +315,50 @@ def update_rms_ri(self, v): def forward(self, s, a, s_next, update_ri=False): if self.obs_normalize: s = normalize_obs(s, self.rms_obs.mean, self.rms_obs.var) - if self.obs_normalize: s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - if self.batch_norm: - s = F.elu(self.bn1(self.conv1(s))) - s = F.elu(self.bn2(self.conv2(s))) - s = F.elu(self.bn3(self.conv3(s))) - s = F.elu(self.bn4(self.conv4(s))) - - s_next = F.elu(self.bn1_next(self.conv1(s_next))) - s_next = F.elu(self.bn2_next(self.conv2(s_next))) - s_next = F.elu(self.bn3_next(self.conv3(s_next))) - - else: - s = F.elu(self.conv1(s)) - s = F.elu(self.conv2(s)) - s = F.elu(self.conv3(s)) - s = F.elu(self.conv4(s)) - - s_next = F.elu(self.conv1(s_next)) - s_next = F.elu(self.conv2(s_next)) - s_next = F.elu(self.conv3(s_next)) - - s_next = F.elu(self.conv4(s_next)) - s = s.view(s.size(0), -1) - s_next = s_next.view(s_next.size(0), -1) + s, s_next = conv_head( + s, + s_next, + self.batch_norm, + self.conv1, + self.conv2, + self.conv3, + self.conv4, + self.bn1, + self.bn2, + self.bn3, + self.bn4, + self.bn1_next, + self.bn2_next, + self.bn3_next, + ) # Forward Model - x_forward = torch.cat((s, a), axis=1) - x_forward = F.relu(self.forward_fc1(x_forward)) - x_forward = torch.cat((x_forward, a), axis=1) - x_forward = self.forward_fc2(x_forward) + x_forward, l_f = forward_model( + s, a, s_next, self.forward_loss, self.forward_fc1, self.forward_fc2 + ) + + # Inverse Model + l_i = inverse_model( + s, + a, + s_next, + self.action_type, + self.inverse_loss, + self.inverse_fc1, + self.inverse_fc2, + ) + # Get Ri r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1) if update_ri: - ri_T = r_i.view(self.num_workers, -1).T # (n_batch, n_workers) - rewems = torch.stack( - [self.rff.update(rit.detach()) for rit in ri_T] - ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) - self.update_rms_ri(rewems) + ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri) + if self.ri_normalize: r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7) - l_f = self.forward_loss(x_forward, s_next.detach()) - - # Inverse Model - x_inverse = torch.cat((s, s_next), axis=1) - x_inverse = F.relu(self.inverse_fc1(x_inverse)) - x_inverse = self.inverse_fc2(x_inverse) - - if self.action_type == "discrete": - l_i = self.inverse_loss(x_inverse, a.view(-1).long()) - else: - l_i = self.inverse_loss(x_inverse, a) - return r_i, l_f, l_i @@ -290,66 +394,40 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - ################################## Conv HEAD ################################## - self.conv1 = torch.nn.Conv2d( - in_channels=self.D_in_img[0], out_channels=32, kernel_size=3, stride=2 - ) - self.conv2 = torch.nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=3, stride=2 - ) - self.conv3 = torch.nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=3, stride=2 - ) - self.conv4 = torch.nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=3, stride=2 - ) - - dim1 = ((self.D_in_img[1] - 3) // 2 + 1, (self.D_in_img[2] - 3) // 2 + 1) - dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1) - dim3 = ((dim2[0] - 3) // 2 + 1, (dim2[1] - 3) // 2 + 1) - dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1) + ( + self.conv1, + self.conv2, + self.conv3, + self.conv4, + feature_size_img, + ) = conv_head_weight(self.D_in_img) - feature_size_img = 32 * dim4[0] * dim4[1] - - ################################## MLP HEAD ################################## feature_size_mlp = 256 - self.fc1_mlp = torch.nn.Linear(self.D_in_vec, D_hidden) - self.fc2_mlp = torch.nn.Linear(D_hidden, feature_size_mlp) - ############################################################################## + self.fc1, self.fc2 = mlp_head_weight(self.D_in_vec, D_hidden, feature_size_mlp) feature_size = feature_size_img + feature_size_mlp - self.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) - self.inverse_fc2 = torch.nn.Linear(D_hidden, self.D_out) - - self.forward_loss = torch.nn.MSELoss() - - if self.action_type == "discrete": - self.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) - self.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) - - self.inverse_loss = torch.nn.CrossEntropyLoss() - else: - self.forward_fc1 = torch.nn.Linear(feature_size + self.D_out, D_hidden) - self.forward_fc2 = torch.nn.Linear(D_hidden + self.D_out, feature_size) - - self.inverse_loss = torch.nn.MSELoss() + self.forward_fc1, self.forward_fc2, self.forward_loss = forward_weight( + feature_size, D_hidden, D_out, action_type + ) + self.inverse_fc1, self.inverse_fc2, self.inverse_loss = inverse_weight( + feature_size, D_hidden, D_out, action_type + ) if self.batch_norm: - self.bn1_conv = torch.nn.BatchNorm2d(32) - self.bn2_conv = torch.nn.BatchNorm2d(32) - self.bn3_conv = torch.nn.BatchNorm2d(32) - self.bn4_conv = torch.nn.BatchNorm2d(32) - - self.bn1_next_conv = torch.nn.BatchNorm2d(32) - self.bn2_next_conv = torch.nn.BatchNorm2d(32) - self.bn3_next_conv = torch.nn.BatchNorm2d(32) - - self.bn1_mlp = torch.nn.BatchNorm1d(D_hidden) - self.bn2_mlp = torch.nn.BatchNorm1d(feature_size_mlp) - - self.bn1_next_mlp = torch.nn.BatchNorm1d(D_hidden) + self.bn1_mlp, self.bn2_mlp, self.bn1_next_mlp = mlp_batch_norm( + D_hidden, feature_size_mlp + ) + ( + self.bn1_conv, + self.bn2_conv, + self.bn3_conv, + self.bn4_conv, + self.bn1_next_conv, + self.bn2_next_conv, + self.bn3_next_conv, + ) = conv_batch_norm() def update_rms_obs(self, v): self.rms_obs_img.update(v[0] / 255.0) @@ -375,71 +453,59 @@ def forward(self, s, a, s_next, update_ri=False): s_next_vec, self.rms_obs_vec.mean, self.rms_obs_vec.var ) - if self.batch_norm: - s_img = F.elu(self.bn1_conv(self.conv1(s_img))) - s_img = F.elu(self.bn2_conv(self.conv2(s_img))) - s_img = F.elu(self.bn3_conv(self.conv3(s_img))) - s_img = F.elu(self.bn4_conv(self.conv4(s_img))) - - s_next_img = F.elu(self.bn1_next_conv(self.conv1(s_next_img))) - s_next_img = F.elu(self.bn2_next_conv(self.conv2(s_next_img))) - s_next_img = F.elu(self.bn3_next_conv(self.conv3(s_next_img))) - - s_vec = F.elu(self.bn1_mlp(self.fc1_mlp(s_vec))) - s_vec = F.elu(self.bn2_mlp(self.fc2_mlp(s_vec))) - - s_next_vec = F.elu(self.bn1_next_mlp(self.fc1_mlp(s_next_vec))) - else: - s_img = F.elu(self.conv1(s_img)) - s_img = F.elu(self.conv2(s_img)) - s_img = F.elu(self.conv3(s_img)) - s_img = F.elu(self.conv4(s_img)) - - s_next_img = F.elu(self.conv1(s_next_img)) - s_next_img = F.elu(self.conv2(s_next_img)) - s_next_img = F.elu(self.conv3(s_next_img)) - - s_vec = F.elu(self.fc1_mlp(s_vec)) - s_vec = F.elu(self.fc2_mlp(s_vec)) - - s_next_vec = F.elu(self.fc1_mlp(s_next_vec)) - - s_next_img = F.elu(self.conv4(s_next_img)) - s_img = s_img.view(s_img.size(0), -1) - s_next_img = s_next_img.view(s_next_img.size(0), -1) - - s_next_vec = F.elu(self.fc2_mlp(s_next_vec)) + s_vec, s_next_vec = mlp_head( + s, + s_next, + self.batch_norm, + self.fc1, + self.fc2, + self.bn1_mlp, + self.bn2_mlp, + self.bn1_next_mlp, + ) + s_img, s_next_img = conv_head( + s, + s_next, + self.batch_norm, + self.conv1, + self.conv2, + self.conv3, + self.conv4, + self.bn1_conv, + self.bn2_conv, + self.bn3_conv, + self.bn4_conv, + self.bn1_next_conv, + self.bn2_next_conv, + self.bn3_next_conv, + ) s = torch.cat((s_img, s_vec), -1) s_next = torch.cat((s_next_img, s_next_vec), -1) # Forward Model - x_forward = torch.cat((s, a), axis=1) - x_forward = F.relu(self.forward_fc1(x_forward)) - x_forward = torch.cat((x_forward, a), axis=1) - x_forward = self.forward_fc2(x_forward) + x_forward, l_f = forward_model( + s, a, s_next, self.forward_loss, self.forward_fc1, self.forward_fc2 + ) + + # Inverse Model + l_i = inverse_model( + s, + a, + s_next, + self.action_type, + self.inverse_loss, + self.inverse_fc1, + self.inverse_fc2, + ) + # Get Ri r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1) if update_ri: - ri_T = r_i.view(self.num_workers, -1).T # (n_batch, n_workers) - rewems = torch.stack( - [self.rff.update(rit.detach()) for rit in ri_T] - ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) - self.update_rms_ri(rewems) + ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri) + if self.ri_normalize: r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7) - l_f = self.forward_loss(x_forward, s_next.detach()) - - # Inverse Model - x_inverse = torch.cat((s, s_next), axis=1) - x_inverse = F.relu(self.inverse_fc1(x_inverse)) - x_inverse = self.inverse_fc2(x_inverse) - - if self.action_type == "discrete": - l_i = self.inverse_loss(x_inverse, a.view(-1).long()) - else: - l_i = self.inverse_loss(x_inverse, a) - return r_i, l_f, l_i From 56d00a42fd94419be0bcde7efab0927e68d457e4 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Nov 2021 17:35:57 +0900 Subject: [PATCH 2/7] rnd refactorized --- jorldy/core/network/rnd.py | 469 +++++++++++++++++++++---------------- 1 file changed, 266 insertions(+), 203 deletions(-) diff --git a/jorldy/core/network/rnd.py b/jorldy/core/network/rnd.py index bcad4478..a4a12696 100644 --- a/jorldy/core/network/rnd.py +++ b/jorldy/core/network/rnd.py @@ -9,6 +9,148 @@ def normalize_obs(obs, m, v): return torch.clip((obs - m) / (torch.sqrt(v) + 1e-7), min=-5.0, max=5.0) +def mlp_head_weight(D_in, D_hidden, feature_size): + fc1_p = torch.nn.Linear(D_in, D_hidden) + fc2_p = torch.nn.Linear(D_hidden, feature_size) + + fc1_t = torch.nn.Linear(D_in, D_hidden) + fc2_t = torch.nn.Linear(D_hidden, feature_size) + + return fc1_p, fc2_p, fc1_t, fc2_t + + +def mlp_batch_norm(D_hidden, feature_size): + bn1_p = torch.nn.BatchNorm1d(D_hidden) + bn2_p = torch.nn.BatchNorm1d(feature_size) + + bn1_t = torch.nn.BatchNorm1d(D_hidden) + bn2_t = torch.nn.BatchNorm1d(feature_size) + + return bn1_p, bn2_p, bn1_t, bn2_t + + +def mlp_head( + s_next, batch_norm, fc1_p, fc2_p, fc1_t, fc2_t, bn1_p, bn2_p, bn1_t, bn2_t +): + if batch_norm: + p = F.relu(bn1_p(fc1_p(s_next))) + p = F.relu(bn2_p(fc2_p(p))) + + t = F.relu(bn1_t(fc1_t(s_next))) + t = F.relu(bn2_t(fc2_t(t))) + else: + p = F.relu(fc1_p(s_next)) + p = F.relu(fc2_p(p)) + + t = F.relu(fc1_t(s_next)) + t = F.relu(fc2_t(t)) + + return p, t + + +def conv_head_weight(D_in): + dim1 = ((D_in[1] - 8) // 4 + 1, (D_in[2] - 8) // 4 + 1) + dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1) + dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1) + + feature_size = 64 * dim3[0] * dim3[1] + + # Predictor Networks + conv1_p = torch.nn.Conv2d( + in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4 + ) + conv2_p = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) + conv3_p = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) + + # Target Networks + conv1_t = torch.nn.Conv2d( + in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4 + ) + conv2_t = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) + conv3_t = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) + + return conv1_p, conv2_p, conv3_p, conv1_t, conv2_t, conv3_t, feature_size + + +def conv_batch_norm(): + bn1_p = torch.nn.BatchNorm2d(32) + bn2_p = torch.nn.BatchNorm2d(64) + bn3_p = torch.nn.BatchNorm2d(64) + + bn1_t = torch.nn.BatchNorm2d(32) + bn2_t = torch.nn.BatchNorm2d(64) + bn3_t = torch.nn.BatchNorm2d(64) + + return bn1_p, bn2_p, bn3_p, bn1_t, bn2_t, bn3_t + + +def conv_head( + s_next, + batch_norm, + conv1_p, + conv2_p, + conv3_p, + conv1_t, + conv2_t, + conv3_t, + bn1_p, + bn2_p, + bn3_p, + bn1_t, + bn2_t, + bn3_t, +): + if batch_norm: + p = F.relu(bn1_p(conv1_p(s_next))) + p = F.relu(bn2_p(conv2_p(p))) + p = F.relu(bn3_p(conv3_p(p))) + + t = F.relu(bn1_t(conv1_t(s_next))) + t = F.relu(bn2_t(conv2_t(t))) + t = F.relu(bn3_t(conv3_t(t))) + else: + p = F.relu(conv1_p(s_next)) + p = F.relu(conv2_p(p)) + p = F.relu(conv3_p(p)) + + t = F.relu(conv1_t(s_next)) + t = F.relu(conv2_t(t)) + t = F.relu(conv3_t(t)) + + p = p.view(p.size(0), -1) + t = t.view(t.size(0), -1) + + return p, t + + +def fc_layers_weight(feature_size, D_hidden): + fc1_p = torch.nn.Linear(feature_size, D_hidden) + fc2_p = torch.nn.Linear(D_hidden, D_hidden) + fc3_p = torch.nn.Linear(D_hidden, D_hidden) + + fc1_t = torch.nn.Linear(feature_size, D_hidden) + + return fc1_p, fc2_p, fc3_p, fc1_t + + +def fc_layers(p, t, fc1_p, fc2_p, fc3_p, fc1_t): + p = F.relu(fc1_p(p)) + p = F.relu(fc2_p(p)) + p = fc3_p(p) + + t = fc1_t(t) + + return p, t + + +def ri_update(r_i, num_workers, rff, update_rms_ri): + ri_T = r_i.view(num_workers, -1).T # (n_batch, n_workers) + rewems = torch.stack( + [rff.update(rit.detach()) for rit in ri_T] + ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) + update_rms_ri(rewems) + + class RND_MLP(torch.nn.Module): def __init__( self, @@ -34,18 +176,12 @@ def __init__( feature_size = 256 - self.fc1_predict = torch.nn.Linear(self.D_in, D_hidden) - self.fc2_predict = torch.nn.Linear(D_hidden, feature_size) - - self.fc1_target = torch.nn.Linear(self.D_in, D_hidden) - self.fc2_target = torch.nn.Linear(D_hidden, feature_size) - - if batch_norm: - self.bn1_predict = torch.nn.BatchNorm1d(D_hidden) - self.bn2_predict = torch.nn.BatchNorm1d(feature_size) - - self.bn1_target = torch.nn.BatchNorm1d(D_hidden) - self.bn2_target = torch.nn.BatchNorm1d(feature_size) + self.fc1_p, self.fc2_p, self.fc1_t, self.fc2_t = mlp_head_weight( + D_in, D_hidden, feature_size + ) + self.bn1_p, self.bn2_p, self.bn1_t, self.bn2_t = mlp_batch_norm( + D_hidden, feature_size + ) def update_rms_obs(self, v): self.rms_obs.update(v) @@ -57,27 +193,24 @@ def forward(self, s_next, update_ri=False): if self.obs_normalize: s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - if self.batch_norm: - p = F.relu(self.bn1_predict(self.fc1_predict(s_next))) - p = F.relu(self.bn2_predict(self.fc2_predict(p))) - - t = F.relu(self.bn1_target(self.fc1_target(s_next))) - t = F.relu(self.bn2_target(self.fc2_target(t))) - else: - p = F.relu(self.fc1_predict(s_next)) - p = F.relu(self.fc2_predict(p)) - - t = F.relu(self.fc1_target(s_next)) - t = F.relu(self.fc2_target(t)) + p, t = mlp_head( + s_next, + self.batch_norm, + self.fc1_p, + self.fc2_p, + self.fc1_t, + self.fc2_t, + self.bn1_p, + self.bn2_p, + self.bn1_t, + self.bn2_t, + ) r_i = torch.mean(torch.square(p - t), axis=1) if update_ri: - ri_T = r_i.view(self.num_workers, -1).T # (n_batch, n_workers) - rewems = torch.stack( - [self.rff.update(rit.detach()) for rit in ri_T] - ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) - self.update_rms_ri(rewems) + ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri) + if self.ri_normalize: r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7) @@ -107,49 +240,27 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - dim1 = ((self.D_in[1] - 8) // 4 + 1, (self.D_in[2] - 8) // 4 + 1) - dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1) - dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1) - - feature_size = 64 * dim3[0] * dim3[1] - - # Predictor Networks - self.conv1_predict = torch.nn.Conv2d( - in_channels=self.D_in[0], out_channels=32, kernel_size=8, stride=4 - ) - self.conv2_predict = torch.nn.Conv2d( - in_channels=32, out_channels=64, kernel_size=4, stride=2 - ) - self.conv3_predict = torch.nn.Conv2d( - in_channels=64, out_channels=64, kernel_size=3, stride=1 - ) - - self.fc1_predict = torch.nn.Linear(feature_size, D_hidden) - self.fc2_predict = torch.nn.Linear(D_hidden, D_hidden) - self.fc3_predict = torch.nn.Linear(D_hidden, D_hidden) - - # Target Networks - self.conv1_target = torch.nn.Conv2d( - in_channels=self.D_in[0], out_channels=32, kernel_size=8, stride=4 - ) - self.conv2_target = torch.nn.Conv2d( - in_channels=32, out_channels=64, kernel_size=4, stride=2 - ) - self.conv3_target = torch.nn.Conv2d( - in_channels=64, out_channels=64, kernel_size=3, stride=1 + ( + self.conv1_p, + self.conv2_p, + self.conv3_p, + self.conv1_t, + self.conv2_t, + self.conv3_t, + feature_size, + ) = conv_head_weight(D_in) + ( + self.bn1_p, + self.bn2_p, + self.bn3_p, + self.bn1_t, + self.bn2_t, + self.bn3_t, + ) = conv_batch_norm() + self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t = fc_layers_weight( + feature_size, D_hidden ) - self.fc1_target = torch.nn.Linear(feature_size, D_hidden) - - if batch_norm: - self.bn1_predict = torch.nn.BatchNorm2d(32) - self.bn2_predict = torch.nn.BatchNorm2d(64) - self.bn3_predict = torch.nn.BatchNorm2d(64) - - self.bn1_target = torch.nn.BatchNorm2d(32) - self.bn2_target = torch.nn.BatchNorm2d(64) - self.bn3_target = torch.nn.BatchNorm2d(64) - def update_rms_obs(self, v): self.rms_obs.update(v / 255.0) @@ -161,40 +272,30 @@ def forward(self, s_next, update_ri=False): if self.obs_normalize: s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - if self.batch_norm: - p = F.relu(self.bn1_predict(self.conv1_predict(s_next))) - p = F.relu(self.bn2_predict(self.conv2_predict(p))) - p = F.relu(self.bn3_predict(self.conv3_predict(p))) - else: - p = F.relu(self.conv1_predict(s_next)) - p = F.relu(self.conv2_predict(p)) - p = F.relu(self.conv3_predict(p)) - - p = p.view(p.size(0), -1) - p = F.relu(self.fc1_predict(p)) - p = F.relu(self.fc2_predict(p)) - p = self.fc3_predict(p) - - if self.batch_norm: - t = F.relu(self.bn1_target(self.conv1_target(s_next))) - t = F.relu(self.bn2_target(self.conv2_target(t))) - t = F.relu(self.bn3_target(self.conv3_target(t))) - else: - t = F.relu(self.conv1_target(s_next)) - t = F.relu(self.conv2_target(t)) - t = F.relu(self.conv3_target(t)) - - t = t.view(t.size(0), -1) - t = self.fc1_target(t) + p, t = conv_head( + s_next, + self.batch_norm, + self.conv1_p, + self.conv2_p, + self.conv3_p, + self.conv1_t, + self.conv2_t, + self.conv3_t, + self.bn1_p, + self.bn2_p, + self.bn3_p, + self.bn1_t, + self.bn2_t, + self.bn3_t, + ) + + p, t = fc_layers(p, t, self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t) r_i = torch.mean(torch.square(p - t), axis=1) if update_ri: - ri_T = r_i.view(self.num_workers, -1).T # (n_batch, n_workers) - rewems = torch.stack( - [self.rff.update(rit.detach()) for rit in ri_T] - ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) - self.update_rms_ri(rewems) + ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri) + if self.ri_normalize: r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7) @@ -229,68 +330,41 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - ################################## Conv HEAD ################################## - dim1 = ((self.D_in_img[1] - 8) // 4 + 1, (self.D_in_img[2] - 8) // 4 + 1) - dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1) - dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1) - - feature_size_img = 64 * dim3[0] * dim3[1] + ( + self.conv1_p, + self.conv2_p, + self.conv3_p, + self.conv1_t, + self.conv2_t, + self.conv3_t, + feature_size_img, + ) = conv_head_weight(self.D_in_img) + ( + self.bn1_p_conv, + self.bn2_p_conv, + self.bn3_p_conv, + self.bn1_t_conv, + self.bn2_t_conv, + self.bn3_t_conv, + ) = conv_batch_norm() - # Predictor Networks - self.conv1_predict = torch.nn.Conv2d( - in_channels=self.D_in_img[0], out_channels=32, kernel_size=8, stride=4 - ) - self.conv2_predict = torch.nn.Conv2d( - in_channels=32, out_channels=64, kernel_size=4, stride=2 - ) - self.conv3_predict = torch.nn.Conv2d( - in_channels=64, out_channels=64, kernel_size=3, stride=1 - ) - - # Target Networks - self.conv1_target = torch.nn.Conv2d( - in_channels=self.D_in_img[0], out_channels=32, kernel_size=8, stride=4 - ) - self.conv2_target = torch.nn.Conv2d( - in_channels=32, out_channels=64, kernel_size=4, stride=2 - ) - self.conv3_target = torch.nn.Conv2d( - in_channels=64, out_channels=64, kernel_size=3, stride=1 - ) - - if batch_norm: - self.bn1_predict_conv = torch.nn.BatchNorm2d(32) - self.bn2_predict_conv = torch.nn.BatchNorm2d(64) - self.bn3_predict_conv = torch.nn.BatchNorm2d(64) - - self.bn1_target_conv = torch.nn.BatchNorm2d(32) - self.bn2_target_conv = torch.nn.BatchNorm2d(64) - self.bn3_target_conv = torch.nn.BatchNorm2d(64) - - ################################## MLP HEAD ################################## feature_size_mlp = 256 - self.fc1_predict_mlp = torch.nn.Linear(self.D_in_vec, D_hidden) - self.fc2_predict_mlp = torch.nn.Linear(D_hidden, feature_size_mlp) - - self.fc1_target_mlp = torch.nn.Linear(self.D_in_vec, D_hidden) - self.fc2_target_mlp = torch.nn.Linear(D_hidden, feature_size_mlp) - - if batch_norm: - self.bn1_predict_mlp = torch.nn.BatchNorm1d(D_hidden) - self.bn2_predict_mlp = torch.nn.BatchNorm1d(feature_size_mlp) + ( + self.fc1_p_mlp, + self.fc2_p_mlp, + self.fc1_t_mlp, + self.fc2_t_mlp, + ) = mlp_head_weight(self.D_in_vec, D_hidden, feature_size_mlp) + self.bn1_p_mlp, self.bn2_p_mlp, self.bn1_t_mlp, self.bn2_t_mlp = mlp_batch_norm( + D_hidden, feature_size + ) - self.bn1_target_mlp = torch.nn.BatchNorm1d(D_hidden) - self.bn2_target_mlp = torch.nn.BatchNorm1d(feature_size_mlp) + feature_size = feature_size_img + feature_size_mlp - ################################## FC Layers ################################## - self.fc1_predict = torch.nn.Linear( - feature_size_img + feature_size_mlp, D_hidden + self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t = fc_layers_weight( + feature_size, D_hidden ) - self.fc2_predict = torch.nn.Linear(D_hidden, D_hidden) - self.fc3_predict = torch.nn.Linear(D_hidden, D_hidden) - - self.fc1_target = torch.nn.Linear(feature_size_img + feature_size_mlp, D_hidden) def update_rms_obs(self, v): self.rms_obs_img.update(v[0] / 255.0) @@ -313,57 +387,46 @@ def forward(self, s_next, update_ri=False): s_next_vec, self.rms_obs_vec.mean, self.rms_obs_vec.var ) - ################################## Predict ################################## - if self.batch_norm: - p_i = F.relu(self.bn1_predict_conv(self.conv1_predict(s_next_img))) - p_i = F.relu(self.bn2_predict_conv(self.conv2_predict(p_i))) - p_i = F.relu(self.bn3_predict_conv(self.conv3_predict(p_i))) - - p_v = F.relu(self.bn1_predict_mlp(self.fc1_predict_mlp(s_next_vec))) - p_v = F.relu(self.bn2_predict_mlp(self.fc2_predict_mlp(p_v))) - else: - p_i = F.relu(self.conv1_predict(s_next_img)) - p_i = F.relu(self.conv2_predict(p_i)) - p_i = F.relu(self.conv3_predict(p_i)) - - p_v = F.relu(self.fc1_predict_mlp(s_next_vec)) - p_v = F.relu(self.fc2_predict_mlp(p_v)) - - p_i = p_i.view(p_i.size(0), -1) - p = torch.cat((p_i, p_v), -1) - - p = F.relu(self.fc1_predict(p)) - p = F.relu(self.fc2_predict(p)) - p = self.fc3_predict(p) - - ################################## target ################################## - if self.batch_norm: - t_i = F.relu(self.bn1_target_conv(self.conv1_target(s_next_img))) - t_i = F.relu(self.bn2_target_conv(self.conv2_target(t_i))) - t_i = F.relu(self.bn3_target_conv(self.conv3_target(t_i))) - - t_v = F.relu(self.bn1_target_mlp(self.fc1_target_mlp(s_next_vec))) - t_v = F.relu(self.bn2_target_mlp(self.fc2_target_mlp(t_v))) - else: - t_i = F.relu(self.conv1_target(s_next_img)) - t_i = F.relu(self.conv2_target(t_i)) - t_i = F.relu(self.conv3_target(t_i)) - - t_v = F.relu(self.fc1_target_mlp(s_next_vec)) - t_v = F.relu(self.fc2_target_mlp(t_v)) - - t_i = t_i.view(t_i.size(0), -1) - t = torch.cat((t_i, t_v), -1) - t = self.fc1_target(t) + p_conv, t_conv = conv_head( + s_next, + self.batch_norm, + self.conv1_p, + self.conv2_p, + self.conv3_p, + self.conv1_t, + self.conv2_t, + self.conv3_t, + self.bn1_p_conv, + self.bn2_p_conv, + self.bn3_p_conv, + self.bn1_t_conv, + self.bn2_t_conv, + self.bn3_t_conv, + ) + + p_mlp, t_mlp = mlp_head( + s_next, + self.batch_norm, + self.fc1_p_mlp, + self.fc2_p_mlp, + self.fc1_t_mlp, + self.fc2_t_mlp, + self.bn1_p_mlp, + self.bn2_p_mlp, + self.bn1_t_mlp, + self.bn2_t_mlp, + ) + + p = torch.cat((p_conv, p_mlp), -1) + t = torch.cat((t_conv, t_mlp), -1) + + p, t = fc_layers(p, t, self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t) r_i = torch.mean(torch.square(p - t), axis=1) if update_ri: - ri_T = r_i.view(self.num_workers, -1).T # (n_batch, n_workers) - rewems = torch.stack( - [self.rff.update(rit.detach()) for rit in ri_T] - ).ravel() # (n_batch, n_workers) -> (n_batch * n_workers) - self.update_rms_ri(rewems) + ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri) + if self.ri_normalize: r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7) From 779cff9fabd99d168f08bad350055fa85f6114e7 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Nov 2021 18:09:18 +0900 Subject: [PATCH 3/7] fix minor issues of icm and rnd network, apply chmod to DroneDelivery Linux env --- .../DroneDelivery/Linux/DroneDelivery.x86_64 | Bin .../ML-Agents/Timers/TrainScene_timers.json | 1 + jorldy/core/network/icm.py | 8 ++++---- jorldy/core/network/rnd.py | 7 ++++--- 4 files changed, 9 insertions(+), 7 deletions(-) mode change 100644 => 100755 jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery.x86_64 create mode 100644 jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json diff --git a/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery.x86_64 b/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery.x86_64 old mode 100644 new mode 100755 diff --git a/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json b/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json new file mode 100644 index 00000000..60e51160 --- /dev/null +++ b/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json @@ -0,0 +1 @@ +{"count":1,"self":43.5147584,"total":255.42027,"children":{"InitializeActuators":{"count":2,"self":0.0025199999999999997,"total":0.0025199999999999997,"children":null},"InitializeSensors":{"count":2,"self":0.011203999999999999,"total":0.011203999999999999,"children":null},"AgentSendState":{"count":118848,"self":0.214773,"total":19.392991,"children":{"CollectObservations":{"count":23859,"self":2.512253,"total":2.512253,"children":null},"WriteActionMask":{"count":23859,"self":0.018352,"total":0.018352,"children":null},"RequestDecision":{"count":23859,"self":0.061986,"total":16.647613,"children":{"AgentInfo.ToProto":{"count":23859,"self":0.081139,"total":16.585627,"children":{"GenerateSensorData":{"count":23859,"self":2.2982519999999997,"total":16.504488,"children":{"RenderTextureSensor.GetCompressedObservation":{"count":119295,"self":14.2062368,"total":14.206235999999999,"children":null}}}}}}}}},"DecideAction":{"count":118848,"self":191.884864,"total":191.88486699999999,"children":null},"AgentAct":{"count":118848,"self":0.487274,"total":0.53745999999999994,"children":{"AgentInfo.ToProto":{"count":77,"self":0.00027499999999999996,"total":0.050185999999999995,"children":{"GenerateSensorData":{"count":77,"self":0.0055049999999999995,"total":0.049911,"children":{"RenderTextureSensor.GetCompressedObservation":{"count":385,"self":0.044406,"total":0.044406,"children":null}}}}}}},"AgentInfo.ToProto":{"count":106,"self":0.000493,"total":0.074543,"children":{"GenerateSensorData":{"count":106,"self":0.010235,"total":0.074049999999999991,"children":{"RenderTextureSensor.GetCompressedObservation":{"count":530,"self":0.063815,"total":0.063815,"children":null}}}}}},"gauges":{"My Behavior.CumulativeReward":{"count":183,"max":1.13508928,"min":-5.50699663,"runningAverage":-1.02159321,"value":-0.9066612,"weightedAverage":-0.937120557}},"metadata":{"timer_format_version":"0.1.0","start_time_seconds":"1637226226","unity_version":"2021.1.6f1","command_line_arguments":"\/data\/private\/JORLDY\/jorldy\/.\/core\/env\/mlagents\/DroneDelivery\/Linux\/DroneDelivery.x86_64 -nographics -batchmode --mlagents-port 5007","communication_protocol_version":"1.5.0","com.unity.ml-agents_version":"2.1.0-exp.1","scene_name":"TrainScene","end_time_seconds":"1637226482"}}name":"TrainScene","end_time_seconds":"1637226482"}} \ No newline at end of file diff --git a/jorldy/core/network/icm.py b/jorldy/core/network/icm.py index ce8de40d..9959fd86 100644 --- a/jorldy/core/network/icm.py +++ b/jorldy/core/network/icm.py @@ -454,8 +454,8 @@ def forward(self, s, a, s_next, update_ri=False): ) s_vec, s_next_vec = mlp_head( - s, - s_next, + s_vec, + s_next_vec, self.batch_norm, self.fc1, self.fc2, @@ -464,8 +464,8 @@ def forward(self, s, a, s_next, update_ri=False): self.bn1_next_mlp, ) s_img, s_next_img = conv_head( - s, - s_next, + s_img, + s_next_img, self.batch_norm, self.conv1, self.conv2, diff --git a/jorldy/core/network/rnd.py b/jorldy/core/network/rnd.py index a4a12696..9307cc3a 100644 --- a/jorldy/core/network/rnd.py +++ b/jorldy/core/network/rnd.py @@ -356,8 +356,9 @@ def __init__( self.fc1_t_mlp, self.fc2_t_mlp, ) = mlp_head_weight(self.D_in_vec, D_hidden, feature_size_mlp) + self.bn1_p_mlp, self.bn2_p_mlp, self.bn1_t_mlp, self.bn2_t_mlp = mlp_batch_norm( - D_hidden, feature_size + D_hidden, feature_size_mlp ) feature_size = feature_size_img + feature_size_mlp @@ -388,7 +389,7 @@ def forward(self, s_next, update_ri=False): ) p_conv, t_conv = conv_head( - s_next, + s_next_img, self.batch_norm, self.conv1_p, self.conv2_p, @@ -405,7 +406,7 @@ def forward(self, s_next, update_ri=False): ) p_mlp, t_mlp = mlp_head( - s_next, + s_next_vec, self.batch_norm, self.fc1_p_mlp, self.fc2_p_mlp, From 1a8d4523a1715793e638e9b4e53d69967b88fb7e Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Nov 2021 18:14:54 +0900 Subject: [PATCH 4/7] apply black --- jorldy/core/network/rnd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jorldy/core/network/rnd.py b/jorldy/core/network/rnd.py index 9307cc3a..bef9650b 100644 --- a/jorldy/core/network/rnd.py +++ b/jorldy/core/network/rnd.py @@ -356,7 +356,7 @@ def __init__( self.fc1_t_mlp, self.fc2_t_mlp, ) = mlp_head_weight(self.D_in_vec, D_hidden, feature_size_mlp) - + self.bn1_p_mlp, self.bn2_p_mlp, self.bn1_t_mlp, self.bn2_t_mlp = mlp_batch_norm( D_hidden, feature_size_mlp ) From 4256bd45adec889f1c26a1a4231035b9b76eeea8 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Nov 2021 17:37:38 +0900 Subject: [PATCH 5/7] icm refactoring done --- jorldy/core/network/icm.py | 351 ++++++++++++------------------------- 1 file changed, 109 insertions(+), 242 deletions(-) diff --git a/jorldy/core/network/icm.py b/jorldy/core/network/icm.py index 9959fd86..bcd69e54 100644 --- a/jorldy/core/network/icm.py +++ b/jorldy/core/network/icm.py @@ -1,49 +1,52 @@ import torch import torch.nn.functional as F -from .rnd import * +from .rnd import normalize_obs +from .utils import RewardForwardFilter, RunningMeanStd -def mlp_head_weight(D_in, D_hidden, feature_size): - fc1 = torch.nn.Linear(D_in, D_hidden) - fc2 = torch.nn.Linear(D_hidden, feature_size) +def mlp_head_weight(instance, D_in, D_hidden, feature_size): + instance.fc1 = torch.nn.Linear(D_in, D_hidden) + instance.fc2 = torch.nn.Linear(D_hidden, feature_size) - return fc1, fc2 +def mlp_batch_norm(instance, D_hidden, feature_size): + instance.bn1 = torch.nn.BatchNorm1d(D_hidden) + instance.bn2 = torch.nn.BatchNorm1d(feature_size) -def mlp_batch_norm(D_hidden, feature_size): - bn1 = torch.nn.BatchNorm1d(D_hidden) - bn2 = torch.nn.BatchNorm1d(feature_size) + instance.bn1_next = torch.nn.BatchNorm1d(D_hidden) - bn1_next = torch.nn.BatchNorm1d(D_hidden) - return bn1, bn2, bn1_next +def mlp_head(instance, s, s_next): + if instance.batch_norm: + s = F.elu(instance.bn1(instance.fc1(s))) + s = F.elu(instance.bn2(instance.fc2(s))) - -def mlp_head(s, s_next, batch_norm, fc1, fc2, bn1, bn2, bn1_next): - if batch_norm: - s = F.elu(bn1(fc1(s))) - s = F.elu(bn2(fc2(s))) - - s_next = F.elu(bn1_next(fc1(s_next))) + s_next = F.elu(instance.bn1_next(instance.fc1(s_next))) else: - s = F.elu(fc1(s)) - s = F.elu(fc2(s)) + s = F.elu(instance.fc1(s)) + s = F.elu(instance.fc2(s)) - s_next = F.elu(fc1(s_next)) + s_next = F.elu(instance.fc1(s_next)) - s_next = F.elu(fc2(s_next)) + s_next = F.elu(instance.fc2(s_next)) return s, s_next -def conv_head_weight(D_in): - conv1 = torch.nn.Conv2d( +def conv_head_weight(instance, D_in): + instance.conv1 = torch.nn.Conv2d( in_channels=D_in[0], out_channels=32, kernel_size=3, stride=2 ) - conv2 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2) - conv3 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2) - conv4 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2) + instance.conv2 = torch.nn.Conv2d( + in_channels=32, out_channels=32, kernel_size=3, stride=2 + ) + instance.conv3 = torch.nn.Conv2d( + in_channels=32, out_channels=32, kernel_size=3, stride=2 + ) + instance.conv4 = torch.nn.Conv2d( + in_channels=32, out_channels=32, kernel_size=3, stride=2 + ) dim1 = ((D_in[1] - 3) // 2 + 1, (D_in[2] - 3) // 2 + 1) dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1) @@ -51,110 +54,90 @@ def conv_head_weight(D_in): dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1) feature_size = 32 * dim4[0] * dim4[1] + return feature_size + + +def conv_batch_norm(instance): + instance.bn1_conv = torch.nn.BatchNorm2d(32) + instance.bn2_conv = torch.nn.BatchNorm2d(32) + instance.bn3_conv = torch.nn.BatchNorm2d(32) + instance.bn4_conv = torch.nn.BatchNorm2d(32) + + instance.bn1_next_conv = torch.nn.BatchNorm2d(32) + instance.bn2_next_conv = torch.nn.BatchNorm2d(32) + instance.bn3_next_conv = torch.nn.BatchNorm2d(32) - return conv1, conv2, conv3, conv4, feature_size - - -def conv_batch_norm(): - bn1 = torch.nn.BatchNorm2d(32) - bn2 = torch.nn.BatchNorm2d(32) - bn3 = torch.nn.BatchNorm2d(32) - bn4 = torch.nn.BatchNorm2d(32) - - bn1_next = torch.nn.BatchNorm2d(32) - bn2_next = torch.nn.BatchNorm2d(32) - bn3_next = torch.nn.BatchNorm2d(32) - - return bn1, bn2, bn3, bn4, bn1_next, bn2_next, bn3_next - - -def conv_head( - s, - s_next, - batch_norm, - conv1, - conv2, - conv3, - conv4, - bn1, - bn2, - bn3, - bn4, - bn1_next, - bn2_next, - bn3_next, -): - if batch_norm: - s = F.elu(bn1(conv1(s))) - s = F.elu(bn2(conv2(s))) - s = F.elu(bn3(conv3(s))) - s = F.elu(bn4(conv4(s))) - - s_next = F.elu(bn1_next(conv1(s_next))) - s_next = F.elu(bn2_next(conv2(s_next))) - s_next = F.elu(bn3_next(conv3(s_next))) + +def conv_head(instance, s, s_next): + if instance.batch_norm: + s = F.elu(instance.bn1_conv(instance.conv1(s))) + s = F.elu(instance.bn2_conv(instance.conv2(s))) + s = F.elu(instance.bn3_conv(instance.conv3(s))) + s = F.elu(instance.bn4_conv(instance.conv4(s))) + + s_next = F.elu(instance.bn1_next_conv(instance.conv1(s_next))) + s_next = F.elu(instance.bn2_next_conv(instance.conv2(s_next))) + s_next = F.elu(instance.bn3_next_conv(instance.conv3(s_next))) else: - s = F.elu(conv1(s)) - s = F.elu(conv2(s)) - s = F.elu(conv3(s)) - s = F.elu(conv4(s)) + s = F.elu(instance.conv1(s)) + s = F.elu(instance.conv2(s)) + s = F.elu(instance.conv3(s)) + s = F.elu(instance.conv4(s)) - s_next = F.elu(conv1(s_next)) - s_next = F.elu(conv2(s_next)) - s_next = F.elu(conv3(s_next)) + s_next = F.elu(instance.conv1(s_next)) + s_next = F.elu(instance.conv2(s_next)) + s_next = F.elu(instance.conv3(s_next)) - s_next = F.elu(conv4(s_next)) + s_next = F.elu(instance.conv4(s_next)) s = s.view(s.size(0), -1) s_next = s_next.view(s_next.size(0), -1) return s, s_next -def forward_weight(feature_size, D_hidden, D_out, action_type): - if action_type == "discrete": - forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) - forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) +def forward_weight(instance, feature_size, D_hidden, D_out): + if instance.action_type == "discrete": + instance.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) + instance.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) else: - forward_fc1 = torch.nn.Linear(feature_size + D_out, D_hidden) - forward_fc2 = torch.nn.Linear(D_hidden + D_out, feature_size) + instance.forward_fc1 = torch.nn.Linear(feature_size + D_out, D_hidden) + instance.forward_fc2 = torch.nn.Linear(D_hidden + D_out, feature_size) - forward_loss = torch.nn.MSELoss() + instance.forward_loss = torch.nn.MSELoss() - return forward_fc1, forward_fc2, forward_loss +def inverse_weight(instance, feature_size, D_hidden, D_out): + instance.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) + instance.inverse_fc2 = torch.nn.Linear(D_hidden, D_out) -def inverse_weight(feature_size, D_hidden, D_out, action_type): - inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) - inverse_fc2 = torch.nn.Linear(D_hidden, D_out) - - inverse_loss = ( - torch.nn.CrossEntropyLoss() if action_type == "discrete" else torch.nn.MSELoss() + instance.inverse_loss = ( + torch.nn.CrossEntropyLoss() + if instance.action_type == "discrete" + else torch.nn.MSELoss() ) - return inverse_fc1, inverse_fc2, inverse_loss - -def forward_model(s, a, s_next, forward_loss, forward_fc1, forward_fc2): +def forward_model(instance, s, a, s_next): x_forward = torch.cat((s, a), axis=1) - x_forward = F.relu(forward_fc1(x_forward)) + x_forward = F.relu(instance.forward_fc1(x_forward)) x_forward = torch.cat((x_forward, a), axis=1) - x_forward = forward_fc2(x_forward) + x_forward = instance.forward_fc2(x_forward) - l_f = forward_loss(x_forward, s_next.detach()) + l_f = instance.forward_loss(x_forward, s_next.detach()) return x_forward, l_f -def inverse_model(s, a, s_next, action_type, inverse_loss, inverse_fc1, inverse_fc2): +def inverse_model(instance, s, a, s_next): x_inverse = torch.cat((s, s_next), axis=1) - x_inverse = F.relu(inverse_fc1(x_inverse)) - x_inverse = inverse_fc2(x_inverse) + x_inverse = F.relu(instance.inverse_fc1(x_inverse)) + x_inverse = instance.inverse_fc2(x_inverse) - if action_type == "discrete": - l_i = inverse_loss(x_inverse, a.view(-1).long()) + if instance.action_type == "discrete": + l_i = instance.inverse_loss(x_inverse, a.view(-1).long()) else: - l_i = inverse_loss(x_inverse, a) + l_i = instance.inverse_loss(x_inverse, a) return l_i @@ -197,15 +180,12 @@ def __init__( feature_size = 256 - self.fc1, self.fc2 = mlp_head_weight(D_in, D_hidden, feature_size) - self.forward_fc1, self.forward_fc2, self.forward_loss = forward_weight( - feature_size, D_hidden, D_out, action_type - ) - self.inverse_fc1, self.inverse_fc2, self.inverse_loss = inverse_weight( - feature_size, D_hidden, D_out, action_type - ) + mlp_head_weight(self, D_in, D_hidden, feature_size) + forward_weight(self, feature_size, D_hidden, D_out) + inverse_weight(self, feature_size, D_hidden, D_out) - self.bn1, self.bn2, self.bn1_next = mlp_batch_norm(D_hidden, feature_size) + if self.batch_norm: + mlp_batch_norm(self, D_hidden, feature_size) def update_rms_obs(self, v): self.rms_obs.update(v) @@ -218,32 +198,13 @@ def forward(self, s, a, s_next, update_ri=False): s = normalize_obs(s, self.rms_obs.mean, self.rms_obs.var) s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - s, s_next = mlp_head( - s, - s_next, - self.batch_norm, - self.fc1, - self.fc2, - self.bn1, - self.bn2, - self.bn1_next, - ) + s, s_next = mlp_head(self, s, s_next) # Forward Model - x_forward, l_f = forward_model( - s, a, s_next, self.forward_loss, self.forward_fc1, self.forward_fc2 - ) + x_forward, l_f = forward_model(self, s, a, s_next) # Inverse Model - l_i = inverse_model( - s, - a, - s_next, - self.action_type, - self.inverse_loss, - self.inverse_fc1, - self.inverse_fc2, - ) + l_i = inverse_model(self, s, a, s_next) # Get Ri r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1) @@ -285,26 +246,12 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - self.conv1, self.conv2, self.conv3, self.conv4, feature_size = conv_head_weight( - self.D_in - ) - self.forward_fc1, self.forward_fc2, self.forward_loss = forward_weight( - feature_size, D_hidden, D_out, action_type - ) - self.inverse_fc1, self.inverse_fc2, self.inverse_loss = inverse_weight( - feature_size, D_hidden, D_out, action_type - ) + feature_size = conv_head_weight(self, self.D_in) + forward_weight(self, feature_size, D_hidden, D_out) + inverse_weight(self, feature_size, D_hidden, D_out) if self.batch_norm: - ( - self.bn1, - self.bn2, - self.bn3, - self.bn4, - self.bn1_next, - self.bn2_next, - self.bn3_next, - ) = conv_batch_norm() + conv_batch_norm(self) def update_rms_obs(self, v): self.rms_obs.update(v / 255.0) @@ -317,38 +264,13 @@ def forward(self, s, a, s_next, update_ri=False): s = normalize_obs(s, self.rms_obs.mean, self.rms_obs.var) s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - s, s_next = conv_head( - s, - s_next, - self.batch_norm, - self.conv1, - self.conv2, - self.conv3, - self.conv4, - self.bn1, - self.bn2, - self.bn3, - self.bn4, - self.bn1_next, - self.bn2_next, - self.bn3_next, - ) + s, s_next = conv_head(self, s, s_next) # Forward Model - x_forward, l_f = forward_model( - s, a, s_next, self.forward_loss, self.forward_fc1, self.forward_fc2 - ) + x_forward, l_f = forward_model(self, s, a, s_next) # Inverse Model - l_i = inverse_model( - s, - a, - s_next, - self.action_type, - self.inverse_loss, - self.inverse_fc1, - self.inverse_fc2, - ) + l_i = inverse_model(self, s, a, s_next) # Get Ri r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1) @@ -394,40 +316,19 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - ( - self.conv1, - self.conv2, - self.conv3, - self.conv4, - feature_size_img, - ) = conv_head_weight(self.D_in_img) - + feature_size_img = conv_head_weight(self, self.D_in_img) feature_size_mlp = 256 - self.fc1, self.fc2 = mlp_head_weight(self.D_in_vec, D_hidden, feature_size_mlp) + mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp) feature_size = feature_size_img + feature_size_mlp - self.forward_fc1, self.forward_fc2, self.forward_loss = forward_weight( - feature_size, D_hidden, D_out, action_type - ) - self.inverse_fc1, self.inverse_fc2, self.inverse_loss = inverse_weight( - feature_size, D_hidden, D_out, action_type - ) + forward_weight(self, feature_size, D_hidden, D_out) + inverse_weight(self, feature_size, D_hidden, D_out) if self.batch_norm: - self.bn1_mlp, self.bn2_mlp, self.bn1_next_mlp = mlp_batch_norm( - D_hidden, feature_size_mlp - ) - ( - self.bn1_conv, - self.bn2_conv, - self.bn3_conv, - self.bn4_conv, - self.bn1_next_conv, - self.bn2_next_conv, - self.bn3_next_conv, - ) = conv_batch_norm() + mlp_batch_norm(self, D_hidden, feature_size_mlp) + conv_batch_norm(self) def update_rms_obs(self, v): self.rms_obs_img.update(v[0] / 255.0) @@ -453,51 +354,17 @@ def forward(self, s, a, s_next, update_ri=False): s_next_vec, self.rms_obs_vec.mean, self.rms_obs_vec.var ) - s_vec, s_next_vec = mlp_head( - s_vec, - s_next_vec, - self.batch_norm, - self.fc1, - self.fc2, - self.bn1_mlp, - self.bn2_mlp, - self.bn1_next_mlp, - ) - s_img, s_next_img = conv_head( - s_img, - s_next_img, - self.batch_norm, - self.conv1, - self.conv2, - self.conv3, - self.conv4, - self.bn1_conv, - self.bn2_conv, - self.bn3_conv, - self.bn4_conv, - self.bn1_next_conv, - self.bn2_next_conv, - self.bn3_next_conv, - ) + s_vec, s_next_vec = mlp_head(self, s_vec, s_next_vec) + s_img, s_next_img = conv_head(self, s_img, s_next_img) s = torch.cat((s_img, s_vec), -1) s_next = torch.cat((s_next_img, s_next_vec), -1) # Forward Model - x_forward, l_f = forward_model( - s, a, s_next, self.forward_loss, self.forward_fc1, self.forward_fc2 - ) + x_forward, l_f = forward_model(self, s, a, s_next) # Inverse Model - l_i = inverse_model( - s, - a, - s_next, - self.action_type, - self.inverse_loss, - self.inverse_fc1, - self.inverse_fc2, - ) + l_i = inverse_model(self, s, a, s_next) # Get Ri r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1) From 8285e06512771c44d56e1d826cb8e30311e961ed Mon Sep 17 00:00:00 2001 From: root Date: Tue, 23 Nov 2021 10:58:21 +0900 Subject: [PATCH 6/7] RND code refactoring done --- jorldy/core/network/rnd.py | 291 +++++++++++-------------------------- 1 file changed, 84 insertions(+), 207 deletions(-) diff --git a/jorldy/core/network/rnd.py b/jorldy/core/network/rnd.py index bef9650b..ff22e03d 100644 --- a/jorldy/core/network/rnd.py +++ b/jorldy/core/network/rnd.py @@ -9,46 +9,40 @@ def normalize_obs(obs, m, v): return torch.clip((obs - m) / (torch.sqrt(v) + 1e-7), min=-5.0, max=5.0) -def mlp_head_weight(D_in, D_hidden, feature_size): - fc1_p = torch.nn.Linear(D_in, D_hidden) - fc2_p = torch.nn.Linear(D_hidden, feature_size) +def mlp_head_weight(instance, D_in, D_hidden, feature_size): + instance.fc1_p_mlp = torch.nn.Linear(D_in, D_hidden) + instance.fc2_p_mlp = torch.nn.Linear(D_hidden, feature_size) - fc1_t = torch.nn.Linear(D_in, D_hidden) - fc2_t = torch.nn.Linear(D_hidden, feature_size) + instance.fc1_t_mlp = torch.nn.Linear(D_in, D_hidden) + instance.fc2_t_mlp = torch.nn.Linear(D_hidden, feature_size) - return fc1_p, fc2_p, fc1_t, fc2_t +def mlp_batch_norm(instance, D_hidden, feature_size): + instance.bn1_p_mlp = torch.nn.BatchNorm1d(D_hidden) + instance.bn2_p_mlp = torch.nn.BatchNorm1d(feature_size) -def mlp_batch_norm(D_hidden, feature_size): - bn1_p = torch.nn.BatchNorm1d(D_hidden) - bn2_p = torch.nn.BatchNorm1d(feature_size) + instance.bn1_t_mlp = torch.nn.BatchNorm1d(D_hidden) + instance.bn2_t_mlp = torch.nn.BatchNorm1d(feature_size) - bn1_t = torch.nn.BatchNorm1d(D_hidden) - bn2_t = torch.nn.BatchNorm1d(feature_size) - return bn1_p, bn2_p, bn1_t, bn2_t +def mlp_head(instance, s_next): + if instance.batch_norm: + p = F.relu(instance.bn1_p_mlp(instance.fc1_p_mlp(s_next))) + p = F.relu(instance.bn2_p_mlp(instance.fc2_p_mlp(p))) - -def mlp_head( - s_next, batch_norm, fc1_p, fc2_p, fc1_t, fc2_t, bn1_p, bn2_p, bn1_t, bn2_t -): - if batch_norm: - p = F.relu(bn1_p(fc1_p(s_next))) - p = F.relu(bn2_p(fc2_p(p))) - - t = F.relu(bn1_t(fc1_t(s_next))) - t = F.relu(bn2_t(fc2_t(t))) + t = F.relu(instance.bn1_t_mlp(instance.fc1_t_mlp(s_next))) + t = F.relu(instance.bn2_t_mlp(instance.fc2_t_mlp(t))) else: - p = F.relu(fc1_p(s_next)) - p = F.relu(fc2_p(p)) + p = F.relu(instance.fc1_p_mlp(s_next)) + p = F.relu(instance.fc2_p_mlp(p)) - t = F.relu(fc1_t(s_next)) - t = F.relu(fc2_t(t)) + t = F.relu(instance.fc1_t_mlp(s_next)) + t = F.relu(instance.fc2_t_mlp(t)) return p, t -def conv_head_weight(D_in): +def conv_head_weight(instance, D_in): dim1 = ((D_in[1] - 8) // 4 + 1, (D_in[2] - 8) // 4 + 1) dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1) dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1) @@ -56,66 +50,49 @@ def conv_head_weight(D_in): feature_size = 64 * dim3[0] * dim3[1] # Predictor Networks - conv1_p = torch.nn.Conv2d( + instance.conv1_p = torch.nn.Conv2d( in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4 ) - conv2_p = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) - conv3_p = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) + instance.conv2_p = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) + instance.conv3_p = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) # Target Networks - conv1_t = torch.nn.Conv2d( + instance.conv1_t = torch.nn.Conv2d( in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4 ) - conv2_t = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) - conv3_t = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) - - return conv1_p, conv2_p, conv3_p, conv1_t, conv2_t, conv3_t, feature_size - - -def conv_batch_norm(): - bn1_p = torch.nn.BatchNorm2d(32) - bn2_p = torch.nn.BatchNorm2d(64) - bn3_p = torch.nn.BatchNorm2d(64) - - bn1_t = torch.nn.BatchNorm2d(32) - bn2_t = torch.nn.BatchNorm2d(64) - bn3_t = torch.nn.BatchNorm2d(64) - - return bn1_p, bn2_p, bn3_p, bn1_t, bn2_t, bn3_t - - -def conv_head( - s_next, - batch_norm, - conv1_p, - conv2_p, - conv3_p, - conv1_t, - conv2_t, - conv3_t, - bn1_p, - bn2_p, - bn3_p, - bn1_t, - bn2_t, - bn3_t, -): - if batch_norm: - p = F.relu(bn1_p(conv1_p(s_next))) - p = F.relu(bn2_p(conv2_p(p))) - p = F.relu(bn3_p(conv3_p(p))) - - t = F.relu(bn1_t(conv1_t(s_next))) - t = F.relu(bn2_t(conv2_t(t))) - t = F.relu(bn3_t(conv3_t(t))) + instance.conv2_t = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) + instance.conv3_t = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) + + return feature_size + + +def conv_batch_norm(instance): + instance.bn1_p_conv = torch.nn.BatchNorm2d(32) + instance.bn2_p_conv = torch.nn.BatchNorm2d(64) + instance.bn3_p_conv = torch.nn.BatchNorm2d(64) + + instance.bn1_t_conv = torch.nn.BatchNorm2d(32) + instance.bn2_t_conv = torch.nn.BatchNorm2d(64) + instance.bn3_t_conv = torch.nn.BatchNorm2d(64) + + +def conv_head(instance, s_next): + if instance.batch_norm: + p = F.relu(instance.bn1_p_conv(instance.conv1_p(s_next))) + p = F.relu(instance.bn2_p_conv(instance.conv2_p(p))) + p = F.relu(instance.bn3_p_conv(instance.conv3_p(p))) + + t = F.relu(instance.bn1_t_conv(instance.conv1_t(s_next))) + t = F.relu(instance.bn2_t_conv(instance.conv2_t(t))) + t = F.relu(instance.bn3_t_conv(instance.conv3_t(t))) else: - p = F.relu(conv1_p(s_next)) - p = F.relu(conv2_p(p)) - p = F.relu(conv3_p(p)) + p = F.relu(instance.conv1_p(s_next)) + p = F.relu(instance.conv2_p(p)) + p = F.relu(instance.conv3_p(p)) - t = F.relu(conv1_t(s_next)) - t = F.relu(conv2_t(t)) - t = F.relu(conv3_t(t)) + t = F.relu(instance.conv1_t(s_next)) + t = F.relu(instance.conv2_t(t)) + t = F.relu(instance.conv3_t(t)) p = p.view(p.size(0), -1) t = t.view(t.size(0), -1) @@ -123,22 +100,20 @@ def conv_head( return p, t -def fc_layers_weight(feature_size, D_hidden): - fc1_p = torch.nn.Linear(feature_size, D_hidden) - fc2_p = torch.nn.Linear(D_hidden, D_hidden) - fc3_p = torch.nn.Linear(D_hidden, D_hidden) +def fc_layers_weight(instance, feature_size, D_hidden): + instance.fc1_p = torch.nn.Linear(feature_size, D_hidden) + instance.fc2_p = torch.nn.Linear(D_hidden, D_hidden) + instance.fc3_p = torch.nn.Linear(D_hidden, D_hidden) - fc1_t = torch.nn.Linear(feature_size, D_hidden) + instance.fc1_t = torch.nn.Linear(feature_size, D_hidden) - return fc1_p, fc2_p, fc3_p, fc1_t +def fc_layers(instance, p, t): + p = F.relu(instance.fc1_p(p)) + p = F.relu(instance.fc2_p(p)) + p = instance.fc3_p(p) -def fc_layers(p, t, fc1_p, fc2_p, fc3_p, fc1_t): - p = F.relu(fc1_p(p)) - p = F.relu(fc2_p(p)) - p = fc3_p(p) - - t = fc1_t(t) + t = instance.fc1_t(t) return p, t @@ -176,12 +151,10 @@ def __init__( feature_size = 256 - self.fc1_p, self.fc2_p, self.fc1_t, self.fc2_t = mlp_head_weight( - D_in, D_hidden, feature_size - ) - self.bn1_p, self.bn2_p, self.bn1_t, self.bn2_t = mlp_batch_norm( - D_hidden, feature_size - ) + mlp_head_weight(self, D_in, D_hidden, feature_size) + + if self.batch_norm: + mlp_batch_norm(self, D_hidden, feature_size) def update_rms_obs(self, v): self.rms_obs.update(v) @@ -193,18 +166,7 @@ def forward(self, s_next, update_ri=False): if self.obs_normalize: s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - p, t = mlp_head( - s_next, - self.batch_norm, - self.fc1_p, - self.fc2_p, - self.fc1_t, - self.fc2_t, - self.bn1_p, - self.bn2_p, - self.bn1_t, - self.bn2_t, - ) + p, t = mlp_head(self, s_next) r_i = torch.mean(torch.square(p - t), axis=1) @@ -240,26 +202,9 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - ( - self.conv1_p, - self.conv2_p, - self.conv3_p, - self.conv1_t, - self.conv2_t, - self.conv3_t, - feature_size, - ) = conv_head_weight(D_in) - ( - self.bn1_p, - self.bn2_p, - self.bn3_p, - self.bn1_t, - self.bn2_t, - self.bn3_t, - ) = conv_batch_norm() - self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t = fc_layers_weight( - feature_size, D_hidden - ) + feature_size = conv_head_weight(self, D_in) + conv_batch_norm(self) + fc_layers_weight(self, feature_size, D_hidden) def update_rms_obs(self, v): self.rms_obs.update(v / 255.0) @@ -272,24 +217,8 @@ def forward(self, s_next, update_ri=False): if self.obs_normalize: s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var) - p, t = conv_head( - s_next, - self.batch_norm, - self.conv1_p, - self.conv2_p, - self.conv3_p, - self.conv1_t, - self.conv2_t, - self.conv3_t, - self.bn1_p, - self.bn2_p, - self.bn3_p, - self.bn1_t, - self.bn2_t, - self.bn3_t, - ) - - p, t = fc_layers(p, t, self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t) + p, t = conv_head(self, s_next) + p, t = fc_layers(self, p, t) r_i = torch.mean(torch.square(p - t), axis=1) @@ -330,42 +259,17 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - ( - self.conv1_p, - self.conv2_p, - self.conv3_p, - self.conv1_t, - self.conv2_t, - self.conv3_t, - feature_size_img, - ) = conv_head_weight(self.D_in_img) - ( - self.bn1_p_conv, - self.bn2_p_conv, - self.bn3_p_conv, - self.bn1_t_conv, - self.bn2_t_conv, - self.bn3_t_conv, - ) = conv_batch_norm() + feature_size_img = conv_head_weight(self, self.D_in_img) + conv_batch_norm(self) feature_size_mlp = 256 + mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp) - ( - self.fc1_p_mlp, - self.fc2_p_mlp, - self.fc1_t_mlp, - self.fc2_t_mlp, - ) = mlp_head_weight(self.D_in_vec, D_hidden, feature_size_mlp) - - self.bn1_p_mlp, self.bn2_p_mlp, self.bn1_t_mlp, self.bn2_t_mlp = mlp_batch_norm( - D_hidden, feature_size_mlp - ) + mlp_batch_norm(self, D_hidden, feature_size_mlp) feature_size = feature_size_img + feature_size_mlp - self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t = fc_layers_weight( - feature_size, D_hidden - ) + fc_layers_weight(self, feature_size, D_hidden) def update_rms_obs(self, v): self.rms_obs_img.update(v[0] / 255.0) @@ -388,40 +292,13 @@ def forward(self, s_next, update_ri=False): s_next_vec, self.rms_obs_vec.mean, self.rms_obs_vec.var ) - p_conv, t_conv = conv_head( - s_next_img, - self.batch_norm, - self.conv1_p, - self.conv2_p, - self.conv3_p, - self.conv1_t, - self.conv2_t, - self.conv3_t, - self.bn1_p_conv, - self.bn2_p_conv, - self.bn3_p_conv, - self.bn1_t_conv, - self.bn2_t_conv, - self.bn3_t_conv, - ) - - p_mlp, t_mlp = mlp_head( - s_next_vec, - self.batch_norm, - self.fc1_p_mlp, - self.fc2_p_mlp, - self.fc1_t_mlp, - self.fc2_t_mlp, - self.bn1_p_mlp, - self.bn2_p_mlp, - self.bn1_t_mlp, - self.bn2_t_mlp, - ) + p_conv, t_conv = conv_head(self, s_next_img) + p_mlp, t_mlp = mlp_head(self, s_next_vec) p = torch.cat((p_conv, p_mlp), -1) t = torch.cat((t_conv, t_mlp), -1) - p, t = fc_layers(p, t, self.fc1_p, self.fc2_p, self.fc3_p, self.fc1_t) + p, t = fc_layers(self, p, t) r_i = torch.mean(torch.square(p - t), axis=1) From 0cb94c732e2e3f2fd975b2287c7be19b76263f0f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 23 Nov 2021 11:12:15 +0900 Subject: [PATCH 7/7] change function name of icm and rnd weight define --- jorldy/core/network/icm.py | 40 +++++++++++++++---------------- jorldy/core/network/rnd.py | 48 ++++++++++++++++++++++---------------- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/jorldy/core/network/icm.py b/jorldy/core/network/icm.py index bcd69e54..37c3e439 100644 --- a/jorldy/core/network/icm.py +++ b/jorldy/core/network/icm.py @@ -5,12 +5,12 @@ from .utils import RewardForwardFilter, RunningMeanStd -def mlp_head_weight(instance, D_in, D_hidden, feature_size): +def define_mlp_head_weight(instance, D_in, D_hidden, feature_size): instance.fc1 = torch.nn.Linear(D_in, D_hidden) instance.fc2 = torch.nn.Linear(D_hidden, feature_size) -def mlp_batch_norm(instance, D_hidden, feature_size): +def define_mlp_batch_norm(instance, D_hidden, feature_size): instance.bn1 = torch.nn.BatchNorm1d(D_hidden) instance.bn2 = torch.nn.BatchNorm1d(feature_size) @@ -34,7 +34,7 @@ def mlp_head(instance, s, s_next): return s, s_next -def conv_head_weight(instance, D_in): +def define_conv_head_weight(instance, D_in): instance.conv1 = torch.nn.Conv2d( in_channels=D_in[0], out_channels=32, kernel_size=3, stride=2 ) @@ -57,7 +57,7 @@ def conv_head_weight(instance, D_in): return feature_size -def conv_batch_norm(instance): +def define_conv_batch_norm(instance): instance.bn1_conv = torch.nn.BatchNorm2d(32) instance.bn2_conv = torch.nn.BatchNorm2d(32) instance.bn3_conv = torch.nn.BatchNorm2d(32) @@ -96,7 +96,7 @@ def conv_head(instance, s, s_next): return s, s_next -def forward_weight(instance, feature_size, D_hidden, D_out): +def define_forward_weight(instance, feature_size, D_hidden, D_out): if instance.action_type == "discrete": instance.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden) instance.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size) @@ -107,7 +107,7 @@ def forward_weight(instance, feature_size, D_hidden, D_out): instance.forward_loss = torch.nn.MSELoss() -def inverse_weight(instance, feature_size, D_hidden, D_out): +def define_inverse_weight(instance, feature_size, D_hidden, D_out): instance.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden) instance.inverse_fc2 = torch.nn.Linear(D_hidden, D_out) @@ -180,12 +180,12 @@ def __init__( feature_size = 256 - mlp_head_weight(self, D_in, D_hidden, feature_size) - forward_weight(self, feature_size, D_hidden, D_out) - inverse_weight(self, feature_size, D_hidden, D_out) + define_mlp_head_weight(self, D_in, D_hidden, feature_size) + define_forward_weight(self, feature_size, D_hidden, D_out) + define_inverse_weight(self, feature_size, D_hidden, D_out) if self.batch_norm: - mlp_batch_norm(self, D_hidden, feature_size) + define_mlp_batch_norm(self, D_hidden, feature_size) def update_rms_obs(self, v): self.rms_obs.update(v) @@ -246,12 +246,12 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - feature_size = conv_head_weight(self, self.D_in) - forward_weight(self, feature_size, D_hidden, D_out) - inverse_weight(self, feature_size, D_hidden, D_out) + feature_size = define_conv_head_weight(self, self.D_in) + define_forward_weight(self, feature_size, D_hidden, D_out) + define_inverse_weight(self, feature_size, D_hidden, D_out) if self.batch_norm: - conv_batch_norm(self) + define_conv_batch_norm(self) def update_rms_obs(self, v): self.rms_obs.update(v / 255.0) @@ -316,19 +316,19 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - feature_size_img = conv_head_weight(self, self.D_in_img) + feature_size_img = define_conv_head_weight(self, self.D_in_img) feature_size_mlp = 256 - mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp) + define_mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp) feature_size = feature_size_img + feature_size_mlp - forward_weight(self, feature_size, D_hidden, D_out) - inverse_weight(self, feature_size, D_hidden, D_out) + define_forward_weight(self, feature_size, D_hidden, D_out) + define_inverse_weight(self, feature_size, D_hidden, D_out) if self.batch_norm: - mlp_batch_norm(self, D_hidden, feature_size_mlp) - conv_batch_norm(self) + define_mlp_batch_norm(self, D_hidden, feature_size_mlp) + define_conv_batch_norm(self) def update_rms_obs(self, v): self.rms_obs_img.update(v[0] / 255.0) diff --git a/jorldy/core/network/rnd.py b/jorldy/core/network/rnd.py index ff22e03d..eec63aac 100644 --- a/jorldy/core/network/rnd.py +++ b/jorldy/core/network/rnd.py @@ -9,7 +9,7 @@ def normalize_obs(obs, m, v): return torch.clip((obs - m) / (torch.sqrt(v) + 1e-7), min=-5.0, max=5.0) -def mlp_head_weight(instance, D_in, D_hidden, feature_size): +def define_mlp_head_weight(instance, D_in, D_hidden, feature_size): instance.fc1_p_mlp = torch.nn.Linear(D_in, D_hidden) instance.fc2_p_mlp = torch.nn.Linear(D_hidden, feature_size) @@ -17,7 +17,7 @@ def mlp_head_weight(instance, D_in, D_hidden, feature_size): instance.fc2_t_mlp = torch.nn.Linear(D_hidden, feature_size) -def mlp_batch_norm(instance, D_hidden, feature_size): +def define_mlp_batch_norm(instance, D_hidden, feature_size): instance.bn1_p_mlp = torch.nn.BatchNorm1d(D_hidden) instance.bn2_p_mlp = torch.nn.BatchNorm1d(feature_size) @@ -42,7 +42,7 @@ def mlp_head(instance, s_next): return p, t -def conv_head_weight(instance, D_in): +def define_conv_head_weight(instance, D_in): dim1 = ((D_in[1] - 8) // 4 + 1, (D_in[2] - 8) // 4 + 1) dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1) dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1) @@ -53,20 +53,28 @@ def conv_head_weight(instance, D_in): instance.conv1_p = torch.nn.Conv2d( in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4 ) - instance.conv2_p = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) - instance.conv3_p = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) + instance.conv2_p = torch.nn.Conv2d( + in_channels=32, out_channels=64, kernel_size=4, stride=2 + ) + instance.conv3_p = torch.nn.Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1 + ) # Target Networks instance.conv1_t = torch.nn.Conv2d( in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4 ) - instance.conv2_t = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) - instance.conv3_t = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1) + instance.conv2_t = torch.nn.Conv2d( + in_channels=32, out_channels=64, kernel_size=4, stride=2 + ) + instance.conv3_t = torch.nn.Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1 + ) return feature_size -def conv_batch_norm(instance): +def define_conv_batch_norm(instance): instance.bn1_p_conv = torch.nn.BatchNorm2d(32) instance.bn2_p_conv = torch.nn.BatchNorm2d(64) instance.bn3_p_conv = torch.nn.BatchNorm2d(64) @@ -100,7 +108,7 @@ def conv_head(instance, s_next): return p, t -def fc_layers_weight(instance, feature_size, D_hidden): +def define_fc_layers_weight(instance, feature_size, D_hidden): instance.fc1_p = torch.nn.Linear(feature_size, D_hidden) instance.fc2_p = torch.nn.Linear(D_hidden, D_hidden) instance.fc3_p = torch.nn.Linear(D_hidden, D_hidden) @@ -151,10 +159,10 @@ def __init__( feature_size = 256 - mlp_head_weight(self, D_in, D_hidden, feature_size) - + define_mlp_head_weight(self, D_in, D_hidden, feature_size) + if self.batch_norm: - mlp_batch_norm(self, D_hidden, feature_size) + define_mlp_batch_norm(self, D_hidden, feature_size) def update_rms_obs(self, v): self.rms_obs.update(v) @@ -202,9 +210,9 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - feature_size = conv_head_weight(self, D_in) - conv_batch_norm(self) - fc_layers_weight(self, feature_size, D_hidden) + feature_size = define_conv_head_weight(self, D_in) + define_conv_batch_norm(self) + define_fc_layers_weight(self, feature_size, D_hidden) def update_rms_obs(self, v): self.rms_obs.update(v / 255.0) @@ -259,17 +267,17 @@ def __init__( self.ri_normalize = ri_normalize self.batch_norm = batch_norm - feature_size_img = conv_head_weight(self, self.D_in_img) - conv_batch_norm(self) + feature_size_img = define_conv_head_weight(self, self.D_in_img) + define_conv_batch_norm(self) feature_size_mlp = 256 - mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp) + define_mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp) - mlp_batch_norm(self, D_hidden, feature_size_mlp) + define_mlp_batch_norm(self, D_hidden, feature_size_mlp) feature_size = feature_size_img + feature_size_mlp - fc_layers_weight(self, feature_size, D_hidden) + define_fc_layers_weight(self, feature_size, D_hidden) def update_rms_obs(self, v): self.rms_obs_img.update(v[0] / 255.0)