Skip to content

Commit

Permalink
Better device_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ncassereau committed Dec 8, 2021
1 parent c507d3b commit d02f71a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ def bitsize(self, type_as):
return torch.finfo(type_as.dtype).bits

def device_type(self, type_as):
return "CPU" if "cpu" in str(type_as.device) else "GPU"
return type_as.device.type.replace("cuda", "gpu").upper()

def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
results = dict()
Expand Down Expand Up @@ -2337,7 +2337,7 @@ def bitsize(self, type_as):
return type_as.dtype.size * 8

def device_type(self, type_as):
return "CPU" if "CPU" in type_as.device else "GPU"
return self.dtype_device(type_as)[1].split(":")[0]

def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
results = dict()
Expand Down

0 comments on commit d02f71a

Please sign in to comment.