Commit
tiny bug fix; blocksparse documentation fix (#344)
Co-authored-by: Chris Yuan <christopheryuan@learnfair1481.h2.fair>
yuanandonly and Chris Yuan authored Jun 30, 2022
1 parent 12e8abc commit 7fdb90d
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions BENCHMARKS.md
@@ -134,6 +134,6 @@ _Note_: The estimated flops currently miss accounting for many operators, and ar
 FP16 | FP32
 :-------------------------:|:-------------------------:
 ![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Runtime_FW_fp16_Blocksize128.png) | ![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Runtime_FW_fp32_Blocksize128.png)
-![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Runtime_FW+BW_fp16_Blocksize128.png) | ![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Runtime_FW+BW_fp32_Blocksize128.png)
-![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW+BW_fp16_Blocksize128.png) | ![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW+BW_fp32_Blocksize128.png)
-![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW+BW_fp16_Blocksize128.png) | ![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW+BW_fp32_Blocksize128.png)
+![fw+bw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Runtime_FW+BW_fp16_Blocksize128.png) | ![fw+bw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Runtime_FW+BW_fp32_Blocksize128.png)
+![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW_fp16_Blocksize128.png) | ![fw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW_fp32_Blocksize128.png)
+![fw+bw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW+BW_fp16_Blocksize128.png) | ![fw+bw](docs/plots/causal_attention_blocksparse/Causal_Blocksparse_Memory_FW+BW_fp32_Blocksize128.png)
6 changes: 3 additions & 3 deletions xformers/triton/dropout.py
@@ -239,9 +239,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore

         # Lazy init (helps with pickling)
-        if self.activation is None:
+        if self.activation is None or self.activation_pytorch is None:
             self.activation = get_triton_activation_kernel(self.activation_type)
-            self.pytorch_activation = build_activation(self.activation_type)
+            self.activation_pytorch = build_activation(self.activation_type)
             self.activation_grad = get_triton_activation_bwd_kernel(
                 self.activation_type
             )
@@ -255,7 +255,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Catch a non-cuda setup, fallback to pytorch
         if not x.is_cuda or not perf_check or p == 0.0:
             x = x + self.bias if self.bias is not None else x
-            x = self.pytorch_activation(x)
+            x = self.activation_pytorch(x)
             return torch.nn.functional.dropout(x, p) if p > 0.0 else x

         # The normal, Triton-backed path
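For reference, the dropout.py change unifies the lazily built PyTorch activation under a single attribute name, self.activation_pytorch, and makes the lazy-init guard rebuild it whenever it is missing. Below is a minimal, self-contained sketch of that pattern; the class name and the _build_* helpers are illustrative stand-ins, not the actual xformers API.

import torch


class FusedDropoutBiasSketch(torch.nn.Module):
    """Illustrative sketch: lazily build both the fused and the PyTorch
    activation under one consistently named pair of attributes."""

    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p
        # Both attributes start unset; the guard in forward() must check
        # both, so that neither can stay None after unpickling.
        self.activation = None
        self.activation_pytorch = None

    def _build_triton_activation(self):
        # Stand-in for get_triton_activation_kernel(); the sketch reuses
        # the eager op so it runs anywhere.
        return torch.nn.functional.gelu

    def _build_pytorch_activation(self):
        # Stand-in for build_activation().
        return torch.nn.functional.gelu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Lazy init (helps with pickling): rebuild whichever callable is missing.
        if self.activation is None or self.activation_pytorch is None:
            self.activation = self._build_triton_activation()
            self.activation_pytorch = self._build_pytorch_activation()

        # Non-CUDA fallback uses the same attribute name assigned above.
        if not x.is_cuda:
            x = self.activation_pytorch(x)
            return torch.nn.functional.dropout(x, self.p) if self.p > 0.0 else x

        # A real implementation would dispatch to the fused Triton kernel here.
        x = self.activation(x)
        return torch.nn.functional.dropout(x, self.p) if self.p > 0.0 else x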
