Commit 758b670: use shape check in torch

heheda12345 committed Nov 24, 2023
1 parent 082e824
Showing 8 changed files with 360 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .mypy.ini
@@ -4,4 +4,6 @@ exclude = (?x)(
)
strict = True
[mypy-torch.*]
follow_imports = skip
[mypy-sympy.*]
follow_imports = skip
158 changes: 151 additions & 7 deletions frontend/fx_graph.py
@@ -1,5 +1,7 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
from functools import partial
import copy
import collections
import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -8,9 +10,16 @@
import torch._dynamo.backends.torchxla
import torch.fx.immutable_collections as fx_immutable
from torch._dispatch.python import enable_python_dispatcher
from torch import SymInt, SymFloat, SymBool
from torch.fx.experimental.symbolic_shapes import Symbol
from sympy.printing.str import StrPrinter
import sympy
from .no_preload import NO_LD_PRELOAD_CTX
from . import config
from .utils import ScalarType
from .pycode_generator import GuardFnCodegen
from .store_pos import StorePos, StoreNegate, StoreInAttr, StoreInIndex
from . import variables as vs

BaseArgumentTypes = Union[
str,
@@ -41,6 +50,48 @@ def backend_compile(gm: torch.fx.GraphModule,
raise RuntimeError(f"Unknown backend: {backend}")


def guard_check_shapeenv(inputs: list[torch.Tensor], fake_inputs: list[Any],
shape_env: ShapeEnv) -> bool:
symbol2value: dict[Symbol, Any] = {}
for fake_input, input in zip(fake_inputs, inputs):
if isinstance(fake_input, torch._subclasses.FakeTensor):
assert isinstance(input, torch.Tensor)
if len(input.shape) != len(fake_input.shape):
return False
for symbol, value in zip(fake_input.shape, input.shape):
expr = symbol.node.expr
if expr in symbol2value:
if symbol2value[expr] != value:
print("false due to shape", fake_input.shape,
input.shape)
print("symbol2value", symbol2value[expr])
return False
else:
symbol2value[expr] = value
else:
raise NotImplementedError
for guard in shape_env.guards:
val = guard.expr.subs(symbol2value)
        if val is not sympy.true:
print("guard fail", guard, symbol2value)
return False
return True
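
In short, `guard_check_shapeenv` binds each sympy shape symbol to the concrete dimension size seen at call time and re-evaluates the recorded guards under that binding. A standalone sketch with plain sympy (the symbols and the guard are invented for illustration):

```python
import sympy

# Two symbolic dimensions, as a ShapeEnv would allocate for a 2-D input.
s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)

# Pretend tracing recorded the guard Eq(s0, 2*s1).
guards = [sympy.Eq(s0, 2 * s1)]

def check(shape: tuple) -> bool:
    # Same idea as symbol2value above: map symbol -> concrete dim size.
    symbol2value = dict(zip((s0, s1), shape))
    return all(g.subs(symbol2value) is sympy.true for g in guards)

print(check((8, 4)))  # True:  8 == 2 * 4
print(check((9, 4)))  # False: 9 != 2 * 4
```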


class ShapeGuardPrinter(StrPrinter): # type: ignore[misc]

def __init__(self, symbol_to_source: Dict[Symbol, list[StorePos]]):
super().__init__()
self.symbol_to_source = symbol_to_source

def _print_Symbol(self, expr: Symbol) -> str:
assert isinstance(expr, Symbol), str(type(expr))
assert expr in self.symbol_to_source, (
f"{expr} (could be from {[s.name() for s in expr.sources]}) "
f"not in {self.symbol_to_source}")
return str(self.symbol_to_source[expr][0])
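
`ShapeGuardPrinter` overrides only symbol printing: each sympy symbol is rendered as the first `StorePos` access path it was observed at, so a printed guard can be checked directly against frame state. A toy version with stand-in path strings (not the real `StorePos` objects):

```python
import sympy
from sympy.printing.str import StrPrinter

class ToyShapeGuardPrinter(StrPrinter):
    def __init__(self, symbol_to_path):
        super().__init__()
        self.symbol_to_path = symbol_to_path

    def _print_Symbol(self, expr):
        # Print the symbol as the place its value can be read from.
        return self.symbol_to_path[expr]

s0 = sympy.Symbol("s0")
printer = ToyShapeGuardPrinter({s0: "x.size()[0]"})
print(printer.doprint(sympy.Eq(2 * s0, 8)))  # Eq(2*x.size()[0], 8)
```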


class FxGraph:
root: torch.nn.Module
result_graph: torch.fx.Graph
@@ -69,11 +120,6 @@ def wrap_fake_exception(fn: Callable[[], Any]) -> Any:
msg = f"Unsupported: {e.reason} with fake tensor propagation."
raise NotImplementedError(msg) from e

-def deepcopy_to_fake_tensor(
-        obj: Any, fake_mode: torch._subclasses.FakeTensorMode) -> Any:
-    with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode):
-        return wrap_fake_exception(lambda: copy.deepcopy(obj))

def as_fake_args_kwargs(
args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
@@ -82,6 +128,9 @@ def as_fake(arg: Any) -> Any:
if isinstance(arg, (tuple, list)):
return fx_immutable.immutable_list(
[as_fake(x) for x in arg])
if isinstance(arg, slice):
return slice(as_fake(arg.start), as_fake(arg.stop),
as_fake(arg.step))
if isinstance(arg, torch.fx.Node):
return arg.meta["fake"]
else:
@@ -108,7 +157,8 @@ def fetch_attr(target: str) -> Any:
assert op not in ("placeholder", "output")
if op == "get_attr":
with self.fake_mode, enable_python_dispatcher():
-                fake = fetch_attr(node.target)
+                param = fetch_attr(node.target)
+                fake = self.fake_mode.from_tensor(param, static_shapes=True)
elif op == "call_function":
with self.fake_mode, enable_python_dispatcher():
fake = node.target(*fake_args, **fake_kwargs)
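
The `get_attr` change fakes parameters with `static_shapes=True`, so module weights keep fixed sizes and never introduce fresh symbolic dimensions (only real graph inputs should). A minimal sketch of that call, assuming the PyTorch 2.x private APIs this commit already relies on:

```python
import torch
from torch._subclasses import FakeTensorMode

fake_mode = FakeTensorMode()
weight = torch.randn(10, 10)

# Fake the parameter: same metadata, no real storage, non-symbolic shape.
fake_weight = fake_mode.from_tensor(weight, static_shapes=True)
print(type(fake_weight).__name__, fake_weight.shape)
# FakeTensor torch.Size([10, 10])
```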
@@ -118,7 +168,8 @@ def fetch_attr(target: str) -> Any:
**fake_kwargs)
elif op == "call_module":
module = fetch_attr(node.target)
-            fake_module = deepcopy_to_fake_tensor(module, self.fake_mode)
+            with torch._subclasses.fake_tensor.FakeCopyMode(self.fake_mode):
+                fake_module = wrap_fake_exception(lambda: copy.deepcopy(module))
with self.fake_mode, enable_python_dispatcher():
fake = fake_module(*fake_args, **fake_kwargs)
else:
@@ -194,12 +245,105 @@ def compile(
for x in self.example_inputs
])
assert callable(compiled_fn)
if self.fake_mode.shape_env is not None:
print("shape_env guards", self.fake_mode.shape_env.format_guards())
# TODO: add backend compiler
return compiled_fn

def get_inputs(self) -> list[torch.fx.Node]:
return [x for x in self.result_graph.nodes if x.op == "placeholder"]

def make_shape_env_guard(self, codegen: GuardFnCodegen) -> None:
        fake_inputs: list[torch._subclasses.FakeTensor] = []
poses: list[StorePos] = []
for node in self.result_graph.nodes:
if node.op == "placeholder":
fake = node.meta["fake"]
fake_inputs.append(fake)
var = node.meta["var"]
assert isinstance(var, (vs.TensorVar, vs.ScalarVar))
pos = var.extract_code_at_start[0]
poses.append(pos)
self.produce_guards(fake_inputs, poses, codegen)

# modified from torch produce_guards
def produce_guards(self, placeholders: list[Any], sources: list[StorePos],
codegen: GuardFnCodegen) -> None:
import math
import operator
SYMPY_INTERP = {
'Eq': operator.eq,
'Ne': operator.ne,
'Gt': operator.gt,
'Lt': operator.lt,
'Le': operator.le,
'Ge': operator.ge,
'Min': min,
'Max': max,
'Mod': operator.mod,
'FloorDiv': operator.floordiv,
'TrueDiv': operator.truediv,
'floor': math.floor,
'ceiling': math.ceil,
}
for k, v in SYMPY_INTERP.items():
codegen.add_obj(v, k, force=True)
input_guards = []
symbol_to_source = collections.defaultdict(list)

def track_symint(source: StorePos, val: Any) -> None:
if isinstance(val, SymInt):
s = val.node.expr

if isinstance(s, sympy.Symbol):
symbol_to_source[s].append(source)
elif isinstance(-s, sympy.Symbol):
symbol_to_source[-s].append(StoreNegate(source))

input_guards.append((source, s))
else:
input_guards.append((source, sympy.Integer(val)))

for t, source in zip(placeholders, sources):
assert isinstance(source, StorePos)
if t is None:
continue
if isinstance(t, SymInt):
track_symint(source, t)
continue
assert isinstance(t, torch.Tensor)
for i, s in enumerate(t.size()):
track_symint(
StoreInIndex(StoreInAttr(source, 0, 'size()'), 0, i), s)

for source, expr in input_guards:
# Small optimization
if (isinstance(expr, Symbol) and expr in symbol_to_source and
source == symbol_to_source[expr][0]):
continue
sexpr = ShapeGuardPrinter(symbol_to_source).doprint(expr)
codegen.add_check(f"{source} == {sexpr}")

for g, tb in self.fake_mode.shape_env.guards:
print("guard", g)
if self.fake_mode.shape_env._maybe_evaluate_static(g) is not None:
print("maybe static")
continue
print("before simplify", g)
g = self.fake_mode.shape_env.simplify(g)
print("after simplify", g)
try:
codegen.add_check(
ShapeGuardPrinter(symbol_to_source).doprint(g))
except Exception:
print(f"Failing guard allocated at: \n{tb}")
raise

for sources in symbol_to_source.values():
assert sources
codegen.add_check(f"{sources[0]} != 0")
codegen.add_check(f"{sources[0]} != 1")


frame_root: dict[int, torch.nn.Module] = {}

24 changes: 20 additions & 4 deletions frontend/guard_tracker.py
@@ -56,6 +56,8 @@ class DeferRestartState:

deberta_model = None

the_first_input = None


class State:
objects: ObjectTable
@@ -164,6 +166,9 @@ def get_common_device(arg: Any) -> None:
def as_fx_node(arg: Any) -> NodeArgs:
if isinstance(arg, (tuple, list)):
return fx_immutable.immutable_list([as_fx_node(x) for x in arg])
if isinstance(arg, slice):
return slice(as_fx_node(arg.start), as_fx_node(arg.stop),
as_fx_node(arg.step))
var = self.objects.get(arg,
allow_unexist_const=True,
fx_graph=self.fx_graph)
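
The new `slice` branch mirrors the tuple/list case: `x[i:j]` reaches the tracer as `slice(i, j, None)`, and `i`/`j` may themselves be traced values that need fx nodes of their own. A plain-Python illustration of the recursion (the `node(...)` strings stand in for real fx nodes):

```python
def as_node(arg):
    if isinstance(arg, slice):
        # Convert each member; None (an omitted bound) passes through.
        return slice(as_node(arg.start), as_node(arg.stop), as_node(arg.step))
    return f"node({arg})" if arg is not None else None

print(as_node(slice(1, 10, None)))
# slice('node(1)', 'node(10)', None)
```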
@@ -229,7 +234,8 @@ def record_function(self,
position = i
else:
node = obj
if scalar is not None and node is not None:
if scalar is not None and node is not None and not config.get_config(
'dynshape'):
fx_node = self.fx_graph.create_node(
"call_function",
torch.full_like,
@@ -991,6 +997,8 @@ def commit(self) -> None:
while var.prev is not None:
var = var.prev
var.make_guard(guard_codegen)
if config.get_config('dynshape'):
self.state.fx_graph.make_shape_env_guard(guard_codegen)
guard_code = guard_codegen.get_code()
graph_codegen = GraphFnCodegen(key=key)
for node in self.state.fx_graph.result_graph.nodes:
@@ -1036,9 +1044,9 @@ def commit(self) -> None:

self.state.fx_graph.set_output_nodes(
graph_codegen.get_graph_outputs())
print("graph input",
[(name, x)
for x, name in self.state.fx_graph.example_inputs])
print("graph input", [
(name, x) for x, name in self.state.fx_graph.example_inputs
])
print("graph", self.state.fx_graph.result_graph)
graph_code = graph_codegen.get_code()
compiled_graph = self.state.fx_graph.compile()
@@ -1130,6 +1138,14 @@ def process_last_inst(self) -> None:
partial_make_var_fn: Optional[
MAKE_VAR_FN_TYPE] = partial.make_var_fn
make_var_fn: MAKE_VAR_FN_TYPE = partial_make_var_fn if partial_make_var_fn is not None else default_make_var_fn
if isinstance(value, bool) and config.get_config(
"dynshape") and node is not None:
fake = node.meta["fake"]
if isinstance(fake, torch.SymBool):
fake_bool = fake.node.expr
import sympy
if fake_bool is sympy.true or fake_bool is sympy.false: # not a dynamic value
node = None
if isinstance(value, torch.Tensor):
if isinstance(value, torch.nn.Parameter):
var = make_var_fn(value, partial.need_guard_check,
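The `SymBool` test added to `process_last_inst` asks whether a Python bool produced under dynamic shapes is actually a constant: if sympy already folded the underlying expression to `true`/`false`, the value is static and needs no graph node. The same distinction in plain sympy:

```python
import sympy

s0 = sympy.Symbol("s0", integer=True, positive=True)

static_expr = sympy.Eq(s0, s0)  # folds to true immediately
dynamic_expr = sympy.Eq(s0, 3)  # stays symbolic

for expr in (static_expr, dynamic_expr):
    is_static = expr is sympy.true or expr is sympy.false
    print(expr, "-> static" if is_static else "-> dynamic")
```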
11 changes: 5 additions & 6 deletions frontend/pycode_generator.py
@@ -4,7 +4,6 @@
import torch.fx
from .pycode_writer import PyCodeWriter, new_name, is_valid_name
from .store_pos import StorePos
-from .variables import Variable
from .config import get_config


@@ -25,22 +24,22 @@ def __init__(self, key: int) -> None:
self.imports = set()
self.objs = {}

-    def add_obj(self, var: Any, name: str = "", force: bool = False) -> str:
+    def add_obj(self, obj: Any, name: str = "", force: bool = False) -> str:
if force:
assert name != ""
assert is_valid_name(name)
if name in self.objs:
-                assert self.objs[name] == var
+                assert self.objs[name] == obj
else:
-                self.objs[name] = var
+                self.objs[name] = obj
return name
else:
if name == "" or not is_valid_name(name):
name = new_name("var")
name = new_name("obj")
elif name in self.objs:
name = new_name(name)

-            self.objs[name] = var
+            self.objs[name] = obj
return name

def add_import(self, module_name: str) -> None:
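The `var`→`obj` rename matches what `add_obj` stores: arbitrary Python objects exposed to generated code. With `force=True` the exact name is pinned (this is how `produce_guards` injects the `SYMPY_INTERP` helpers); otherwise a fresh name is generated on collision. A minimal stand-in reproducing that contract (not the real class; the generated-name format is an assumption):

```python
class ToyCodegen:
    def __init__(self):
        self.objs = {}
        self._n = 0

    def add_obj(self, obj, name="", force=False):
        if force:
            assert name, "forced registration needs an explicit name"
            if name in self.objs:
                assert self.objs[name] == obj
            self.objs[name] = obj
            return name
        if not name or name in self.objs:
            name, self._n = f"obj_{self._n}", self._n + 1
        self.objs[name] = obj
        return name

cg = ToyCodegen()
print(cg.add_obj(min, "Min", force=True))  # Min
print(cg.add_obj(object()))                # obj_0
```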
16 changes: 15 additions & 1 deletion frontend/store_pos.py
@@ -144,6 +144,20 @@ def get_value_from_frame(self, frame: FrameType) -> Any:
self.self_pos.get_value_from_frame(frame))[self.self_index]


class StoreNegate(StorePos):
pos: StorePos
neg_id: int

def __init__(self, pos: StorePos) -> None:
self.pos = pos

def __repr__(self) -> str:
return f"-({self.pos})"

def get_value_from_frame(self, frame: FrameType) -> Any:
return -self.pos.get_value_from_frame(frame)
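
`StoreNegate` supports `track_symint` above: when a dimension's expression is `-s` for a symbol `s`, the symbol's source is recorded as the negation of the original position, so the guard reads the frame value and flips its sign. A toy run with stand-in classes (mirroring, not importing, the real code):

```python
class ToyPos:
    def __repr__(self):
        return "x.size()[0]"

    def get_value_from_frame(self, frame):
        return 4  # pretend dim 0 of x is 4 in this frame

class ToyNegate:
    def __init__(self, pos):
        self.pos = pos

    def __repr__(self):
        return f"-({self.pos})"

    def get_value_from_frame(self, frame):
        return -self.pos.get_value_from_frame(frame)

neg = ToyNegate(ToyPos())
print(neg, "=", neg.get_value_from_frame(None))  # -(x.size()[0]) = -4
```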


class ExtractFromMethod(StorePos):
self_pos: StorePos
self_id: int
@@ -241,4 +255,4 @@ def __init__(self) -> None:
super().__init__()

def __repr__(self) -> str:
return "@__unknown_pos_in_caller__"
return "@__unknown_pos_in_caller__"