Skip to content

Commit

Permalink
[MetaSchedule] Enable BertTuning with MetaScheduler (apache#11)
Browse files Browse the repository at this point in the history
* Test bert using gluon model.

Change gluon to torch.

Revert evil work around.

Skip test.

* Minor fix.
  • Loading branch information
zxybazh authored Jan 16, 2022
1 parent fd7712c commit 2feb3fb
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
52 changes: 52 additions & 0 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class MODEL_TYPE(Enum): # pylint: disable=invalid-name
VIDEO_CLASSIFICATION = (2,)
SEGMENTATION = (3,)
OBJECT_DETECTION = (4,)
TEXT_CLASSIFICATION = (5,)


# Specify the type of each model
Expand Down Expand Up @@ -95,6 +96,11 @@ class MODEL_TYPE(Enum): # pylint: disable=invalid-name
"r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
"mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
"r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
# Text classification
"bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION,
"bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
"bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION,
"bert_large": MODEL_TYPE.TEXT_CLASSIFICATION,
}


Expand All @@ -121,6 +127,8 @@ def get_torch_model(

import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel
from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel
import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel
import os # type: ignore # pylint: disable=import-error,import-outside-toplevel

def do_trace(model, inp):
model_trace = torch.jit.trace(model, inp)
Expand All @@ -136,6 +144,50 @@ def do_trace(model, inp):
model = getattr(models.detection, model_name)()
elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
model = getattr(models.video, model_name)()
elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
config_dict = {
"bert_tiny": transformers.BertConfig(
num_hidden_layers=6,
hidden_size=512,
intermediate_size=2048,
num_attention_heads=8,
return_dict=False,
),
"bert_base": transformers.BertConfig(
num_hidden_layers=12,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
return_dict=False,
),
"bert_medium": transformers.BertConfig(
num_hidden_layers=12,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
return_dict=False,
),
"bert_large": transformers.BertConfig(
num_hidden_layers=24,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
return_dict=False,
),
}
configuration = config_dict[model_name]
model = transformers.BertModel(configuration)
input_name = "input_ids"
A = torch.randint(10000, input_shape)

model.eval()
scripted_model = torch.jit.trace(model, [A], strict=False)

input_name = "input_ids"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
return mod, params
else:
raise ValueError("Unsupported model in Torch model zoo.")

Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void TaskSchedulerNode::Tune() {

int running_tasks = tasks.size();
for (int task_id; (task_id = NextTaskId()) != -1;) {
LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name;
TuneContext task = tasks[task_id];
ICHECK(!task->is_stopped);
ICHECK(!task->runner_futures.defined());
Expand Down
11 changes: 9 additions & 2 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@pytest.mark.skip("Integration test")
@pytest.mark.parametrize("model_name", ["resnet18"])
@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
Expand All @@ -47,6 +47,9 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
input_shape = (1, 3, 300, 300)
elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
input_shape = (batch_size, 3, 3, 299, 299)
elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
seq_length = 128
input_shape = (batch_size, seq_length)
else:
raise ValueError("Unsupported model: " + model_name)
output_shape: Tuple[int, int] = (batch_size, 1000)
Expand All @@ -71,7 +74,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
work_dir=work_dir,
)
for i, sch in enumerate(schs):
print("-" * 10 + f" Part {i}/{len(schs)} " + "-" * 10)
print("-" * 10 + f" Part {i+1}/{len(schs)} " + "-" * 10)
if sch is None:
print("No valid schedule found!")
else:
Expand All @@ -82,3 +85,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
if __name__ == """__main__""":
test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")

0 comments on commit 2feb3fb

Please sign in to comment.