Add two transformer models via upload #508

Merged (18 commits) on Jul 22, 2021
Changes from 1 commit
341 changes: 341 additions & 0 deletions qlib/contrib/model/pytorch_localformer.py
@@ -0,0 +1,341 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


from __future__ import division
from __future__ import print_function

import os
import numpy as np
import pandas as pd
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList


# qrun benchmarks/Transformer/workflow_config_localformer_Alpha158.yaml
# 0.992366, @13,
'''
{'IC': 0.037426503365732174,
'ICIR': 0.28977883455541603,
'Rank IC': 0.04659889541774283,
'Rank ICIR': 0.373569340092482}

'The following are analysis results of the excess return without cost.'
risk
mean 0.000381
std 0.004109
annualized_return 0.096066
information_ratio 1.472729
max_drawdown -0.094917
'The following are analysis results of the excess return with cost.'
risk
mean 0.000213
std 0.004111
annualized_return 0.053630
information_ratio 0.821711
max_drawdown -0.113694
'''


class LocalformerModel(Model):
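    """Localformer wrapper conforming to qlib's ``Model`` interface.

    ``fit`` trains the ``Transformer`` network defined below (a Transformer
    encoder augmented with per-layer local convolutions and a GRU head) on a
    time-series dataset, and ``predict`` returns a ``pd.Series`` of scores
    indexed like the test segment.
    """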
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 8192,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=2,
seed=None,
**kwargs
):

# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
        self.logger.info("GPU available: {}".format(torch.cuda.is_available()))
self.logger.info(
"Improved Transformer:"
"\nbatch_size : {}"
"\ndevice : {}".format(self.batch_size, self.device)
)

if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)

self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

self.fitted = False
self.model.to(self.device)

@property
def use_gpu(self):
return self.device != torch.device("cpu")

def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)

def loss_fn(self, pred, label):
mask = ~torch.isnan(label)

if self.loss == "mse":
return self.mse(pred[mask], label[mask])

raise ValueError("unknown loss `%s`" % self.loss)

def metric_fn(self, pred, label):

mask = torch.isfinite(label)

if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])

raise ValueError("unknown metric `%s`" % self.metric)

def train_epoch(self, data_loader):

self.model.train()

for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)

pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)

self.train_optimizer.zero_grad()
loss.backward()
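            # clip each gradient element to [-3, 3] before stepping to curb exploding gradients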
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()

def test_epoch(self, data_loader):

self.model.eval()

scores = []
losses = []

for data in data_loader:

feature = data[:, :, 0:-1].to(self.device)
# feature[torch.isnan(feature)] = 0
label = data[:, -1, -1].to(self.device)

with torch.no_grad():
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
losses.append(loss.item())

score = self.metric_fn(pred, label)
scores.append(score.item())

return np.mean(losses), np.mean(scores)

def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):

dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)

dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

save_path = get_or_create_path(save_path)

stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []

# train
self.logger.info("training...")
self.fitted = True

for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)

if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break

self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)

if self.use_gpu:
torch.cuda.empty_cache()

def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")

dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.model.eval()
preds = []

for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)

with torch.no_grad():
pred = self.model(feature.float()).detach().cpu().numpy()

preds.append(pred)

return pd.Series(np.concatenate(preds), index=dl_test.get_index())


class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)

def forward(self, x):
# [T, N, F]
return x + self.pe[:x.size(0), :]


def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])


class LocalformerEncoder(nn.Module):
__constants__ = ['norm']

def __init__(self, encoder_layer, num_layers, d_model):
super(LocalformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
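        # one Conv1d (kernel 3, stride 1, padding 1) per encoder layer; forward() applies it
        # along the time axis so each position is mixed with its immediate neighbours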
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
self.num_layers = num_layers

def forward(self, src, mask):
output = src
out = src

for i, mod in enumerate(self.layers):
# [T, N, F] --> [N, T, F] --> [N, F, T]
out = output.transpose(1, 0).transpose(2, 1)
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)

output = mod(output+out, src_mask=mask)

return output + out


class Transformer(nn.Module):
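    """Localformer network: a linear feature projection (d_feat -> d_model),
    positional encoding, the LocalformerEncoder (Transformer layers plus local
    convolutions), a GRU over the encoded sequence, and a linear head applied
    to the last time step.
    """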
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.rnn = nn.GRU(
input_size=d_model,
hidden_size=d_model,
num_layers=num_layers,
batch_first=False,
dropout=dropout,
)
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat

def forward(self, src):
# src [N, T, F], [512, 60, 6]

src = self.feature_layer(src) # [512, 60, 8]

# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first

mask = None

src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]

output, _ = self.rnn(output)

        # [T, N, F] --> [N, T, F]; take the last time step --> [N, F] --> [N, 1]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]

return output.squeeze()
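

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the intended entry point is
#   qrun benchmarks/Transformer/workflow_config_localformer_Alpha158.yaml
# but the model can also be driven directly, e.g. via
#   python -m qlib.contrib.model.pytorch_localformer
# The handler/segment/step_len values below are assumptions modelled on the
# Alpha158 benchmark configs, not values taken from this PR.
if __name__ == "__main__":
    import qlib

    qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")
    dataset = TSDatasetH(
        handler={
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "instruments": "csi300",
                "start_time": "2008-01-01",
                "end_time": "2020-08-01",
                "fit_start_time": "2008-01-01",
                "fit_end_time": "2014-12-31",
            },
        },
        segments={
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
        step_len=20,
    )
    model = LocalformerModel(d_feat=158, GPU=0)
    model.fit(dataset)                    # early stopping on the "valid" segment
    print(model.predict(dataset).head())  # predictions aligned with the test index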
