[AutoScheduler] Fix task extraction with TE compiler (apache#8560)
* [AutoScheduler] Fix task extraction with TE compiler

* fix

* test

* Update python/tvm/auto_scheduler/relay_integration.py
comaniac authored and ylc committed Sep 29, 2021
1 parent 3c7e561 commit 280fb7b
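
Before the file-by-file diff, here is a minimal sketch (not part of the patch) of the bookkeeping change this commit makes. With the TE compiler, identical TE workloads reached from differently named Relay functions should collapse into one task, so the tracing map is re-keyed by workload key alone. The workload keys and function names below are made up for illustration.

```python
# Before: weights keyed by (func_name, workload_key); identical workloads reached from
# differently named functions produced separate entries.
old_wkl_key_to_weight = {
    ("tvmgen_default_fused_nn_conv2d", "wkl-key-A"): 1,
    ("tvmgen_default_fused_nn_conv2d_1", "wkl-key-A"): 1,
}

# After: weights keyed by workload_key alone, carrying the set of function names, so
# identical workloads collapse into a single task with an accumulated weight.
new_wkl_key_to_weight = {
    "wkl-key-A": (2, {"tvmgen_default_fused_nn_conv2d", "tvmgen_default_fused_nn_conv2d_1"}),
}

# A companion map lets the TE compiler's weight callback translate function names
# back to workload keys.
func_name_to_wkl_key = {
    "tvmgen_default_fused_nn_conv2d": "wkl-key-A",
    "tvmgen_default_fused_nn_conv2d_1": "wkl-key-A",
}

# The task description becomes the comma-joined function names.
print(",".join(sorted(new_wkl_key_to_weight["wkl-key-A"][1])))
```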
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 73 deletions.
47 changes: 38 additions & 9 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -150,7 +150,7 @@ def extract_tasks(
# create search tasks
tasks = []
weights = []
for (func_name, wkl_key), weight in env.wkl_key_to_weight.items():
for wkl_key, (weight, func_names) in env.wkl_key_to_weight.items():
tasks.append(
SearchTask(
workload_key=wkl_key,
@@ -165,7 +165,7 @@
else None
),
task_inputs_save_to_file=True,
desc=func_name,
desc=",".join(func_names),
)
)
weights.append(weight)
@@ -189,6 +189,7 @@ class TracingEnvironment:
def __init__(self, tracing_mode):
self.tracing_mode = tracing_mode
self.relay_disable_build_cache = "false"
self.func_name_to_wkl_key = {}
self.wkl_key_to_weight = {}
self.wkl_key_to_input_names = {}

@@ -210,10 +211,12 @@ def add_workload_key(self, func_name, workload_key):
workload_key: str
The workload key of a task.
"""
key = (func_name, workload_key)
if key not in self.wkl_key_to_weight:
self.wkl_key_to_weight[key] = 0
self.wkl_key_to_weight[key] += 1
self.func_name_to_wkl_key[func_name] = workload_key
if workload_key not in self.wkl_key_to_weight:
self.wkl_key_to_weight[workload_key] = (0, set())
weight, func_names = self.wkl_key_to_weight[workload_key]
func_names.add(func_name)
self.wkl_key_to_weight[workload_key] = (weight + 1, func_names)

def add_workload_input_names(self, workload_key, input_names):
"""Add special task inputs to this workload.
@@ -379,11 +382,37 @@ def auto_schedule_topi(func_name, outs):

@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights")
def te_compiler_update_weights(function_weights):
"""A callback for updating the weights of extracted tasks."""
"""A callback for updating the weights of extracted tasks. When using the TE compiler
that avoids compiling the same function multiple times by caching, all extracted tasks
have weight 1, so the TE compiler invokes this callback at the end. In this case,
we override the existing weights with the use_count in the TE compiler cache.

Parameters
----------
function_weights: Dict[str, int]
Mapping from function names to their weights.
"""
env = TracingEnvironment.current
if env is not None:
for key in env.wkl_key_to_weight:
env.wkl_key_to_weight[key] = function_weights[key[0]]
# Override this map with the weights in the TE compiler.
env.wkl_key_to_weight = {}

for func_name, weight in function_weights.items():
# If the function name is not in the map, then it means we are not interested in
# this function during task extraction (e.g., a function without reduction).
if func_name not in env.func_name_to_wkl_key:
continue

workload_key = env.func_name_to_wkl_key[func_name]
if workload_key not in env.wkl_key_to_weight:
env.wkl_key_to_weight[workload_key] = (0, set())

# Note that a function appearing multiple times in a model will be renamed
# to keep function names unique, so we use the workload key generated
# from the function's TE compute to accumulate its weight.
old_weight, func_names = env.wkl_key_to_weight[workload_key]
func_names.add(func_name)
env.wkl_key_to_weight[workload_key] = (old_weight + weight, func_names)


def tensor_no_check_call(self, *indices):
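
To make the control flow above concrete, here is a toy, TVM-independent reproduction of the two hooks working together: `add_workload_key` runs once per compiled (cached) function, and the TE compiler's weight callback later rebuilds the weights from the cache use counts. All names and counts below are hypothetical.

```python
func_name_to_wkl_key = {}
wkl_key_to_weight = {}

def add_workload_key(func_name, workload_key):
    # Record the workload once per traced function; accumulate the function names.
    func_name_to_wkl_key[func_name] = workload_key
    weight, func_names = wkl_key_to_weight.get(workload_key, (0, set()))
    func_names.add(func_name)
    wkl_key_to_weight[workload_key] = (weight + 1, func_names)

def te_compiler_update_weights(function_weights):
    # Rebuild the weights from the TE compiler's use counts.
    wkl_key_to_weight.clear()
    for func_name, weight in function_weights.items():
        if func_name not in func_name_to_wkl_key:
            continue  # e.g. a simple function we did not extract a task for
        key = func_name_to_wkl_key[func_name]
        old_weight, func_names = wkl_key_to_weight.get(key, (0, set()))
        func_names.add(func_name)
        wkl_key_to_weight[key] = (old_weight + weight, func_names)

# The TE compiler caches, so each function is traced only once ...
add_workload_key("fused_conv2d", "wkl-A")
add_workload_key("fused_conv2d_1", "wkl-B")
# ... and the real multiplicity arrives via the use_count callback.
te_compiler_update_weights({"fused_conv2d": 3, "fused_conv2d_1": 1, "fused_softmax": 2})
assert wkl_key_to_weight["wkl-A"] == (3, {"fused_conv2d"})
assert wkl_key_to_weight["wkl-B"] == (1, {"fused_conv2d_1"})
```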
136 changes: 72 additions & 64 deletions tests/python/relay/test_auto_scheduler_task_extraction.py
@@ -96,52 +96,61 @@ def get_network(name, batch_size=1, layout="NHWC"):


@tvm.testing.requires_cuda
def test_task_extraction_cuda():
@pytest.mark.parametrize(
"params",
[
("mlp", "NHWC", 1, 2),
("resnet-18", "NHWC", 24, 25),
("resnet-18", "NCHW", 24, 25),
("mobilenet", "NHWC", 22, 30),
("mobilenet", "NCHW", 22, 30),
("resnet3d-18", "NCDHW", 23, 24),
("resnet3d-18", "NDHWC", 23, 24),
],
)
def test_task_extraction_cuda(params):
target = tvm.target.Target("cuda")
network, layout, expected_task, expected_weights = params

mod, params = get_network("mlp")
mod, params = get_network(network, layout=layout)
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

assert len(tasks) == 1
assert sum(task_weights) == 2

for layout in ["NHWC", "NCHW"]:
mod, params = get_network("resnet-18", layout=layout)
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

assert len(tasks) == 24
assert sum(task_weights) == 25

mod, params = get_network("mobilenet", layout=layout)
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

assert len(tasks) == 22
assert sum(task_weights) == 30

for layout in ["NCDHW", "NDHWC"]:
mod, params = get_network("resnet3d-18", layout=layout)
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

assert len(tasks) == 23
assert sum(task_weights) == 24, sum(task_weights)


def test_task_extraction():
for task, weight in zip(tasks, task_weights):
print(task.desc, task.workload_key, weight)

assert len(tasks) == expected_task
assert sum(task_weights) == expected_weights


@pytest.mark.parametrize(
"params",
[
# Relay FuseOps puts the two conv2ds into separate functions, resulting in two tasks.
("basic_func", 2, False),
# Relay FuseOps will not break the primitive function, resulting in one task.
("fused_func", 1, False),
# The Relay function without complex ops will not form a task by default.
("simple_func", 0, False),
# Every Relay function becomes a task regardless of the ops in its body.
("simple_func", 1, True),
# The Relay function without any reduce op is considered a simple task.
("shape_of_func", 0, False),
("shape_of_func", 1, True),
# The Relay function with dynamic shape inputs/outputs will not be extracted.
("dyn_shape_func", 0, False),
# The Conv2D in the Relay function with control flow can still form a task.
# Also, two identical Conv2Ds should result in only one task with weight=2.
("control_flow_func", 1, False),
# The first function, which contains an unsupported op (NMS), will not be extracted.
("func_w_unsupported_op", 1, True),
],
)
def test_task_extraction_cpu(params):
ishape = (1, 3, 224, 224)
w1shape = (32, 3, 3, 3)
w2shape = (32, 32, 3, 3)
dtype = "float32"
target = tvm.target.Target("llvm")

def verify_task_extraction(func, expected_task, include_simple_tasks=False):
mod = tvm.IRModule.from_expr(func)
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], None, target, include_simple_tasks=include_simple_tasks
)

assert len(tasks) == expected_task
assert len(task_weights) == expected_task

def get_func():
data = relay.var("data", shape=(ishape), dtype=dtype)
weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
@@ -183,13 +192,16 @@ def get_func_with_dynamic_shape():

def get_func_with_control_flow():
data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(32, 3, 3, 3))
weight = relay.var("weight", shape=(3, 3, 3, 3))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")
eq = relay.equal(eq1, eq2)

true_branch = relay.zeros(shape=(1, 32, 222, 222), dtype="float32")
false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=32)
true_branch = relay.zeros(shape=(1, 3, 224, 224), dtype="float32")
false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=3, padding=(1, 1))
false_branch = relay.nn.conv2d(
false_branch, weight, kernel_size=(3, 3), channels=3, padding=(1, 1)
)
ife = relay.If(eq, true_branch, false_branch)
out = relay.erf(ife)
return relay.Function([data, weight, eq1, eq2], out)
@@ -213,32 +225,28 @@ def get_postproc_func():
out = relay.Call(get_postproc_func(), [nms])
return relay.Function([cls_prob, loc_pred, anchors], out)

# Relay FuseOps puts two conv2ds to separate functions and results in two tasks.
verify_task_extraction(get_func(), 2)

# By setting the function to primitive, Relay FuseOps will not break it and result in one task.
verify_task_extraction(get_fused_func(), 1)

# The Relay function without complex ops will not form a task by default.
verify_task_extraction(get_simple_func(), 0)

# Every Relay function becomes a task regardless what ops in its body.
verify_task_extraction(get_simple_func(), 1, True)

# The Relay function without any reduce op is considered as a simple task.
verify_task_extraction(get_shape_of_func(), 0)
verify_task_extraction(get_shape_of_func(), 1, True)

# The Relay function with dynamic shape inputs/outputs will not be extracted.
verify_task_extraction(get_func_with_dynamic_shape(), 0)
func_map = {
"basic_func": get_func,
"fused_func": get_fused_func,
"simple_func": get_simple_func,
"shape_of_func": get_shape_of_func,
"dyn_shape_func": get_func_with_dynamic_shape,
"control_flow_func": get_func_with_control_flow,
"func_w_unsupported_op": get_func_with_unsupported_op,
}

def verify_task_extraction(func_name, expected_task, include_simple_tasks=False):
func = func_map[func_name]()
mod = tvm.IRModule.from_expr(func)
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], None, target, include_simple_tasks=include_simple_tasks
)

# The Conv2D in the Relay function with control flow could still be a task.
verify_task_extraction(get_func_with_control_flow(), 1)
assert len(tasks) == expected_task
assert len(task_weights) == expected_task

# Func1 (with NMS) -> Func2 (injective).
verify_task_extraction(get_func_with_unsupported_op(), 1, True)
verify_task_extraction(*params)


if __name__ == "__main__":
test_task_extraction_cuda()
test_task_extraction()
pytest.main([__file__])
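
As a usage note (not part of the patch), individual parametrized cases can be selected with pytest's `-k` filter, which matches against the generated test ids (these include the network or function name from the parameter tuples). For example:

```python
import pytest

if __name__ == "__main__":
    # Run only the mobilenet CUDA cases of this test file
    # (requires a CUDA-enabled TVM build for the CUDA test).
    pytest.main([__file__, "-k", "mobilenet"])
```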
