Using GAZSL from *A Generative Adversarial Approach for Zero-Shot Learning from Noisy Texts* [1] in the AttnGAN [2] model.
The goal is to use noisy texts in AttnGAN to generate bird images, which raises the following challenges:
- sentences are of arbitrary length
- not all sentences are relevant
The original model may be improved with the zero-shot GAN approach [1] to generate better (in some sense) text and image embeddings.
In the original AttnGAN model there are two types of text-modality embeddings: word-level embeddings and sentence-level embeddings (an average over all words). Hence, there are several ways to use ZSL within the model:
- Use the ZSL encoder (text/image) to summarize the text and the image into a matrix of embeddings (pseudo-words) or into word-level embeddings, fully replacing the original encoders, and match the distribution between the words and the matrix of image features
- Use the ZSL encoder (text/image) on top of the encoded words and the matrix of image features: `ZSL_ENC: W -> emb`
- Use the ZSL encoder (text/image) on top of the sentence embedding and the averaged image features: `ZSL_ENC: W.mean() -> emb`
- Incorporate the ZSL approach into the main training phase rather than the pretraining one
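For concreteness, here is a shape-level sketch of the `W -> emb` and `W.mean() -> emb` notation (the dimensions and tensor names below are assumptions for illustration, not values from the code):

```python
import torch

B, T, D = 16, 18, 256                    # batch, words per caption, embedding dim (assumed)
W = torch.randn(B, T, D)                 # word-level text embeddings from the text encoder
img_feats = torch.randn(B, 17 * 17, D)   # region-level image features (e.g. a 17x17 grid)

# `ZSL_ENC: W -> emb`: the ZSL encoder consumes the full word / region matrices.
word_input, region_input = W, img_feats

# `ZSL_ENC: W.mean() -> emb`: the ZSL encoder consumes the averaged embeddings.
sent_emb = W.mean(dim=1)                 # (B, D) sentence embedding
global_img = img_feats.mean(dim=1)       # (B, D) averaged image feature
```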
To keep things simple, I built a ZSL Generator and a ZSL Discriminator on top of the average embeddings obtained from the original encoder architecture (`ZSL_ENC: W.mean() -> emb`). The overall task is to introduce adversarial and classification losses, computed via the discriminative model, into the optimization objective.
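A minimal sketch of what such a ZSL Generator and ZSL Discriminator could look like, in the spirit of GAZSL (module names, layer sizes, and the number of classes are assumptions, not the project's actual code):

```python
import torch
import torch.nn as nn

class ZSLGenerator(nn.Module):
    """Maps a sentence embedding plus noise z ~ N(0, I) to a fake
    visual feature, following the GAZSL generator idea."""
    def __init__(self, text_dim=256, z_dim=100, vis_dim=256, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(text_dim + z_dim, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, vis_dim),
            nn.ReLU(),
        )

    def forward(self, sent_emb, z):
        return self.net(torch.cat([sent_emb, z], dim=1))

class ZSLDiscriminator(nn.Module):
    """Scores a visual feature as real/fake and classifies it; the two
    heads supply the adversarial and the classification loss terms."""
    def __init__(self, vis_dim=256, hidden=1024, n_classes=200):  # 200 = CUB classes (assumed)
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(vis_dim, hidden), nn.LeakyReLU(0.2))
        self.adv_head = nn.Linear(hidden, 1)          # real/fake logit
        self.cls_head = nn.Linear(hidden, n_classes)  # class logits

    def forward(self, vis_feat):
        h = self.backbone(vis_feat)
        return self.adv_head(h), self.cls_head(h)
```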
- The original ZSL paper uses WGAN-GP, for which it is usually recommended to update the Discriminator more frequently than the Generator (5:1 in the original ZSL code); for simplicity I ignore that and use the vanilla adversarial loss function (see the sketch after this list)
- Visual pivot regularization from the original paper is dropped
- The KL loss can be omitted because the Sentence Embedder now has a stochastic part (the ZSL component draws z ~ N(0, I))
- There may be bugs; further checks are needed
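The vanilla adversarial loss mentioned above, combined with the classification loss, could be computed as follows (a sketch building on the `ZSLGenerator`/`ZSLDiscriminator` modules above; the function name and shapes are assumptions):

```python
import torch
import torch.nn.functional as F

def zsl_losses(D, G, real_vis, sent_emb, labels, z_dim=100):
    """Vanilla (non-saturating, BCE-based) adversarial loss plus a
    classification loss, standing in for GAZSL's WGAN-GP objective.
    Assumed shapes: real_vis (B, vis_dim), sent_emb (B, text_dim), labels (B,)."""
    z = torch.randn(real_vis.size(0), z_dim, device=real_vis.device)
    fake_vis = G(sent_emb, z)

    # Discriminator side: real features scored as real, fakes (detached) as
    # fake, plus classification of the real features.
    real_adv, real_cls = D(real_vis)
    fake_adv, _ = D(fake_vis.detach())
    d_loss = (F.binary_cross_entropy_with_logits(real_adv, torch.ones_like(real_adv))
              + F.binary_cross_entropy_with_logits(fake_adv, torch.zeros_like(fake_adv))
              + F.cross_entropy(real_cls, labels))

    # Generator side: non-saturating adversarial term, plus classification
    # of the generated features against the true labels.
    fake_adv_g, fake_cls_g = D(fake_vis)
    g_loss = (F.binary_cross_entropy_with_logits(fake_adv_g, torch.ones_like(fake_adv_g))
              + F.cross_entropy(fake_cls_g, labels))
    return d_loss, g_loss
```

In practice the two losses would be stepped with separate optimizers; they are returned together here only for brevity.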
Remaining TODOs:
- Choose a better adversarial loss function and adopt training hacks from other papers
- Try to incorporate ZSL for word-level embeddings (`ZSL_ENC: W -> emb`) and in the main training step (not the pretraining one)
- Test eval
- Tune hyperparams
- TODOs in the code
Current status:
- The DAMSM model was trained for 200 epochs; however, this was only a single run, so there may still be bugs
- The main model is training, but the lambdas and smoothing factors still need tuning to balance the components of the objective function (see the image below)
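For reference, a sketch of how the weighted components might be combined in the generator objective (the function name and the default lambda values are placeholders to be tuned, not recommended settings):

```python
def total_generator_loss(attngan_g_loss, damsm_loss, zsl_g_adv, zsl_g_cls,
                         lambda_damsm=5.0, lambda_adv=1.0, lambda_cls=0.1):
    """Hypothetical weighting of the objective components that need
    balancing; all lambdas here are placeholders, not tuned settings."""
    return (attngan_g_loss                 # original AttnGAN generator loss
            + lambda_damsm * damsm_loss    # word/image matching (DAMSM) term
            + lambda_adv * zsl_g_adv       # ZSL adversarial term
            + lambda_cls * zsl_g_cls)      # ZSL classification term
```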