forked from juncongmoo/pyllama
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquant_infer.py
29 lines (20 loc) · 885 Bytes
/
quant_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import hiq, time
from hiq.memory import total_gpu_memory_mb, get_memory_mb
def main():
    """Profile quantized LLaMA inference with HiQ and print the results.

    Builds a ``hiq.HiQLatency`` driver over the functions listed in the
    table below, runs ``llama.llama_infer`` with the arguments returned by
    its own ``get_args()``, then prints a banner followed by the driver's
    report (``driver.show()``).
    """
    driver = hiq.HiQLatency(
        # Rows look like [module, class name, function name, label]; an empty
        # class name presumably means a module-level function.
        # NOTE(review): column meanings inferred from the entries -- confirm
        # against the HiQ documentation.
        # NOTE(review): recent `transformers` releases spell these classes
        # `LlamaTokenizer` / `LlamaForCausalLM`; verify the names match the
        # installed versions, or those targets may not be found.
        hiq_table_or_path=[
            ["llama.llama_infer", "", "run", "run_quant"],
            ["llama.llama_infer", "AutoTokenizer", "from_pretrained", "from_pretrained"],
            ["transformers.models.llama.tokenization_llama", "LLaMATokenizer", "encode", "encode"],
            ["llama.llama_infer", "", "load_quant", "load_quant"],
            ["llama.hf.modeling_llama", "LLaMAForCausalLM", "generate", "generate"],
        ],
        # Metrics collected by the driver: wall-clock time plus the two
        # memory probes imported from hiq.memory (presumably GPU and
        # process memory in MB, going by their names).
        metric_funcs=[time.time, total_gpu_memory_mb, get_memory_mb],
        # Optional: passing extra_metrics={hiq.ExtraMetrics.ARGS} presumably
        # also records call arguments -- confirm in the HiQ docs.
    )

    # Resolve the inference module once and reuse it for both calls.
    # The driver is created above first, presumably so the traced
    # functions are instrumented before the run starts.
    infer_module = hiq.mod("llama.llama_infer")
    args = infer_module.get_args()
    infer_module.run(args)

    print("*" * 30, "GPU/CPU/Latency Profiling", "*" * 30)
    driver.show()
if __name__ == "__main__":
    # Run the profiling entry point only when this file is executed as a
    # script; importing the module must not start inference.
    main()