apply Black 2024 style in fbcode (9/16)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447729

fbshipit-source-id: fc781322b254f7027c24888cdadd5f1e90325ba8
amyreese authored and facebook-github-bot committed on Mar 3, 2024
Parent: 2966e38 · Commit: 75352ad
Showing 44 changed files with 152 additions and 106 deletions.
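Most hunks below are mechanical rewrites from the Black 2024 stable style. Two patterns dominate: a long subscripted or annotated assignment target is kept on one line and the value is parenthesized instead, and long conditional expressions (in call arguments and comprehensions) get their own parentheses. A minimal sketch of both, assuming invented names and data that are not from the changed files:

from typing import Dict, List, Optional

# Invented stand-ins, only so the snippet runs.
input_node_map: Dict[str, str] = {"a": "a_out"}
user_args: List[str] = ["a", "b"]

# The older style split the annotated target across lines:
#     metadata: Dict[
#         str, List[str]
#     ] = some_call(...)
# The 2024 style keeps the target on one line and parenthesizes the value instead
# (a long trailing comment is pulled inside the parentheses as well):
metadata: Optional[Dict[str, List[str]]] = (
    None  # placeholder value; a long trailing comment would be wrapped like this
)

# The older style left a long conditional expression bare inside a comprehension:
#     user_new_arg = [
#         input_node_map[arg]
#         if arg in input_node_map
#         else arg
#         for arg in user_args
#     ]
# The 2024 style wraps the conditional expression in its own parentheses
# (split across lines only when it is too long, as in the hunks below):
user_new_arg = [
    (input_node_map[arg] if arg in input_node_map else arg) for arg in user_args
]

print(metadata, user_new_arg)  # None ['a_out', 'b']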
@@ -127,9 +127,11 @@ def call(  # noqa: suprress function is too complex (13)
             # if user is in to_dim_op_set, it means the user's arg is already set to_dim op
             if user not in to_dim_op_set:
                 user_new_arg = [
-                    input_node_map[user_arg]
-                    if user_arg in input_node_map
-                    else user_arg
+                    (
+                        input_node_map[user_arg]
+                        if user_arg in input_node_map
+                        else user_arg
+                    )
                     for user_arg in user.args
                 ]
                 # Update input node's users arg
2 changes: 1 addition & 1 deletion backends/transforms/duplicate_dynamic_quant_chain.py
@@ -77,7 +77,7 @@ def _replicate_chose_qparam_nodes_for_q_dq(
             )
             q_dq_pair.append((user, dq_node))

-        for (q_node, dq_node) in q_dq_pair:
+        for q_node, dq_node in q_dq_pair:
             with gm.graph.inserting_after(get_item_node_1):
                 new_get_item_node_1 = gm.graph.node_copy(get_item_node_1)
                 new_get_item_node_2 = gm.graph.node_copy(get_item_node_2)
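The one-line change above is a separate, smaller rule applied by this commit (it shows up again in exir/tests/test_common.py further down): redundant parentheses around the tuple target of a for loop are dropped. A tiny sketch with invented data:

q_dq_pairs = [("q0", "dq0"), ("q1", "dq1")]  # invented placeholder data

# Before: for (q_node, dq_node) in q_dq_pairs:
# After:  the parentheses around the tuple target are removed.
for q_node, dq_node in q_dq_pairs:
    print(q_node, dq_node)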
@@ -12,6 +12,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult

+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
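Several of the pure one-line additions in this commit, including the hunk above and the two backends/xnnpack/utils hunks that follow, only insert a blank line so that a top-level comment-plus-definition ends up separated from the preceding imports by the usual two blank lines; this looks like part of the 2024 style's blank-line normalization. The resulting layout, sketched with invented module contents:

import os
from typing import Optional


# The leading comment now sits two blank lines below the imports, together with
# the definition it describes.
def cache_dir(base: Optional[str] = None) -> str:
    return base or os.path.expanduser("~/.cache")


print(cache_dir())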
1 change: 1 addition & 0 deletions backends/xnnpack/utils/configs.py
@@ -13,6 +13,7 @@
 )
 from executorch.exir.pass_manager import PassType

+
 ### XNNPACK Configs ###
 def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig:
     return exir.EdgeCompileConfig(
1 change: 1 addition & 0 deletions backends/xnnpack/utils/utils.py
@@ -26,6 +26,7 @@
     is_param,
 )

+
 ### XNNPACK Capture ###
 def capture_graph_for_xnnpack(
     module: torch.nn.Module,
6 changes: 3 additions & 3 deletions codegen/tools/test/test_gen_oplist.py
@@ -231,9 +231,9 @@ def test_dump_operator_from_ops_schema_yaml_with_mix_syntax(self) -> None:
         self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])

     def test_get_kernel_metadata_from_ops_yaml(self) -> None:
-        metadata: Dict[
-            str, List[str]
-        ] = gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
+        metadata: Dict[str, List[str]] = (
+            gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
+        )

         self.assertEqual(len(metadata), 2)
6 changes: 3 additions & 3 deletions examples/models/llama2/export_llama_lib.py
@@ -387,9 +387,9 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     # to_backend
     partitioners = {}
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        partitioners[
-            XnnpackDynamicallyQuantizedPartitioner.__name__
-        ] = XnnpackDynamicallyQuantizedPartitioner()
+        partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
+            XnnpackDynamicallyQuantizedPartitioner()
+        )
         modelname = f"xnnpack_dq_{modelname}"

     if args.xnnpack:
1 change: 1 addition & 0 deletions examples/xtensa/aot/quantizer.py
@@ -46,6 +46,7 @@ def quantize_tensor_multiplier(
         result = RoundingRightShift(FixedPointMultiplication(int32_value,
         out_multiplier[i]), right_shift[i])
     """
+
     # This is identical to C++11 std::round(). The general python round rounds
     # down, and C++ rounds away from zero.
     def round_away_zero(f) -> int:
6 changes: 3 additions & 3 deletions exir/backend/backend_details.py
@@ -21,9 +21,9 @@ def enforcedmethod(func):
 @dataclass
 class PreprocessResult:
     processed_bytes: bytes = bytes()
-    debug_handle_map: Optional[
-        Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]
-    ] = None
+    debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
+        None
+    )


 """
1 change: 1 addition & 0 deletions exir/backend/test/backend_with_delegate_mapping_demo.py
@@ -14,6 +14,7 @@
 from torch import nn
 from torch.export.exported_program import ExportedProgram

+
 # A simple way to represent an op along with its delegate debug identifier.
 class DummyOp:
     def __init__(
6 changes: 5 additions & 1 deletion exir/backend/test/test_backends.py
@@ -720,7 +720,11 @@ def forward(self, x_raw, h, c):
         composite_m = CompositeModel(3)
         orig_res = composite_m(*inputs)

-        traced = exir.capture(composite_m, inputs, exir.CaptureConfig(),).to_edge(
+        traced = exir.capture(
+            composite_m,
+            inputs,
+            exir.CaptureConfig(),
+        ).to_edge(
             # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
             exir.EdgeCompileConfig(_check_ir_validity=False)
         )
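The hunk above is the only place on this page where a call is exploded onto one argument per line; the trigger is presumably Black's magic trailing comma, since the original single-line call to exir.capture already ended its argument list with a comma. A sketch of the resulting shape, using an invented helper rather than the real exir API:

def capture(model: str, inputs: tuple, config: object = None) -> tuple:
    # Invented stand-in for a three-argument call such as exir.capture(...).
    return (model, inputs, config)


# The trailing comma after the last argument keeps the call in this exploded,
# one-argument-per-line form.
traced = capture(
    "composite_model",
    ("x", "y"),
    None,
)
print(traced)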
6 changes: 3 additions & 3 deletions exir/backend/utils.py
@@ -277,9 +277,9 @@ def __init__(self, generated_identifiers: bool = False):

         # Note that the internal struct has a Set value, while the getter
         # function returns the values as a tuple
-        self._debug_handle_map: Union[
-            Dict[int, Set[int]], Dict[str, Set[int]]
-        ] = defaultdict(set)
+        self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
+            defaultdict(set)
+        )
         self._next_index: int = 0

     def get_delegate_mapping(
16 changes: 10 additions & 6 deletions exir/capture/_capture.py
@@ -432,9 +432,11 @@ def capture_multiple(
                 "forward",
                 m.forward,
                 args,
-                dynamic_shapes["forward"]
-                if dynamic_shapes and "forward" in dynamic_shapes
-                else None,
+                (
+                    dynamic_shapes["forward"]
+                    if dynamic_shapes and "forward" in dynamic_shapes
+                    else None
+                ),
             )
         )
     else:
@@ -447,9 +449,11 @@ def capture_multiple(
                     method_name,
                     getattr(m, method_name),
                     method_args,
-                    dynamic_shapes[method_name]
-                    if dynamic_shapes and method_name in dynamic_shapes
-                    else None,
+                    (
+                        dynamic_shapes[method_name]
+                        if dynamic_shapes and method_name in dynamic_shapes
+                        else None
+                    ),
                 )
             )
     if prim_getters is not None:
4 changes: 3 additions & 1 deletion exir/capture/_config.py
@@ -21,7 +21,9 @@ class CaptureConfig:
     pt2_mode: bool = True
     enable_functionalization: bool = True
     enable_dynamic_shape: bool = False  # This flag does nothing if enable_aot is True
-    enable_aot: bool = False  # When it's true it implies automatic dynamic shapes via default dynamo config
+    enable_aot: bool = (
+        False  # When it's true it implies automatic dynamic shapes via default dynamo config
+    )
     _dynamo_config: "ExirDynamoConfig" = field(default_factory=ExirDynamoConfig)
     _unlift: bool = False  # This flag does nothing if enable_aot is False.
     _use_old_decomp_table: bool = False
12 changes: 6 additions & 6 deletions exir/capture/_unlift.py
@@ -135,17 +135,17 @@ def unlift_exported_program_lifted_states(
             if node.name in ep.graph_signature.inputs_to_buffers:
                 buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
                 if buffer_name in param_buffer_name_to_corrected_name:
-                    inp_pos_to_param_buffer_name[
-                        count
-                    ] = param_buffer_name_to_corrected_name[buffer_name]
+                    inp_pos_to_param_buffer_name[count] = (
+                        param_buffer_name_to_corrected_name[buffer_name]
+                    )
                 else:
                     inp_pos_to_param_buffer_name[count] = buffer_name
             if node.name in ep.graph_signature.inputs_to_parameters:
                 param_name = ep.graph_signature.inputs_to_parameters[node.name]
                 if param_name in param_buffer_name_to_corrected_name:
-                    inp_pos_to_param_buffer_name[
-                        count
-                    ] = param_buffer_name_to_corrected_name[param_name]
+                    inp_pos_to_param_buffer_name[count] = (
+                        param_buffer_name_to_corrected_name[param_name]
+                    )
                 else:
                     inp_pos_to_param_buffer_name[count] = param_name
             count += 1
1 change: 1 addition & 0 deletions exir/delegate.py
@@ -35,6 +35,7 @@

 LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"

+
 # pyre-ignore
 def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
     # pyre-ignore
1 change: 0 additions & 1 deletion exir/dialects/_ops.py
@@ -72,7 +72,6 @@ def wrapper(f: Callable):


 class _OpNamespace(types.ModuleType):
-
     """
     EXIR Dialect op namespace object. Contains ops and overloads registered into PyTorch dispatcher.
3 changes: 2 additions & 1 deletion exir/dialects/edge/spec/gen.py
@@ -432,7 +432,8 @@ def gen_op_yaml(op_name: str) -> Optional[EdgeOpYamlInfo]:
     Arguments:
         op_name: The name of operator. Needs to conform the convention of "<name>.<overload_name>".
                  If no overload name for the operator, needs to use "default" as overload name.
-    Return the yaml info for given operator if generation succeed. Otherwise return None."""
+    Return the yaml info for given operator if generation succeed. Otherwise return None.
+    """

     try:
         func_schema: torch._C.FunctionSchema = get_callable(op_name)._schema
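This docstring hunk and the one in exir/dialects/edge/spec/utils.py right below apply the same rule: when the final docstring line is long enough that appending the closing quotes would push it past the line limit, the closing quotes move to their own line. A sketch of the resulting shape, with an invented function and text:

def gen_index_example() -> dict:
    """Build a tiny index mapping; the closing quotes sit on their own line because appending them to this long line would overflow the limit.
    """
    return {0: "Double", 1: "Int"}


print(gen_index_example())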
3 changes: 2 additions & 1 deletion exir/dialects/edge/spec/utils.py
@@ -171,7 +171,8 @@ def gen_index_pairs_to_types_mapping(
     type_alias: Dict[Tuple[str], int], type_constraint: List[List[int]]
 ) -> Dict[Tuple[int], List[str]]:
     """Generate mapping from index pairs to types. For example, given type_constraint [0, 0], [1, 1]
-    type_alias ('Double',): 0, ('Int',): 1, output will be {(0, 1): ['Double', 'Int', 'Double', 'Int']}."""
+    type_alias ('Double',): 0, ('Int',): 1, output will be {(0, 1): ['Double', 'Int', 'Double', 'Int']}.
+    """

     def gen(x: List[int]):
         """Generate all possible pairs of elements in the list."""
6 changes: 3 additions & 3 deletions exir/emit/_emit_program.py
@@ -186,9 +186,9 @@ def emit_program(
         plans.append(emitter.plan())

         debug_handle_map[name] = emitter.debug_handle_map
-        method_to_delegate_debug_id_map[
-            name
-        ] = emitter.instr_id_to_delegate_debug_id_map
+        method_to_delegate_debug_id_map[name] = (
+            emitter.instr_id_to_delegate_debug_id_map
+        )

     # emit any primitive getters
     if prim_getters is not None:
1 change: 1 addition & 0 deletions exir/emit/_emitter.py
@@ -187,6 +187,7 @@ class _AbstractValue:
     Dict[int, Tuple[int]], Dict[str, Tuple[int]]
 ]

+
 # pyre-ignore[13]: Attribute `node` is never initialized.
 class _Emitter(torch.fx.Interpreter):
     """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
1 change: 1 addition & 0 deletions exir/experimental/export_pt2.py
@@ -141,6 +141,7 @@ def trace(root: Callable[..., Value], concrete_args: Tuple[Value, ...]) -> Trace
         concrete_args,
         CaptureConfig(enable_functionalization=False),
     ).graph_module
+
     # TODO convert torchdynamo guards to our own guards
     def _convert_dynamo_guard_to_exir_guard(
         dynamo_guard: DynamoGuard,
1 change: 1 addition & 0 deletions exir/memory_planning.py
@@ -435,6 +435,7 @@ class SharedObject:
     last_used_index attribute. The shared object will be available for nodes
     with index greater than last_used_index.
     """
+
     # index of the shared object in the list of shared objects, used as a unique id
     idx: int
     # offset in the memory buffer
6 changes: 3 additions & 3 deletions exir/operator/convert.py
@@ -206,9 +206,9 @@ def set_mapping_for_op(op: OpOverload) -> None:
     mismatched_out_schema: Optional[FunctionSchema] = next(
         (s for s in all_schemas if s.kind() == SchemaKind.out), None
     )
-    _schema_mismatch_map[
-        schema_to_opoverload(func_op_schema)
-    ] = mismatched_out_schema
+    _schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
+        mismatched_out_schema
+    )

     # update hte map even if scratch_schema is None to cache the negative
     # case
16 changes: 8 additions & 8 deletions exir/passes/__init__.py
@@ -483,14 +483,14 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
     ]
 ).passes

-base_post_op_replace_passes: List[
-    Callable[[torch.nn.Module], PassResult]
-] = PassManager(
-    passes=[
-        dead_code_elimination_pass,
-        DebugHandleGeneratorPass(),
-    ]
-).passes
+base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
+    PassManager(
+        passes=[
+            dead_code_elimination_pass,
+            DebugHandleGeneratorPass(),
+        ]
+    ).passes
+)


 def propagate_dynamic_shape(
6 changes: 3 additions & 3 deletions exir/passes/constant_prop_pass.py
@@ -112,9 +112,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
             )
             prop_constant_data.append(prop_constant_node_input_spec)
             buffers.append(prop_constant_tensor_fqn)
-            exported_program.state_dict[
-                prop_constant_tensor_fqn
-            ] = prop_constant_tensor
+            exported_program.state_dict[prop_constant_tensor_fqn] = (
+                prop_constant_tensor
+            )
             exported_program.graph_signature.input_specs.append(
                 prop_constant_node_input_spec
             )
18 changes: 10 additions & 8 deletions exir/passes/insert_write_back_for_buffers_pass.py
@@ -72,15 +72,17 @@ def _insert_copy(
 def insert_write_back_for_buffers_pass(ep: ExportedProgram):
     gm: torch.fx.GraphModule = ep.graph_module
     lifted_inputs: List[Optional[str]] = [
-        in_spec.target
-        if in_spec.kind
-        in (
-            InputKind.BUFFER,
-            InputKind.CONSTANT_TENSOR,
-            InputKind.PARAMETER,
-            InputKind.CUSTOM_OBJ,
+        (
+            in_spec.target
+            if in_spec.kind
+            in (
+                InputKind.BUFFER,
+                InputKind.CONSTANT_TENSOR,
+                InputKind.PARAMETER,
+                InputKind.CUSTOM_OBJ,
+            )
+            else None
         )
-        else None
         for in_spec in ep.graph_signature.input_specs
     ]
16 changes: 10 additions & 6 deletions exir/serde/export_serialize.py
@@ -835,9 +835,11 @@ def serialize_module_call_graph(
         return [
             ModuleCallEntry(
                 fqn=entry.fqn,
-                signature=self.serialize_module_call_signature(entry.signature)
-                if entry.signature
-                else None,
+                signature=(
+                    self.serialize_module_call_signature(entry.signature)
+                    if entry.signature
+                    else None
+                ),
             )
             for entry in module_call_graph
         ]
@@ -1668,9 +1670,11 @@ def deserialize_module_call_graph(
         return [
             ep.ModuleCallEntry(
                 fqn=entry.fqn,
-                signature=self.deserialize_module_call_signature(entry.signature)
-                if entry.signature
-                else None,
+                signature=(
+                    self.deserialize_module_call_signature(entry.signature)
+                    if entry.signature
+                    else None
+                ),
             )
             for entry in module_call_graph
         ]
2 changes: 1 addition & 1 deletion exir/tests/test_common.py
@@ -28,7 +28,7 @@ def test_get_schema_for_operators(self) -> None:

         schemas = get_schema_for_operators(op_list)
         pat = re.compile(r"[^\(]+\([^\)]+\) -> ")
-        for (_op_name, schema) in schemas.items():
+        for _op_name, schema in schemas.items():
             self.assertIsNotNone(re.match(pat, schema))

     def test_get_out_args(self) -> None:
1 change: 1 addition & 0 deletions exir/tests/test_pass_infra.py
@@ -36,6 +36,7 @@ def test_pass_registry_func(self) -> None:
         """
         Test if we register a callable correctly
         """
+
         # Registering w/o specifying pass_name
         @PassRegistry.register()
         def test_pass1(graph_module: torch.fx.GraphModule) -> None: