From 6dccc8d18a68538d0fd0a2794ff390b9a81a5f62 Mon Sep 17 00:00:00 2001
From: Elias Joseph
Date: Mon, 19 Feb 2024 10:22:49 -0800
Subject: [PATCH 1/5] added opt-builder script

---
 opt-builder.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 opt-builder.py

diff --git a/opt-builder.py b/opt-builder.py
new file mode 100644
index 000000000..136e8c86a
--- /dev/null
+++ b/opt-builder.py
@@ -0,0 +1,101 @@
+from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM
+import safetensors
+from iree.compiler.ir import Context
+import torch
+import shark_turbine.aot as aot
+from shark_turbine.aot import *
+
+class HFTransformerBuilder:
+    """
+    A model builder that uses Hugging Face's transformers library to build a PyTorch model.
+
+    Args:
+        example_input (torch.Tensor): An example input tensor to the model.
+        hf_id (str): The Hugging Face model ID.
+        auto_model (AutoModel): The AutoModel class to use for loading the model.
+        auto_tokenizer (AutoTokenizer): The AutoTokenizer class to use for loading the tokenizer.
+        auto_config (AutoConfig): The AutoConfig class to use for loading the model configuration.
+    """
+
+    def __init__(
+        self,
+        example_input: torch.Tensor,
+        hf_id: str,
+        auto_model: AutoModel = AutoModelForCausalLM,
+        auto_tokenizer: AutoTokenizer = AutoTokenizer,
+        auto_config: AutoConfig = None,
+        hf_auth_token="hf_xxx",
+    ) -> None:
+        self.example_input = example_input
+        self.hf_id = hf_id
+        self.auto_model = auto_model
+        self.auto_tokenizer = auto_tokenizer
+        self.auto_config = auto_config
+        self.hf_auth_token = hf_auth_token
+        self.model = None
+        self.tokenizer = None
+        self.build_model()
+
+    def build_model(self) -> None:
+        """
+        Builds a PyTorch model using Hugging Face's transformers library.
+        """
+        # TODO: check cloud storage for existing ir
+        self.model = self.auto_model.from_pretrained(
+            self.hf_id, token=self.hf_auth_token, torch_dtype=torch.float, trust_remote_code=True
+        )
+        #if self.auto_tokenizer is not None:
+        #    self.tokenizer = self.auto_tokenizer.from_pretrained(
+        #        self.hf_id, token=self.hf_auth_token, use_fast=False
+        #    )
+        #else:
+        self.tokenizer = None
+
+    def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule:
+        """
+        Compiles the PyTorch model into a compiled module using SHARK-Turbine's AOT compiler.
+
+        Args:
+            save_to (str): one of: input (Torch IR) or import (linalg).
+
+        Returns:
+            aot.CompiledModule: The compiled module binary.
+        """
+        module = aot.export(self.model, self.example_input)
+        compiled_binary = module.compile(save_to=save_to)
+        return compiled_binary
+
+
+if __name__ == "__main__":
+    import sys
+    hf_id = sys.argv[-1]
+    safe_name = hf_id.replace("/", "_").replace("-", "_")
+    inp = torch.zeros(1, 1, dtype=torch.int64)
+    model = HFTransformerBuilder(inp, hf_id)
+    mapper=dict()
+    mod_params = dict(model.model.named_parameters())
+    for name in mod_params:
+        mapper["params." + name] = name
+#    safetensors.torch.save_file(mod_params, safe_name+".safetensors")
+    class GlobalModule(CompiledModule):
+        params = export_parameters(model.model, external=True, external_scope="",)
+        compute = jittable(model.model.forward)
+
+        def run(self, x=aot.AbstractTensor(1, None, dtype=torch.int64)):
+            return self.compute(x, constraints=[
+                x.dynamic_dim(1),]
+            )
+
+        def run_not(self, x=abstractify(inp)):
+            return self.compute(x)
+
+    print("module defined")
+    inst = GlobalModule(context=Context())
+    print("module inst")
+    module = CompiledModule.get_mlir_module(inst)
+#    compiled = module.compile()
+    print("got mlir module")
+    with open(safe_name+".mlir", "w+") as f:
+        f.write(str(module))
+
+    print("done")
\ No newline at end of file
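A note on the script above (not part of the patch): everything is driven by the final
command-line argument, which is taken as the Hugging Face model ID, and the generated
MLIR is written next to the script. A minimal usage sketch follows; "facebook/opt-125m"
is an illustrative model ID, and the import name is hypothetical since opt-builder.py
is not importable as-is (the hyphen would have to become an underscore):

    # Shell usage: python opt-builder.py facebook/opt-125m
    # writes facebook_opt_125m.mlir to the working directory.
    import torch
    from opt_builder import HFTransformerBuilder  # hypothetical module name

    inp = torch.zeros(1, 1, dtype=torch.int64)  # one-token example input
    builder = HFTransformerBuilder(inp, "facebook/opt-125m")
    vmfb = builder.get_compiled_module()  # aot.export(model, inp).compile(...)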
From 1f40422b6de3bcb79c93e8c07dd144b9724a3fc7 Mon Sep 17 00:00:00 2001
From: Elias Joseph
Date: Mon, 19 Feb 2024 13:59:00 -0800
Subject: [PATCH 2/5] fixed empty faketensor issue

---
 core/shark_turbine/aot/builtins/jittable.py | 6 ++++++
 opt-builder.py                              | 4 ++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py
index d2c85b73f..de5f9db46 100644
--- a/core/shark_turbine/aot/builtins/jittable.py
+++ b/core/shark_turbine/aot/builtins/jittable.py
@@ -206,6 +206,12 @@ def flat_wrapped_f(*args):
         if "functorch_functionalize" in self._passes:
             transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
 
+        for node in transformed_f.graph.nodes:
+            if node.op == "call_function":
+                if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
+                    node.target = torch._ops.ops.aten.clone.default
+        transformed_f.recompile()
+
         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
         logger.debug("Performing dynamo.export(constraints=%r)", constraints)
diff --git a/opt-builder.py b/opt-builder.py
index 136e8c86a..f0b65520a 100644
--- a/opt-builder.py
+++ b/opt-builder.py
@@ -86,8 +86,8 @@ def run(self, x=aot.AbstractTensor(1, None, dtype=torch.int64)):
                 x.dynamic_dim(1),]
             )
 
-        def run_not(self, x=abstractify(inp)):
-            return self.compute(x)
+        #def run(self, x=abstractify(inp)):
+        #    return self.compute(x)
 
     print("module defined")
     inst = GlobalModule(context=Context())
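The jittable.py hunk above rewrites the traced FX graph in place before handing it to
dynamo: any call_function node targeting aten.lift_fresh_copy (which surfaces as an
empty FakeTensor during functionalization, per the commit subject) is retargeted to
aten.clone, and recompile() regenerates the GraphModule's code. A standalone sketch of
the same node-retargeting mechanism on a toy graph, with torch.add -> torch.mul standing
in for the lift_fresh_copy -> clone swap:

    import torch
    import torch.fx

    def f(x):
        return torch.add(x, 2.0)

    gm = torch.fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.add:
            node.target = torch.mul  # retarget; existing args are reused as-is
    gm.recompile()  # regenerate forward() from the mutated graph

    print(gm(torch.tensor(3.0)))  # tensor(6.) since add became mul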
- """ - - def __init__( - self, - example_input: torch.Tensor, - hf_id: str, - auto_model: AutoModel = AutoModelForCausalLM, - auto_tokenizer: AutoTokenizer = AutoTokenizer, - auto_config: AutoConfig = None, - hf_auth_token="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", - ) -> None: - self.example_input = example_input - self.hf_id = hf_id - self.auto_model = auto_model - self.auto_tokenizer = auto_tokenizer - self.auto_config = auto_config - self.hf_auth_token = hf_auth_token - self.model = None - self.tokenizer = None - self.build_model() - - def build_model(self) -> None: - """ - Builds a PyTorch model using Hugging Face's transformers library. - """ - # TODO: check cloud storage for existing ir - self.model = self.auto_model.from_pretrained( - self.hf_id, token=self.hf_auth_token, torch_dtype=torch.float, trust_remote_code=True - ) - #if self.auto_tokenizer is not None: - # self.tokenizer = self.auto_tokenizer.from_pretrained( - # self.hf_id, token=self.hf_auth_token, use_fast=False - # ) - #else: - self.tokenizer = None - - def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: - """ - Compiles the PyTorch model into a compiled module using SHARK-Turbine's AOT compiler. - - Args: - save_to (str): one of: input (Torch IR) or import (linalg). - - Returns: - aot.CompiledModule: The compiled module binary. - """ - module = aot.export(self.model, self.example_input) - compiled_binary = module.compile(save_to=save_to) - return compiled_binary - - -if __name__ == "__main__": - import sys - hf_id = sys.argv[-1] - safe_name = hf_id.replace("/", "_").replace("-", "_") - inp = torch.zeros(1, 1, dtype=torch.int64) - model = HFTransformerBuilder(inp, hf_id) - mapper=dict() - mod_params = dict(model.model.named_parameters()) - for name in mod_params: - mapper["params." + name] = name -# safetensors.torch.save_file(mod_params, safe_name+".safetensors") - class GlobalModule(CompiledModule): - params = export_parameters(model.model, external=True, external_scope="",) - compute = jittable(model.model.forward) - - def run(self, x=aot.AbstractTensor(1, None, dtype=torch.int64)): - return self.compute(x, constraints=[ - x.dynamic_dim(1),] - ) - - #def run(self, x=abstractify(inp)): - # return self.compute(x) - - print("module defined") - inst = GlobalModule(context=Context()) - print("module inst") - module = CompiledModule.get_mlir_module(inst) -# compiled = module.compile() - print("got mlir module") - with open(safe_name+".mlir", "w+") as f: - f.write(str(module)) - - print("done") \ No newline at end of file From d21fbcbe1aa2eb89c39dbf438a57469b26428bb8 Mon Sep 17 00:00:00 2001 From: Elias Joseph Date: Tue, 20 Feb 2024 10:52:47 -0800 Subject: [PATCH 4/5] formatting --- core/shark_turbine/aot/builtins/jittable.py | 74 ++++++++++++++++----- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index de5f9db46..f81bc57ba 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -7,7 +7,17 @@ """Tracing builtins.""" -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import torch from torch._decomp import get_decompositions @@ -97,7 +107,9 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]: # legal). 
From d21fbcbe1aa2eb89c39dbf438a57469b26428bb8 Mon Sep 17 00:00:00 2001
From: Elias Joseph
Date: Tue, 20 Feb 2024 10:52:47 -0800
Subject: [PATCH 4/5] formatting

---
 core/shark_turbine/aot/builtins/jittable.py | 74 ++++++++++++++++-----
 1 file changed, 57 insertions(+), 17 deletions(-)

diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py
index de5f9db46..f81bc57ba 100644
--- a/core/shark_turbine/aot/builtins/jittable.py
+++ b/core/shark_turbine/aot/builtins/jittable.py
@@ -7,7 +7,17 @@
 
 """Tracing builtins."""
 
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import torch
 from torch._decomp import get_decompositions
@@ -97,7 +107,9 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
             # legal). Note that the merger will ignore these since they already
             # exist in the target module.
             if materialized_global.symbol_name not in cloned_global_symbols:
-                materialized_global.global_op.operation.clone(ip=gni.fx_importer._m_ip)
+                materialized_global.global_op.operation.clone(
+                    ip=gni.fx_importer._m_ip
+                )
                 cloned_global_symbols.add(materialized_global.symbol_name)
 
             # Emit a global load and conversion.
@@ -156,7 +168,9 @@ def __init__(
         self.constraints = constraints
         self.decomposition_table = decomposition_table
         self.wrapped_f = wrapped_f
-        self.function_name = function_name if function_name else wrapped_f.__name__
+        self.function_name = (
+            function_name if function_name else wrapped_f.__name__
+        )
         self._passes = set(passes)
         for p in passes:
             if p not in ALL_PASSES:
@@ -186,7 +200,9 @@ def resolve_call(
         flat_pytorch_args = []
         flat_ir_args = []
         for py_arg in flat_py_args:
-            ir_arg, pytorch_arg = self._split_py_arg(py_arg, constraints=constraints)
+            ir_arg, pytorch_arg = self._split_py_arg(
+                py_arg, constraints=constraints
+            )
             flat_ir_args.append(ir_arg)
             flat_pytorch_args.append(pytorch_arg)
 
@@ -204,14 +220,16 @@ def flat_wrapped_f(*args):
         # Run pre-processing passes.
         transformed_f = flat_wrapped_f
         if "functorch_functionalize" in self._passes:
-            transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
+            transformed_f = functorch_functionalize(
+                transformed_f, *flat_pytorch_args
+            )
 
         for node in transformed_f.graph.nodes:
             if node.op == "call_function":
                 if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
                     node.target = torch._ops.ops.aten.clone.default
         transformed_f.recompile()
-        
+
         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
         logger.debug("Performing dynamo.export(constraints=%r)", constraints)
@@ -240,10 +258,14 @@ def flat_wrapped_f(*args):
         fx_importer = FxImporter(
             context=proc_trace.context,
             config_check=False,
-            literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
+            literal_resolver_callback=_make_literal_resolver(
+                proc_trace.module_builder
+            ),
            py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
         )
-        fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)
+        fx_importer.import_stateless_graph(
+            gm.graph, func_name=self.function_name
+        )
 
         # TODO: Real debugging options
         # print(fx_importer.module, file=sys.stderr)
@@ -300,11 +322,17 @@ def flat_wrapped_f(*args):
         assert len(flat_ir_results) == len(result_tensor_infos)
 
         flat_py_results = []
-        for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos):
+        for ir_result, result_tensor_info in zip(
+            flat_ir_results, result_tensor_infos
+        ):
             (dtype,) = result_tensor_info
-            native_ir_result = type_converter.materialize_torch_to_native(ir_result)
+            native_ir_result = type_converter.materialize_torch_to_native(
+                ir_result
+            )
             if dtype is not None:
-                flat_py_results.append(IrImmediateTensor(native_ir_result, dtype))
+                flat_py_results.append(
+                    IrImmediateTensor(native_ir_result, dtype)
+                )
             else:
                 raise TypeError(
                     f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}"
@@ -313,7 +341,9 @@ def flat_wrapped_f(*args):
         tree_py_results = tree_unflatten(flat_py_results, out_spec)
         return tree_py_results
 
-    def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]:
+    def _split_py_arg(
+        self, arg, constraints: List[Constraint]
+    ) -> Tuple[Value, Any]:
         if isinstance(arg, IrTensor):
             meta_tensor, meta_constraints = arg._to_meta_tensor()
             constraints.extend(meta_constraints)
@@ -358,7 +388,9 @@ def merge(self) -> Optional[Operation]:
         imported_func_op: Optional[Operation] = None
 
         # Import functions.
-        func_ops = _get_top_level_ops(self.from_module_op, func_d.FuncOp.OPERATION_NAME)
+        func_ops = _get_top_level_ops(
+            self.from_module_op, func_d.FuncOp.OPERATION_NAME
+        )
         for func_op in func_ops:
             # Pre-rename, check if it is the one we are looking for.
             func_name = _get_symbol_name(func_op)
@@ -374,7 +406,9 @@ def merge(self) -> Optional[Operation]:
         for from_symbol, to_symbol in self.rename_map.items():
             from_name = StringAttr(from_symbol).value
             to_name = StringAttr(to_symbol).value
-            SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation)
+            SymbolTable.replace_all_symbol_uses(
+                from_name, to_name, sym_operation
+            )
 
         return imported_func_op
 
@@ -384,7 +418,9 @@ def import_symbol_op(self, symbol_op):
         orig_symbol = SymbolTable.get_symbol_name(symbol_op)
         orig_symbol_name = StringAttr(orig_symbol).value
         # Make sure it is unique.
-        new_symbol_name = _uniqueify_name(orig_symbol_name, target_symbol_table)
+        new_symbol_name = _uniqueify_name(
+            orig_symbol_name, target_symbol_table
+        )
         if new_symbol_name != orig_symbol_name:
             SymbolTable.set_symbol_name(symbol_op, new_symbol_name)
             self._rename(orig_symbol, new_symbol_name)
@@ -393,7 +429,9 @@ def import_symbol_op(self, symbol_op):
         self.nested_symbol_ops.append(symbol_op)
         target_symbol_table.insert(symbol_op)
 
-    def _rename(self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr):
+    def _rename(
+        self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr
+    ):
         from_symbol = self._make_string_attr(from_symbol)
         to_symbol = self._make_string_attr(to_symbol)
         if from_symbol != to_symbol:
@@ -407,7 +445,9 @@ def _make_string_attr(self, string_attr_or_str: StringAttrOrStr):
         return StringAttr(string_attr_or_str)
 
 
-def _get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]:
+def _get_top_level_ops(
+    module_op: Operation, *op_names: str
+) -> Sequence[Operation]:
     results = []
     for op_view in module_op.regions[0].blocks[0]:
         op = op_view.operation

From 685fd802612b2703c730fae2bb05ff64db35dbbf Mon Sep 17 00:00:00 2001
From: Elias Joseph
Date: Tue, 20 Feb 2024 19:31:27 -0800
Subject: [PATCH 5/5] reverted format changes

---
 core/shark_turbine/aot/builtins/jittable.py | 60 ++++++---------------
 1 file changed, 15 insertions(+), 45 deletions(-)

diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py
index f81bc57ba..58c9fa790 100644
--- a/core/shark_turbine/aot/builtins/jittable.py
+++ b/core/shark_turbine/aot/builtins/jittable.py
@@ -107,9 +107,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
             # legal). Note that the merger will ignore these since they already
             # exist in the target module.
             if materialized_global.symbol_name not in cloned_global_symbols:
-                materialized_global.global_op.operation.clone(
-                    ip=gni.fx_importer._m_ip
-                )
+                materialized_global.global_op.operation.clone(ip=gni.fx_importer._m_ip)
                 cloned_global_symbols.add(materialized_global.symbol_name)
 
             # Emit a global load and conversion.
@@ -168,9 +166,7 @@ def __init__(
         self.constraints = constraints
         self.decomposition_table = decomposition_table
         self.wrapped_f = wrapped_f
-        self.function_name = (
-            function_name if function_name else wrapped_f.__name__
-        )
+        self.function_name = function_name if function_name else wrapped_f.__name__
         self._passes = set(passes)
         for p in passes:
             if p not in ALL_PASSES:
@@ -200,9 +196,7 @@ def resolve_call(
         flat_pytorch_args = []
         flat_ir_args = []
         for py_arg in flat_py_args:
-            ir_arg, pytorch_arg = self._split_py_arg(
-                py_arg, constraints=constraints
-            )
+            ir_arg, pytorch_arg = self._split_py_arg(py_arg, constraints=constraints)
             flat_ir_args.append(ir_arg)
             flat_pytorch_args.append(pytorch_arg)
 
@@ -220,9 +214,7 @@ def flat_wrapped_f(*args):
         # Run pre-processing passes.
         transformed_f = flat_wrapped_f
         if "functorch_functionalize" in self._passes:
-            transformed_f = functorch_functionalize(
-                transformed_f, *flat_pytorch_args
-            )
+            transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
 
         for node in transformed_f.graph.nodes:
             if node.op == "call_function":
@@ -258,14 +250,10 @@ def flat_wrapped_f(*args):
         fx_importer = FxImporter(
             context=proc_trace.context,
             config_check=False,
-            literal_resolver_callback=_make_literal_resolver(
-                proc_trace.module_builder
-            ),
+            literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
             py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
         )
-        fx_importer.import_stateless_graph(
-            gm.graph, func_name=self.function_name
-        )
+        fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)
 
         # TODO: Real debugging options
         # print(fx_importer.module, file=sys.stderr)
@@ -322,17 +310,11 @@ def flat_wrapped_f(*args):
         assert len(flat_ir_results) == len(result_tensor_infos)
 
         flat_py_results = []
-        for ir_result, result_tensor_info in zip(
-            flat_ir_results, result_tensor_infos
-        ):
+        for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos):
             (dtype,) = result_tensor_info
-            native_ir_result = type_converter.materialize_torch_to_native(
-                ir_result
-            )
+            native_ir_result = type_converter.materialize_torch_to_native(ir_result)
             if dtype is not None:
-                flat_py_results.append(
-                    IrImmediateTensor(native_ir_result, dtype)
-                )
+                flat_py_results.append(IrImmediateTensor(native_ir_result, dtype))
             else:
                 raise TypeError(
                     f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}"
@@ -341,9 +323,7 @@ def flat_wrapped_f(*args):
         tree_py_results = tree_unflatten(flat_py_results, out_spec)
         return tree_py_results
 
-    def _split_py_arg(
-        self, arg, constraints: List[Constraint]
-    ) -> Tuple[Value, Any]:
+    def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]:
         if isinstance(arg, IrTensor):
             meta_tensor, meta_constraints = arg._to_meta_tensor()
             constraints.extend(meta_constraints)
@@ -388,9 +368,7 @@ def merge(self) -> Optional[Operation]:
         imported_func_op: Optional[Operation] = None
 
         # Import functions.
-        func_ops = _get_top_level_ops(
-            self.from_module_op, func_d.FuncOp.OPERATION_NAME
-        )
+        func_ops = _get_top_level_ops(self.from_module_op, func_d.FuncOp.OPERATION_NAME)
         for func_op in func_ops:
             # Pre-rename, check if it is the one we are looking for.
             func_name = _get_symbol_name(func_op)
@@ -406,9 +384,7 @@ def merge(self) -> Optional[Operation]:
         for from_symbol, to_symbol in self.rename_map.items():
             from_name = StringAttr(from_symbol).value
             to_name = StringAttr(to_symbol).value
-            SymbolTable.replace_all_symbol_uses(
-                from_name, to_name, sym_operation
-            )
+            SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation)
 
         return imported_func_op
 
@@ -418,9 +394,7 @@ def import_symbol_op(self, symbol_op):
         orig_symbol = SymbolTable.get_symbol_name(symbol_op)
         orig_symbol_name = StringAttr(orig_symbol).value
         # Make sure it is unique.
-        new_symbol_name = _uniqueify_name(
-            orig_symbol_name, target_symbol_table
-        )
+        new_symbol_name = _uniqueify_name(orig_symbol_name, target_symbol_table)
         if new_symbol_name != orig_symbol_name:
             SymbolTable.set_symbol_name(symbol_op, new_symbol_name)
             self._rename(orig_symbol, new_symbol_name)
@@ -429,9 +403,7 @@ def import_symbol_op(self, symbol_op):
         self.nested_symbol_ops.append(symbol_op)
         target_symbol_table.insert(symbol_op)
 
-    def _rename(
-        self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr
-    ):
+    def _rename(self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr):
         from_symbol = self._make_string_attr(from_symbol)
         to_symbol = self._make_string_attr(to_symbol)
         if from_symbol != to_symbol:
@@ -445,9 +417,7 @@ def _make_string_attr(self, string_attr_or_str: StringAttrOrStr):
         return StringAttr(string_attr_or_str)
 
 
-def _get_top_level_ops(
-    module_op: Operation, *op_names: str
-) -> Sequence[Operation]:
+def _get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]:
     results = []
     for op_view in module_op.regions[0].blocks[0]:
         op = op_view.operation
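Since patch 5 undoes patch 4 except for the multi-line typing import and one stripped
trailing-whitespace line, the series' net behavioral change is the lift_fresh_copy ->
clone rewrite from patch 2. The substitution preserves values; a quick eager-mode check
(a sketch, assuming both aten overloads remain directly callable, as they are in recent
PyTorch releases):

    import torch

    t = torch.arange(4.0)
    a = torch.ops.aten.lift_fresh_copy.default(t)  # op removed by the rewrite
    b = torch.ops.aten.clone.default(t)            # op substituted in its place
    assert torch.equal(a, b)  # identical values; clone avoids the empty-FakeTensor issue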