diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 970613f27..fa0eaa53f 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -297,11 +297,16 @@ class ParallelRWKV_ChannelMix(nn.Module): Channel Mix layer. The ffn in RWKV """ + def __init__(self, neox_args, layer_number, init_method): def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + + world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) @@ -363,7 +368,7 @@ class RWKVResidualLayer(nn.Module): """ RWKV layer definition """ - + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args