fixed cuda graph demo
BlackSamorez committed Jul 16, 2024
1 parent f13ca20 commit 559a366
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions notebooks/aqlm_cuda_graph.ipynb
@@ -40,7 +40,7 @@
"%%capture\n",
"!pip install aqlm[gpu]>=1.1.0\n",
"!pip install accelerate>=0.27.0\n",
"!pip install git+https://github.com/huggingface/transformers.git@main"
"!pip install transformers>=4.41.0"
]
},
{
@@ -210,12 +210,16 @@
"source": [
"import torch\n",
"\n",
"def decode_one_tokens(model, cur_token, input_pos, cache_position):\n",
"def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):\n",
" logits = model(\n",
" cur_token, position_ids=None, cache_position=cache_position, return_dict=False, use_cache=True\n",
" cur_token,\n",
" position_ids=input_pos,\n",
" cache_position=cache_position,\n",
" past_key_values=past_key_values,\n",
" return_dict=False,\n",
" use_cache=True\n",
" )[0]\n",
" new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)\n",
"\n",
" new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]\n",
" return new_token\n",
"\n",
"MAX_NEW_TOKENS = 128"
@@ -242,7 +246,14 @@
"\n",
"input_ids = tokenizer(\"I'm AQLM, \", return_tensors=\"pt\").to(\"cuda\")[\"input_ids\"]\n",
"seq_length = input_ids.shape[1]\n",
"quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + MAX_NEW_TOKENS * 2 + 1)"
"\n",
"past_key_values = StaticCache(\n",
" quantized_model.config,\n",
" 1,\n",
" seq_length + MAX_NEW_TOKENS * 2 + 1,\n",
" quantized_model.device,\n",
" quantized_model.dtype\n",
")"
]
},
{
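The StaticCache above is built with positional arguments; spelled out with keyword names (argument names follow the transformers StaticCache signature around v4.41, so treat them as an assumption on other versions), the call reads:

from transformers import StaticCache

# Batch size 1, with room for the prompt plus two rounds of MAX_NEW_TOKENS
# generations and one extra slot, matching the positional call above.
past_key_values = StaticCache(
    config=quantized_model.config,
    max_batch_size=1,
    max_cache_len=seq_length + MAX_NEW_TOKENS * 2 + 1,
    device=quantized_model.device,
    dtype=quantized_model.dtype,
)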
@@ -284,7 +295,9 @@
},
"outputs": [],
"source": [
"logits = quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]\n",
"logits = quantized_model(\n",
" input_ids, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True\n",
")[0]\n",
"next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)\n",
"generated_ids[:, [seq_length]] = next_token"
]
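This prefill call fills the static cache with the whole prompt and writes the first sampled token into a preallocated output buffer. The buffer and cache_position come from earlier notebook cells that this diff does not touch; a minimal sketch of that assumed setup:

# Assumed setup from unchanged cells (not part of this diff): a buffer for the
# full generated sequence and a cache_position covering every prompt token.
generated_ids = torch.zeros(1, seq_length + MAX_NEW_TOKENS * 2, dtype=torch.int, device="cuda")
generated_ids[:, :seq_length] = input_ids
cache_position = torch.arange(seq_length, device="cuda")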
@@ -314,8 +327,8 @@
" cache_position = torch.tensor([seq_length + 1], device=\"cuda\")\n",
" for _ in range(1, MAX_NEW_TOKENS):\n",
" with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):\n",
" next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position)\n",
" generated_ids.index_copy_(1, cache_position, next_token)\n",
" next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position, past_key_values)\n",
" generated_ids[:, cache_position] = next_token.int()\n",
" cache_position += 1"
]
},
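Besides threading past_key_values through, this hunk swaps generated_ids.index_copy_ for plain advanced indexing; both write the new token into the same slot of the preallocated buffer. A standalone check of that equivalence (shapes here are illustrative, not taken from the notebook):

import torch

buf_old = torch.zeros(1, 8, dtype=torch.int)
buf_new = torch.zeros(1, 8, dtype=torch.int)
pos = torch.tensor([3])                      # plays the role of cache_position
tok = torch.tensor([[42]], dtype=torch.int)  # plays the role of next_token

buf_old.index_copy_(1, pos, tok)  # old form
buf_new[:, pos] = tok             # new form used in the updated cell
assert torch.equal(buf_old, buf_new)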
@@ -363,8 +376,8 @@
"with torch.no_grad():\n",
" for _ in range(MAX_NEW_TOKENS):\n",
" with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):\n",
" next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position)\n",
" generated_ids.index_copy_(1, cache_position, next_token)\n",
" next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position, past_key_values)\n",
" generated_ids[:, cache_position] = next_token.int()\n",
" cache_position += 1\n",
"end = time.perf_counter()"
]
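The second loop is identical but bracketed by perf_counter timestamps so the replayed CUDA graphs can be timed. The reporting cell is not part of this diff; a typical way to turn the measurement into throughput, assuming start was taken just before the loop:

# Hypothetical reporting step (not in this diff): tokens per second over the timed loop.
elapsed = end - start
print(f"Generated {MAX_NEW_TOKENS} tokens in {elapsed:.2f} s ({MAX_NEW_TOKENS / elapsed:.1f} tok/s)")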
