Skip to content

Commit

Permalink
Remove debug (apache#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored Nov 4, 2023
2 parents 9677f42 + ff7890d commit a16a341
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 10 deletions.
1 change: 0 additions & 1 deletion frontend/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def _fn(*args: Any, **kwargs: Any) -> Any:
print("exception in _fn:", e, type(e))
raise e
finally:
print("restoring frame, prior =", prior)
set_eval_frame(prior)

return _fn
Expand Down
1 change: 1 addition & 0 deletions frontend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

CONFIG = {
"backend": "inductor", # Union[str, Callable[..., Any]]
"debug": True,
}


Expand Down
4 changes: 4 additions & 0 deletions frontend/fx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torch.fx
import torch._inductor.compile_fx
import torch._dynamo.backends.torchxla
from .utils import NO_LD_PRELOAD_CTX
from . import config

Expand All @@ -27,6 +28,9 @@ def backend_compile(gm: torch.fx.GraphModule,
return gm
elif backend == 'inductor':
return torch._inductor.compile_fx.compile_fx(gm, example_inputs)
elif backend == 'xla':
return torch._dynamo.backends.torchxla.aot_torchxla_trace_once(
gm, example_inputs)
else:
raise RuntimeError(f"Unknown backend: {backend}")

Expand Down
14 changes: 9 additions & 5 deletions frontend/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .pycode_writer import PyCodeWriter, new_name, is_valid_name
from .store_pos import StorePos
from .variables import Variable
from .config import get_config


def gen_imports(writer: PyCodeWriter, imports: set[str]) -> None:
Expand Down Expand Up @@ -83,8 +84,9 @@ def get_code(self) -> str:
gen_imports(writer, self.imports)
writer.wl(f"def fn(locals):")
writer.block_start()
writer.wl(
f"print('running graph_fn (key = {self.key})', locals.keys())")
if get_config('debug'):
writer.wl(
f"print('running graph_fn (key = {self.key})', locals.keys())")
# TODO: simplify
graph_inputs = []
for x, to_tensor in self.graph_inputs:
Expand Down Expand Up @@ -149,15 +151,17 @@ def get_code(self) -> str:
writer.block_start()
writer.write(f"try:")
writer.block_start()
writer.wl(
f"print('running guard_fn (key = {self.key})', locals.keys())")
if get_config('debug'):
writer.wl(
f"print('running guard_fn (key = {self.key})', locals.keys())")
writer.write(self.writer.get_code())
if len(self.checks) == 0:
writer.wl(f"ok = True")
else:
writer.wl(
f"ok = {' and '.join(map(lambda x: f'({x})', self.checks))}")
writer.wl(f"print('ok = ', ok)")
if get_config('debug'):
writer.wl(f"print('ok = ', ok)")
writer.block_end()
writer.wl(f"except Exception as e:")
writer.block_start()
Expand Down
13 changes: 9 additions & 4 deletions frontend/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .bytecode_writter import rewrite_bytecode
from .code import ProcessedCode
from .instruction import format_insts
from .config import get_config


def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]:
Expand Down Expand Up @@ -73,12 +74,15 @@ def get_process_frame(
f: Callable[..., Any],
is_callee: bool) -> Tuple[Callable[..., Any], Callable[..., Any]]:

is_debug = get_config('debug')

def preprocess_frame(
frame: FrameType, frame_id: int
) -> Tuple[CodeType, Callable[..., Any], ProcessedCode]:
try:
print(f"preprocess frame {frame.f_code.co_filename}", frame_id,
hex(id(frame)), frame.f_code.co_name)
if is_debug:
print(f"preprocess frame {frame.f_code.co_filename}", frame_id,
hex(id(frame)), frame.f_code.co_name)
enable_cache(frame_id)
set_frame_root(frame_id, f)
if check_cache_updated(frame_id):
Expand All @@ -89,13 +93,14 @@ def preprocess_frame(
trace_func = get_trace_func(frame_id)

else:
print("old bytecode: \n")
old_frame = get_frame_cache(frame_id)
assert old_frame.code_map is not None, "Code map doesn't exist for frame id {}".format(
frame_id)
assert old_frame.new_code is not None, "New code doesn't exist for frame id {}".format(
frame_id)
print(format_insts(old_frame.code_map.guard_insts))
if is_debug:
print("old bytecode: \n")
print(format_insts(old_frame.code_map.guard_insts))
new_code = old_frame.new_code
code_map = old_frame.code_map
trace_func = get_trace_func(frame_id)
Expand Down

0 comments on commit a16a341

Please sign in to comment.