Skip to content

Commit

Permalink
[Docs] More clarifications on BT + FA (huggingface#25823)
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada authored and blbadger committed Nov 8, 2023
1 parent 6b348d9 commit 759c31d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to("cuda")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()

Expand All @@ -99,6 +99,8 @@ try using the PyTorch nightly version, which may have a broader coverage for Fla
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

Or make sure your model is correctly casted in float16 or bfloat16


Have a look at [this detailed blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to read more about what is possible to do with `BetterTransformer` + SDPA API.

Expand Down Expand Up @@ -270,4 +272,4 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable
outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
```

0 comments on commit 759c31d

Please sign in to comment.