diff --git a/hawk/external.py b/hawk/external.py
index c22f4e1..21417e3 100644
--- a/hawk/external.py
+++ b/hawk/external.py
@@ -168,26 +168,6 @@ def rnn_param_init(
         raise NotImplementedError()
 
 
-_MAX_SQRT_GRADIENT = 1000.0
-
-
-class SqrtBoundDerivative(torch.autograd.Function):
-    """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`."""
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        """The forward pass, which is a normal `sqrt`."""
-        ctx.save_for_backward(x)
-        return torch.sqrt(x)
-
-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore
-        """The backward pass, which clips the `sqrt` gradient."""
-        (x,) = ctx.saved_tensors
-        clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2))
-        return grad_output / torch.sqrt(clipped_x_times_4)
-
-
 class BlockDiagonalLinear(nn.Module):
     """Block-diagonal linear layer."""
 
diff --git a/hawk/hawk.py b/hawk/hawk.py
index 31d6189..833fa33 100644
--- a/hawk/hawk.py
+++ b/hawk/hawk.py
@@ -9,8 +9,8 @@ import torch.nn.functional as F
 
 from .cache import RNNCache
-from .external import BlockDiagonalLinear, Conv1D, SqrtBoundDerivative, rnn_param_init
-from .scan import linear_scan
+from .external import BlockDiagonalLinear, Conv1D, rnn_param_init
+from .scan_fused import fused_linear_scan
 
 # ------
 # Config
 # ------
@@ -28,9 +28,9 @@ class HawkConfig:
     post_norm: bool = False
 
 
-# ------
-# Helper
-# ------
+# ----
+# Init
+# ----
 
 
 def lecun_init(w: torch.Tensor, d_in: int):
@@ -139,7 +139,8 @@ def forget_init(self, w: torch.Tensor) -> torch.Tensor:
     def epilogue(self, gate, h):
         return self.resid_proj(F.gelu(gate) * self.norm(h))
 
-    def prologue(self, x):
+    def inference_prologue(self, x):
+        # inference-only prologue function
        gate_x = torch.sigmoid(self.rg_lru_input_gate(x))
        gate_a = torch.sigmoid(self.rg_lru_a_gate(x))
 
@@ -148,9 +149,7 @@ def prologue(self, x):
         a_square = torch.exp(2 * log_a.float())
 
         gated_x = x * gate_x
-        multiplier = SqrtBoundDerivative.apply(1 - a_square)
-
-        assert multiplier is not None
+        multiplier = torch.sqrt(1 - a_square)
 
         normalized_x = gated_x * multiplier.to(x.dtype)
 
@@ -184,11 +183,12 @@ def forward(
         if has_layer_past:
             x = x[:, -1:, ...]
 
-        a, normalized_x = self.prologue(x)
+        x_rg_lru = self.rg_lru_input_gate(x)
+        a_rg_lru = self.rg_lru_a_gate(x)
 
         cache = None
         if not has_layer_past:
-            h = linear_scan(a, normalized_x)
+            h = fused_linear_scan(x, x_rg_lru, a_rg_lru, self.rg_lru_a_param.float())
 
             if self.use_cache:
                 assert h is not None
@@ -196,6 +196,8 @@
                 cache = layer_past
 
         else:
+            a, normalized_x = self.inference_prologue(x)
+
             h = (a * layer_past.recc_state) + normalized_x
 
             layer_past.update_cache(h[:, -1:, :])
diff --git a/hawk/scan.py b/hawk/scan.py
deleted file mode 100644
index 50ec997..0000000
--- a/hawk/scan.py
+++ /dev/null
@@ -1,234 +0,0 @@
-import warnings
-from typing import Tuple
-
-import torch
-import triton
-import triton.language as tl
-from torch.autograd import Function
-
-
-# fmt: off
-@triton.jit
-def sequential_scan_fwd_kernel(
-    alpha_ptr,beta_ptr,hidden_ptr,
-    bs_stride,sq_stride,
-    num_context: tl.constexpr,
-    numel: tl.constexpr,
-    BLOCKSIZE: tl.constexpr
-):
-    #fmt: on
-
-    bs_pid = tl.program_id(0)
-
-    alpha_ptr += bs_pid * bs_stride
-    beta_ptr += bs_pid * bs_stride
-    hidden_ptr += bs_pid * bs_stride
-
-    offs = tl.arange(0, BLOCKSIZE)
-    mask = offs < numel
-    # compute h_0 outside loop
-
-    hidden_t = tl.load(beta_ptr + offs, mask = mask).to(tl.float32)
-
-    tl.store(hidden_ptr + offs, hidden_t.to(tl.bfloat16), mask = mask)
-
-    for i in range(1, num_context):
-        beta_ptr += sq_stride
-        alpha_ptr += sq_stride
-        hidden_ptr += sq_stride
-
-        alpha_t = tl.load(alpha_ptr + offs, mask = mask).to(tl.float32)
-        beta_t = tl.load(beta_ptr + offs, mask = mask).to(tl.float32)
-
-        hidden_t = (alpha_t * hidden_t) + beta_t
-
-        tl.store(hidden_ptr + offs, hidden_t.to(tl.bfloat16), mask = mask)
-
-#fmt: off
-@triton.jit
-def sequential_scan_bwd_kernel(
-    alpha_saved_ptr,h_saved_ptr,d_out_ptr,
-    d_alpha_ptr,d_beta_ptr,
-    bs_stride, sq_stride,
-    num_context: tl.constexpr,
-    numel: tl.constexpr,
-    BLOCKSIZE: tl.constexpr
-):
-    #fmt: on
-    bs_pid = tl.program_id(0)
-
-
-    # offset ptrs to correct batch start
-    alpha_saved_ptr += (bs_pid * bs_stride) + ((num_context)*sq_stride)
-    h_saved_ptr += (bs_pid * bs_stride) + ((num_context -2)*sq_stride)
-    d_out_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride)
-
-    d_alpha_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride)
-    d_beta_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride)
-
-    offs = tl.arange(0, BLOCKSIZE)
-
-    mask = offs < numel
-
-    # compute (t = T) outside loop
-    h_grad = tl.load(d_out_ptr + offs, mask=mask).to(tl.float32)
-    h_rec = tl.load(h_saved_ptr + offs, mask=mask).to(tl.float32)
-
-    d_alpha = h_grad*h_rec
-    d_beta = h_grad
-
-    tl.store(d_alpha_ptr + offs, d_alpha.to(tl.bfloat16),mask=mask)
-    tl.store(d_beta_ptr + offs, d_beta.to(tl.bfloat16),mask=mask)
-
-    for _ in range(2, num_context):
-        # reduce pointer offsets
-        d_alpha_ptr -= sq_stride
-        d_beta_ptr -= sq_stride
-        h_saved_ptr -= sq_stride
-        d_out_ptr -= sq_stride
-        alpha_saved_ptr -= sq_stride
-
-        alpha = tl.load(alpha_saved_ptr + offs,mask=mask).to(tl.float32)
-        grad_out = tl.load(d_out_ptr + offs,mask=mask).to(tl.float32)
-        h_rec = tl.load(h_saved_ptr + offs,mask=mask).to(tl.float32)
-
-
-        h_grad = alpha * h_grad
-        h_grad += grad_out
-
-        d_alpha = h_grad * h_rec
-        d_beta = h_grad
-
-        tl.store(d_alpha_ptr + offs, d_alpha.to(tl.bfloat16),mask=mask)
-        tl.store(d_beta_ptr + offs, d_beta.to(tl.bfloat16),mask=mask)
-
-
-    # first grad (t = 0)
-    d_alpha_ptr -= sq_stride
-    d_beta_ptr -= sq_stride
-    d_out_ptr -= sq_stride
-    alpha_saved_ptr -= sq_stride
-
-    alpha = tl.load(alpha_saved_ptr + offs,mask=mask).to(tl.float32)
-    grad_out = tl.load(d_out_ptr + offs,mask=mask).to(tl.float32)
-
-    h_grad = alpha * h_grad
-    h_grad += grad_out
-    d_beta = h_grad
-
-    d_alpha = tl.zeros_like(d_beta).to(tl.float32)
-
-    tl.store(d_alpha_ptr + offs, d_alpha.to(tl.bfloat16),mask=mask)
-    tl.store(d_beta_ptr + offs, d_beta.to(tl.bfloat16),mask=mask)
-
-def sequential_scan_forward(
-    alpha: torch.Tensor, # [b,sq,d]
-    beta: torch.Tensor, # [b,sq,d]
-) -> torch.Tensor:
-    """Computes forward pass of a linear scan."""
-
-    hidden = torch.empty_like(alpha)
-
-    b, sq, d = alpha.shape
-
-    BLOCKSIZE = triton.next_power_of_2(d)
-
-    grid = (b,)
-
-    match d:
-        case _ if d <= 256:
-            warps = 1
-
-        case _ if d <= 512:
-            warps = 2
-
-        case _ if d <= 1024:
-            warps = 4
-
-        case _ :
-            warps = 8
-
-
-    #fmt: off
-    sequential_scan_fwd_kernel[grid](
-        alpha,beta,hidden,
-        alpha.stride(0),alpha.stride(1),
-        sq,d,BLOCKSIZE,
-        num_warps = warps, # type: ignore
-        num_stages = 2 # type: ignore
-    )
-    #fmt: on
-    return hidden
-
-
-def sequential_scan_backward(
-    alpha_saved: torch.Tensor, # [b,sq,d]
-    h_saved: torch.Tensor, # [b,sq,d]
-    grad_out: torch.Tensor, # [b,sq,d]
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Computes backward pass of a linear scan."""
-
-    alpha_grad = torch.empty_like(alpha_saved)
-    beta_grad = torch.empty_like(alpha_saved)
-
-    b, sq, d = alpha_saved.shape
-
-    BLOCKSIZE = triton.next_power_of_2(d)
-
-    grid = (b,)
-
-    match d:
-        case _ if d <= 256:
-            warps = 1
-
-        case _ if d <= 512:
-            warps = 2
-
-        case _ if d <= 1024:
-            warps = 4
-
-        case _ :
-            warps = 8
-
-    # semi-cryptic errors if tensors not contiguous
-    assert alpha_saved.is_contiguous()
-    assert h_saved.is_contiguous()
-
-    if not grad_out.is_contiguous():
-        grad_out = grad_out.contiguous()
-        warnings.warn("`grad_out` tensor is not contiguous. Setting to contiguous and attempting to continue. This may impact runtime.")
-    assert grad_out.is_contiguous()
-
-    #fmt: off
-    sequential_scan_bwd_kernel[grid](
-        alpha_saved,h_saved, grad_out,
-        alpha_grad, beta_grad,
-        alpha_saved.stride(0), alpha_saved.stride(1),
-        sq, d, BLOCKSIZE,
-        num_warps = warps, # type: ignore
-        num_stages = 2 # type: ignore
-    )
-    #fmt: on
-
-    return alpha_grad, beta_grad
-
-class DiagonalRecurrence(Function):
-    @staticmethod
-    @torch.amp.custom_fwd(device_type = 'cuda', cast_inputs=torch.bfloat16) # type: ignore
-    def forward(ctx, input_alpha, input_beta) -> torch.Tensor:
-        h = sequential_scan_forward(input_alpha, input_beta)
-        ctx.save_for_backward(input_alpha, h)
-
-        return h
-
-    @staticmethod
-    @torch.amp.custom_bwd(device_type = 'cuda') # type: ignore
-    def backward(ctx, grad_output) -> Tuple[torch.Tensor, torch.Tensor, None]: # type: ignore
-        (input_alpha, h) = ctx.saved_tensors
-
-        alpha_grad, beta_grad = sequential_scan_backward(input_alpha, h, grad_output)
-
-        return alpha_grad, beta_grad, None
-
-
-linear_scan = DiagonalRecurrence.apply
diff --git a/hawk/scan_fused.py b/hawk/scan_fused.py
new file mode 100644
index 0000000..bb0d99f
--- /dev/null
+++ b/hawk/scan_fused.py
@@ -0,0 +1,429 @@
+import warnings
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+from torch.autograd import Function
+
+
+@triton.jit
+def clipped_sqrt_grad(x, MAX_SQRT_GRADIENT=1000):
+    min_value = 1.0 / (MAX_SQRT_GRADIENT * MAX_SQRT_GRADIENT)
+    scaled_x = 4.0 * x
+    clipped_value = tl.where(scaled_x > min_value, scaled_x, min_value)
+    return 1.0 / tl.sqrt(clipped_value)
+
+
+@triton.jit
+def softplus_fwd(x):
+    return tl.log(1 + tl.exp(x))
+
+
+@triton.jit
+def softplus_bwd(x):
+    return tl.exp(x) / (1 + tl.exp(x))
+
+
+@triton.jit
+def sigmoid_bwd(x):
+    return tl.sigmoid(x) * (1 - tl.sigmoid(x))
+
+
+# fmt: off
+@triton.jit
+def sequential_scan_fwd_kernel(
+    x_ptr, x_rg_lru_ptr, a_rg_lru_ptr, a_param_ptr,
+    hidden_ptr,
+    bs_stride,sq_stride,
+    num_context: tl.constexpr,
+    BLOCKSIZE: tl.constexpr
+):
+    #fmt: on
+
+    C = -8.0
+
+    bs_pid = tl.program_id(0)
+
+    channel_pid = tl.program_id(1)
+
+    offs = tl.arange(0, BLOCKSIZE)
+
+    x_ptr += bs_pid * bs_stride + (channel_pid*BLOCKSIZE)
+    x_rg_lru_ptr += bs_pid * bs_stride + (channel_pid*BLOCKSIZE)
+    a_rg_lru_ptr += bs_pid * bs_stride + (channel_pid*BLOCKSIZE)
+    hidden_ptr += bs_pid * bs_stride + (channel_pid*BLOCKSIZE)
+
+    a_param = tl.load(a_param_ptr + (channel_pid*BLOCKSIZE) + offs, ).to(tl.float32)
+
+    # compute first hidden state
+
+    x = tl.load(x_ptr + offs, ).to(tl.float32)
+    x_rg_lru = tl.load(x_rg_lru_ptr + offs, ).to(tl.float32)
+    a_rg_lru = tl.load(a_rg_lru_ptr + offs, ).to(tl.float32)
+
+    x_rg_lru = tl.sigmoid(x_rg_lru)
+    a_rg_lru = tl.sigmoid(a_rg_lru)
+
+    log_a = C * a_rg_lru * softplus_fwd(a_param)
+    a_square = tl.exp(2 * log_a)
+
+    multiplier = tl.sqrt(1 - a_square)
+
+    gated_x = x * x_rg_lru
+
+    beta_t = gated_x * multiplier
+
+    # compute h_0 outside loop
+
+    hidden_t = beta_t
+
+    tl.store(hidden_ptr + offs, hidden_t.to(tl.bfloat16), )
+
+    for i in range(1, num_context):
+        hidden_ptr += sq_stride
+        x_ptr += sq_stride
+        x_rg_lru_ptr += sq_stride
+        a_rg_lru_ptr += sq_stride
+
+        x = tl.load(x_ptr + offs, ).to(tl.float32)
+        x_rg_lru = tl.load(x_rg_lru_ptr + offs, ).to(tl.float32)
+        a_rg_lru = tl.load(a_rg_lru_ptr + offs, ).to(tl.float32)
+
+        x_rg_lru = tl.sigmoid(x_rg_lru)
+        a_rg_lru = tl.sigmoid(a_rg_lru)
+
+        log_a = C * a_rg_lru * softplus_fwd(a_param)
+        alpha_t = tl.exp(log_a)
+
+        a_square = tl.exp(2 * log_a)
+
+        multiplier = tl.sqrt(1 - a_square)
+
+        gated_x = x * x_rg_lru
+
+        beta_t = gated_x * multiplier
+
+        hidden_t = (alpha_t * hidden_t) + beta_t
+
+        tl.store(hidden_ptr + offs, hidden_t.to(tl.bfloat16), )
+
+#fmt: off
+@triton.jit
+def sequential_scan_bwd_kernel(
+    x_ptr, x_rg_lru_ptr, a_rg_lru_ptr, a_param_ptr, h_saved_ptr,
+    d_out_ptr,
+    dx_ptr, dx_rg_lru_ptr, da_rg_lru_ptr, da_param_ptr,
+    bs_stride, sq_stride,
+    aparam_bs_stride,
+    num_context: tl.constexpr,
+    BLOCKSIZE: tl.constexpr
+):
+
+    C = -8.0
+
+    #fmt: on
+    bs_pid = tl.program_id(0)
+
+    channel_pid = tl.program_id(1)
+
+    a_param_batched_grad = tl.zeros([BLOCKSIZE], dtype = tl.float32)
+
+
+    h_saved_ptr += (bs_pid * bs_stride) + ((num_context -2)*sq_stride) + (channel_pid*BLOCKSIZE)
+
+    # offset ptrs to correct batch start
+    d_out_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+
+    dx_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+    dx_rg_lru_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+    da_rg_lru_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+    da_param_ptr += (bs_pid * aparam_bs_stride) + (channel_pid*BLOCKSIZE)
+
+
+    x_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+    x_rg_lru_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+    a_rg_lru_ptr += (bs_pid * bs_stride) + ((num_context-1)*sq_stride) + (channel_pid*BLOCKSIZE)
+
+    offs = tl.arange(0, BLOCKSIZE)
+
+    a_param = tl.load(a_param_ptr + (channel_pid*BLOCKSIZE) + offs, ).to(tl.float32)
+
+    # compute (t = T) outside loop
+    h_grad = tl.load(d_out_ptr + offs).to(tl.float32)
+    h_rec = tl.load(h_saved_ptr + offs).to(tl.float32)
+
+    x = tl.load(x_ptr + offs).to(tl.float32)
+    x_rg_lru = tl.load(x_rg_lru_ptr + offs).to(tl.float32)
+    a_rg_lru = tl.load(a_rg_lru_ptr + offs).to(tl.float32)
+
+    d_alpha = h_grad*h_rec
+    d_beta = h_grad
+
+    log_a = (
+        C
+        * tl.sigmoid(a_rg_lru)
+        * softplus_fwd(a_param)
+    )
+
+    a_square_T = tl.exp(2 * log_a)
+
+    i_T = tl.sigmoid(x_rg_lru)
+
+    multiplier_T = tl.sqrt(1 - a_square_T)
+
+    dlog_a = d_alpha * tl.exp(log_a)
+
+    sqrt_grad = clipped_sqrt_grad(1 - a_square_T)
+    extra_term = -2.0 * a_square_T
+
+    dlog_a += (
+        d_beta
+        * (i_T * x)
+        *sqrt_grad * extra_term
+    )
+
+    a_rg_lru_grad = sigmoid_bwd(a_rg_lru) * dlog_a * C * softplus_fwd(a_param)
+
+    x_grad = d_beta * multiplier_T * i_T
+
+    x_rg_lru_grad = d_beta * multiplier_T * x * sigmoid_bwd(x_rg_lru)
+
+    a_param_batched_grad += dlog_a * C * tl.sigmoid(a_rg_lru) * softplus_bwd(a_param)
+
+    tl.store(dx_ptr + offs, x_grad.to(tl.bfloat16), )
+    tl.store(dx_rg_lru_ptr+ offs, x_rg_lru_grad.to(tl.bfloat16), )
+    tl.store(da_rg_lru_ptr+ offs, a_rg_lru_grad.to(tl.bfloat16), )
+
+
+    for _ in range(2, num_context):
+        # reduce pointer offsets
+        dx_ptr -= sq_stride
+        dx_rg_lru_ptr -= sq_stride
+        da_rg_lru_ptr -= sq_stride
+
+        x_ptr -= sq_stride
+        x_rg_lru_ptr -= sq_stride
+        a_rg_lru_ptr -= sq_stride
+
+
+        h_saved_ptr -= sq_stride
+        d_out_ptr -= sq_stride
+
+        a = tl.exp(log_a)
+
+        h_grad = a * h_grad
+
+        grad_out = tl.load(d_out_ptr + offs,).to(tl.float32)
+        h_rec = tl.load(h_saved_ptr + offs, ).to(tl.float32)
+        h_grad += grad_out
+
+        d_alpha = h_grad * h_rec
+        d_beta = h_grad
+
+        x = tl.load(x_ptr + offs,).to(tl.float32)
+        x_rg_lru = tl.load(x_rg_lru_ptr + offs,).to(tl.float32)
+        a_rg_lru = tl.load(a_rg_lru_ptr + offs,).to(tl.float32)
+
+        log_a = (
+            C
+            * tl.sigmoid(a_rg_lru)
+            * softplus_fwd(a_param)
+        )
+
+        a_square_t = tl.exp(2 * log_a)
+
+        i_t = tl.sigmoid(x_rg_lru)
+
+        multiplier_t = tl.sqrt(1 - a_square_t)
+
+        x_grad = d_beta * multiplier_t * i_t
+
+        x_rg_lru_grad = d_beta * multiplier_t * x * sigmoid_bwd(x_rg_lru)
+
+        dlog_a = d_alpha * tl.exp(log_a)
+        sqrt_grad = clipped_sqrt_grad(1 - a_square_t)
+        extra_term = -2.0 * a_square_t
+
+        dlog_a += (
+            d_beta
+            * (i_t * x)
+            *sqrt_grad * extra_term
+        )
+
+        a_rg_lru_grad = sigmoid_bwd(a_rg_lru) * dlog_a * C * softplus_fwd(a_param)
+
+        a_param_batched_grad += dlog_a * C * tl.sigmoid(a_rg_lru) * softplus_bwd(a_param)
+
+        tl.store(dx_ptr + offs, x_grad.to(tl.bfloat16), )
+        tl.store(dx_rg_lru_ptr + offs, x_rg_lru_grad.to(tl.bfloat16), )
+        tl.store(da_rg_lru_ptr + offs, a_rg_lru_grad.to(tl.bfloat16), )
+
+    dx_ptr -= sq_stride
+    dx_rg_lru_ptr -= sq_stride
+    da_rg_lru_ptr -= sq_stride
+
+    x_ptr -= sq_stride
+    x_rg_lru_ptr -= sq_stride
+    a_rg_lru_ptr -= sq_stride
+
+
+    d_out_ptr -= sq_stride
+
+    a = tl.exp(log_a)
+    h_grad = a * h_grad
+    grad_out = tl.load(d_out_ptr + offs).to(tl.float32)
+    h_grad += grad_out
+
+    d_beta = h_grad
+
+    x = tl.load(x_ptr + offs,).to(tl.float32)
+    x_rg_lru = tl.load(x_rg_lru_ptr + offs,).to(tl.float32)
+    a_rg_lru = tl.load(a_rg_lru_ptr + offs,).to(tl.float32)
+
+    log_a = (
+        C
+        * tl.sigmoid(a_rg_lru)
+        * softplus_fwd(a_param)
+    )
+
+    a_square_t = tl.exp(2 * log_a)
+
+    i_t = tl.sigmoid(x_rg_lru)
+
+    multiplier_t = tl.sqrt(1 - a_square_t)
+
+    x_grad = d_beta * multiplier_t * i_t
+
+    x_rg_lru_grad = d_beta * multiplier_t * x * sigmoid_bwd(x_rg_lru)
+
+    sqrt_grad = clipped_sqrt_grad(1 - a_square_t)
+    extra_term = -2.0 * a_square_t
+
+    dlog_a = (
+        d_beta
+        * (i_t * x)
+        *sqrt_grad * extra_term
+    )
+
+    a_rg_lru_grad = sigmoid_bwd(a_rg_lru) * dlog_a * C * softplus_fwd(a_param)
+
+    a_param_batched_grad += dlog_a * C * tl.sigmoid(a_rg_lru) * softplus_bwd(a_param)
+
+    tl.store(dx_ptr + offs, x_grad.to(tl.bfloat16))
+    tl.store(dx_rg_lru_ptr + offs, x_rg_lru_grad.to(tl.bfloat16))
+    tl.store(da_rg_lru_ptr + offs, a_rg_lru_grad.to(tl.bfloat16))
+
+    tl.store(da_param_ptr + offs,a_param_batched_grad.to(tl.float32), )
+
+def sequential_scan_forward(
+    x: torch.Tensor,
+    x_rg_lru: torch.Tensor,
+    a_rg_lru: torch.Tensor, # [b,sq,d]
+    a_param: torch.Tensor, # [d]
+) -> torch.Tensor:
+    """Computes forward pass of a linear scan."""
+
+    hidden = torch.empty_like(x)
+
+    b, sq, d = x.shape
+
+    BLOCKSIZE = 64
+
+    assert d % BLOCKSIZE == 0, f"Error: expected model dimension to be a multiple of {BLOCKSIZE}"
+
+    num_blocks = d // BLOCKSIZE
+
+    grid = (b,num_blocks)
+
+    warps = 1
+
+    #fmt: off
+    sequential_scan_fwd_kernel[grid](
+        x,x_rg_lru,a_rg_lru, a_param,
+        hidden,
+        x.stride(0),x.stride(1),
+        sq,BLOCKSIZE,
+        num_warps = warps, # type: ignore
+        num_stages = 2 # type: ignore
+    )
+    #fmt: on
+    return hidden
+
+
+def sequential_scan_backward(
+    x_saved, # [b,sq,d]
+    x_rg_lru_saved, # [b,sq,d]
+    a_rg_lru_saved, # [b,sq,d]
+    a_param_saved, # [d]
+    h_saved, # [b,sq,d]
+    grad_out: torch.Tensor, # [b,sq,d]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Computes backward pass of a linear scan."""
+
+    b,l,d = grad_out.shape
+
+    x_grad = torch.empty_like(x_saved)
+    x_rg_lru_grad = torch.empty_like(x_saved)
+    a_rg_lru_grad = torch.empty_like(x_saved)
+    a_param_batched = torch.empty((b, d), device=x_saved.device, dtype = a_param_saved.dtype)
+
+    b, sq, d = h_saved.shape
+
+    BLOCKSIZE = 64
+
+    assert d % BLOCKSIZE == 0, f"Error: expected model dimension to be a multiple of {BLOCKSIZE}"
+
+    num_blocks = d // BLOCKSIZE
+
+    grid = (b,num_blocks)
+
+    warps = 1
+
+    # semi-cryptic errors if tensors not contiguous
+    assert x_saved.is_contiguous()
+    assert x_rg_lru_saved.is_contiguous()
+    assert a_rg_lru_saved.is_contiguous()
+    assert a_param_saved.is_contiguous()
+    assert h_saved.is_contiguous()
+
+    if not grad_out.is_contiguous():
+        grad_out = grad_out.contiguous()
+        warnings.warn("`grad_out` tensor is not contiguous. Setting to contiguous and attempting to continue. This may impact runtime.")
+    assert grad_out.is_contiguous()
+
+    #fmt: off
+    sequential_scan_bwd_kernel[grid](
+        x_saved, x_rg_lru_saved, a_rg_lru_saved, a_param_saved, h_saved,
+        grad_out,
+        x_grad, x_rg_lru_grad, a_rg_lru_grad, a_param_batched,
+        x_saved.stride(0), x_saved.stride(1),
+        a_param_batched.stride(0),
+        sq, BLOCKSIZE,
+        num_warps = warps, # type: ignore
+        num_stages = 2 # type: ignore
+    )
+    #fmt: on
+
+    return x_grad, x_rg_lru_grad, a_rg_lru_grad, a_param_batched.sum(dim = 0)
+
+class DiagonalRecurrence(Function):
+    @staticmethod
+    @torch.amp.custom_fwd(device_type = 'cuda', cast_inputs=torch.bfloat16) # type: ignore
+    def forward(ctx, x,x_rg_lru,a_rg_lru,a_param) -> torch.Tensor:
+        h = sequential_scan_forward(x,x_rg_lru,a_rg_lru,a_param)
+        ctx.save_for_backward(x,x_rg_lru,a_rg_lru,a_param,h)
+
+        return h
+
+    @staticmethod
+    @torch.amp.custom_bwd(device_type = 'cuda') # type: ignore
+    def backward(ctx, grad_output) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore
+        (x,x_rg_lru,a_rg_lru,a_param,h_saved) = ctx.saved_tensors
+
+        x_grad, x_rg_lru_grad, a_rg_lru_grad, a_param_grad = sequential_scan_backward(x,x_rg_lru,a_rg_lru,a_param, h_saved, grad_output)
+
+        return x_grad, x_rg_lru_grad, a_rg_lru_grad, a_param_grad
+
+
+fused_linear_scan = DiagonalRecurrence.apply
diff --git a/pyproject.toml b/pyproject.toml
index 3a175a3..3150960 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "hawk-pytorch"
 authors = [{ name = "Benjamin Fattori", email = "fattoribenjamin@gmail.com" }]
-version = "1.0.0"
+version = "1.1.0"
 description = "PyTorch implementation of Hawk"
 license = { file = "LICENSE" }
 dependencies = [