Skip to content

Commit

Permalink
Added class mapping to the AutoML class
Browse files Browse the repository at this point in the history
  • Loading branch information
screengreen committed Aug 20, 2024
1 parent 38fdfa9 commit 6ab6e04
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
16 changes: 15 additions & 1 deletion lightautoml/automl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,20 @@ def fit_predict(
self.timer.start()
train_dataset = self.reader.fit_read(train_data, train_features, roles)

# Saving class mapping
if self.reader.task.name == "binary":
self.classes_ = [1]
elif self.reader.task.name == "multi:reg":
self.classes_ = roles["target"]
elif self.reader.task.name == "reg":
self.classes_ = [roles["target"]]
else:
self.classes_ = (
sorted(self.reader.class_mapping, key=self.reader.class_mapping.get, reverse=False)
if self.reader.class_mapping
else None
)

assert (
len(self._levels) <= 1 or train_dataset.folds is not None
), "Not possible to fit more than 1 level without cv folds"
Expand Down Expand Up @@ -259,7 +273,7 @@ def fit_predict(
else:
break

blended_prediction, last_pipes = self.blender.fit_predict(level_predictions, pipes)
blended_prediction, last_pipes = self.blender.fit_predict(level_predictions, pipes, self.classes_)
self.levels.append(last_pipes)

self.reader.upd_used_features(remove=list(set(self.reader.used_features) - set(self.collect_used_feats())))
Expand Down
14 changes: 8 additions & 6 deletions lightautoml/automl/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def outp_dim(self) -> int: # noqa: D102
return self._outp_dim

def fit_predict(
self, predictions: Sequence[LAMLDataset], pipes: Sequence[MLPipeline]
self, predictions: Sequence[LAMLDataset], pipes: Sequence[MLPipeline], class_mapping: dict
) -> Tuple[LAMLDataset, Sequence[MLPipeline]]:
"""Wraps custom ``._fit_predict`` methods of blenders.
Expand All @@ -63,7 +63,7 @@ def fit_predict(
self._bypass = True
return predictions[0], pipes

return self._fit_predict(predictions, pipes)
return self._fit_predict(predictions, pipes, class_mapping)

def _fit_predict(
self, predictions: Sequence[LAMLDataset], pipes: Sequence[MLPipeline]
Expand Down Expand Up @@ -134,7 +134,7 @@ def split_models(self, predictions: Sequence[LAMLDataset]) -> Tuple[Sequence[LAM

return splitted_preds, model_idx, pipe_idx

def _set_metadata(self, predictions: Sequence[LAMLDataset], pipes: Sequence[MLPipeline]):
def _set_metadata(self, predictions: Sequence[LAMLDataset], pipes: Sequence[MLPipeline], class_mapping: dict):

pred0 = predictions[0]
pipe0 = pipes[0]
Expand All @@ -143,6 +143,8 @@ def _set_metadata(self, predictions: Sequence[LAMLDataset], pipes: Sequence[MLPi
self._outp_prob = pred0.task.name in ["binary", "multiclass"]
self._score = predictions[0].task.get_dataset_metric()

self._class_mapping = class_mapping

def score(self, dataset: LAMLDataset) -> float:
"""Score metric for blender.
Expand Down Expand Up @@ -321,7 +323,7 @@ def _get_weighted_pred(self, splitted_preds: Sequence[NumpyDataset], wts: Option
outp = splitted_preds[0].empty()
outp.set_data(
weighted_pred,
["WeightedBlend_{0}".format(x) for x in range(weighted_pred.shape[1])],
self._class_mapping if self._class_mapping else list(range(weighted_pred.shape[1])),
NumericRole(np.float32, prob=self._outp_prob),
)

Expand Down Expand Up @@ -436,7 +438,7 @@ def _prune_pipe(
return new_pipes, wts

def _fit_predict(
self, predictions: Sequence[NumpyDataset], pipes: Sequence[MLPipeline]
self, predictions: Sequence[NumpyDataset], pipes: Sequence[MLPipeline], class_mapping: dict
) -> Tuple[NumpyDataset, Sequence[MLPipeline]]:
"""Perform coordinate descent.
Expand All @@ -451,7 +453,7 @@ def _fit_predict(
Dataset and MLPipeline.
"""
self._set_metadata(predictions, pipes)
self._set_metadata(predictions, pipes, class_mapping)
splitted_preds, _, pipe_idx = cast(List[NumpyDataset], self.split_models(predictions))

wts = self._optimize(splitted_preds)
Expand Down

0 comments on commit 6ab6e04

Please sign in to comment.