Skip to content

Commit

Permalink
FIX SAM for bfloat16 (#1764)
Browse files Browse the repository at this point in the history
Summary:
Ok this was kinda annoying

Basically the SAM codebase had a few places where it hardcodes `torch.float32` such that even if you convert the model to `torch.bfloat16` a few parts of the model won't be and will have type mismatch errors - this fixes the problem cpuhrsch desertfire - idk enough about floats and why there isn't some type promotion rule for bfloat16

I wonder whether we should add tests for multiple dtypes in torchbench to make checking for this kind of issue more robust especially now that bfloat16 seems to be the default for dynamo xuzhao9

## Logs

```
FAILED (errors=1)
(sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_sam_eval_cuda"
E
======================================================================
ERROR: test_sam_eval_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last):
  File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 482, in _run_block
    exec(  # noqa: P204
  File "<subprocess-worker>", line 2, in <module>
  File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 280, in invoke
    out = self.eval()
  File "/home/ubuntu/benchmark/torchbenchmark/models/sam/__init__.py", line 65, in eval
    masks, scores, logits = predictor.predict(
  File "/home/ubuntu/benchmark/torchbenchmark/models/sam/predictor.py", line 164, in predict
    low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16

    working_dir: /tmp/tmpg5de41du
    stdout:
        [2023-07-13] 01:57:38.499061: TIMER_SUBPROCESS_BEGIN_EXEC
        [2023-07-13] 01:57:39.002078: TIMER_SUBPROCESS_FAILED
        [2023-07-13] 01:57:39.002141: TIMER_SUBPROCESS_FINISHED
        [2023-07-13] 01:57:39.002153: TIMER_SUBPROCESS_BEGIN_READ

    stderr:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/benchmark/test.py", line 104, in eval_fn
    task.invoke()
  File "/home/ubuntu/benchmark/torchbenchmark/__init__.py", line 402, in invoke
    self.worker.run("""
  File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_worker.py", line 155, in run
    self._run(snippet)
  File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_worker.py", line 320, in _run
    subprocess_rpc.SerializedException.raise_from(
  File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 458, in raise_from
    raise e from ChildTraceException(traceback_str)
TypeError: Got unsupported ScalarType BFloat16

----------------------------------------------------------------------
Ran 1 test in 7.814s

FAILED (errors=1)
(sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_sam_eval_cuda"
.
----------------------------------------------------------------------
Ran 1 test in 8.315s

OK
```

Pull Request resolved: #1764

Reviewed By: drisspg, cpuhrsch

Differential Revision: D47441873

Pulled By: msaroufim

fbshipit-source-id: a60880fd7c0826cfd469ace39d76894469ca0e5e
  • Loading branch information
msaroufim authored and facebook-github-bot committed Jul 14, 2023
1 parent 2ea018e commit 745644f
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
4 changes: 3 additions & 1 deletion torchbenchmark/models/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def get_module(self):
]

multimask_output = False

return self.model, (example_input, multimask_output)

def train(self):
Expand All @@ -57,6 +56,9 @@ def train(self):
return NotImplementedError(error_msg)

def eval(self):
# To test for bfloat16 uncomment the below line
# predictor = SamPredictor(self.model.to(dtype=torch.bfloat16))

predictor = SamPredictor(self.model)

predictor.set_image(self.image)
Expand Down
1 change: 1 addition & 0 deletions torchbenchmark/models/sam/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def predict_masks(
b, c, h, w = src.shape

# Run the transformer
tokens = tokens.to(src.dtype)
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/models/sam/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def predict(
)

masks_np = masks[0].detach().cpu().numpy()
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
iou_predictions_np = iou_predictions[0].to(torch.float32).detach().cpu().numpy()
low_res_masks_np = low_res_masks[0].to(torch.float32).detach().cpu().numpy()
return masks_np, iou_predictions_np, low_res_masks_np

@torch.no_grad()
Expand Down
2 changes: 2 additions & 0 deletions torchbenchmark/models/sam/prompt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)

coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
Expand Down

0 comments on commit 745644f

Please sign in to comment.