A transformer neural network for a gesture keyboard: it transduces curves swiped across the keyboard into word candidates
Contribution:
- A new method for constructing swipe point embeddings (SPE) that outperforms existing ones. It leverages a weighted sum of all keyboard key embeddings, resulting in a notable performance boost: a 0.67% increase in Swipe MRR and a 0.73% increase in accuracy compared to SPE construction methods described in the literature
Other highlights:
- Enhanced inference with a custom beam search: a modified beam search masks out logits corresponding to token continuations that are impossible (according to a dictionary) given the already generated prefix (see the sketch below). It is both faster and more accurate than standard beam search
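A minimal sketch of the masking step (illustrative, not the repository's exact implementation; `allowed_token_ids` would come from a vocabulary lookup, e.g. a trie, of tokens that can extend the generated prefix into at least one dictionary word):

```python
import torch

def mask_impossible_tokens(logits: torch.Tensor, allowed_token_ids: list) -> torch.Tensor:
    """Set logits of tokens that cannot continue any dictionary word to -inf
    so the beam search never expands impossible prefixes."""
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed_token_ids] = 0.0  # allowed continuations keep their original logits
    return logits + mask
```

Pruning entire impossible subtrees before they are expanded shrinks the search space, which is what makes the modified search both faster and more accurate than an unconstrained beam search.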
This repository used to contain my Yandex Cup 2023 solution (7th place), but after many improvements, it has become a standalone project
Try out a live demo with a trained model from the competition through this web app
Note
If the website is not available, you can run the demo yourself by following the instructions in the web app's GitHub repository.
Note
The website may take a minute to load, as it is not yet fully optimized. If you encounter a "Something went wrong" page, try refreshing the page. This usually resolves the issue.
Note
The model is an old and underfit legacy transformer variant (`m1_bigger` in `models.py`) that was used in the competition. A significant update to both this project and the web app is planned for winter 2024
Access a brief research report here, which includes:
- Overview of existing research
- Description of the developed method for constructing swipe point embeddings
- Comparative analysis and results
For in-depth insights, you can refer to my master's thesis (in Russian)
Install the dependencies:
pip install -r requirements.txt
- Inference was tested with Python 3.10
- Training was conducted on Kaggle using a Tesla P100 GPU
A trained model is defined not only by its class and weights but also by the dataset transformation used during training.
All current models are instances of model.EncoderDecoderTransformerLike
and consist of the following components:
- Swipe point embedder
- Word component token embedder (currently char-level)
- Encoder
- Decoder
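Conceptually, these components compose roughly as follows. This is an illustrative sketch only (the component interfaces and mask handling are assumptions, not the actual `model.py` code); the `forward` signature matches the `packed_model_in` format described below.

```python
import torch.nn as nn

class SwipeDecoderSketch(nn.Module):
    """Illustrative composition; see model.EncoderDecoderTransformerLike
    for the real interface."""

    def __init__(self, swipe_point_embedder, token_embedder, encoder, decoder):
        super().__init__()
        self.swipe_point_embedder = swipe_point_embedder  # embeds swipe points
        self.token_embedder = token_embedder              # char-level token embedder
        self.encoder = encoder                            # encodes the swipe sequence
        self.decoder = decoder                            # autoregressive word decoder

    def forward(self, encoder_input, decoder_input, swipe_pad_mask, word_pad_mask):
        # Encode the embedded swipe points into a memory sequence.
        memory = self.encoder(self.swipe_point_embedder(encoder_input), swipe_pad_mask)
        # Decode the embedded target-word tokens, attending to the memory.
        return self.decoder(self.token_embedder(decoder_input), memory,
                            word_pad_mask, swipe_pad_mask)
```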
Transforms extract features from the raw dataset, converting each dataset item from the format `(x, y, t, grid_name, tgt_word)` to `((encoder_input, decoder_input), decoder_output)`. After collating the dataset, the format becomes `(packed_model_in, dec_out)`, where `packed_model_in` is `(encoder_input, decoder_input, swipe_pad_mask, word_pad_mask)`. `packed_model_in` is passed to the model via unpacking (`model(*packed_model_in)`).
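Schematically (`transform`, `collate_fn`, `batch`, and `model` are placeholder names):

```python
# One dataset item after the transform:
(encoder_input, decoder_input), decoder_output = transform((x, y, t, grid_name, tgt_word))

# After collation:
packed_model_in, dec_out = collate_fn(batch)
encoder_input, decoder_input, swipe_pad_mask, word_pad_mask = packed_model_in

# The model is called via unpacking:
logits = model(*packed_model_in)
```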
`encoder_input` is passed as the only argument to the swipe point embedder's `forward` method. Its type depends on which swipe point embedding layer you use; it can be a single object or a tuple of objects. `decoder_input` and `decoder_output` are `tokenized_target_word[:-1]` and `tokenized_target_word[1:]` respectively.
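For example (the `<sos>` / `<eos>` special tokens here are assumed for illustration):

```python
tokenized_target_word = ["<sos>", "h", "e", "l", "l", "o", "<eos>"]
decoder_input  = tokenized_target_word[:-1]  # ["<sos>", "h", "e", "l", "l", "o"]
decoder_output = tokenized_target_word[1:]   # ["h", "e", "l", "l", "o", "<eos>"]
```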
A trained swipe decoding method is defined by:
- model class
- model weights
- dataset transformation
- decoding algorithm
Your custom dataset must have items of the format `(x, y, t, grid_name, tgt_word)`. These raw features won't be used directly: `feature_extractors.py` defines transforms, one for every type of swipe point embedding layer, that extract the needed features. You can apply these transforms in your dataset's `__init__` method or in `__getitem__` / `__iter__`. The data formats after transform and after collation are described above.
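A minimal sketch of a map-style dataset that applies a transform in `__init__` (the class name and constructor arguments are illustrative):

```python
from torch.utils.data import Dataset

class MySwipeDataset(Dataset):
    """Illustrative custom dataset; `transform` is one of the callables
    from feature_extractors.py matching your swipe point embedding layer."""

    def __init__(self, raw_items, transform):
        # raw_items: list of (x, y, t, grid_name, tgt_word) tuples
        self.items = [transform(item) for item in raw_items]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        # -> (encoder_input, decoder_input), decoder_output
        return self.items[idx]
```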
You also need to add your keyboard layout to `grid_name_to_grid.json`.
The training is done in `train.ipynb`.
Warning
`train.ipynb` drains RAM when using `n_workers > 0` in the DataLoader. This can result in up to `dataset_size * n_workers` extra gigabytes of RAM usage. This is a known issue (see here) that occurs when a dataset uses a Python list to store its data. Although `torch.cuda.empty_cache()` can be used as a workaround, it doesn't seem to work with PyTorch Lightning. It appears I didn't commit this workaround, but you can adapt `train.ipynb` from the `before-lightning` branch by adding `torch.cuda.empty_cache()` after each epoch, as sketched below. When training in a Kaggle notebook, the issue is not a problem, since a Kaggle session comes with 30 GB of RAM.
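A sketch of where that call would go in a plain, pre-lightning training loop (`num_epochs` and `train_dataloader` are placeholders):

```python
import torch

for epoch in range(num_epochs):
    for batch in train_dataloader:
        ...  # forward pass, loss, backward pass, optimizer step
    # Workaround described above: release cached memory after each epoch.
    torch.cuda.empty_cache()
```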
`word_generation_demo.ipynb` serves as an example of how to run prediction with a trained model.
`predict_v2.py` is used to obtain word candidates for a whole dataset and pickle them.
Warning
If the decoding algorithm in the `predict_v2.py` script uses a vocabulary for masking (`use_vocab_for_generation: true` in the config), multiprocessing must be disabled by passing the command-line argument `--num-workers 0` to the script. Otherwise, the prediction will take a very long time. This is a bug that will be fixed.
Work-in-progress documentation can be found here. It doesn't contain much information yet but will be updated; for now, please refer to the docstrings in the code.
See refactoring plan