diff --git a/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json b/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json
new file mode 100644
index 00000000..60e51160
--- /dev/null
+++ b/jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/ML-Agents/Timers/TrainScene_timers.json
@@ -0,0 +1 @@
+{"count":1,"self":43.5147584,"total":255.42027,"children":{"InitializeActuators":{"count":2,"self":0.0025199999999999997,"total":0.0025199999999999997,"children":null},"InitializeSensors":{"count":2,"self":0.011203999999999999,"total":0.011203999999999999,"children":null},"AgentSendState":{"count":118848,"self":0.214773,"total":19.392991,"children":{"CollectObservations":{"count":23859,"self":2.512253,"total":2.512253,"children":null},"WriteActionMask":{"count":23859,"self":0.018352,"total":0.018352,"children":null},"RequestDecision":{"count":23859,"self":0.061986,"total":16.647613,"children":{"AgentInfo.ToProto":{"count":23859,"self":0.081139,"total":16.585627,"children":{"GenerateSensorData":{"count":23859,"self":2.2982519999999997,"total":16.504488,"children":{"RenderTextureSensor.GetCompressedObservation":{"count":119295,"self":14.2062368,"total":14.206235999999999,"children":null}}}}}}}}},"DecideAction":{"count":118848,"self":191.884864,"total":191.88486699999999,"children":null},"AgentAct":{"count":118848,"self":0.487274,"total":0.53745999999999994,"children":{"AgentInfo.ToProto":{"count":77,"self":0.00027499999999999996,"total":0.050185999999999995,"children":{"GenerateSensorData":{"count":77,"self":0.0055049999999999995,"total":0.049911,"children":{"RenderTextureSensor.GetCompressedObservation":{"count":385,"self":0.044406,"total":0.044406,"children":null}}}}}}},"AgentInfo.ToProto":{"count":106,"self":0.000493,"total":0.074543,"children":{"GenerateSensorData":{"count":106,"self":0.010235,"total":0.074049999999999991,"children":{"RenderTextureSensor.GetCompressedObservation":{"count":530,"self":0.063815,"total":0.063815,"children":null}}}}}},"gauges":{"My Behavior.CumulativeReward":{"count":183,"max":1.13508928,"min":-5.50699663,"runningAverage":-1.02159321,"value":-0.9066612,"weightedAverage":-0.937120557}},"metadata":{"timer_format_version":"0.1.0","start_time_seconds":"1637226226","unity_version":"2021.1.6f1","command_line_arguments":"\/data\/private\/JORLDY\/jorldy\/.\/core\/env\/mlagents\/DroneDelivery\/Linux\/DroneDelivery.x86_64 -nographics -batchmode --mlagents-port 5007","communication_protocol_version":"1.5.0","com.unity.ml-agents_version":"2.1.0-exp.1","scene_name":"TrainScene","end_time_seconds":"1637226482"}}
\ No newline at end of file
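Note: TrainScene_timers.json is the profiling dump Unity ML-Agents writes at the end of a run — a timer tree ("self" vs. "total" seconds per callsite), a "gauges" section with the behavior's cumulative-reward statistics, and run metadata. A minimal sketch of inspecting such a dump, assuming only the Python standard library (the path is the file added above):

import json

path = (
    "jorldy/core/env/mlagents/DroneDelivery/Linux/DroneDelivery_Data/"
    "ML-Agents/Timers/TrainScene_timers.json"
)
with open(path) as f:
    timers = json.load(f)

# "self" is time spent in a node itself; "total" includes its children.
print(f"run total: {timers['total']:.1f}s")
for name, node in timers["children"].items():
    print(f"{name}: total={node['total']:.3f}s over {node['count']} calls")
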
diff --git a/jorldy/core/network/icm.py b/jorldy/core/network/icm.py
index a87cbb83..37c3e439 100644
--- a/jorldy/core/network/icm.py
+++ b/jorldy/core/network/icm.py
@@ -1,7 +1,153 @@
 import torch
 import torch.nn.functional as F
 
-from .rnd import *
+from .rnd import normalize_obs
+from .utils import RewardForwardFilter, RunningMeanStd
+
+
+def define_mlp_head_weight(instance, D_in, D_hidden, feature_size):
+    instance.fc1 = torch.nn.Linear(D_in, D_hidden)
+    instance.fc2 = torch.nn.Linear(D_hidden, feature_size)
+
+
+def define_mlp_batch_norm(instance, D_hidden, feature_size):
+    instance.bn1 = torch.nn.BatchNorm1d(D_hidden)
+    instance.bn2 = torch.nn.BatchNorm1d(feature_size)
+
+    instance.bn1_next = torch.nn.BatchNorm1d(D_hidden)
+
+
+def mlp_head(instance, s, s_next):
+    if instance.batch_norm:
+        s = F.elu(instance.bn1(instance.fc1(s)))
+        s = F.elu(instance.bn2(instance.fc2(s)))
+
+        s_next = F.elu(instance.bn1_next(instance.fc1(s_next)))
+    else:
+        s = F.elu(instance.fc1(s))
+        s = F.elu(instance.fc2(s))
+
+        s_next = F.elu(instance.fc1(s_next))
+
+    s_next = F.elu(instance.fc2(s_next))
+
+    return s, s_next
+
+
+def define_conv_head_weight(instance, D_in):
+    instance.conv1 = torch.nn.Conv2d(
+        in_channels=D_in[0], out_channels=32, kernel_size=3, stride=2
+    )
+    instance.conv2 = torch.nn.Conv2d(
+        in_channels=32, out_channels=32, kernel_size=3, stride=2
+    )
+    instance.conv3 = torch.nn.Conv2d(
+        in_channels=32, out_channels=32, kernel_size=3, stride=2
+    )
+    instance.conv4 = torch.nn.Conv2d(
+        in_channels=32, out_channels=32, kernel_size=3, stride=2
+    )
+
+    dim1 = ((D_in[1] - 3) // 2 + 1, (D_in[2] - 3) // 2 + 1)
+    dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1)
+    dim3 = ((dim2[0] - 3) // 2 + 1, (dim2[1] - 3) // 2 + 1)
+    dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1)
+
+    feature_size = 32 * dim4[0] * dim4[1]
+    return feature_size
+
+
+def define_conv_batch_norm(instance):
+    instance.bn1_conv = torch.nn.BatchNorm2d(32)
+    instance.bn2_conv = torch.nn.BatchNorm2d(32)
+    instance.bn3_conv = torch.nn.BatchNorm2d(32)
+    instance.bn4_conv = torch.nn.BatchNorm2d(32)
+
+    instance.bn1_next_conv = torch.nn.BatchNorm2d(32)
+    instance.bn2_next_conv = torch.nn.BatchNorm2d(32)
+    instance.bn3_next_conv = torch.nn.BatchNorm2d(32)
+
+
+def conv_head(instance, s, s_next):
+    if instance.batch_norm:
+        s = F.elu(instance.bn1_conv(instance.conv1(s)))
+        s = F.elu(instance.bn2_conv(instance.conv2(s)))
+        s = F.elu(instance.bn3_conv(instance.conv3(s)))
+        s = F.elu(instance.bn4_conv(instance.conv4(s)))
+
+        s_next = F.elu(instance.bn1_next_conv(instance.conv1(s_next)))
+        s_next = F.elu(instance.bn2_next_conv(instance.conv2(s_next)))
+        s_next = F.elu(instance.bn3_next_conv(instance.conv3(s_next)))
+
+    else:
+        s = F.elu(instance.conv1(s))
+        s = F.elu(instance.conv2(s))
+        s = F.elu(instance.conv3(s))
+        s = F.elu(instance.conv4(s))
+
+        s_next = F.elu(instance.conv1(s_next))
+        s_next = F.elu(instance.conv2(s_next))
+        s_next = F.elu(instance.conv3(s_next))
+
+    s_next = F.elu(instance.conv4(s_next))
+    s = s.view(s.size(0), -1)
+    s_next = s_next.view(s_next.size(0), -1)
+
+    return s, s_next
+
+
+def define_forward_weight(instance, feature_size, D_hidden, D_out):
+    if instance.action_type == "discrete":
+        instance.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden)
+        instance.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size)
+    else:
+        instance.forward_fc1 = torch.nn.Linear(feature_size + D_out, D_hidden)
+        instance.forward_fc2 = torch.nn.Linear(D_hidden + D_out, feature_size)
+
+    instance.forward_loss = torch.nn.MSELoss()
+
+
+def define_inverse_weight(instance, feature_size, D_hidden, D_out):
+    instance.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden)
+    instance.inverse_fc2 = torch.nn.Linear(D_hidden, D_out)
+
+    instance.inverse_loss = (
+        torch.nn.CrossEntropyLoss()
+        if instance.action_type == "discrete"
+        else torch.nn.MSELoss()
+    )
+
+
+def forward_model(instance, s, a, s_next):
+    x_forward = torch.cat((s, a), axis=1)
+    x_forward = F.relu(instance.forward_fc1(x_forward))
+    x_forward = torch.cat((x_forward, a), axis=1)
+    x_forward = instance.forward_fc2(x_forward)
+
+    l_f = instance.forward_loss(x_forward, s_next.detach())
+
+    return x_forward, l_f
+
+
+def inverse_model(instance, s, a, s_next):
+    x_inverse = torch.cat((s, s_next), axis=1)
+    x_inverse = F.relu(instance.inverse_fc1(x_inverse))
+    x_inverse = instance.inverse_fc2(x_inverse)
+
+    if instance.action_type == "discrete":
+        l_i = instance.inverse_loss(x_inverse, a.view(-1).long())
+    else:
+        l_i = instance.inverse_loss(x_inverse, a)
+
+    return l_i
+
+
+def ri_update(r_i, num_workers, rff, update_rms_ri):
+    ri_T = r_i.view(num_workers, -1).T  # (n_batch, n_workers)
+    rewems = torch.stack(
+        [rff.update(rit.detach()) for rit in ri_T]
+    ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
+    update_rms_ri(rewems)
 
 
 class ICM_MLP(torch.nn.Module):
@@ -34,30 +180,12 @@ def __init__(
 
         feature_size = 256
 
-        self.fc1 = torch.nn.Linear(self.D_in, D_hidden)
-        self.fc2 = torch.nn.Linear(D_hidden, feature_size)
-
-        self.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden)
-        self.inverse_fc2 = torch.nn.Linear(D_hidden, self.D_out)
-
-        self.forward_loss = torch.nn.MSELoss()
-
-        if self.action_type == "discrete":
-            self.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden)
-            self.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size)
-
-            self.inverse_loss = torch.nn.CrossEntropyLoss()
-        else:
-            self.forward_fc1 = torch.nn.Linear(feature_size + self.D_out, D_hidden)
-            self.forward_fc2 = torch.nn.Linear(D_hidden + self.D_out, feature_size)
-
-            self.inverse_loss = torch.nn.MSELoss()
+        define_mlp_head_weight(self, D_in, D_hidden, feature_size)
+        define_forward_weight(self, feature_size, D_hidden, D_out)
+        define_inverse_weight(self, feature_size, D_hidden, D_out)
 
         if self.batch_norm:
-            self.bn1 = torch.nn.BatchNorm1d(D_hidden)
-            self.bn2 = torch.nn.BatchNorm1d(feature_size)
-
-            self.bn1_next = torch.nn.BatchNorm1d(D_hidden)
+            define_mlp_batch_norm(self, D_hidden, feature_size)
 
     def update_rms_obs(self, v):
         self.rms_obs.update(v)
@@ -68,51 +196,25 @@ def update_rms_ri(self, v):
     def forward(self, s, a, s_next, update_ri=False):
         if self.obs_normalize:
             s = normalize_obs(s, self.rms_obs.mean, self.rms_obs.var)
-        if self.obs_normalize:
             s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var)
 
-        if self.batch_norm:
-            s = F.elu(self.bn1(self.fc1(s)))
-            s = F.elu(self.bn2(self.fc2(s)))
-
-            s_next = F.elu(self.bn1_next(self.fc1(s_next)))
-        else:
-            s = F.elu(self.fc1(s))
-            s = F.elu(self.fc2(s))
-
-            s_next = F.elu(self.fc1(s_next))
-
-        s_next = F.elu(self.fc2(s_next))
+        s, s_next = mlp_head(self, s, s_next)
 
         # Forward Model
-        x_forward = torch.cat((s, a), axis=1)
-        x_forward = F.relu(self.forward_fc1(x_forward))
-        x_forward = torch.cat((x_forward, a), axis=1)
-        x_forward = self.forward_fc2(x_forward)
+        x_forward, l_f = forward_model(self, s, a, s_next)
+
+        # Inverse Model
+        l_i = inverse_model(self, s, a, s_next)
 
+        # Get Ri
         r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1)
 
         if update_ri:
-            ri_T = r_i.view(self.num_workers, -1).T  # (n_batch, n_workers)
-            rewems = torch.stack(
-                [self.rff.update(rit.detach()) for rit in ri_T]
-            ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
-            self.update_rms_ri(rewems)
+            ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri)
+
         if self.ri_normalize:
             r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7)
 
-        l_f = self.forward_loss(x_forward, s_next.detach())
-
-        # Inverse Model
-        x_inverse = torch.cat((s, s_next), axis=1)
-        x_inverse = F.relu(self.inverse_fc1(x_inverse))
-        x_inverse = self.inverse_fc2(x_inverse)
-
-        if self.action_type == "discrete":
-            l_i = self.inverse_loss(x_inverse, a.view(-1).long())
-        else:
-            l_i = self.inverse_loss(x_inverse, a)
-
         return r_i, l_f, l_i
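Note: forward_model and inverse_model above are the two halves of the ICM objective — the forward model predicts the next-state feature from the current feature and the action (its error doubles as the intrinsic reward), while the inverse model predicts the action from the two features. A standalone toy rendering of the forward half; every shape and layer size here is illustrative, not the repo's API:

import torch
import torch.nn.functional as F

batch, feature_size, hidden = 5, 8, 16
s = torch.randn(batch, feature_size)       # phi(s)
s_next = torch.randn(batch, feature_size)  # phi(s')
a = torch.randn(batch, 1)                  # discrete action as a single column

fc1 = torch.nn.Linear(feature_size + 1, hidden)
fc2 = torch.nn.Linear(hidden + 1, feature_size)

# The action is concatenated before both layers, as in forward_model above.
x = fc2(torch.cat((F.relu(fc1(torch.cat((s, a), dim=1))), a), dim=1))
l_f = F.mse_loss(x, s_next.detach())                 # forward loss
r_i = 0.5 * torch.sum(torch.abs(x - s_next), dim=1)  # intrinsic reward (eta omitted)
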
@@ -144,51 +246,12 @@ def __init__(
         self.ri_normalize = ri_normalize
         self.batch_norm = batch_norm
 
-        self.conv1 = torch.nn.Conv2d(
-            in_channels=self.D_in[0], out_channels=32, kernel_size=3, stride=2
-        )
-        self.conv2 = torch.nn.Conv2d(
-            in_channels=32, out_channels=32, kernel_size=3, stride=2
-        )
-        self.conv3 = torch.nn.Conv2d(
-            in_channels=32, out_channels=32, kernel_size=3, stride=2
-        )
-        self.conv4 = torch.nn.Conv2d(
-            in_channels=32, out_channels=32, kernel_size=3, stride=2
-        )
-
-        dim1 = ((self.D_in[1] - 3) // 2 + 1, (self.D_in[2] - 3) // 2 + 1)
-        dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1)
-        dim3 = ((dim2[0] - 3) // 2 + 1, (dim2[1] - 3) // 2 + 1)
-        dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1)
-
-        feature_size = 32 * dim4[0] * dim4[1]
-
-        self.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden)
-        self.inverse_fc2 = torch.nn.Linear(D_hidden, self.D_out)
-
-        self.forward_loss = torch.nn.MSELoss()
-
-        if self.action_type == "discrete":
-            self.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden)
-            self.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size)
-
-            self.inverse_loss = torch.nn.CrossEntropyLoss()
-        else:
-            self.forward_fc1 = torch.nn.Linear(feature_size + self.D_out, D_hidden)
-            self.forward_fc2 = torch.nn.Linear(D_hidden + self.D_out, feature_size)
-
-            self.inverse_loss = torch.nn.MSELoss()
+        feature_size = define_conv_head_weight(self, self.D_in)
+        define_forward_weight(self, feature_size, D_hidden, D_out)
+        define_inverse_weight(self, feature_size, D_hidden, D_out)
 
         if self.batch_norm:
-            self.bn1 = torch.nn.BatchNorm2d(32)
-            self.bn2 = torch.nn.BatchNorm2d(32)
-            self.bn3 = torch.nn.BatchNorm2d(32)
-            self.bn4 = torch.nn.BatchNorm2d(32)
-
-            self.bn1_next = torch.nn.BatchNorm2d(32)
-            self.bn2_next = torch.nn.BatchNorm2d(32)
-            self.bn3_next = torch.nn.BatchNorm2d(32)
+            define_conv_batch_norm(self)
 
     def update_rms_obs(self, v):
         self.rms_obs.update(v / 255.0)
@@ -199,62 +262,25 @@ def update_rms_ri(self, v):
     def forward(self, s, a, s_next, update_ri=False):
         if self.obs_normalize:
             s = normalize_obs(s, self.rms_obs.mean, self.rms_obs.var)
-        if self.obs_normalize:
             s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var)
 
-        if self.batch_norm:
-            s = F.elu(self.bn1(self.conv1(s)))
-            s = F.elu(self.bn2(self.conv2(s)))
-            s = F.elu(self.bn3(self.conv3(s)))
-            s = F.elu(self.bn4(self.conv4(s)))
-
-            s_next = F.elu(self.bn1_next(self.conv1(s_next)))
-            s_next = F.elu(self.bn2_next(self.conv2(s_next)))
-            s_next = F.elu(self.bn3_next(self.conv3(s_next)))
-
-        else:
-            s = F.elu(self.conv1(s))
-            s = F.elu(self.conv2(s))
-            s = F.elu(self.conv3(s))
-            s = F.elu(self.conv4(s))
-
-            s_next = F.elu(self.conv1(s_next))
-            s_next = F.elu(self.conv2(s_next))
-            s_next = F.elu(self.conv3(s_next))
-
-        s_next = F.elu(self.conv4(s_next))
-        s = s.view(s.size(0), -1)
-        s_next = s_next.view(s_next.size(0), -1)
+        s, s_next = conv_head(self, s, s_next)
 
         # Forward Model
-        x_forward = torch.cat((s, a), axis=1)
-        x_forward = F.relu(self.forward_fc1(x_forward))
-        x_forward = torch.cat((x_forward, a), axis=1)
-        x_forward = self.forward_fc2(x_forward)
+        x_forward, l_f = forward_model(self, s, a, s_next)
+
+        # Inverse Model
+        l_i = inverse_model(self, s, a, s_next)
 
+        # Get Ri
         r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1)
 
         if update_ri:
-            ri_T = r_i.view(self.num_workers, -1).T  # (n_batch, n_workers)
-            rewems = torch.stack(
-                [self.rff.update(rit.detach()) for rit in ri_T]
-            ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
-            self.update_rms_ri(rewems)
+            ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri)
+
         if self.ri_normalize:
             r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7)
 
-        l_f = self.forward_loss(x_forward, s_next.detach())
-
-        # Inverse Model
-        x_inverse = torch.cat((s, s_next), axis=1)
-        x_inverse = F.relu(self.inverse_fc1(x_inverse))
-        x_inverse = self.inverse_fc2(x_inverse)
-
-        if self.action_type == "discrete":
-            l_i = self.inverse_loss(x_inverse, a.view(-1).long())
-        else:
-            l_i = self.inverse_loss(x_inverse, a)
-
         return r_i, l_f, l_i
@@ -290,66 +316,19 @@ def __init__(
         self.ri_normalize = ri_normalize
         self.batch_norm = batch_norm
 
-        ################################## Conv HEAD ##################################
-        self.conv1 = torch.nn.Conv2d(
-            in_channels=self.D_in_img[0], out_channels=32, kernel_size=3, stride=2
-        )
-        self.conv2 = torch.nn.Conv2d(
-            in_channels=32, out_channels=32, kernel_size=3, stride=2
-        )
-        self.conv3 = torch.nn.Conv2d(
-            in_channels=32, out_channels=32, kernel_size=3, stride=2
-        )
-        self.conv4 = torch.nn.Conv2d(
-            in_channels=32, out_channels=32, kernel_size=3, stride=2
-        )
-
-        dim1 = ((self.D_in_img[1] - 3) // 2 + 1, (self.D_in_img[2] - 3) // 2 + 1)
-        dim2 = ((dim1[0] - 3) // 2 + 1, (dim1[1] - 3) // 2 + 1)
-        dim3 = ((dim2[0] - 3) // 2 + 1, (dim2[1] - 3) // 2 + 1)
-        dim4 = ((dim3[0] - 3) // 2 + 1, (dim3[1] - 3) // 2 + 1)
-
-        feature_size_img = 32 * dim4[0] * dim4[1]
-
-        ################################## MLP HEAD ##################################
+        feature_size_img = define_conv_head_weight(self, self.D_in_img)
         feature_size_mlp = 256
 
-        self.fc1_mlp = torch.nn.Linear(self.D_in_vec, D_hidden)
-        self.fc2_mlp = torch.nn.Linear(D_hidden, feature_size_mlp)
-        ##############################################################################
+        define_mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp)
 
         feature_size = feature_size_img + feature_size_mlp
 
-        self.inverse_fc1 = torch.nn.Linear(2 * feature_size, D_hidden)
-        self.inverse_fc2 = torch.nn.Linear(D_hidden, self.D_out)
-
-        self.forward_loss = torch.nn.MSELoss()
-
-        if self.action_type == "discrete":
-            self.forward_fc1 = torch.nn.Linear(feature_size + 1, D_hidden)
-            self.forward_fc2 = torch.nn.Linear(D_hidden + 1, feature_size)
-
-            self.inverse_loss = torch.nn.CrossEntropyLoss()
-        else:
-            self.forward_fc1 = torch.nn.Linear(feature_size + self.D_out, D_hidden)
-            self.forward_fc2 = torch.nn.Linear(D_hidden + self.D_out, feature_size)
-
-            self.inverse_loss = torch.nn.MSELoss()
+        define_forward_weight(self, feature_size, D_hidden, D_out)
+        define_inverse_weight(self, feature_size, D_hidden, D_out)
 
         if self.batch_norm:
-            self.bn1_conv = torch.nn.BatchNorm2d(32)
-            self.bn2_conv = torch.nn.BatchNorm2d(32)
-            self.bn3_conv = torch.nn.BatchNorm2d(32)
-            self.bn4_conv = torch.nn.BatchNorm2d(32)
-
-            self.bn1_next_conv = torch.nn.BatchNorm2d(32)
-            self.bn2_next_conv = torch.nn.BatchNorm2d(32)
-            self.bn3_next_conv = torch.nn.BatchNorm2d(32)
-
-            self.bn1_mlp = torch.nn.BatchNorm1d(D_hidden)
-            self.bn2_mlp = torch.nn.BatchNorm1d(feature_size_mlp)
-
-            self.bn1_next_mlp = torch.nn.BatchNorm1d(D_hidden)
+            define_mlp_batch_norm(self, D_hidden, feature_size_mlp)
+            define_conv_batch_norm(self)
 
     def update_rms_obs(self, v):
         self.rms_obs_img.update(v[0] / 255.0)
@@ -375,71 +354,25 @@ def forward(self, s, a, s_next, update_ri=False):
                 s_next_vec, self.rms_obs_vec.mean, self.rms_obs_vec.var
             )
 
-        if self.batch_norm:
-            s_img = F.elu(self.bn1_conv(self.conv1(s_img)))
-            s_img = F.elu(self.bn2_conv(self.conv2(s_img)))
-            s_img = F.elu(self.bn3_conv(self.conv3(s_img)))
-            s_img = F.elu(self.bn4_conv(self.conv4(s_img)))
-
-            s_next_img = F.elu(self.bn1_next_conv(self.conv1(s_next_img)))
-            s_next_img = F.elu(self.bn2_next_conv(self.conv2(s_next_img)))
-            s_next_img = F.elu(self.bn3_next_conv(self.conv3(s_next_img)))
-
-            s_vec = F.elu(self.bn1_mlp(self.fc1_mlp(s_vec)))
-            s_vec = F.elu(self.bn2_mlp(self.fc2_mlp(s_vec)))
-
-            s_next_vec = F.elu(self.bn1_next_mlp(self.fc1_mlp(s_next_vec)))
-        else:
-            s_img = F.elu(self.conv1(s_img))
-            s_img = F.elu(self.conv2(s_img))
-            s_img = F.elu(self.conv3(s_img))
-            s_img = F.elu(self.conv4(s_img))
-
-            s_next_img = F.elu(self.conv1(s_next_img))
-            s_next_img = F.elu(self.conv2(s_next_img))
-            s_next_img = F.elu(self.conv3(s_next_img))
-
-            s_vec = F.elu(self.fc1_mlp(s_vec))
-            s_vec = F.elu(self.fc2_mlp(s_vec))
-
-            s_next_vec = F.elu(self.fc1_mlp(s_next_vec))
-
-        s_next_img = F.elu(self.conv4(s_next_img))
-        s_img = s_img.view(s_img.size(0), -1)
-        s_next_img = s_next_img.view(s_next_img.size(0), -1)
-
-        s_next_vec = F.elu(self.fc2_mlp(s_next_vec))
+        s_vec, s_next_vec = mlp_head(self, s_vec, s_next_vec)
+        s_img, s_next_img = conv_head(self, s_img, s_next_img)
 
         s = torch.cat((s_img, s_vec), -1)
         s_next = torch.cat((s_next_img, s_next_vec), -1)
 
         # Forward Model
-        x_forward = torch.cat((s, a), axis=1)
-        x_forward = F.relu(self.forward_fc1(x_forward))
-        x_forward = torch.cat((x_forward, a), axis=1)
-        x_forward = self.forward_fc2(x_forward)
+        x_forward, l_f = forward_model(self, s, a, s_next)
+
+        # Inverse Model
+        l_i = inverse_model(self, s, a, s_next)
 
+        # Get Ri
         r_i = (self.eta * 0.5) * torch.sum(torch.abs(x_forward - s_next), axis=1)
 
         if update_ri:
-            ri_T = r_i.view(self.num_workers, -1).T  # (n_batch, n_workers)
-            rewems = torch.stack(
-                [self.rff.update(rit.detach()) for rit in ri_T]
-            ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
-            self.update_rms_ri(rewems)
+            ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri)
+
        if self.ri_normalize:
             r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7)
 
-        l_f = self.forward_loss(x_forward, s_next.detach())
-
-        # Inverse Model
-        x_inverse = torch.cat((s, s_next), axis=1)
-        x_inverse = F.relu(self.inverse_fc1(x_inverse))
-        x_inverse = self.inverse_fc2(x_inverse)
-
-        if self.action_type == "discrete":
-            l_i = self.inverse_loss(x_inverse, a.view(-1).long())
-        else:
-            l_i = self.inverse_loss(x_inverse, a)
-
         return r_i, l_f, l_i
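Note: ri_update, now shared by all three ICM heads, reshapes the flat intrinsic-reward batch into per-worker series, pushes each timestep through a discounted forward filter, and hands the result to the running mean/std used for normalization. A simplified sketch; the RewardForwardFilter below is a stand-in, since the real class lives in .utils and is not shown in this diff:

import torch

class RewardForwardFilter:  # simplified stand-in for the .utils class
    def __init__(self, gamma, num_workers):
        self.gamma = gamma
        self.rewems = torch.zeros(num_workers)

    def update(self, rews):
        self.rewems = self.rewems * self.gamma + rews
        return self.rewems

num_workers, n_batch = 4, 3
r_i = torch.rand(num_workers * n_batch)  # flat intrinsic rewards
rff = RewardForwardFilter(gamma=0.99, num_workers=num_workers)

ri_T = r_i.view(num_workers, -1).T  # (n_batch, n_workers)
rewems = torch.stack([rff.update(rit) for rit in ri_T]).ravel()
# rewems feeds rms_ri.update(...); r_i is then divided by the running std.
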
diff --git a/jorldy/core/network/rnd.py b/jorldy/core/network/rnd.py
index bcad4478..eec63aac 100644
--- a/jorldy/core/network/rnd.py
+++ b/jorldy/core/network/rnd.py
@@ -9,6 +9,131 @@ def normalize_obs(obs, m, v):
     return torch.clip((obs - m) / (torch.sqrt(v) + 1e-7), min=-5.0, max=5.0)
 
 
+def define_mlp_head_weight(instance, D_in, D_hidden, feature_size):
+    instance.fc1_p_mlp = torch.nn.Linear(D_in, D_hidden)
+    instance.fc2_p_mlp = torch.nn.Linear(D_hidden, feature_size)
+
+    instance.fc1_t_mlp = torch.nn.Linear(D_in, D_hidden)
+    instance.fc2_t_mlp = torch.nn.Linear(D_hidden, feature_size)
+
+
+def define_mlp_batch_norm(instance, D_hidden, feature_size):
+    instance.bn1_p_mlp = torch.nn.BatchNorm1d(D_hidden)
+    instance.bn2_p_mlp = torch.nn.BatchNorm1d(feature_size)
+
+    instance.bn1_t_mlp = torch.nn.BatchNorm1d(D_hidden)
+    instance.bn2_t_mlp = torch.nn.BatchNorm1d(feature_size)
+
+
+def mlp_head(instance, s_next):
+    if instance.batch_norm:
+        p = F.relu(instance.bn1_p_mlp(instance.fc1_p_mlp(s_next)))
+        p = F.relu(instance.bn2_p_mlp(instance.fc2_p_mlp(p)))
+
+        t = F.relu(instance.bn1_t_mlp(instance.fc1_t_mlp(s_next)))
+        t = F.relu(instance.bn2_t_mlp(instance.fc2_t_mlp(t)))
+    else:
+        p = F.relu(instance.fc1_p_mlp(s_next))
+        p = F.relu(instance.fc2_p_mlp(p))
+
+        t = F.relu(instance.fc1_t_mlp(s_next))
+        t = F.relu(instance.fc2_t_mlp(t))
+
+    return p, t
+
+
+def define_conv_head_weight(instance, D_in):
+    dim1 = ((D_in[1] - 8) // 4 + 1, (D_in[2] - 8) // 4 + 1)
+    dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1)
+    dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1)
+
+    feature_size = 64 * dim3[0] * dim3[1]
+
+    # Predictor Networks
+    instance.conv1_p = torch.nn.Conv2d(
+        in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4
+    )
+    instance.conv2_p = torch.nn.Conv2d(
+        in_channels=32, out_channels=64, kernel_size=4, stride=2
+    )
+    instance.conv3_p = torch.nn.Conv2d(
+        in_channels=64, out_channels=64, kernel_size=3, stride=1
+    )
+
+    # Target Networks
+    instance.conv1_t = torch.nn.Conv2d(
+        in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4
+    )
+    instance.conv2_t = torch.nn.Conv2d(
+        in_channels=32, out_channels=64, kernel_size=4, stride=2
+    )
+    instance.conv3_t = torch.nn.Conv2d(
+        in_channels=64, out_channels=64, kernel_size=3, stride=1
+    )
+
+    return feature_size
+
+
+def define_conv_batch_norm(instance):
+    instance.bn1_p_conv = torch.nn.BatchNorm2d(32)
+    instance.bn2_p_conv = torch.nn.BatchNorm2d(64)
+    instance.bn3_p_conv = torch.nn.BatchNorm2d(64)
+
+    instance.bn1_t_conv = torch.nn.BatchNorm2d(32)
+    instance.bn2_t_conv = torch.nn.BatchNorm2d(64)
+    instance.bn3_t_conv = torch.nn.BatchNorm2d(64)
+
+
+def conv_head(instance, s_next):
+    if instance.batch_norm:
+        p = F.relu(instance.bn1_p_conv(instance.conv1_p(s_next)))
+        p = F.relu(instance.bn2_p_conv(instance.conv2_p(p)))
+        p = F.relu(instance.bn3_p_conv(instance.conv3_p(p)))
+
+        t = F.relu(instance.bn1_t_conv(instance.conv1_t(s_next)))
+        t = F.relu(instance.bn2_t_conv(instance.conv2_t(t)))
+        t = F.relu(instance.bn3_t_conv(instance.conv3_t(t)))
+    else:
+        p = F.relu(instance.conv1_p(s_next))
+        p = F.relu(instance.conv2_p(p))
+        p = F.relu(instance.conv3_p(p))
+
+        t = F.relu(instance.conv1_t(s_next))
+        t = F.relu(instance.conv2_t(t))
+        t = F.relu(instance.conv3_t(t))
+
+    p = p.view(p.size(0), -1)
+    t = t.view(t.size(0), -1)
+
+    return p, t
+
+
+def define_fc_layers_weight(instance, feature_size, D_hidden):
+    instance.fc1_p = torch.nn.Linear(feature_size, D_hidden)
+    instance.fc2_p = torch.nn.Linear(D_hidden, D_hidden)
+    instance.fc3_p = torch.nn.Linear(D_hidden, D_hidden)
+
+    instance.fc1_t = torch.nn.Linear(feature_size, D_hidden)
+
+
+def fc_layers(instance, p, t):
+    p = F.relu(instance.fc1_p(p))
+    p = F.relu(instance.fc2_p(p))
+    p = instance.fc3_p(p)
+
+    t = instance.fc1_t(t)
+
+    return p, t
+
+
+def ri_update(r_i, num_workers, rff, update_rms_ri):
+    ri_T = r_i.view(num_workers, -1).T  # (n_batch, n_workers)
+    rewems = torch.stack(
+        [rff.update(rit.detach()) for rit in ri_T]
+    ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
+    update_rms_ri(rewems)
+
+
 class RND_MLP(torch.nn.Module):
     def __init__(
         self,
@@ -34,18 +159,10 @@ def __init__(
 
         feature_size = 256
 
-        self.fc1_predict = torch.nn.Linear(self.D_in, D_hidden)
-        self.fc2_predict = torch.nn.Linear(D_hidden, feature_size)
-
-        self.fc1_target = torch.nn.Linear(self.D_in, D_hidden)
-        self.fc2_target = torch.nn.Linear(D_hidden, feature_size)
+        define_mlp_head_weight(self, D_in, D_hidden, feature_size)
 
-        if batch_norm:
-            self.bn1_predict = torch.nn.BatchNorm1d(D_hidden)
-            self.bn2_predict = torch.nn.BatchNorm1d(feature_size)
-
-            self.bn1_target = torch.nn.BatchNorm1d(D_hidden)
-            self.bn2_target = torch.nn.BatchNorm1d(feature_size)
+        if self.batch_norm:
+            define_mlp_batch_norm(self, D_hidden, feature_size)
 
     def update_rms_obs(self, v):
         self.rms_obs.update(v)
@@ -57,27 +174,13 @@ def forward(self, s_next, update_ri=False):
         if self.obs_normalize:
             s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var)
 
-        if self.batch_norm:
-            p = F.relu(self.bn1_predict(self.fc1_predict(s_next)))
-            p = F.relu(self.bn2_predict(self.fc2_predict(p)))
-
-            t = F.relu(self.bn1_target(self.fc1_target(s_next)))
-            t = F.relu(self.bn2_target(self.fc2_target(t)))
-        else:
-            p = F.relu(self.fc1_predict(s_next))
-            p = F.relu(self.fc2_predict(p))
-
-            t = F.relu(self.fc1_target(s_next))
-            t = F.relu(self.fc2_target(t))
+        p, t = mlp_head(self, s_next)
 
         r_i = torch.mean(torch.square(p - t), axis=1)
 
         if update_ri:
-            ri_T = r_i.view(self.num_workers, -1).T  # (n_batch, n_workers)
-            rewems = torch.stack(
-                [self.rff.update(rit.detach()) for rit in ri_T]
-            ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
-            self.update_rms_ri(rewems)
+            ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri)
+
         if self.ri_normalize:
             r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7)
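Note: the predictor/target pairs that define_mlp_head_weight and define_conv_head_weight create are the core of RND — a frozen, randomly initialized target network and a trained predictor, with the per-sample prediction error serving as the intrinsic reward. A toy standalone version with illustrative sizes (not the repo's API):

import torch

D_in, D_hidden, feature_size = 6, 32, 16
predictor = torch.nn.Sequential(
    torch.nn.Linear(D_in, D_hidden), torch.nn.ReLU(),
    torch.nn.Linear(D_hidden, feature_size),
)
target = torch.nn.Sequential(
    torch.nn.Linear(D_in, D_hidden), torch.nn.ReLU(),
    torch.nn.Linear(D_hidden, feature_size),
)
for param in target.parameters():
    param.requires_grad = False  # the target stays random and fixed

s_next = torch.randn(5, D_in)
r_i = torch.mean(torch.square(predictor(s_next) - target(s_next)), dim=1)
# r_i is large for states the predictor has not fit yet, i.e. novel states.
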
@@ -107,48 +210,9 @@ def __init__(
         self.ri_normalize = ri_normalize
         self.batch_norm = batch_norm
 
-        dim1 = ((self.D_in[1] - 8) // 4 + 1, (self.D_in[2] - 8) // 4 + 1)
-        dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1)
-        dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1)
-
-        feature_size = 64 * dim3[0] * dim3[1]
-
-        # Predictor Networks
-        self.conv1_predict = torch.nn.Conv2d(
-            in_channels=self.D_in[0], out_channels=32, kernel_size=8, stride=4
-        )
-        self.conv2_predict = torch.nn.Conv2d(
-            in_channels=32, out_channels=64, kernel_size=4, stride=2
-        )
-        self.conv3_predict = torch.nn.Conv2d(
-            in_channels=64, out_channels=64, kernel_size=3, stride=1
-        )
-
-        self.fc1_predict = torch.nn.Linear(feature_size, D_hidden)
-        self.fc2_predict = torch.nn.Linear(D_hidden, D_hidden)
-        self.fc3_predict = torch.nn.Linear(D_hidden, D_hidden)
-
-        # Target Networks
-        self.conv1_target = torch.nn.Conv2d(
-            in_channels=self.D_in[0], out_channels=32, kernel_size=8, stride=4
-        )
-        self.conv2_target = torch.nn.Conv2d(
-            in_channels=32, out_channels=64, kernel_size=4, stride=2
-        )
-        self.conv3_target = torch.nn.Conv2d(
-            in_channels=64, out_channels=64, kernel_size=3, stride=1
-        )
-
-        self.fc1_target = torch.nn.Linear(feature_size, D_hidden)
-
-        if batch_norm:
-            self.bn1_predict = torch.nn.BatchNorm2d(32)
-            self.bn2_predict = torch.nn.BatchNorm2d(64)
-            self.bn3_predict = torch.nn.BatchNorm2d(64)
-
-            self.bn1_target = torch.nn.BatchNorm2d(32)
-            self.bn2_target = torch.nn.BatchNorm2d(64)
-            self.bn3_target = torch.nn.BatchNorm2d(64)
+        feature_size = define_conv_head_weight(self, D_in)
+        define_conv_batch_norm(self)
+        define_fc_layers_weight(self, feature_size, D_hidden)
 
     def update_rms_obs(self, v):
         self.rms_obs.update(v / 255.0)
@@ -161,40 +225,14 @@ def forward(self, s_next, update_ri=False):
         if self.obs_normalize:
             s_next = normalize_obs(s_next, self.rms_obs.mean, self.rms_obs.var)
 
-        if self.batch_norm:
-            p = F.relu(self.bn1_predict(self.conv1_predict(s_next)))
-            p = F.relu(self.bn2_predict(self.conv2_predict(p)))
-            p = F.relu(self.bn3_predict(self.conv3_predict(p)))
-        else:
-            p = F.relu(self.conv1_predict(s_next))
-            p = F.relu(self.conv2_predict(p))
-            p = F.relu(self.conv3_predict(p))
-
-        p = p.view(p.size(0), -1)
-        p = F.relu(self.fc1_predict(p))
-        p = F.relu(self.fc2_predict(p))
-        p = self.fc3_predict(p)
-
-        if self.batch_norm:
-            t = F.relu(self.bn1_target(self.conv1_target(s_next)))
-            t = F.relu(self.bn2_target(self.conv2_target(t)))
-            t = F.relu(self.bn3_target(self.conv3_target(t)))
-        else:
-            t = F.relu(self.conv1_target(s_next))
-            t = F.relu(self.conv2_target(t))
-            t = F.relu(self.conv3_target(t))
-
-        t = t.view(t.size(0), -1)
-        t = self.fc1_target(t)
+        p, t = conv_head(self, s_next)
+        p, t = fc_layers(self, p, t)
 
         r_i = torch.mean(torch.square(p - t), axis=1)
 
         if update_ri:
-            ri_T = r_i.view(self.num_workers, -1).T  # (n_batch, n_workers)
-            rewems = torch.stack(
-                [self.rff.update(rit.detach()) for rit in ri_T]
-            ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
-            self.update_rms_ri(rewems)
+            ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri)
+
         if self.ri_normalize:
             r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7)
@@ -229,68 +267,17 @@ def __init__(
         self.ri_normalize = ri_normalize
         self.batch_norm = batch_norm
 
-        ################################## Conv HEAD ##################################
-        dim1 = ((self.D_in_img[1] - 8) // 4 + 1, (self.D_in_img[2] - 8) // 4 + 1)
-        dim2 = ((dim1[0] - 4) // 2 + 1, (dim1[1] - 4) // 2 + 1)
-        dim3 = ((dim2[0] - 3) // 1 + 1, (dim2[1] - 3) // 1 + 1)
-
-        feature_size_img = 64 * dim3[0] * dim3[1]
-
-        # Predictor Networks
-        self.conv1_predict = torch.nn.Conv2d(
-            in_channels=self.D_in_img[0], out_channels=32, kernel_size=8, stride=4
-        )
-        self.conv2_predict = torch.nn.Conv2d(
-            in_channels=32, out_channels=64, kernel_size=4, stride=2
-        )
-        self.conv3_predict = torch.nn.Conv2d(
-            in_channels=64, out_channels=64, kernel_size=3, stride=1
-        )
-
-        # Target Networks
-        self.conv1_target = torch.nn.Conv2d(
-            in_channels=self.D_in_img[0], out_channels=32, kernel_size=8, stride=4
-        )
-        self.conv2_target = torch.nn.Conv2d(
-            in_channels=32, out_channels=64, kernel_size=4, stride=2
-        )
-        self.conv3_target = torch.nn.Conv2d(
-            in_channels=64, out_channels=64, kernel_size=3, stride=1
-        )
-
-        if batch_norm:
-            self.bn1_predict_conv = torch.nn.BatchNorm2d(32)
-            self.bn2_predict_conv = torch.nn.BatchNorm2d(64)
-            self.bn3_predict_conv = torch.nn.BatchNorm2d(64)
-
-            self.bn1_target_conv = torch.nn.BatchNorm2d(32)
-            self.bn2_target_conv = torch.nn.BatchNorm2d(64)
-            self.bn3_target_conv = torch.nn.BatchNorm2d(64)
-
-        ################################## MLP HEAD ##################################
-        feature_size_mlp = 256
-
-        self.fc1_predict_mlp = torch.nn.Linear(self.D_in_vec, D_hidden)
-        self.fc2_predict_mlp = torch.nn.Linear(D_hidden, feature_size_mlp)
+        feature_size_img = define_conv_head_weight(self, self.D_in_img)
+        define_conv_batch_norm(self)
 
-        self.fc1_target_mlp = torch.nn.Linear(self.D_in_vec, D_hidden)
-        self.fc2_target_mlp = torch.nn.Linear(D_hidden, feature_size_mlp)
-
-        if batch_norm:
-            self.bn1_predict_mlp = torch.nn.BatchNorm1d(D_hidden)
-            self.bn2_predict_mlp = torch.nn.BatchNorm1d(feature_size_mlp)
+        feature_size_mlp = 256
+        define_mlp_head_weight(self, self.D_in_vec, D_hidden, feature_size_mlp)
 
-            self.bn1_target_mlp = torch.nn.BatchNorm1d(D_hidden)
-            self.bn2_target_mlp = torch.nn.BatchNorm1d(feature_size_mlp)
+        define_mlp_batch_norm(self, D_hidden, feature_size_mlp)
 
-        ################################## FC Layers ##################################
-        self.fc1_predict = torch.nn.Linear(
-            feature_size_img + feature_size_mlp, D_hidden
-        )
-        self.fc2_predict = torch.nn.Linear(D_hidden, D_hidden)
-        self.fc3_predict = torch.nn.Linear(D_hidden, D_hidden)
+        feature_size = feature_size_img + feature_size_mlp
 
-        self.fc1_target = torch.nn.Linear(feature_size_img + feature_size_mlp, D_hidden)
+        define_fc_layers_weight(self, feature_size, D_hidden)
 
     def update_rms_obs(self, v):
         self.rms_obs_img.update(v[0] / 255.0)
@@ -313,57 +300,19 @@ def forward(self, s_next, update_ri=False):
                 s_next_vec, self.rms_obs_vec.mean, self.rms_obs_vec.var
             )
 
-        ################################## Predict ##################################
-        if self.batch_norm:
-            p_i = F.relu(self.bn1_predict_conv(self.conv1_predict(s_next_img)))
-            p_i = F.relu(self.bn2_predict_conv(self.conv2_predict(p_i)))
-            p_i = F.relu(self.bn3_predict_conv(self.conv3_predict(p_i)))
-
-            p_v = F.relu(self.bn1_predict_mlp(self.fc1_predict_mlp(s_next_vec)))
-            p_v = F.relu(self.bn2_predict_mlp(self.fc2_predict_mlp(p_v)))
-        else:
-            p_i = F.relu(self.conv1_predict(s_next_img))
-            p_i = F.relu(self.conv2_predict(p_i))
-            p_i = F.relu(self.conv3_predict(p_i))
+        p_conv, t_conv = conv_head(self, s_next_img)
+        p_mlp, t_mlp = mlp_head(self, s_next_vec)
 
-            p_v = F.relu(self.fc1_predict_mlp(s_next_vec))
-            p_v = F.relu(self.fc2_predict_mlp(p_v))
+        p = torch.cat((p_conv, p_mlp), -1)
+        t = torch.cat((t_conv, t_mlp), -1)
 
-        p_i = p_i.view(p_i.size(0), -1)
-        p = torch.cat((p_i, p_v), -1)
-
-        p = F.relu(self.fc1_predict(p))
-        p = F.relu(self.fc2_predict(p))
-        p = self.fc3_predict(p)
-
-        ################################## target ##################################
-        if self.batch_norm:
-            t_i = F.relu(self.bn1_target_conv(self.conv1_target(s_next_img)))
-            t_i = F.relu(self.bn2_target_conv(self.conv2_target(t_i)))
-            t_i = F.relu(self.bn3_target_conv(self.conv3_target(t_i)))
-
-            t_v = F.relu(self.bn1_target_mlp(self.fc1_target_mlp(s_next_vec)))
-            t_v = F.relu(self.bn2_target_mlp(self.fc2_target_mlp(t_v)))
-        else:
-            t_i = F.relu(self.conv1_target(s_next_img))
-            t_i = F.relu(self.conv2_target(t_i))
-            t_i = F.relu(self.conv3_target(t_i))
-
-            t_v = F.relu(self.fc1_target_mlp(s_next_vec))
-            t_v = F.relu(self.fc2_target_mlp(t_v))
-
-        t_i = t_i.view(t_i.size(0), -1)
-        t = torch.cat((t_i, t_v), -1)
-        t = self.fc1_target(t)
+        p, t = fc_layers(self, p, t)
 
         r_i = torch.mean(torch.square(p - t), axis=1)
 
         if update_ri:
-            ri_T = r_i.view(self.num_workers, -1).T  # (n_batch, n_workers)
-            rewems = torch.stack(
-                [self.rff.update(rit.detach()) for rit in ri_T]
-            ).ravel()  # (n_batch, n_workers) -> (n_batch * n_workers)
-            self.update_rms_ri(rewems)
+            ri_update(r_i, self.num_workers, self.rff, self.update_rms_ri)
+
         if self.ri_normalize:
             r_i = r_i / (torch.sqrt(self.rms_ri.var) + 1e-7)
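Note: define_conv_head_weight in both files derives the flattened feature size from the standard no-padding convolution arithmetic, out = (in - kernel) // stride + 1 per layer. A quick check of that formula for the RND conv head (kernel/stride pairs 8/4, 4/2, 3/1), assuming a hypothetical 84x84 input:

def conv_out(d, kernel, stride):
    return (d - kernel) // stride + 1

d = 84
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    d = conv_out(d, kernel, stride)

print(64 * d * d)  # 3136 = 64 * 7 * 7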