Skip to content

Commit

Permalink
initial tp commits
Browse files Browse the repository at this point in the history
  • Loading branch information
jahatef committed Oct 31, 2024
1 parent 12aac35 commit 97c7915
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion megatron/model/rwkv/v6/rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,16 @@ class ParallelRWKV_ChannelMix(nn.Module):
Channel Mix layer. The ffn in RWKV
"""

def __init__(self, neox_args, layer_number, init_method):
def __init__(self, neox_args, layer_number, init_method):
super().__init__()
self.neox_args = neox_args
self.layer_number = layer_number

world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)


world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)

Expand Down Expand Up @@ -363,7 +368,7 @@ class RWKVResidualLayer(nn.Module):
"""
RWKV layer definition
"""

def __init__(self, neox_args, init_method, layer_number):
super().__init__()
self.neox_args = neox_args
Expand Down

0 comments on commit 97c7915

Please sign in to comment.