-
Notifications
You must be signed in to change notification settings - Fork 51
/
Quantize.py
77 lines (66 loc) · 2.11 KB
/
Quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import copy
import numpy as np
import torch
import memtorch
import memtorch_bindings
quant_methods = ["linear", "log"]  # supported schemes; the list index is what gets passed to memtorch_bindings.quantize
def quantize(
    tensor,
    quant,
    overflow_rate=0.0,
    quant_method=None,
    min=float("nan"),
    max=float("nan"),
    override_original=False,
):
    """Quantize a tensor using memtorch's native bindings.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor.
    quant : int
        Bit width (if quant_method is not None) or the number of discrete
        quantization levels (if quant_method is None).
    overflow_rate : float, optional
        Overflow rate threshold for linear quantization. Must lie in [0, 1].
    quant_method : str, optional
        Quantization method; must be one of quant_methods ("linear", "log").
        When None, quant is interpreted as a level count instead of a bit width.
    min : float or torch.Tensor, optional
        Minimum value(s) to clip numbers to. NaN (default) leaves the bound
        for the bindings to determine.
    max : float or torch.Tensor, optional
        Maximum value(s) to clip numbers to. NaN (default) leaves the bound
        for the bindings to determine.
    override_original : bool, optional
        Whether to quantize the original tensor in place (True) or operate on
        a deep copy (False).

    Returns
    -------
    torch.Tensor
        Quantized tensor, moved to the package's active device.
    """
    # memtorch ships distinct CPU/CUDA builds and encodes the variant in its
    # version string; this picks the device the installed build targets.
    device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
    assert 0 <= overflow_rate <= 1, "overflow_rate must be >= 0 and <= 1."
    # bool is excluded explicitly: isinstance(True, int) holds, but the
    # original type(...) == int check rejected bools, and a bool is not a
    # meaningful bit width or level count.
    assert (
        isinstance(quant, int) and not isinstance(quant, bool) and quant > 0
    ), "The bit width or number of discrete quantization levels must be a positive integer."
    # Promote plain-int bounds to float for the bindings. Bools are
    # deliberately left untouched, mirroring the original type() == int checks.
    if isinstance(min, int) and not isinstance(min, bool):
        min = float(min)
    if isinstance(max, int) and not isinstance(max, bool):
        max = float(max)
    if not override_original:
        tensor = copy.deepcopy(tensor)
    # NOTE(review): .cpu() copies a GPU-resident tensor, so with
    # override_original=True a CUDA caller's tensor may not actually be
    # mutated in place — confirm this is the intended contract.
    tensor = tensor.cpu()
    # The native quantization kernels operate on floating-point data only.
    if tensor.dtype not in (torch.float32, torch.float64):
        tensor = tensor.float()
    if quant_method is not None:
        assert quant_method in quant_methods, "quant_method is invalid."
        memtorch_bindings.quantize(
            tensor,
            bits=quant,
            overflow_rate=overflow_rate,
            quant_method=quant_methods.index(quant_method),
            min=min,
            max=max,
        )
    else:
        memtorch_bindings.quantize(tensor, n_quant_levels=quant, min=min, max=max)
    return tensor.to(device)