From 687a766d628eae4885ce170fc6a88f2c98923b51 Mon Sep 17 00:00:00 2001 From: Kathryn-cat Date: Thu, 23 Jun 2022 12:44:18 -0700 Subject: [PATCH] [MetaSchedule] Added a cost model --- .../tvm/meta_schedule/cost_model/__init__.py | 1 + .../meta_schedule/cost_model/cost_model.py | 5 +- .../tvm/meta_schedule/cost_model/mlp_model.py | 743 ++++++++++++++++++ 3 files changed, 748 insertions(+), 1 deletion(-) create mode 100644 python/tvm/meta_schedule/cost_model/mlp_model.py diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 8fc6f04ac9558..2ef789cc8ede1 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -18,5 +18,6 @@ The tvm.meta_schedule.cost_model package. """ from .cost_model import CostModel, PyCostModel +from .mlp_model import MLPModel from .random_model import RandomModel from .xgb_model import XGBModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index 2fdb9b93494f9..439ee9c16a21a 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -175,6 +175,7 @@ def update( context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], + skip_model: bool = False, ) -> None: """Update the cost model given running results. @@ -186,11 +187,13 @@ def update( The measure candidates. results : List[RunnerResult] The running results of the measure candidates. + skip_model : bool + Skip updating the cost model, only load the data into runtime. """ raise NotImplementedError def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: - """Update the cost model given running results. + """Predict the normalized score using the cost model. Parameters ---------- diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py new file mode 100644 index 0000000000000..1f96a9ce8a688 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/mlp_model.py @@ -0,0 +1,743 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +MLP-based cost model +""" + +import logging +import math +import os +import random +import tempfile +from collections import OrderedDict +from itertools import chain as itertools_chain +from typing import Dict, List, NamedTuple +from tqdm import tqdm + +import numpy as np +import torch +from torch import nn + +# pylint: disable=relative-beyond-top-level +from ...contrib.tar import tar, untar +from ...runtime import NDArray +from ..cost_model import PyCostModel +from ..feature_extractor import FeatureExtractor, PerStoreFeature +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..tune_context import TuneContext +from ..utils import derived_object, shash2hex + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +# pylint: disable=no-member + + +class SegmentSumMLPConfig(NamedTuple): + """SegmentSum MLP model configuration + + Parameters + ---------- + input_dim : int + The input dim for the model. + hidden_dim : int + The hidden dim for the model. + output_dim : int + The output dim for the model. + use_norm : bool + Whether to normalize the segment sum or not. + use_sigmoid : bool + Whether to use sigmoid on the final output or not. + """ + + input_dim: int = 172 + hidden_dim: int = 256 + output_dim: int = 1 + use_norm: bool = False + use_sigmoid: bool = False + + def to_dict(self): # pylint: disable=missing-function-docstring + return { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "use_norm": self.use_norm, + "use_sigmoid": self.use_sigmoid, + } + + +# pylint: disable=too-few-public-methods +class FeatureGroup: + """Feature group + + Parameters + ---------- + group_hash : str + The hash of the group + features : List[np.ndarray] + The features + costs : List[float] + The costs + min_cost : float + The minimum cost + """ + + group_hash: str + features: List[np.ndarray] + costs: np.ndarray + min_cost: float + + def __init__( + self, + group_hash: str, + features: List[np.ndarray], + costs: np.ndarray, + ) -> None: + self.group_hash = group_hash + self.features = features + self.costs = costs + self.min_cost = np.min(costs) + + def append( # pylint: disable=missing-function-docstring + self, + features: List[np.ndarray], + costs: np.ndarray, + ) -> None: + self.features.extend(features) + self.costs = np.append(self.costs, costs) + self.min_cost = np.min(self.costs) + + +# pylint: disable=too-many-instance-attributes +class SegmentDataLoader: + """Dataloader for SegmentSum MLP model. 
+ + Parameters + ---------- + features : List[np.ndarray] + The features + results : np.ndarray + The measured results + batch_size : int + The batch size + shuffle : bool + Whether to shuffle the dataset or not + """ + + def __init__( + self, + features, + results, + batch_size=128, + shuffle=False, + ): + self.batch_size = batch_size + self.shuffle = shuffle + self.data_size = len(features) + + # flatten features and store the starting indices + self.segment_sizes = torch.tensor([len(feature) for feature in features]) + self.feature_offsets = ( + torch.cumsum(self.segment_sizes, 0, dtype=torch.int32) - self.segment_sizes + ) + features = torch.cat([torch.tensor(feature) for feature in features]) + norm = features.max(dim=0)[0] + norm[norm == 0] = 1 + self.features = features / norm + self.results = torch.tensor(results) + self.iter_order = self.pointer = None + + def __len__(self): + return self.data_size + + def __iter__(self): + if self.shuffle: + self.iter_order = torch.randperm(self.data_size) + else: + self.iter_order = torch.arange(self.data_size) + self.pointer = 0 + return self + + def __next__(self): + if self.pointer >= self.data_size: + raise StopIteration + batch_indices = self.iter_order[self.pointer : self.pointer + self.batch_size] + self.pointer += self.batch_size + return self._fetch_indices(batch_indices) + + def _fetch_indices(self, indices): + segment_sizes, feature_offsets = self.segment_sizes[indices], self.feature_offsets[indices] + feature_indices = torch.empty(segment_sizes.sum(), dtype=torch.int32) + idx = 0 + for offset, seg_size in zip(feature_offsets, segment_sizes): + feature_indices[idx : idx + seg_size] = torch.arange(offset, offset + seg_size) + idx += seg_size + features = self.features[feature_indices.long()] + results = self.results[indices.long()] + return segment_sizes, features, results + + +class SegmentSumMLPModule(nn.Module): + """SegmentSum MLP model. + + Parameters + ---------- + input_dim : int + The input dim for the model. + hidden_dim : int + The hidden dim for the model. + output_dim : int + The output dim for the model. + use_norm : bool + Whether to normalize the segment sum or not. + use_sigmoid : bool + Whether to use sigmoid on the final output or not. + """ + + input_dim: int + hidden_dim: int + output_dim: int + use_norm: bool + use_sigmoid: bool + + def __init__( # pylint: disable=too-many-arguments + self, + input_dim: int = 172, + hidden_dim: int = 256, + output_dim: int = 1, + use_norm: bool = False, + use_sigmoid: bool = False, + ): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) + self.norm = nn.BatchNorm1d(hidden_dim) if use_norm else nn.Identity() + self.layer0 = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) + self.layer1 = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) + self.decoder = nn.Linear(hidden_dim, output_dim) + self.sigmoid = nn.Sigmoid() if use_sigmoid else nn.Identity() + + def forward( + self, + segment_sizes: torch.Tensor, + features: torch.Tensor, + ): + """Forward the inputs with the model. + + Parameters + ---------- + segment_sizes : Tensor + The sizes of the segments. + features : Tensor + The feature vectors of the candidates. 
+ """ + n_seg = len(segment_sizes) + encoded_features = self.encoder(features) + segment_indices = torch.repeat_interleave( + torch.arange(n_seg, device=features.device), + segment_sizes.long(), + ) + n_dim = encoded_features.shape[1] + segment_sum = torch.scatter_add( + input=torch.zeros((n_seg, n_dim), dtype=encoded_features.dtype, device=features.device), + dim=0, + index=segment_indices.view(-1, 1).expand(-1, n_dim), + src=encoded_features, + ) + out = self.norm(segment_sum) + out = self.layer0(out) + out + out = self.layer1(out) + out + out = self.decoder(out).squeeze() + out = self.sigmoid(out) + return out + + # pylint: disable=no-self-use,too-many-arguments,too-many-locals,invalid-name + def lambda_rank_loss(self, preds, labels, k=None, eps=1e-10, sigma=1.0): + """LambdaLoss: Metric-Driven Loss for Learning-to-Rank + + Parameters + ---------- + preds : Tensor + The predicted runtime for each candidate. + labels : Tensor + The measured runtime for each candidate. + k : int + Loss for top k. + Default is None, which means computing all scores. + eps : float + The minimum value to the denominator and argument of log if they reach 0. + sigma : float + The scaling factor to the input of the sigmoid function. + """ + device = preds.device + y_pred, y_true = preds[None, :], labels[None, :] + y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1) + y_true_sorted, _ = y_true.sort(descending=True, dim=-1) + true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred) + true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :] + padded_pairs_mask = torch.isfinite(true_diffs) & (true_diffs > 0) + ndcg_at_k_mask = torch.zeros( + (y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device + ) + ndcg_at_k_mask[:k, :k] = 1 + true_sorted_by_preds.clamp_(min=0.0) + y_true_sorted.clamp_(min=0.0) + pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device) + D = torch.log2(1.0 + pos_idxs.float())[None, :] + maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps) + G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None] + weights = torch.abs( + torch.pow(D[:, :, None], -1.0) - torch.pow(D[:, None, :], -1.0) + ) * torch.abs(G[:, :, None] - G[:, None, :]) + scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp( + min=-1e8, max=1e8 + ) + scores_diffs[torch.isnan(scores_diffs)] = 0.0 + weighted_probs = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp( + min=eps + ) + losses = torch.log2(weighted_probs) + masked_losses = losses[padded_pairs_mask & ndcg_at_k_mask] + loss = -torch.sum(masked_losses) + return loss + + def topk_score( + self, + pred_results: torch.Tensor, + gt_results: torch.Tensor, + k: int, + ) -> float: + """ + Evaluate the top-k score + + Parameters + ---------- + pred_results: Tensor + The raw prediction + gt_results: Tensor + The measured results + k : int + The k in top k score + + Returns + ------- + score : float + The top-k score + """ + topk_indices = torch.topk(pred_results, k, largest=False).indices + score = gt_results.min() / gt_results[topk_indices].min() + return score.item() + + +@derived_object +class MLPModel(PyCostModel): + """MLP model + + Parameters + ---------- + extractor : FeatureExtractor + The feature extractor for the model. + config : SegmentSumMLPConfig + The SegmentSum MLP model config. + num_epoch : int + Number of epoches. + learning_rate : float + Learning rate. + grad_clip : float + Gradient clipping max norm. 
+ weight_decay : float + Adam weight decay. + batch_size : int + The batch size for dataloader. + test_split : float + The portion of the testing set. + test_interval : int + The testing interval (in number of epoches). + train_verbose : int + The verbose frequency for training (in number of batches). + """ + + extractor: FeatureExtractor + config: SegmentSumMLPConfig + num_epoch: int + learning_rate: float + grad_clip: float + weight_decay: float + batch_size: int + test_split: float + test_interval: int + train_verbose: int + data: Dict[str, FeatureGroup] + data_size: int + + def __init__( + self, + *, + extractor: FeatureExtractor = PerStoreFeature(extract_workload=True), + config: SegmentSumMLPConfig = SegmentSumMLPConfig(), + num_epoch: int = 50, + learning_rate: float = 7e-4, + grad_clip: float = 0.5, + weight_decay: float = 1e-6, + batch_size: int = 128, + test_split: float = 0.2, + test_interval: int = 5, + train_verbose: int = 25, + ): + super().__init__() + self.extractor = extractor + self.config = config + self.num_epoch = num_epoch + self.learning_rate = learning_rate + self.grad_clip = grad_clip + self.weight_decay = weight_decay + self.batch_size = batch_size + self.test_split = test_split + self.test_interval = test_interval + self.train_verbose = train_verbose + self.model = SegmentSumMLPModule(**self.config.to_dict()) + self.device = "cuda" if torch.cuda.device_count() else "cpu" + self.data = OrderedDict() + self.data_size = 0 + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + To expedite data loading and processing, each time this method loads the model + together with previously cached feature vectors if exist. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model.pth") + data_path = os.path.join(tmp_dir, "data.npy") + # Step 1. Untar + untar(path, tmp_dir) + # Step 2. Load data + if os.path.exists(data_path): + data = OrderedDict() + data_size = 0 + for group_hash, features, costs in np.load(data_path, allow_pickle=True): + data[group_hash] = FeatureGroup( + group_hash=group_hash, + features=list(features), + costs=costs, + ) + data_size += len(costs) + self.data = data + self.data_size = data_size + # Step 3. Load the model + if os.path.exists(model_path): + self.model.load_state_dict(torch.load(model_path)) + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + To expedite data loading and processing, each time this method saves the model + together with previously cached feature vectors. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model.pth") + data_path = os.path.join(tmp_dir, "data.npy") + # Step 1. Save the model + torch.save(self.model.state_dict(), model_path) + # Step 2. Save data + data = [ + ( + g.group_hash, + g.features, + g.costs, + ) + for g in self.data.values() + ] + np.save( + file=data_path, + arr=np.array(data, dtype=object), + ) + # Step 3. Tar it + tar(path, [x for x in [model_path, data_path] if x is not None]) + logger.info("Saved MLPModel to %s", path) + + def update( + self, + context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + skip_model: bool = False, + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + context : TuneContext + The tuning context. 
+        candidates : List[MeasureCandidate]
+            The measure candidates.
+        results : List[RunnerResult]
+            The running results of the measure candidates.
+        skip_model : bool
+            Skip updating the cost model, only load the data into runtime.
+        """
+        assert len(candidates) == len(results)
+        if len(candidates) == 0:
+            return
+
+        # Step 1. Get the feature group
+        new_group_hash = shash2hex(context.mod)
+        group = self.data.get(new_group_hash, None)
+
+        # Step 2. Extract features
+        def _feature(feature: NDArray) -> np.ndarray:
+            return feature.numpy().astype("float32")
+
+        def _mean_cost(res: RunnerResult) -> float:
+            if not res.run_secs:
+                return 1e10
+            return float(np.median([float(s) for s in res.run_secs]))
+
+        new_features = [_feature(x) for x in self.extractor.extract_from(context, candidates)]
+        new_mean_costs = np.array([_mean_cost(x) for x in results]).astype("float32")
+
+        # Step 3. Run validation
+        if not skip_model and group is not None:
+            logger.debug(
+                "MLP validation: %s",
+                "\t".join(
+                    f"{key}: {score:.6f}"
+                    for key, score in self._validate(
+                        features=new_features,
+                        gt_results=group.min_cost / new_mean_costs,
+                    ).items()
+                ),
+            )
+
+        # Step 4. Add the features into the data points
+        if group is None:
+            group = FeatureGroup(
+                group_hash=new_group_hash,
+                features=new_features,
+                costs=new_mean_costs,
+            )
+        else:
+            group.append(new_features, new_mean_costs)
+        self.data[new_group_hash] = group
+        self.data_size += len(new_features)
+
+        # Step 5. Re-train the model
+        if not skip_model:
+            self._train()
+
+    def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+        """Predict the normalized score using the cost model.
+
+        Parameters
+        ----------
+        context : TuneContext
+            The tuning context.
+        candidates : List[MeasureCandidate]
+            The measure candidates.
+
+        Returns
+        -------
+        result : np.ndarray
+            The predicted normalized score.
+ """ + features = [ + torch.tensor(x.numpy().astype("float32")) + for x in self.extractor.extract_from(context, candidates) + ] + segment_sizes = torch.tensor([len(feature) for feature in features]).to(self.device) + features = torch.cat(features).to(self.device) + norm = features.max(dim=0)[0] + norm[norm == 0] = 1 + features /= norm + # begin predicting + self.model = self.model.to(self.device) + self.model.eval() + result = self.model(segment_sizes, features).detach().cpu().numpy() + return result + + def _train(self) -> None: # pylint: disable=too-many-locals,too-many-statements + """Train the MLP model using all the data in the runtime.""" + # split into training and testing set + keys = list(self.data.keys()) + test_keys = random.sample(keys, k=math.floor(len(keys) * self.test_split)) + test_data = OrderedDict() + for key in test_keys: + test_data[key] = self.data[key] + del self.data[key] + train_features = list( + itertools_chain.from_iterable([g.features for g in self.data.values()]) + ) + test_features = list( + itertools_chain.from_iterable([g.features for g in test_data.values()]) + ) + train_results = np.concatenate([g.min_cost / g.costs for g in self.data.values()]) + test_results = np.concatenate([g.min_cost / g.costs for g in test_data.values()]) + train_loader = SegmentDataLoader( + train_features, train_results, batch_size=self.batch_size, shuffle=True + ) + test_loader = SegmentDataLoader( + test_features, test_results, batch_size=self.batch_size, shuffle=False + ) + self.data, test_data = None, None # save memory + logger.info("Training size: %d, testing size: %d", len(train_loader), len(test_loader)) + + # begin training + self.model = self.model.to(self.device) + optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay + ) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=self.num_epoch // 3, gamma=0.7 + ) + min_test_loss = 1e10 + with tempfile.TemporaryDirectory() as tmp_dir: + model_cache_path = os.path.join(tmp_dir, "best_model.pth") + for epoch in range(self.num_epoch): + logger.info("Epoch: %d", epoch) + # training + self.model.train() + train_loss = None + for batch, (segment_sizes, features, gt_results) in tqdm(enumerate(train_loader)): + optimizer.zero_grad() + segment_sizes, features, gt_results = ( + segment_sizes.to(self.device), + features.to(self.device), + gt_results.to(self.device), + ) + pred_results = self.model(segment_sizes, features) + loss = self.model.lambda_rank_loss(pred_results, gt_results) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) + optimizer.step() + loss = loss.detach().cpu() + train_loss = ( + train_loss * 0.95 + loss.item() * 0.05 + if train_loss is not None + else loss.item() + ) + segment_sizes, features, gt_results, pred_results = ( + segment_sizes.detach().cpu(), + features.detach().cpu(), + gt_results.detach().cpu(), + pred_results.detach().cpu(), + ) + if batch % self.train_verbose == 0: + logger.info("Batch: %d, train loss: %6f", batch, train_loss) + del pred_results + del loss + scheduler.step() + # testing + if epoch % self.test_interval == 0: + self.model.eval() + test_scores, test_losses = [], [] + for batch, (segment_sizes, features, gt_results) in tqdm(enumerate(test_loader)): + segment_sizes, features = ( + segment_sizes.to(self.device), + features.to(self.device), + ) + pred_results = self.model(segment_sizes, features) + segment_sizes, features, pred_results = ( + segment_sizes.detach().cpu(), + 
features.detach().cpu(), + pred_results.detach().cpu(), + ) + test_losses.append(self.model.lambda_rank_loss(pred_results, gt_results).item()) + scores = [] + for k in [1, 5, 10]: + scores.append(self.model.topk_score(pred_results, gt_results, k)) + test_scores.append(scores) + del pred_results + test_loss = ( + np.array(test_losses[:-1]).mean() if len(test_losses) > 1 else test_losses[0] + ) + logger.info( + "Average test loss: %6f, top1 score: %5f, top5 score: %5f, top10 score: %5f", + test_loss, + np.array(test_scores)[:, 0].mean(), + np.array(test_scores)[:, 1].mean(), + np.array(test_scores)[:, 2].mean(), + ) + if test_loss < min_test_loss: + min_test_loss = test_loss + torch.save(self.model.state_dict(), model_cache_path) + self.model.to("cpu").load_state_dict(torch.load(model_cache_path)) + + def _validate(self, features: List[np.ndarray], gt_results: np.ndarray) -> Dict[str, float]: + """Run validation without a test dataset. + + Parameters + ---------- + features : List[np.ndarray] + The features + gt_results : np.ndarray + The measured results + + Returns + ------- + result : Dict[str, float] + The validation result. + """ + segment_sizes = torch.tensor([len(feature) for feature in features]).to(self.device) + features = torch.cat([torch.tensor(feature) for feature in features]).to(self.device) + norm = features.max(dim=0)[0] + norm[norm == 0] = 1 + features /= norm + gt_results = torch.tensor(gt_results) + # begin validating + self.model = self.model.to(self.device) + self.model.eval() + pred_results = self.model(segment_sizes, features) + segment_sizes, features, pred_results = ( + segment_sizes.detach().cpu(), + features.detach().cpu(), + pred_results.detach().cpu(), + ) + loss = self.model.lambda_rank_loss(pred_results, gt_results).item() + scores = [] + for k in [1, 5, 10]: + scores.append(self.model.topk_score(pred_results, gt_results, k)) + return { + "loss": loss, + "top1_score": scores[0], + "top5_score": scores[1], + "top10_score": scores[2], + }
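
A minimal sketch of the two PyTorch building blocks this patch introduces, runnable outside a tuning session. It assumes the patch is applied and PyTorch is installed; the random 172-dimensional feature rows stand in for PerStoreFeature output, and the random scores in (0, 1] stand in for the normalized throughputs (min_cost / cost) that MLPModel.update feeds into training. The module path and identifiers come from the diff above; the synthetic data and hyperparameters are illustrative only.

import numpy as np
import torch

from tvm.meta_schedule.cost_model.mlp_model import SegmentDataLoader, SegmentSumMLPModule

rng = np.random.default_rng(0)
# Each candidate yields a variable number of 172-dim per-store feature rows.
features = [rng.random((int(rng.integers(2, 8)), 172)).astype("float32") for _ in range(32)]
# Normalized scores: higher is better, mirroring min_cost / cost.
scores = rng.random(32).astype("float32")

loader = SegmentDataLoader(features, scores, batch_size=16, shuffle=True)
model = SegmentSumMLPModule(input_dim=172, hidden_dim=256, output_dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=7e-4)

model.train()
for segment_sizes, batch_features, gt_results in loader:
    optimizer.zero_grad()
    # One score per candidate: per-row encodings are segment-summed, then decoded.
    pred_results = model(segment_sizes, batch_features)
    loss = model.lambda_rank_loss(pred_results, gt_results)
    loss.backward()
    optimizer.step()
    print(f"rank loss: {loss.item():.4f}")

During tuning these pieces are driven indirectly: the task scheduler calls MLPModel.update with measured candidates (which caches features per workload hash and re-trains), MLPModel.predict to rank new candidates, and save/load to round-trip both the weights and the cached feature groups through a single tar archive.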