Differential Translation

A Transformer model (Vaswani et al., 2017) adapted with a differential attention mechanism (Ye et al., 2024) for translation tasks.

The Differential Transformer paper proposed the differential attention mechanism and studied the impact of its use in the masked multi-head self-attention layer of a decoder-only transformer model.

The purpose of this project is to extend this research to a full encoder-decoder transformer model, applying the differential attention mechanism in both the cross-attention and self-attention layers.
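For readers unfamiliar with the mechanism, here is a minimal pure-Python sketch of the differential attention formula from Ye et al. (2024): the difference of two softmax attention maps, scaled by λ, applied to the values. It is an illustration only and not this repository's implementation. The function names are hypothetical, `lam` is passed in as a fixed scalar (the paper learns it through a re-parameterization), and the per-head normalization from the paper is omitted.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_map(queries, keys):
    """softmax(Q K^T / sqrt(d)), one row of weights per query."""
    d = len(keys[0])
    scale = 1.0 / math.sqrt(d)
    return [
        softmax([scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys])
        for q in queries
    ]

def differential_attention(q1, k1, q2, k2, values, lam):
    """(softmax(Q1 K1^T / sqrt(d)) - lam * softmax(Q2 K2^T / sqrt(d))) V.

    Assumption: lam is a plain scalar here; the paper learns it.
    """
    a1 = attention_map(q1, k1)
    a2 = attention_map(q2, k2)
    out = []
    for row1, row2 in zip(a1, a2):
        # Subtract the second attention map from the first.
        weights = [w1 - lam * w2 for w1, w2 in zip(row1, row2)]
        out.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return out
```

The same function covers both layer types. For self-attention the queries, keys, and values all come from one sequence. For cross-attention the queries come from the decoder while the keys and values come from the encoder output, so the number of query rows may differ from the number of key rows.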

If you use this work in your research, please consider citing it:

@misc{differential_translation,
  author = {Daniel Vega-Myhre},
  title = {Translate AI: Differential Attention in an Encoder-Decoder Transformer Model},
  month = {November},
  year = 2024,
  url = {https://github.com/danielvegamyhre/translate-ai}
}

Feature overview

Training

Training script templates for:

  • Single GPU training
  • Multi-GPU training
  • Multi-node training

Example for multi-node training on 2 nodes with 1 GPU each:

# Comments are listed up front because a comment after a trailing backslash breaks the command.
# 🔥 --device, --mixed-precision: configure device type and mixed precision data type
# 🔥 --epochs, --learning-rate, --batch-size: training hyperparams
# 🔥 --num-layers through --max-output-tokens: architecture hyperparams
# 🔥 --eval-interval, --eval-iters: evaluation interval and iterations for computing validation loss
# 🔥 --wandb-project, --wandb-api-key: Weights & Biases configuration for training observability
# 🔥 --distributed: use DDP for distributed training
torchrun --nproc_per_node=1 --nnodes 2 --master_port=12345 \
    translate_ai/train.py \
    --dataset-file data/english-spanish.csv \
    --device cuda \
    --mixed-precision bf16 \
    --epochs 10 \
    --learning-rate .004 \
    --batch-size 128 \
    --num-layers 2 \
    --embed-dim 128 \
    --d-model 128 \
    --ffwd-dim 512 \
    --seq-len 128 \
    --max-output-tokens 128 \
    --eval-interval 200 \
    --eval-iters 10 \
    --checkpoint-interval 200 \
    --save-checkpoint checkpoints/chkpt.pt \
    --wandb-project "${WANDB_PROJECT}" \
    --wandb-api-key "${WANDB_API_KEY}" \
    --distributed
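With `--nproc_per_node=1` and `--nnodes 2`, torchrun starts two worker processes in total and exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each one. The sketch below shows how a training script could read those variables and what they imply for the effective batch size. The helper names are hypothetical, and treating `--batch-size` as a per-process value is an assumption, not something this README states.

```python
import os

def distributed_context(environ=os.environ):
    """Read the process-group settings that torchrun exports to each worker."""
    return {
        "rank": int(environ.get("RANK", 0)),              # global worker index
        "local_rank": int(environ.get("LOCAL_RANK", 0)),  # index within the node
        "world_size": int(environ.get("WORLD_SIZE", 1)),  # total worker count
    }

def global_batch_size(per_process_batch_size, world_size):
    """Samples consumed per optimizer step across all DDP workers.

    Assumption: each worker loads its own batch of per_process_batch_size.
    """
    return per_process_batch_size * world_size
```

Under that assumption, the example command above would process 128 × 2 = 256 samples per step.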

Performance Analysis

Performance analysis script templates to estimate MFU for:

  • Single GPU training
  • Multi-GPU training
  • Multi-node training

Example MFU estimation for multi-node training:

# Comments are listed up front because a comment after a trailing backslash breaks the command.
# 🔥 --hardware-peak-tflops: set peak accelerator TFLOPs by referencing manufacturer docs
# 🔥 --distributed: use DDP
# 🔥 --estimate-mfu: run MFU estimation instead of training
torchrun --nproc_per_node=1 --nnodes 2 --master_port=12345 translate_ai/train.py \
    --dataset-file data/english-spanish.csv \
    --device cuda \
    --mixed-precision bf16 \
    --learning-rate .001 \
    --batch-size 32 \
    --num-layers 2 \
    --embed-dim 128 \
    --d-model 128 \
    --ffwd-dim 512 \
    --seq-len 128 \
    --max-output-tokens 128 \
    --hardware-peak-tflops "${HARDWARE_PEAK_TFLOPS}" \
    --distributed \
    --estimate-mfu
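MFU (model FLOPs utilization) is the ratio of the FLOPs per second the model actually achieves to the peak FLOPs per second of the hardware. The sketch below shows that arithmetic. It is not the estimator used by `train.py`: the function names are hypothetical, and the `6 * N * D` rule of thumb for training FLOPs (which ignores attention FLOPs) is a common approximation assumed here for illustration.

```python
def approx_train_flops_per_step(num_params, tokens_per_step):
    """Rule-of-thumb training cost: about 6 FLOPs per parameter per token
    (2 for the forward pass, 4 for the backward pass). Assumption: attention
    FLOPs are ignored."""
    return 6 * num_params * tokens_per_step

def estimate_mfu(flops_per_step, step_time_s, peak_tflops_per_device, num_devices=1):
    """Achieved FLOPs/s divided by aggregate peak FLOPs/s.

    flops_per_step is the total across all devices for one training step.
    """
    if step_time_s <= 0 or peak_tflops_per_device <= 0 or num_devices < 1:
        raise ValueError("step time, peak TFLOPS, and device count must be positive")
    achieved_flops_per_s = flops_per_step / step_time_s
    peak_flops_per_s = peak_tflops_per_device * 1e12 * num_devices
    return achieved_flops_per_s / peak_flops_per_s
```

For example, a step that performs 6.5e12 FLOPs in one second on a single accelerator with a 65 TFLOPS peak corresponds to an MFU of 10%.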

Inference

Example:

python3 translate_ai/translate.py --english-query "The cat is blue." --checkpoint-file checkpoints/chkpt.pt
...
"El gato es azul."

Training workload orchestration with Kubernetes

Using Kubernetes and the JobSet API simplifies the process of orchestrating distributed training workloads, especially for very large scale workloads.

To deploy a training workload using Kubernetes, follow the steps below. This guide uses Google Cloud to provision the infrastructure, but any Kubernetes cluster (on-prem or cloud-based) will work.

Prerequisites

  1. Google Cloud account set up.
  2. gcloud installed in your local development environment.
  3. Docker installed in your local development environment.
  4. The following Google Cloud APIs are enabled: GKE and Artifact Registry.

Steps

  1. Build the container image and push it to Artifact Registry with the following command:
PROJECT_ID=your-gcp-project REPO_NAME=your-ar-repo IMAGE_NAME=translate TAG=latest ./build_and_push.sh
  2. Create a GKE cluster with a single GPU node pool. The example script below provisions a GKE cluster called demo in zone us-central1-c, with a GPU node pool called gpu-pool. The GPU pool has 2 nodes of type n1-standard-4, each with 1 NVIDIA Tesla T4 GPU attached. These GPUs have a peak hardware capacity of 8.1 TFLOPS (fp32) or 65 TFLOPS (fp16, bf16); you will need this value if you want to run performance analysis to estimate MFU. If you use a different GPU, refer to the vendor spec to get the hardware peak FLOPS.
./create_cluster.sh
  3. Install the JobSet API, a k8s native distributed ML training orchestrator.
VERSION=v0.6.0
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/$VERSION/manifests.yaml

You can find more detailed installation information here.

  4. Modify the reference JobSet manifest. The items marked optional are already configured to match the infrastructure provisioned in these steps, so they need no changes for this example. If you train on different infrastructure (more/fewer nodes, more/fewer GPUs per node, etc.), you'll need to configure these parameters as well.
  • Set the container image parameter to the container image you built and pushed in step 1.
  • (Optional) Set the Job template parallelism and completions fields to match the number of nodes in your GPU pool.
  • (Optional) Set the environment variable NPROC_PER_NODE to the number of GPUs per node (in this case, 1).
  • (Optional) Set the environment variable NNODES to the number of nodes in your GPU pool.
  • (Optional) Set the environment variable WANDB_PROJECT to your Weights and Biases project name.
  • (Optional) Set the environment variable WANDB_API_KEY to your Weights and Biases API key. The best practices for doing this securely can be found here.
  5. Deploy the workload!
kubectl apply -f jobset.yaml
  6. Verify the training is working by viewing container logs, TensorBoard, or Weights and Biases (depending on what observability instrumentation you have set up).
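
The manifest edits described in the steps above can also be expressed as a small script, which helps when the node count or GPU count changes often. This is a sketch under assumptions: it operates on a manifest already loaded into a Python dict, it assumes the usual JobSet layout (`spec.replicatedJobs[].template` is a Job template), and the function name, example manifest, and image path are hypothetical rather than taken from this repository's jobset.yaml.

```python
import copy

def configure_jobset(manifest, image, num_nodes, gpus_per_node, extra_env=None):
    """Return a copy of a JobSet manifest with image, node count, and env updated."""
    patched = copy.deepcopy(manifest)
    env_updates = {"NNODES": str(num_nodes), "NPROC_PER_NODE": str(gpus_per_node)}
    env_updates.update(extra_env or {})
    for replicated_job in patched["spec"]["replicatedJobs"]:
        job_spec = replicated_job["template"]["spec"]
        # One pod per node: parallelism and completions both equal the node count.
        job_spec["parallelism"] = num_nodes
        job_spec["completions"] = num_nodes
        for container in job_spec["template"]["spec"]["containers"]:
            container["image"] = image
            # Merge by name so existing variables are overwritten, not duplicated.
            env = {entry["name"]: entry for entry in container.get("env", [])}
            for name, value in env_updates.items():
                env[name] = {"name": name, "value": value}
            container["env"] = list(env.values())
    return patched

# Hypothetical minimal manifest, for illustration only.
example_manifest = {
    "apiVersion": "jobset.x-k8s.io/v1alpha2",
    "kind": "JobSet",
    "metadata": {"name": "translate"},
    "spec": {"replicatedJobs": [{
        "name": "workers",
        "template": {"spec": {
            "parallelism": 1,
            "completions": 1,
            "template": {"spec": {"containers": [{
                "name": "trainer",
                "image": "placeholder",
                "env": [{"name": "NNODES", "value": "1"}],
            }]}},
        }},
    }]},
}

patched_manifest = configure_jobset(
    example_manifest,
    image="us-central1-docker.pkg.dev/your-gcp-project/your-ar-repo/translate:latest",
    num_nodes=2,
    gpus_per_node=1,
)
```

The function returns a modified copy and leaves the input untouched, so the reference manifest can be reused for several cluster shapes. Writing the key for `WANDB_API_KEY` as a plain value through `extra_env` is possible but exposes it in the manifest.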