[MetaSchedule] Handle cases when no features found by FeatureExtractor (#14591)

This PR fixes the problem that ops like `topi.full` cannot be tuned correctly.

Currently, the MetaSchedule system uses an XGBoost model to predict the running performance of an IRModule based on features extracted from it.

However, some operators, such as `topi.full`, produce an IRModule so simple that no features can be extracted from it. `topi.full` contains only one BufferStore node with a FloatImm value, so the `PerStoreFeature` extractor cannot find any features in it, and `XGBModel` cannot handle the case where no features exist.

This PR handles this case explicitly in `PerStoreFeatureNode` and `XGBModel`.
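
As background, a minimal sketch of how such a trivial module arises (this assumes a recent TVM build; the variable names are illustrative):

from tvm import te, topi

# topi.full lowers to a module whose body is a single BufferStore of a
# constant FloatImm: there are no loads and no arithmetic, so
# PerStoreFeature has nothing to extract.
out = topi.full((2, 3), "float32", 1.0)
func = te.create_prim_func([out])
print(func)  # body is essentially: T_full[ax0, ax1] = T.float32(1)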
Ubospica authored Apr 12, 2023
1 parent b1ab4dc commit 742c5ee
Showing 4 changed files with 66 additions and 8 deletions.
15 changes: 11 additions & 4 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -479,7 +479,14 @@ def _mean_cost(x: RunnerResult) -> float:
             return float(np.median([float(s) for s in x.run_secs]))
 
         new_features = [_feature(x) for x in self.extractor.extract_from(context, candidates)]
-        new_mean_costs = np.array([_mean_cost(x) for x in results]).astype("float32")
+        new_mean_costs = [_mean_cost(x) for x in results]
+
+        # Filter instances with no features
+        new_mean_costs = [c for i, c in enumerate(new_mean_costs) if len(new_features[i]) != 0]
+        new_mean_costs_np = np.array(new_mean_costs).astype("float32")
+        new_features = [f for f in new_features if len(f) != 0]
+        if not new_features:
+            return
 
         # Steps 3. Run validation
         if group is not None and self.booster is not None:
@@ -489,7 +496,7 @@ def _mean_cost(x: RunnerResult) -> float:
                     f"{key}: {score:.6f}"
                     for key, score in self._validate(
                         xs=new_features,
-                        ys=group.min_cost / new_mean_costs,
+                        ys=group.min_cost / new_mean_costs_np,
                     )
                 ),
             )
@@ -499,10 +506,10 @@ def _mean_cost(x: RunnerResult) -> float:
             group = FeatureGroup(
                 group_hash=new_group_hash,
                 features=new_features,
-                costs=new_mean_costs,
+                costs=new_mean_costs_np,
             )
         else:
-            group.append(new_features, new_mean_costs)
+            group.append(new_features, new_mean_costs_np)
         self.data[new_group_hash] = group
         self.data_size += len(new_features)
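
To see the filtering above in isolation, here is a toy sketch (the data is hypothetical; each entry of `features` stands for one candidate's per-store feature matrix):

import numpy as np

# Mimic three candidates: the middle one has no extractable features.
features = [np.ones((4, 10)), np.empty((0, 10)), np.ones((2, 10))]
costs = [0.5, 1e10, 0.7]

# Drop (feature, cost) pairs whose feature matrix is empty, as update() now does.
costs = [c for i, c in enumerate(costs) if len(features[i]) != 0]
features = [f for f in features if len(f) != 0]
costs_np = np.array(costs).astype("float32")

assert len(features) == 2 and costs_np.shape == (2,)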

11 changes: 7 additions & 4 deletions src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -217,12 +217,15 @@ int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const IntVec&
 /*!
  * \brief Converts a 2-dimensional STL vector to a TVM NDArray
  * \param src The source 2-dimensional STL vector
+ * \param second_dim_size The length of the second dimension. When the first dim of src is 0,
+ * second_dim_size must be specified, and in such case the shape of the result NDArray is
+ * (0, second_dim_size).
  * \return The converted TVM NDArray
  */
-runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
-  ICHECK(!src.empty());
+runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src, int second_dim_size = -1) {
+  ICHECK(!src.empty() || second_dim_size != -1);
   int n = src.size();
-  int m = src[0].size();
+  int m = src.empty() ? second_dim_size : src[0].size();
   runtime::NDArray tgt = runtime::NDArray::Empty(
       /*shape=*/{n, m},
       /*dtype=*/DLDataType{kDLFloat, 64, 1},
@@ -1404,7 +1407,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
         feature_group6->Export(&feature);
       }
     }
-    results[task_id] = tir::utils::AsNDArray(features);
+    results[task_id] = tir::utils::AsNDArray(features, this->feature_vector_length);
   };
   support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
   return results;
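
A NumPy analogue of the `AsNDArray` change may make the intent clearer (a sketch; `as_2d_array` is a hypothetical helper, and `second_dim_size` plays the role of `feature_vector_length`):

import numpy as np

def as_2d_array(src, second_dim_size=-1):
    # With no rows, the second dimension cannot be inferred from src[0],
    # so the caller must supply it explicitly.
    assert src or second_dim_size != -1
    m = second_dim_size if not src else len(src[0])
    return np.array(src, dtype="float64").reshape(len(src), m)

# A candidate with no extractable features now yields an empty matrix of
# the right width instead of an out-of-bounds access on src[0].
# (164 here stands for the extractor's feature vector length, N_FEATURES
# in the tests below.)
assert as_2d_array([], second_dim_size=164).shape == (0, 164)
assert as_2d_array([[1.0, 2.0], [3.0, 4.0]]).shape == (2, 2)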
26 changes: 26 additions & 0 deletions tests/python/unittest/test_meta_schedule_cost_model.py
@@ -54,6 +54,19 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
 
+@tvm.script.ir_module
+class FullModule:
+    @T.prim_func
+    def main(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+            with T.block("T_full"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads()
+                T.writes(T_full[v_ax0, v_ax1])
+                T_full[v_ax0, v_ax1] = T.float32(1)
+
+
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,disable=unused-argument


@@ -165,6 +178,19 @@ def test_meta_schedule_xgb_model():
     model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])
 
 
+def test_meta_schedule_xgb_model_no_feature():
+    model = XGBModel(num_warmup_samples=0)
+    tune_ctx = TuneContext(
+        FullModule,
+        target="llvm --num-cores 16",
+        space_generator="post-order-apply",
+        search_strategy="evolutionary",
+    )
+    candidate = MeasureCandidate(Schedule(FullModule), [])
+    model.update(tune_ctx, [candidate], [_dummy_result()])
+    model.predict(tune_ctx, [candidate])
+
+
 def test_meta_schedule_xgb_model_reload():
     extractor = RandomFeatureExtractor()
     model = XGBModel(extractor=extractor, num_warmup_samples=10)
22 changes: 22 additions & 0 deletions tests/python/unittest/test_meta_schedule_feature_extractor.py
@@ -704,6 +704,28 @@ def _create_schedule():
     )
 
 
+def test_empty_feature():
+    @T.prim_func
+    def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+            with T.block("T_full"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads()
+                T.writes(T_full[v_ax0, v_ax1])
+                T_full[v_ax0, v_ax1] = T.float32(1)
+
+    def _create_schedule():
+        return tir.Schedule(full, debug_mask="all")
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("llvm")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    assert feature.shape == (0, N_FEATURES)
+
+
 def test_gpu():
     def _create_schedule():
         func = matmul
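
As a quick check, one might run just the two new tests from a TVM checkout (a sketch; the feature-extractor test path is inferred from the hunk contents above):

import pytest

pytest.main([
    "tests/python/unittest/test_meta_schedule_cost_model.py::test_meta_schedule_xgb_model_no_feature",
    "tests/python/unittest/test_meta_schedule_feature_extractor.py::test_empty_feature",
])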
