Decouple custom ops in llama_transformer.py Part 2/N (#3007)
Summary:
Pull Request resolved: #3007

Keep llama_transformer.py looking like the stock implementation, so that it can be reused everywhere.

Instead, do a module swap: at export time, export_llama_lib.py replaces the stock SDPA module with the custom-op-backed SDPACustom.
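
For illustration, a minimal usage sketch of the module-swap approach applied to an eager model before export. This is not part of the diff; the import path for replace_sdpa_with_custom_op is an assumption based on the file layout shown below.

# Hypothetical usage sketch (not part of this commit): swap the stock SDPA
# submodules for the custom-op-backed SDPACustom before exporting the model.
import torch

# Assumed import path, mirroring the repo layout in this diff.
from executorch.examples.models.llama2.export_llama_lib import (
    replace_sdpa_with_custom_op,
)


def prepare_for_custom_sdpa(model: torch.nn.Module) -> torch.nn.Module:
    # Recursively replaces every SDPA child module with SDPACustom, which
    # dispatches to torch.ops.llama.sdpa_with_kv_cache at runtime.
    return replace_sdpa_with_custom_op(model)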

Differential Revision: D56048640
mergennachin authored and facebook-github-bot committed Apr 12, 2024
1 parent 19e2a3b commit 3b0c813
Showing 3 changed files with 59 additions and 57 deletions.
examples/models/llama2/TARGETS (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,6 @@ runtime.python_library(
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
],
)

@@ -86,6 +85,7 @@ runtime.python_library(
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/examples/models:model_base",
"//executorch/examples/models:models",
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
"//executorch/examples/portable:utils",
"//executorch/exir:lib",
"//executorch/sdk/etrecord:etrecord",
examples/models/llama2/export_llama_lib.py (61 changes: 58 additions & 3 deletions)
@@ -23,7 +23,11 @@
XnnpackDynamicallyQuantizedPartitioner,
)

-from executorch.examples.models.llama2.llama_transformer import Transformer
+from executorch.examples.models.llama2.llama_transformer import (
+    KVCache,
+    SDPA,
+    Transformer,
+)
from executorch.exir.backend.backend_details import CompileSpec

from executorch.sdk.etrecord import generate_etrecord
@@ -88,6 +92,58 @@ def materialze_broadcast_of_rope_freq_cis(
return module


+class SDPACustom(torch.nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(child.kv_cache, child.mask, child.dim),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    _replace_sdpa_with_custom_op(module)
+    return module
+
+
def quantize(
model: torch.nn.Module,
qmode: str,
@@ -493,8 +549,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
-        pass
-        # TODO: Next diff transforms.append()
+        transforms.append(replace_sdpa_with_custom_op)

return (
load_llama_model(
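With the flag enabled, the swap is appended to the same transforms list as the other export-time rewrites. Below is a rough sketch of how such a transform list is typically applied to the eager model; the actual application point (inside load_llama_model) is an assumption and is not shown in this diff.

# Sketch only: applying a list of module-to-module transforms like the one
# built in _prepare_for_llama_export. The application point is assumed here.
from typing import Callable, List

import torch


def apply_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    for transform in transforms:
        model = transform(model)  # e.g. replace_sdpa_with_custom_op(model)
    return model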
examples/models/llama2/llama_transformer.py (53 changes: 0 additions & 53 deletions)
@@ -214,14 +214,12 @@ def __init__(
self,
kv_cache: KVCache,
mask,
-        use_sdpa_with_kv_cache_op: bool,
dim: int,
n_rep: int,
):
super().__init__()
self.kv_cache = kv_cache
self.mask = mask
-        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
self.dim = dim
self.n_rep = n_rep

@@ -233,56 +231,6 @@ def forward(
v: torch.Tensor,
bsz,
seqlen,
) -> torch.Tensor:
-        if not self.use_sdpa_with_kv_cache_op:
-            return self._forward_default(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-        else:
-            return self._forward_custom(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-
-    def _forward_custom(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
-    ):
-        from .custom_ops import sdpa_with_kv_cache  # noqa
-
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
-            input_pos[-1].item(),
-            seqlen,
-        )
-        return output.view(bsz, seqlen, self.dim)
-
-    def _forward_default(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
-    ) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
@@ -341,7 +289,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
self.SDPA = SDPA(
self.kv_cache,
self.mask,
-            args.use_sdpa_with_kv_cache_op,
self.dim,
self.n_rep,
)
