Implementation of the RQ Transformer, which proposes a more efficient way of autoregressively modeling multi-dimensional sequences. For now, this repository contains only the transformer; you can use this vector quantization library for the residual VQ.
This type of axial autoregressive transformer should be compatible with memcodes, proposed in NWT. It would likely also work well with multi-headed VQ.
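As a rough sketch of where the codes could come from, assuming the vector quantization library above is something like vector-quantize-pytorch with its ResidualVQ module, it might look like the following; the dim, num_quantizers, and codebook_size values are just illustrative and chosen to line up with the usage example further down.

import torch
from vector_quantize_pytorch import ResidualVQ

# illustrative settings - 4 residual quantizers over a 16k codebook,
# chosen only to match the RQTransformer example below
residual_vq = ResidualVQ(
    dim = 512,
    num_quantizers = 4,     # becomes the depth axis of the token grid
    codebook_size = 16000
)

features = torch.randn(1, 1024, 512)                     # (batch, space, feature dim)
quantized, indices, commit_loss = residual_vq(features)

# indices should come out as (1, 1024, 4) - the (batch, space, depth)
# layout that the transformer below consumes as input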
$ pip install RQ-transformer
import torch
from rq_transformer import RQTransformer
model = RQTransformer(
    num_tokens = 16000,          # number of tokens, in the paper they had a codebook size of 16k
    dim = 512,                   # transformer model dimension
    max_spatial_seq_len = 1024,  # maximum positions along space
    depth_seq_len = 4,           # number of positions along depth (residual quantizations in paper)
    spatial_layers = 8,          # number of layers for space
    depth_layers = 4,            # number of layers for depth
    dim_head = 64,               # dimension per head
    heads = 8,                   # number of attention heads
)
x = torch.randint(0, 16000, (1, 1024, 4))
loss = model(x, return_loss = True)
loss.backward()
# then after much training
logits = model(x)
# and sample from the logits accordingly
# or you can use the generate function
sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
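If you take the manual sampling route mentioned in the comment above, a minimal sketch with top-k filtering and temperature could look like this; it assumes the logits come back with a (batch, space, depth, num_tokens) layout, which is worth double checking against the actual output.

import torch
import torch.nn.functional as F

def sample_from(logits, temperature = 0.9, k = 50):
    # logits: distribution over the codebook for a single position, shape (num_tokens,)
    values, indices = logits.topk(k)
    probs = F.softmax(values / temperature, dim = -1)
    return indices[torch.multinomial(probs, 1)]

# continuing the example above, pick the code for the last spatial position
# at the first depth slot (assumed layout: batch, space, depth, num_tokens)
next_code = sample_from(logits[0, -1, 0])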
I also think there is something deeper going on here, and have generalized this to any number of dimensions. You can use it by importing the HierarchicalCausalTransformer.
import torch
from rq_transformer import HierarchicalCausalTransformer
model = HierarchicalCausalTransformer(
    num_tokens = 16000,       # number of tokens
    dim = 512,                # feature dimension
    dim_head = 64,            # dimension of attention heads
    heads = 8,                # number of attention heads
    depth = (4, 4, 2),        # 3 stages (but can be any number) - transformer of depths 4, 4, 2
    max_seq_len = (16, 4, 5)  # the maximum sequence length of the first stage, then the fixed sequence lengths of all subsequent stages
).cuda()
x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()
loss = model(x, return_loss = True)
loss.backward()
# after a lot of training
sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 16, 4, 5)
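For completeness, here is a bare-bones sketch of what the training loop could look like; the learning rate, number of steps, and the random integer codes are all placeholders for your real setup.

import torch
from torch.optim import Adam
from rq_transformer import HierarchicalCausalTransformer

model = HierarchicalCausalTransformer(
    num_tokens = 16000,
    dim = 512,
    dim_head = 64,
    heads = 8,
    depth = (4, 4, 2),
    max_seq_len = (16, 4, 5)
).cuda()

optim = Adam(model.parameters(), lr = 3e-4)              # placeholder learning rate

for _ in range(100):                                     # placeholder number of steps
    x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()    # stand-in for real codes from your quantizer

    loss = model(x, return_loss = True)
    loss.backward()

    optim.step()
    optim.zero_grad()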
- move hierarchical causal transformer to a separate repository, since it seems to be working
@misc{lee2022autoregressive,
    title   = {Autoregressive Image Generation using Residual Quantization},
    author  = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han},
    year    = {2022}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}