Commit d9d7e61

WIP: Mixed precision

georgepaw committed Dec 6, 2023
1 parent 4d1af54 commit d9d7e61
Showing 8 changed files with 100 additions and 18 deletions.
78 changes: 64 additions & 14 deletions backends/apple/mps/mps_preprocess.py
@@ -37,11 +37,15 @@ def get_param_from_node(
     return None
 
 
-def create_mpsgraph_constant_tensor(tensor: torch.Tensor, mpsGraph):
+def create_mpsgraph_constant_tensor(tensor: torch.Tensor, mpsGraph, convert_model_to_fp16: bool):
+    dtype = get_mps_data_type(tensor.dtype)
+    if convert_model_to_fp16 and dtype == get_mps_data_type(torch.float32):
+        tensor = tensor.half()
+        dtype = get_mps_data_type(torch.float16)
     if tensor.dim() == 0:
-        return mpsGraph.constant(tensor.item(), get_mps_data_type(tensor.dtype))
+        return mpsGraph.constant(tensor.item(), dtype)
     else:
-        return mpsGraph.constantTensor(tensor, get_mps_data_type(tensor.dtype))
+        return mpsGraph.constantTensor(tensor, dtype)
 
 
 @final
@@ -153,15 +157,58 @@ def preprocess( # noqa: C901
             exir_ops.edge.aten.pow.Tensor_Scalar: mpsGraph.pow,
         }
 
-        # `graph_nodes` dictionary is made out of <key> : <MPSGraphTensor*>
-        graphNodes: Dict[str, Any] = {}
+        # GraphNodesDict is a dictionary mapping node names to MPSGraph
+        # outputs. Each value is either an MPSGraphTensor* or a
+        # tuple(MPSGraphTensor*). This class can automatically insert the
+        # casts required to convert a model from float32 to float16.
+        class GraphNodesDict(dict):
+            def __init__(self, convert_model_to_fp16=False):
+                self._convert_model_to_fp16 = convert_model_to_fp16
+
+            def get_node(self, key, cast_to_fp16=False):
+                value = dict.__getitem__(self, key)
+                if cast_to_fp16:
+                    def handle(value):
+                        current_data_type = mpsGraph.get_data_type(value)
+                        if current_data_type == get_mps_data_type(torch.float32):
+                            value = mpsGraph.cast_tensor(value, get_mps_data_type(torch.float16))
+                        return value
+                    value = tuple([handle(x) for x in value]) if isinstance(value, tuple) else handle(value)
+                return value
+
+            def set_node(self, key, value, cast_to_fp32=False):
+                if cast_to_fp32:
+                    def handle(value):
+                        current_data_type = mpsGraph.get_data_type(value)
+                        if current_data_type == get_mps_data_type(torch.float16):
+                            value = mpsGraph.cast_tensor(value, get_mps_data_type(torch.float32))
+                        return value
+                    value = tuple([handle(x) for x in value]) if isinstance(value, tuple) else handle(value)
+                dict.__setitem__(self, key, value)
+
+            def __getitem__(self, key):
+                return self.get_node(key, self._convert_model_to_fp16)
+
+            def __setitem__(self, key, value):
+                self.set_node(key, value, self._convert_model_to_fp16)
+
+            def __repr__(self):
+                return dict.__repr__(self)
+
+        # Check whether the model should be converted to fp16.
+        convert_model_to_fp16 = True
+        for spec in compile_specs:
+            if spec.key == "use_fp16":
+                convert_model_to_fp16 = bool(list(bytes(spec.value))[0])
+        graphNodes = GraphNodesDict(convert_model_to_fp16=convert_model_to_fp16)
 
         for node in edge_program.graph.nodes:
             if node.op == "get_attr":
                 attr = MPSBackend.fetch_attr(node, edge_program)
                 graphNodes[node.name] = create_mpsgraph_constant_tensor(
-                    tensor=attr, mpsGraph=mpsGraph
-                )
+                    tensor=attr,
+                    mpsGraph=mpsGraph,
+                    convert_model_to_fp16=convert_model_to_fp16)
 
             # Handle inputs to the graph.
             elif node.op == "placeholder":
@@ -170,20 +217,22 @@
                 lifted_param_or_buffer = get_param_from_node(node, edge_program)
                 if lifted_param_or_buffer is not None:
                     graphNodes[node.name] = create_mpsgraph_constant_tensor(
-                        tensor=lifted_param_or_buffer, mpsGraph=mpsGraph
-                    )
+                        tensor=lifted_param_or_buffer,
+                        mpsGraph=mpsGraph,
+                        convert_model_to_fp16=convert_model_to_fp16)
                 else:
                     if node.meta["val"] is None:
                         continue
                     shape = MPSBackend.eval_shape(node.meta["val"])
+                    # Call set_node explicitly to preserve the input signature.
                     if shape is None:
-                        graphNodes[node.name] = mpsGraph.mpsGraphUnrankedPlaceHolder(
+                        graphNodes.set_node(node.name, mpsGraph.mpsGraphUnrankedPlaceHolder(
                             get_mps_data_type(node.meta["val"].dtype)
-                        )
+                        ))
                     else:
-                        graphNodes[node.name] = mpsGraph.mpsGraphRankedPlaceHolder(
+                        graphNodes.set_node(node.name, mpsGraph.mpsGraphRankedPlaceHolder(
                             get_mps_data_type(node.meta["val"].dtype), shape
-                        )
+                        ))
 
             # Handle `call_function` calls.
             elif node.op == "call_function":
@@ -768,7 +817,8 @@ def preprocess( # noqa: C901
                 output_nodes = []
                 for i in range(len(node.args)):
                     for j in range(len(node.args[i])):
-                        output_nodes.append(graphNodes[node.args[i][j].name])
+                        # Call get_node explicitly to preserve the output signature.
+                        output_nodes.append(graphNodes.get_node(node.args[i][j].name))
                 mpsGraph.set_outputs(*output_nodes)
             else:
                 torch._assert(
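
The GraphNodesDict above casts on access: reads (op inputs) go float32 -> float16, while writes (op outputs) are cast back to float32, so the recorded values keep the model's original dtypes while each op consumes fp16 tensors. Below is a minimal standalone sketch of that pattern, with numpy arrays standing in for MPSGraphTensor* handles; CastingDict is illustrative only, and the tuple-valued case is omitted for brevity:

    import numpy as np

    class CastingDict(dict):
        def __init__(self, convert_to_fp16=False):
            super().__init__()
            self._convert_to_fp16 = convert_to_fp16

        def __getitem__(self, key):
            # Reads (op inputs) are cast float32 -> float16.
            value = dict.__getitem__(self, key)
            if self._convert_to_fp16 and value.dtype == np.float32:
                value = value.astype(np.float16)
            return value

        def __setitem__(self, key, value):
            # Writes (op outputs) are cast back float16 -> float32, so
            # stored values keep the original model's dtypes.
            if self._convert_to_fp16 and value.dtype == np.float16:
                value = value.astype(np.float32)
            dict.__setitem__(self, key, value)

    nodes = CastingDict(convert_to_fp16=True)
    nodes["x"] = np.ones(4, dtype=np.float32)
    assert nodes["x"].dtype == np.float16  # fp16 at the point of use
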
5 changes: 5 additions & 0 deletions backends/apple/mps/operations/ShapeOps.mm
@@ -222,6 +222,11 @@
                                  name:@"permutation"];
 }
 
+PyMPSGraphTensor*
+MPSGraphModule::cast_tensor(MPSGraphTensor* inputTensor, MPSDataType dtype) {
+  return castMPSTensor(mpsGraph, inputTensor, dtype);
+}
+
 PyMPSGraphTensor*
 MPSGraphModule::squeeze(MPSGraphTensor* inputTensor) {
   return [mpsGraph squeezeTensor:inputTensor
5 changes: 4 additions & 1 deletion backends/apple/mps/test/test_mps.py
@@ -27,6 +27,7 @@
 
 from executorch.exir import ExirExportedProgram
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.backend_details import CompileSpec
 from executorch.exir.tests.models import (
     BasicSinMax,
     CompositeDelegateModule,
@@ -77,6 +78,7 @@ def run_model(
     model_type: MODEL_TYPE = MODEL_TYPE.EXIR_DEFAULT_MODEL,
     dump_non_lowered_module: bool = False,
     dump_lowered_module: bool = False,
+    use_fp16: bool = False,
 ):
     logging.info(f"Step 1: Retrieving model: {model}...")
     if model_type == MODEL_TYPE.EXIR_DEFAULT_MODEL:
@@ -100,7 +102,8 @@
 
     # Step 3: Lower to MPSGraph
     logging.info("Step 3: Lowering to MPSGraph...")
-    lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, [])
+    compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]
+    lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, compile_specs)
 
     logging.info("Step 4: Capturing executorch program with lowered module...")
 
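
For reference, the one-byte use_fp16 flag round-trips through a CompileSpec as sketched below; the encode side mirrors run_model above and the decode side mirrors the loop added to preprocess(). Only the CompileSpec import already shown in the diff is assumed:

    from executorch.exir.backend.backend_details import CompileSpec

    use_fp16 = True
    spec = CompileSpec("use_fp16", bytes([use_fp16]))  # True encodes as b"\x01"
    # Decode as preprocess() does; bool(spec.value[0]) would be equivalent.
    decoded = bool(list(bytes(spec.value))[0])
    assert decoded == use_fp16
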
7 changes: 6 additions & 1 deletion backends/apple/mps/test/test_mps_utils.py
@@ -15,6 +15,7 @@
 from executorch.exir import ExecutorchProgram, ExirExportedProgram
 from executorch.exir.backend.backend_api import to_backend, validation_disabled
 
+from executorch.exir.backend.backend_details import CompileSpec
 from executorch.exir.print_program import print_program
 from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
 from executorch.sdk.bundled_program.core import create_bundled_program
@@ -122,6 +123,7 @@ def lower_module_and_test_output(
     use_partitioner: bool = False,
     dump_non_lowered_module: bool = False,
     dump_lowered_module: bool = False,
+    use_fp16: bool = False,
 ) -> ExirExportedProgram:
     """
     Helper testing function that takes a torch.nn.Module and lowers it to XNNPACK with
@@ -151,8 +153,9 @@ def forward(self, *args):
         with validation_disabled():
             None
     else:
+        compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]
         delegated_program = to_backend(
-            "MPSBackend", edge_program.exported_program, []
+            "MPSBackend", edge_program.exported_program, compile_specs
         )
 
     logging.info("Step 3: Capturing executorch program with lowered module...")
@@ -218,6 +221,7 @@ def lower_and_test_with_partitioner(
     graph_module,
     example_inputs,
     func_name: str,
+    use_fp16: bool = False,
 ):
     logging.info(func_name)
     # MPS TODO: partitioner support
@@ -226,4 +230,5 @@
         example_inputs,
         use_partitioner=False,
         func_name=func_name,
+        use_fp16=use_fp16,
     )
6 changes: 6 additions & 0 deletions backends/apple/mps/utils/Bindings.mm
@@ -183,6 +183,9 @@
   .def("permute", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef axes) {
     return self.permute(static_cast<MPSGraphTensor*>(inputTensor), axes);
   })
+  .def("cast_tensor", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, MPSDataType dtype) {
+    return self.cast_tensor(static_cast<MPSGraphTensor*>(inputTensor), dtype);
+  })
   .def("cumsum", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim) {
     return self.cumsum(static_cast<MPSGraphTensor*>(inputTensor), dim);
   })
@@ -316,6 +319,9 @@
   .def("floor_divide", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor) {
     return self.div_mode_template(static_cast<MPSGraphTensor*>(primaryTensor), static_cast<MPSGraphTensor*>(secondaryTensor), "floor", "floor_divide");
   })
+  .def("get_data_type", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor) {
+    return self.getDataType(static_cast<MPSGraphTensor*>(inputTensor));
+  })
 
   //
   // Graph debug methods.
5 changes: 5 additions & 0 deletions backends/apple/mps/utils/MPSGraphInterface.h
@@ -160,6 +160,7 @@ class MPSGraphModule {
   PyMPSGraphTensor* select(MPSGraphTensor* inputTensor, int dim, int index);
   PyMPSGraphTensor* view(MPSGraphTensor* inputTensor, IntArrayRef shape);
   PyMPSGraphTensor* permute(MPSGraphTensor* inputTensor, IntArrayRef axes);
+  PyMPSGraphTensor* cast_tensor(MPSGraphTensor* inputTensor, MPSDataType dtype);
   PyMPSGraphTensor* cumsum(MPSGraphTensor* inputTensor, int dim);
   PyMPSGraphTensor* addmm(
       MPSGraphTensor* biasTensor,
@@ -234,6 +235,10 @@
       int64_t dim,
       MPSGraphTensor* indexTensor);
 
+  MPSDataType getDataType(MPSGraphTensor* input) {
+    return [input dataType];
+  }
+
   MPSGraph* getMPSGraph() {
     return mpsGraph;
   }
2 changes: 1 addition & 1 deletion examples/apple/mps/CMakeLists.txt
@@ -78,7 +78,7 @@ add_executable(
 target_include_directories(
   mps_executor_runner INTERFACE ${CMAKE_BINARY_DIR}/schema/include/
                                 ${EXECUTORCH_ROOT}/third-party/flatbuffers/include)
-target_link_libraries(mps_executor_runner bundled_program program_schema
+target_link_libraries(mps_executor_runner program_schema
                       ${_executor_runner_libs}
                       ${mps_executor_runner_libs})
 target_compile_options(mps_executor_runner PUBLIC ${_common_compile_options})
10 changes: 9 additions & 1 deletion examples/apple/mps/scripts/mps_example.py
@@ -13,6 +13,7 @@
 from executorch.backends.apple.mps.mps_preprocess import MPSBackend
 
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.backend_details import CompileSpec
 from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
 from executorch.sdk.bundled_program.core import create_bundled_program
 from executorch.sdk.bundled_program.serialize import (
@@ -37,6 +38,12 @@
         help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
     )
 
+    parser.add_argument(
+        "--use_fp16",
+        default=True,
+        action=argparse.BooleanOptionalAction,
+        help="Whether to automatically convert float32 operations to float16 operations.")
+
     parser.add_argument(
         "-b",
         "--bundled",
@@ -64,7 +71,8 @@
     ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
     logging.info(f"Exported graph:\n{edge.exported_program.graph}")
 
-    lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, [])
+    compile_specs = [CompileSpec("use_fp16", bytes([args.use_fp16]))]
+    lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, compile_specs)
 
     logging.info(f"Lowered graph:\n{edge.exported_program.graph}")
 
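
Because the flag uses argparse.BooleanOptionalAction (Python 3.9+), a paired --no-use_fp16 option is generated and fp16 conversion is on by default. Hypothetical invocations, assuming the script's model-name flag from the truncated parser setup above:

    python3 examples/apple/mps/scripts/mps_example.py --model_name mv2                 # fp16 (default)
    python3 examples/apple/mps/scripts/mps_example.py --model_name mv2 --no-use_fp16   # keep float32
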
