Skip to content

Commit

Permalink
Fix the rotation segments for rope and shortrope
Browse files Browse the repository at this point in the history
  • Loading branch information
gkielian committed Nov 9, 2023
1 parent 75441bd commit b1d0e22
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def __init__(self, config):
self.register_buffer('freq_right', 1.0 / (10000 ** (torch.arange(0, self.dim//2).float() / self.dim//2)))

def forward(self, x):

seq_len = x.shape[-2]
device = x.device

Expand All @@ -172,15 +171,10 @@ def forward(self, x):

# Apply frequencies
x_left, x_right = x[..., :self.dim//2], x[..., self.dim//2:]
x_left = x_left * freqs_left.cos()
x_right = x_right * freqs_right.sin()

# Apply frequencies
x_left = x_left * freqs_left.cos()
x_right = x_right * freqs_right.sin()
x_left = x_left * freqs_left.cos() - x_right * freqs_left.sin()
x_right = x_left * freqs_right.sin() + x_right * freqs_right.cos()

# Remaining logic
x_right = torch.flip(x_right, [-1]) * -1
# Combine the left and right parts back
x = torch.cat([x_left, x_right], dim=-1)

return x
Expand Down Expand Up @@ -225,8 +219,8 @@ def forward(self, x):
x_right = x_rotated[..., self.n//2:]

# Apply the cosine and sine rotations
x_left = x_left * freqs_left.cos()
x_right = x_right * freqs_right.sin()
x_left = x_left * freqs_left.cos() - x_right * freqs_left.sin()
x_right = x_left * freqs_right.sin() + x_right * freqs_right.cos()

# Invert the order of the right tensor's last dimension and negate it
x_right = torch.flip(x_right, dims=[-1]) * -1
Expand Down

0 comments on commit b1d0e22

Please sign in to comment.