model.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter


class VIN(nn.Module):
    """Value Iteration Network (Tamar et al., 2016)."""

    def __init__(self, config):
        super(VIN, self).__init__()
        self.config = config
        # Feature extraction from the input observation
        self.h = nn.Conv2d(
            in_channels=config.l_i,
            out_channels=config.l_h,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            bias=True)
        # Reward map: one channel per grid cell
        self.r = nn.Conv2d(
            in_channels=config.l_h,
            out_channels=1,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            bias=False)
        # Q values computed from the reward map
        self.q = nn.Conv2d(
            in_channels=1,
            out_channels=config.l_q,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            bias=False)
        # Maps the Q values at the agent's position to 8 actions
        self.fc = nn.Linear(in_features=config.l_q, out_features=8, bias=False)
        # Learned weights that convolve the value map into Q values
        self.w = Parameter(
            torch.zeros(config.l_q, 1, 3, 3), requires_grad=True)
        self.sm = nn.Softmax(dim=1)

    def forward(self, input_view, state_x, state_y, k):
        """
        :param input_view: (batch_sz, l_i, imsize, imsize)
        :param state_x: (batch_sz,), 0 <= state_x < imsize
        :param state_y: (batch_sz,), 0 <= state_y < imsize
        :param k: number of value-iteration steps
        :return: logits and softmaxed logits over the 8 actions
        """
        h = self.h(input_view)  # Intermediate feature map
        r = self.r(h)           # Reward map
        q = self.q(r)           # Initial Q values from the reward alone
        v, _ = torch.max(q, dim=1, keepdim=True)

        def eval_q(r, v):
            # Stack the reward with the most recent value map, then convolve with
            # the r->q weights and the v->q weights; these act as transition weights.
            return F.conv2d(
                torch.cat([r, v], 1),
                torch.cat([self.q.weight, self.w], 1),
                stride=1,
                padding=1)

        # Value iteration: alternately update q and v for k steps
        for i in range(k - 1):
            q = eval_q(r, v)
            v, _ = torch.max(q, dim=1, keepdim=True)
        q = eval_q(r, v)

        # q: (batch_sz, l_q, imsize, imsize); select the Q values at the agent's state
        batch_sz, l_q, _, _ = q.size()
        q_out = q[torch.arange(batch_sz), :, state_x.long(), state_y.long()].view(batch_sz, l_q)
        logits = self.fc(q_out)  # Map the selected Q values to action logits
        return logits, self.sm(logits)
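

# For reference, a minimal usage sketch of the module above. The config field
# values, batch size, grid size, and k are illustrative assumptions, not values
# taken from this repository; only the field names l_i, l_h, l_q are read by VIN.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(l_i=2, l_h=150, l_q=10)  # hypothetical hyperparameters
    model = VIN(config)

    batch_sz, imsize = 4, 8
    input_view = torch.randn(batch_sz, config.l_i, imsize, imsize)  # e.g. obstacle + goal maps
    state_x = torch.randint(0, imsize, (batch_sz,))
    state_y = torch.randint(0, imsize, (batch_sz,))

    logits, probs = model(input_view, state_x, state_y, k=10)
    print(logits.shape, probs.shape)  # torch.Size([4, 8]) torch.Size([4, 8])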