diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py index 0fabf223fb..96c40e6536 100644 --- a/exir/passes/constant_prop_pass.py +++ b/exir/passes/constant_prop_pass.py @@ -112,11 +112,11 @@ def get_propagated_const_tensor_dict( # Initialize dict with all constant placeholders. const_node_to_tensor = get_constant_placeholder_dict(exported_program) - all_skip_targets: set[EdgeOpOverload] = set() - # Default set of targets to skip. - all_skip_targets.update(_DEFAULT_SKIP_TARGETS) if custom_skip_targets is not None: - all_skip_targets.update(custom_skip_targets) + all_skip_targets = custom_skip_targets + else: + # Default set of targets to skip. + all_skip_targets = _DEFAULT_SKIP_TARGETS for node in exported_program.graph.nodes: if node.op != "call_function" or node.target in all_skip_targets: