diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index 2ccc85dc5a..f60da8310d 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -14,13 +14,15 @@ def __init__(self): self.lstm1 = nn.LSTMCell(1, 51) self.lstm2 = nn.LSTMCell(51, 51) self.linear = nn.Linear(51, 1) + self.dummy_param = nn.Parameter(torch.empty(0)) def forward(self, input, future = 0): outputs = [] - h_t = torch.zeros(input.size(0), 51, dtype=torch.double) - c_t = torch.zeros(input.size(0), 51, dtype=torch.double) - h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double) - c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double) + device = self.dummy_param.device + h_t = torch.zeros(input.size(0), 51, dtype=torch.double).to(device) + c_t = torch.zeros(input.size(0), 51, dtype=torch.double).to(device) + h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double).to(device) + c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double).to(device) for input_t in input.split(1, dim=1): h_t, c_t = self.lstm1(input_t, (h_t, c_t)) @@ -39,20 +41,29 @@ def forward(self, input, future = 0): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--steps', type=int, default=15, help='steps to run') + parser.add_argument('--device', type=str, default='cuda', help='training device. cuda, mps, or cpu.') opt = parser.parse_args() + # training device + device_name = opt.device + if device_name == 'cuda' and not torch.cuda.is_available(): + print('cuda is not available') + exit(-1) + elif device_name == 'mps' and not torch.backends.mps.is_available(): + print('mps is not available') + exit(-1) + device = torch.device(device_name) # set random seed to 0 np.random.seed(0) torch.manual_seed(0) # load data and make training set - data = torch.load('traindata.pt') - input = torch.from_numpy(data[3:, :-1]) - target = torch.from_numpy(data[3:, 1:]) - test_input = torch.from_numpy(data[:3, :-1]) - test_target = torch.from_numpy(data[:3, 1:]) + data = torch.from_numpy(torch.load('traindata.pt')).to(device) + input = data[3:, :-1] + target = data[3:, 1:] + test_input = data[:3, :-1] + test_target = data[:3, 1:] # build the model - seq = Sequence() - seq.double() - criterion = nn.MSELoss() + seq = Sequence().to(device).double() + criterion = nn.MSELoss().to(device) # use LBFGS as optimizer since we can load the whole data to train optimizer = optim.LBFGS(seq.parameters(), lr=0.8) #begin to train @@ -72,7 +83,7 @@ def closure(): pred = seq(test_input, future=future) loss = criterion(pred[:, :-future], test_target) print('test loss:', loss.item()) - y = pred.detach().numpy() + y = pred.cpu().detach().numpy() # draw the result plt.figure(figsize=(30,10)) plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)