diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 82b5a62b..bbed5e82 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -450,12 +450,66 @@ def forward(self, x):
         x = RUN_CUDA(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
 
         return self.jit_func_2(x, g)
 
+
+########################################################################################################
+# RWKV: RWKV WaveNet-mem + rotary memory
+########################################################################################################
+class Short_Mem(nn.Module):
+    def __init__(self, args, shiftAmount=1):
+        super().__init__()
+        # self.time_shift1 = TimeShift(args.n_embd, shiftAmount=shiftAmount, batch=args.micro_bsz)
+        self.time_shift1 = nn.ZeroPad2d((0, 0, shiftAmount, -shiftAmount))
+        self.activation = nn.Sequential(
+            nn.Linear(args.n_embd * 2, args.n_embd, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        xv = self.activation(torch.cat([self.time_shift1(x), x], dim=-1))
+        return xv
+
+class WaveNet_Mem(Short_Mem):
+    def __init__(self, args, layer_id, modulo=12, undilated=False):
+        if undilated:
+            super().__init__(args, shiftAmount=1)
+        else:
+            super().__init__(args, shiftAmount=2 ** (layer_id % modulo))
+
+class Rotary_Memory(nn.Module):
+    def __init__(self, args, layer_id):
+        nn.Module.__init__(self)
+        self.args = args
+        self.layer_id = layer_id
+
+        self.complexsize = args.n_embd
+        self.short = WaveNet_Mem(args, layer_id, undilated=True)
+        self.key = nn.Linear(args.n_embd, self.complexsize * 2, bias=False, dtype=torch.bfloat16)
+        # self.cumprod = CumProd(torch.complex(torch.ones(args.micro_bsz, 1, self.complexsize), torch.zeros(args.micro_bsz, 1, self.complexsize)))
+        # self.cummax = CumMax()
+        self.activation = nn.Linear(self.complexsize * 2, args.n_embd, bias=False, dtype=torch.bfloat16)
+
+    def forward(self, x):
+        B, T, C = x.size()
+        k = self.key(x).float()
+
+
+        complexval = torch.view_as_complex(k.reshape(B, T, self.complexsize, 2))
+        # scale = self.cummax(torch.abs(complexval))
+        scale = torch.cummax(torch.abs(complexval), dim=-2)[0]
+        complexval2 = complexval / scale
+        # kv = self.cumprod(complexval2)
+        kv = torch.cumprod(complexval2, dim=-2)
+        out = self.activation(torch.view_as_real(kv).reshape(B, T, self.complexsize * 2)) * self.short(x * scale)
+
+        return out
 
 ########################################################################################################
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
+
+
 
 class RWKV_TimeMix(MyModule):
     def __init__(self, args, layer_id):
         super().__init__()
@@ -632,6 +686,11 @@ def __init__(self, args, layer_id):
         self.ln1 = nn.LayerNorm(args.n_embd)
         self.ln2 = nn.LayerNorm(args.n_embd)
 
+        self.use_mem = 'c' in os.environ["RWKV_MY_TESTING"]
+        if self.use_mem:
+            self.ln3 = nn.LayerNorm(args.n_embd)
+            self.mem = Rotary_Memory(args, layer_id)
+
         if self.layer_id == 0:
             self.ln0 = nn.LayerNorm(args.n_embd)
             if args.my_pos_emb > 0:
@@ -663,6 +722,8 @@ def __init__(self, args, layer_id):
         if args.dropout > 0:
             self.drop0 = nn.Dropout(p = args.dropout)
             self.drop1 = nn.Dropout(p = args.dropout)
+            if self.use_mem:
+                self.drop2 = nn.Dropout(p = args.dropout)
 
     def forward(self, x, x_emb=None):
         args = self.args
@@ -678,12 +739,16 @@ def forward(self, x, x_emb=None):
                 x = x + self.ffnPre(self.ln1(x))
             else:
                 x = x + self.att(self.ln1(x))
+            if self.use_mem:
+                x = x + self.mem(self.ln3(x))
             x = x + self.ffn(self.ln2(x))
         else:
             if self.layer_id == 0 and args.pre_ffn > 0:
                 x = self.drop0(x + self.ffnPre(self.ln1(x)))
             else:
                 x = self.drop0(x + self.att(self.ln1(x)))
+            if self.use_mem:
+                x = self.drop2(x + self.mem(self.ln3(x)))
             x = self.drop1(x + self.ffn(self.ln2(x)))
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
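
Note (not part of the patch): Short_Mem reuses the standard RWKV time-shift trick. nn.ZeroPad2d pads the last two dims as (left, right, top, bottom), so on a (B, T, C) tensor, (0, 0, s, -s) prepends s zero steps along T and crops s steps off the end: position t sees the features of position t - s. WaveNet_Mem then grows s as 2 ** (layer_id % modulo), giving WaveNet-style exponentially dilated receptive fields across layers. A minimal standalone sketch:

import torch
import torch.nn as nn

# (0, 0, s, -s) = no padding on the channel dim, prepend s zeros on the time
# dim, drop the last s steps -> a right-shift by s that preserves length T.
s = 2
time_shift = nn.ZeroPad2d((0, 0, s, -s))

x = torch.arange(1.0, 6.0).view(1, 5, 1)  # one sequence, tokens 1..5
print(time_shift(x).flatten())            # tensor([0., 0., 1., 2., 3.])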
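
Note (not part of the patch): the core of Rotary_Memory.forward treats each pair of key channels as one complex number, divides by the running cummax of its magnitude so every factor has modulus at most 1, and then takes a cumprod over the time axis (dim=-2), i.e. a numerically bounded cumulative rotation-and-decay. An isolated sketch with assumed toy shapes:

import torch

B, T, C = 1, 4, 3
k = torch.randn(B, T, C, 2)                        # (real, imag) per channel, like the output of self.key
z = torch.view_as_complex(k)                       # (B, T, C), complex64

# Running max of |z| over time; dividing by it guarantees |z / scale| <= 1,
# so the cumulative product below can only rotate and decay, never blow up.
scale = torch.cummax(z.abs(), dim=-2)[0]
kv = torch.cumprod(z / scale, dim=-2)              # cumulative complex product over T

out = torch.view_as_real(kv).reshape(B, T, C * 2)  # back to real-valued features
print(out.shape)                                   # torch.Size([1, 4, 6])

In the patch, this kv path is then projected by self.activation and gated elementwise by self.short(x * scale), the undilated sigmoid short-memory mix.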
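
Note (not part of the patch): the new branch is opt-in. Block.__init__ enables it only when the letter 'c' appears in the RWKV_MY_TESTING environment variable, matching how the other experimental paths in this file are gated; assuming the repo's usual wiring, where the trainer copies --my_testing into RWKV_MY_TESTING, passing something like --my_testing 'c' should activate Rotary_Memory (plus its ln3 LayerNorm and drop2 dropout) in every block.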