From db7e1cebf0342551560287ffd613eed1a807512b Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Tue, 9 Apr 2024 10:58:51 -0700
Subject: [PATCH] Update TP examples to align with tutorials

as titled
---
 distributed/tensor_parallelism/README.md      |  16 +-
 .../tensor_parallelism/fsdp_tp_example.py     | 142 ++--
 .../tensor_parallelism/llama2_model.py        | 490 ++++++++++++++++++
 .../tensor_parallelism/requirements.txt       |   2 +-
 4 files changed, 579 insertions(+), 71 deletions(-)
 create mode 100644 distributed/tensor_parallelism/llama2_model.py

diff --git a/distributed/tensor_parallelism/README.md b/distributed/tensor_parallelism/README.md
index d72cbf5feb..b49d1672e8 100644
--- a/distributed/tensor_parallelism/README.md
+++ b/distributed/tensor_parallelism/README.md
@@ -1,14 +1,14 @@
-# PyTorch Tensor Parallelism for distributed training
+# PyTorch native Tensor Parallel for distributed training

-This example demonstrates SPMD Megatron-LM style tensor parallel by using
-PyTorch native Tensor Parallelism APIs, which include:
+This example demonstrates SPMD Megatron-LM style Tensor Parallel by using
+PyTorch native Tensor Parallel APIs, which include:

-1. High-level APIs for module-level parallelism with a dummy MLP model.
-2. Model agnostic ops for `DistributedTensor`, such as `Linear` and `RELU`.
-3. A E2E demo of tensor parallel for a given toy model (Forward/backward + optimization).
+1. Simple module-level Tensor Parallelism on a dummy MLP model.
+2. Simple module-level Tensor Parallelism with Sequence Parallel inputs/outputs on a dummy MLP model.
+3. An E2E demo of Fully Sharded Data Parallel + Tensor Parallel (with Sequence Parallel) on an example Llama2 model.

-More details about the design can be found:
-https://github.com/pytorch/pytorch/issues/89884
+For more details about the PyTorch native Tensor Parallel APIs, please see the PyTorch docs:
+https://pytorch.org/docs/stable/distributed.tensor.parallel.html

 ```
 pip install -r requirements.txt
diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index a85c798e6f..216c2ce32d 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,20 +1,12 @@
 import sys
+import os

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.tensor.parallel import (
-    parallelize_module,
-    ColwiseParallel,
-    RowwiseParallel,
-)
-
-import os

 from log_utils import rank_log, get_logger, verify_min_gpu_count

-
 # ---- GPU check ------------
 _min_gpu_count = 4

@@ -23,13 +15,24 @@
     sys.exit()
 # ---------------------------

-from torch.distributed._tensor.device_mesh import init_device_mesh
+from llama2_model import Transformer, ModelArgs
+
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed._tensor import Shard, Replicate
+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+    RowwiseParallel,
+    PrepareModuleInput,
+    SequenceParallel
+)

 """
 This is the script to test 2D Parallel which combines Tensor/Sequence
-parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
-in the SPMD style. We show an E2E working flow from forward, backward
+parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example
+Llama2 model. We show an E2E working flow of forward, backward
 and optimization.

 We enabled Fully Sharded Data Parallel + Tensor Parallel in
@@ -53,41 +56,10 @@
 [0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
 ======================================================================

-More details can be seen in the slide:
-https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
+More details can be found in the PyTorch tutorial:
+https://pytorch.org/tutorials/intermediate/TP_tutorial.html
 """

-
-def find_multiple(n: int, k: int) -> int:
-    """function to find resizing multiple for SwiGLU MLP"""
-    if n % k == 0:
-        return n
-    return n + k - (n % k)
-
-
-class MLP_swiglu(nn.Module):
-    """SwiGLU to showcase a Llama style MLP model"""
-
-    def __init__(self, mlp_dim: int = 1024) -> None:
-        super().__init__()
-        hidden_dim = 4 * mlp_dim
-        scaled_hidden = int(2 * hidden_dim / 3)
-        rounded_hidden = find_multiple(scaled_hidden, 256)
-
-        self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
-        self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
-        self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = F.silu(self.in_proj(x)) * self.gate_proj(x)
-        x = self.out_proj(x)
-        return x
-
-
-"""
-Main body of the demo of a basic version of tensor parallel by using
-PyTorch native APIs.
-"""
 tp_size = 2
 logger = get_logger()

@@ -120,26 +92,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # to mimic the behavior of the dataloader.
 dp_rank = dp_mesh.get_local_rank()

-# create model and move it to GPU with id rank
-_mlp_dim = 1024
-base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")
-
-
-# Custom parallelization plan for the swiglu MLP model
-custom_tp_model = parallelize_module(
-    module=base_model_swiglu,
-    device_mesh=tp_mesh,
-    parallelize_plan={
-        "in_proj": ColwiseParallel(),
-        "gate_proj": ColwiseParallel(),
-        "out_proj": RowwiseParallel(),
-    },
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
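+# NOTE: this is a deliberately small Llama2 configuration so the demo runs quickly on a
+# few GPUs; the ModelArgs defaults in llama2_model.py (dim=4096, n_layers=32, n_heads=32)
+# correspond to the full Llama2-7B size.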
+simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000) + +model = Transformer.from_model_args(simple_llama2_config).to("cuda") + +# init model weights +model.init_weights() + +# parallelize the first embedding and the last linear out projection +model = parallelize_module( + model, + tp_mesh, + { + "embeddings.tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + ), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Replicate() + ), + "norm": SequenceParallel(), + "layers.0": PrepareModuleInput( + input_layouts=(Replicate(), None), + desired_input_layouts=(Shard(1), None), + use_local_output=True, + ), + } ) -rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n") +for layer_id, transformer_block in enumerate(model.layers): + layer_tp_plan = { + "attention": PrepareModuleInput( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(output_layouts=Shard(1)), + "attention_norm": SequenceParallel(), + "feed_forward": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + "feed_forward.w3": ColwiseParallel(), + "ffn_norm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.attention + attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() + attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() + + # Custom parallelization plan for the model + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan + ) # Init FSDP using the dp device mesh -sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True) +sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True) + +rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n") # Create an optimizer for the parallelized and sharded model. lr = 3e-3 @@ -156,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i in range(num_iterations): # seeding with dp_rank to ensure identical inputs for TP groups torch.manual_seed(i + dp_rank) - inp = torch.rand(batch_size, _mlp_dim, device="cuda") + inp = torch.randint(32000, (8, 256), device="cuda") output = sharded_model(inp) output.sum().backward() diff --git a/distributed/tensor_parallelism/llama2_model.py b/distributed/tensor_parallelism/llama2_model.py new file mode 100644 index 0000000000..c5e8aca42e --- /dev/null +++ b/distributed/tensor_parallelism/llama2_model.py @@ -0,0 +1,490 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
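+#
+# A self-contained Llama2-style Transformer definition (RMSNorm, rotary position
+# embeddings, grouped-query attention via repeat_kv, and SwiGLU feed-forward) that
+# the TP + FSDP example parallelizes with parallelize_module.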
+ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + + max_batch_size: int = 32 + max_seq_len: int = 32768 + depth_init: bool = ( + True # initialization uses each unique layer_id or total model layer count + ) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_local_kv_heads (int): Number of local key and value heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + cache_k (torch.Tensor): Cached keys for attention. + cache_v (torch.Tensor): Cached values for attention. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. 
+
+        """
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(
+            xk, self.n_rep
+        )  # (bs, seqlen, n_local_heads, head_dim)
+        values = repeat_kv(
+            xv, self.n_rep
+        )  # (bs, seqlen, n_local_heads, head_dim)
+
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+
+        # we use a causal mask for training
+        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
+        output = output.view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class RotaryEmbedding(nn.Module):
+    """
+    RotaryEmbedding Module
+    """
+
+    def __init__(self, model_args: ModelArgs):
+        super().__init__()
+        self.model_args = model_args
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+
+        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+
+    def _precompute_freqs_cis(self):
+        return precompute_freqs_cis(
+            self.model_args.dim // self.model_args.n_heads,
+            # Need to compute until at least the max token limit for generation
+            # (use 2x max sequence length to be safe)
+            self.model_args.max_seq_len * 2,
+        )
+
+    def forward(self, tokens: torch.Tensor):
+        """
+        Perform a forward pass through the embedding module.
+
+        Args:
+            tokens (torch.Tensor): Input tensor.
+ + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + freqs_cis = self.freqs_cis[0:seqlen] + return h, freqs_cis + + def init_weights(self): + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + nn.init.normal_(self.tok_embeddings.weight) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = RMSNorm( + dim=model_args.dim, eps=model_args.norm_eps + ) + self.ffn_norm = RMSNorm( + dim=model_args.dim, eps=model_args.norm_eps + ) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module): + """ + Transformer Module + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
+ + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.model_dim = model_args.dim + + self.embeddings = RotaryEmbedding(model_args) + self.layers = torch.nn.ModuleList() + for layer_id in range(model_args.n_layers): + self.layers.append(TransformerBlock(layer_id, model_args)) + + self.norm = RMSNorm( + dim=model_args.dim, eps=model_args.norm_eps + ) + + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights(self): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + self.embeddings.init_weights() + for layer in self.layers: + layer.init_weights() + self.norm.reset_parameters() + final_out_std = self.model_dim**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def forward(self, tokens: torch.Tensor): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + h, freqs_cis = self.embeddings(tokens) + for layer in self.layers: + h = layer(h, freqs_cis) + h = self.norm(h) + output = self.output(h).float() + return output + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a ModelArgs object. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) diff --git a/distributed/tensor_parallelism/requirements.txt b/distributed/tensor_parallelism/requirements.txt index c6b283a441..80fad36bf2 100644 --- a/distributed/tensor_parallelism/requirements.txt +++ b/distributed/tensor_parallelism/requirements.txt @@ -3,4 +3,4 @@ --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -torch >= 2.2.0.dev0; sys_platform == "linux" +torch >= 2.3.0.dev0; sys_platform == "linux"
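
For reference, `fsdp_tp_example.py` obtains `dp_mesh` and `tp_mesh` from a 2D device mesh that is created in code outside the hunks above. A minimal sketch of that setup, assuming a single 4-GPU host launched with `torchrun --nproc_per_node=4` and the `tp_size = 2` used in the example:

```python
# Sketch only: 2D (FSDP x TP) device mesh setup assumed by the example above.
import os

from torch.distributed.device_mesh import init_device_mesh

tp_size = 2
world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
dp_size = world_size // tp_size

# One call builds the 2D mesh (and initializes the default process group if needed):
# dim 0 ("dp") is used for FSDP sharding, dim 1 ("tp") for tensor/sequence parallel.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
dp_mesh = device_mesh["dp"]
tp_mesh = device_mesh["tp"]
```

With N hosts of 8 GPUs each, the same call with `(dp_size, tp_size) = (N, 8)` reproduces the rank layout sketched in the module docstring: TP/SP groups run across each row, FSDP across each column.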