Merged
20 changes: 0 additions & 20 deletions hawk/external.py
@@ -168,26 +168,6 @@ def rnn_param_init(
raise NotImplementedError()


_MAX_SQRT_GRADIENT = 1000.0


class SqrtBoundDerivative(torch.autograd.Function):
"""Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`."""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
"""The forward pass, which is a normal `sqrt`."""
ctx.save_for_backward(x)
return torch.sqrt(x)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore
"""The backward pass, which clips the `sqrt` gradient."""
(x,) = ctx.saved_tensors
clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2))
return grad_output / torch.sqrt(clipped_x_times_4)


class BlockDiagonalLinear(nn.Module):
"""Block-diagonal linear layer."""

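Context on the deletion above (not part of the PR itself): the derivative of sqrt(x) is 1 / (2 * sqrt(x)), which is unbounded as x approaches zero, and the removed SqrtBoundDerivative existed to clip that backward pass at _MAX_SQRT_GRADIENT = 1000.0. A minimal sketch of the difference:

import torch

# Gradient of sqrt(x) is 1 / (2 * sqrt(x)); it blows up near x = 0.
x = torch.tensor([0.0, 1e-8, 1.0], requires_grad=True)
torch.sqrt(x).sum().backward()
print(x.grad)  # tensor([inf, 5.0000e+03, 5.0000e-01])

# The removed SqrtBoundDerivative instead returned
#   grad / sqrt(clip(4 * x, min=1 / 1000**2))
# so these same inputs would give gradients capped at 1000.0.

Presumably the fused scan introduced below handles this normalization internally (or its inputs are bounded away from zero), which would be why the autograd wrapper could be dropped; that rationale is an inference, not stated in the diff.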
24 changes: 13 additions & 11 deletions hawk/hawk.py
@@ -9,8 +9,8 @@
import torch.nn.functional as F

from .cache import RNNCache
from .external import BlockDiagonalLinear, Conv1D, SqrtBoundDerivative, rnn_param_init
from .scan import linear_scan
from .external import BlockDiagonalLinear, Conv1D, rnn_param_init
from .scan_fused import fused_linear_scan

# ------
# Config
@@ -28,9 +28,9 @@ class HawkConfig:
post_norm: bool = False


# ------
# Helper
# ------
# ----
# Init
# ----


def lecun_init(w: torch.Tensor, d_in: int):
@@ -139,7 +139,8 @@ def forget_init(self, w: torch.Tensor) -> torch.Tensor:
def epilogue(self, gate, h):
return self.resid_proj(F.gelu(gate) * self.norm(h))

def prologue(self, x):
def inference_prologue(self, x):
# inference-only prologue function
gate_x = torch.sigmoid(self.rg_lru_input_gate(x))
gate_a = torch.sigmoid(self.rg_lru_a_gate(x))

@@ -148,9 +149,7 @@ def prologue(self, x):
a_square = torch.exp(2 * log_a.float())
gated_x = x * gate_x

multiplier = SqrtBoundDerivative.apply(1 - a_square)

assert multiplier is not None
multiplier = torch.sqrt(1 - a_square)

normalized_x = gated_x * multiplier.to(x.dtype)

Expand Down Expand Up @@ -184,18 +183,21 @@ def forward(
if has_layer_past:
x = x[:, -1:, ...]

a, normalized_x = self.prologue(x)
x_rg_lru = self.rg_lru_input_gate(x)
a_rg_lru = self.rg_lru_a_gate(x)

cache = None
if not has_layer_past:
h = linear_scan(a, normalized_x)
h = fused_linear_scan(x, x_rg_lru, a_rg_lru, self.rg_lru_a_param.float())

if self.use_cache:
assert h is not None
layer_past.update_cache(h[:, -1:, :])
cache = layer_past

else:
a, normalized_x = self.inference_prologue(x)

h = (a * layer_past.recc_state) + normalized_x

layer_past.update_cache(h[:, -1:, :])
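The call site above replaces the two-step prologue + linear_scan with a single fused_linear_scan(x, x_rg_lru, a_rg_lru, self.rg_lru_a_param.float()), passing the raw gate pre-activations instead of pre-gated inputs. The fused kernel itself is not shown in this diff; the sketch below is an unfused reference of what it presumably computes, reconstructed from the old prologue + linear_scan path. The -8.0 * softplus decay parameterization is an assumption from the usual RG-LRU formulation and is not visible in these hunks.

import torch
import torch.nn.functional as F

def unfused_reference(x, x_rg_lru, a_rg_lru, a_param):
    # Sketch only: names and the decay parameterization are assumptions,
    # based on the removed prologue + linear_scan path, not the fused kernel.
    gate_x = torch.sigmoid(x_rg_lru)                 # input gate
    gate_a = torch.sigmoid(a_rg_lru)                 # recurrence gate
    log_a = -8.0 * gate_a * F.softplus(a_param)      # per-step log-decay (assumed)
    a = torch.exp(log_a)
    # Normalize the gated input so the hidden state keeps roughly unit scale.
    normalized_x = (x * gate_x) * torch.sqrt(1.0 - torch.exp(2.0 * log_a)).to(x.dtype)
    h = torch.zeros_like(normalized_x[:, :1])
    out = []
    for t in range(x.shape[1]):                      # sequential scan over time
        h = a[:, t : t + 1] * h + normalized_x[:, t : t + 1]
        out.append(h)
    return torch.cat(out, dim=1)

A reasonable sanity check for the refactor would be to compare this reference against fused_linear_scan on random prefill inputs, and against the single-step inference_prologue path for decode.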
234 changes: 0 additions & 234 deletions hawk/scan.py

This file was deleted.
