-
-
Notifications
You must be signed in to change notification settings - Fork 135
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
223 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from amlb.benchmark import TaskConfig | ||
from amlb.data import Dataset | ||
from amlb.utils import call_script_in_same_dir | ||
|
||
|
||
def setup(*args, **kwargs):
    """Run the ``setup.sh`` script that sits next to this module.

    Every positional and keyword argument is forwarded unchanged to
    ``call_script_in_same_dir``.
    """
    script_name = "setup.sh"
    call_script_in_same_dir(__file__, script_name, *args, **kwargs)
|
||
|
||
def run(dataset: Dataset, config: TaskConfig):
    """Hand the task over to ``exec.py`` through ``run_in_venv``.

    The payload contains the target name plus the feature frames and the
    ``y_enc`` targets of the train and test splits. When
    ``config.measure_inference_time`` is set, the parquet inference
    subsample files are added as well.

    :param dataset: the benchmark dataset providing ``target``, ``train`` and ``test``.
    :param config: the task configuration.
    :return: whatever ``run_in_venv`` returns.
    """
    from frameworks.shared.caller import run_in_venv

    payload = {
        "target": dataset.target.name,
        "train": {
            "X": dataset.train.X,
            "y": dataset.train.y_enc,
        },
        "test": {
            "X": dataset.test.X,
            "y": dataset.test.y_enc,
        },
    }
    if config.measure_inference_time:
        payload["inference_subsample_files"] = dataset.inference_subsample_files(fmt="parquet")

    # Ask for sparse dataframes to be deserialized in dense format.
    serialization_options = {"sparse_dataframe_deserialized_format": 'dense'}

    return run_in_venv(
        __file__,
        "exec.py",
        input_data=payload,
        dataset=dataset,
        config=config,
        options={"serialization": serialization_options},
    )
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import logging | ||
import os | ||
import pickle | ||
import re | ||
import subprocess | ||
import sys | ||
import tempfile as tmp | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import pandas as pd | ||
|
||
# Environment variables are set here, ahead of the imports further down.
# NOTE(review): indentation was lost in the diff rendering; this assumes only
# the first assignment is macOS-specific -- confirm against the real file.
if sys.platform == 'darwin':
    os.environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES'
os.environ.update({
    'JOBLIB_TEMP_FOLDER': tmp.gettempdir(),
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
})
|
||
from frameworks.shared.callee import call_run, result, output_subdir, \ | ||
measure_inference_times | ||
from frameworks.shared.utils import Timer, touch | ||
|
||
from naiveautoml import NaiveAutoML | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def _installed_naiveautoml_version():
    """Return the NaiveAutoML version listed by ``pip list``, or ``"unknown"``.

    pip is invoked through ``sys.executable`` so that the environment of the
    running interpreter is inspected, not whichever ``python`` is first on
    PATH. The version is only used for logging, so a missing entry must not
    abort the run.
    """
    pip_list = subprocess.run(
        [sys.executable, "-m", "pip", "list"], capture_output=True
    )
    match = re.search(
        r"naiveautoml\s+([^\n]+)", pip_list.stdout.decode(), flags=re.IGNORECASE
    )
    return match.group(1).strip() if match else "unknown"


def run(dataset, config):
    """Fit NaiveAutoML on the train split and save predictions for the test split.

    :param dataset: provides ``train.X``, ``train.y``, ``test.X``, ``test.y`` and
        ``inference_subsample_files``.
    :param config: task configuration; reads ``metric``, ``cores``,
        ``max_runtime_seconds``, ``type``, ``measure_inference_time``,
        ``output_predictions_file`` and ``framework_params``.
        Entries of ``framework_params`` whose key does not start with ``_`` are
        passed to the ``NaiveAutoML`` constructor and override the values
        computed here. ``_use_default_time_and_iterations`` keeps NaiveAutoML's
        own time/iteration settings.
    :return: the value returned by ``result(...)``.
    :raises ValueError: if ``config.metric`` has no entry in the metric mapping.
    """
    log.info("\n**** NaiveAutoML [v%s] ****", _installed_naiveautoml_version())

    metrics_mapping = dict(
        acc='accuracy',
        balacc='balanced_accuracy',
        auc='roc_auc',
        logloss='neg_log_loss',
        mae='neg_mean_absolute_error',
        r2='r2',
        # (negated) MSE ranks models identically to RMSE.
        rmse='neg_mean_squared_error',
    )
    scoring_metric = metrics_mapping.get(config.metric)
    if scoring_metric is None:
        raise ValueError(f"Performance metric {config.metric} not supported.")

    kwargs = dict(
        scoring=scoring_metric,
        num_cpus=config.cores,
    )
    # NAML wasn't really designed to run for long time constraints, so we
    # make it easy to run NAML with its default configuration for time/iterations.
    if not config.framework_params.get("_use_default_time_and_iterations", False):
        kwargs["timeout"] = config.max_runtime_seconds
        # NAML stops at its first met criterion: iterations or time.
        # To ensure time is the first criterion, set max_hpo_iterations very high
        kwargs["max_hpo_iterations"] = 1e10
        # NAML has a static per-pipeline evaluation time of 10 seconds,
        # which is not accommodating for larger datasets.
        kwargs["execution_timeout"] = max(config.max_runtime_seconds // 20, 10)
    else:
        log.info("`_use_default_time_and_iterations` is set, ignoring time constraint.")

    # User-supplied (non-underscore) parameters take precedence over the defaults above.
    kwargs |= {k: v for k, v in config.framework_params.items() if not k.startswith("_")}
    automl = NaiveAutoML(**kwargs)

    with Timer() as training:
        automl.fit(dataset.train.X, dataset.train.y)
    log.info(f"Finished fit in {training.duration}s.")

    is_classification = (config.type == 'classification')

    def infer(data: Union[str, pd.DataFrame]):
        # A string is treated as the path of a parquet file to load first.
        test_data = pd.read_parquet(data) if isinstance(data, str) else data
        predict_fn = automl.predict_proba if is_classification else automl.predict
        return predict_fn(test_data)

    inference_times = {}
    if config.measure_inference_time:
        inference_times["file"] = measure_inference_times(infer, dataset.inference_subsample_files)
        # 100 single-row samples of the test frame, each drawn with its own seed.
        inference_times["df"] = measure_inference_times(
            infer,
            [(1, dataset.test.X.sample(1, random_state=i)) for i in range(100)],
        )
        log.info("Finished inference time measurements.")

    with Timer() as predict:
        predictions = automl.predict(dataset.test.X)
        probabilities = automl.predict_proba(dataset.test.X) if is_classification else None
    log.info(f"Finished predict in {predict.duration}s.")

    save_artifacts(automl, config)

    return result(
        output_file=config.output_predictions_file,
        predictions=predictions,
        probabilities=probabilities,
        truth=dataset.test.y,
        training_duration=training.duration,
        predict_duration=predict.duration,
        inference_times=inference_times,
        target_is_encoded=is_classification,
    )
|
||
|
||
def save_artifacts(naive_automl, config):
    """Write the requested artifacts into the ``artifacts`` output sub-directory.

    The artifact names come from ``config.framework_params['_save_artifacts']``
    and default to ``['history']``:

    - ``'history'``: ``naive_automl.history`` is written to ``history.csv``.
    - ``'model'``: ``naive_automl.chosen_model`` is written as text to
      ``model_str.txt`` and pickled to ``model.pkl``.

    Saving is best-effort: any exception is logged as a warning and swallowed.
    """
    requested = config.framework_params.get('_save_artifacts', ['history'])
    try:
        target_dir = Path(output_subdir("artifacts", config))

        if 'history' in requested:
            history_file = target_dir / "history.csv"
            naive_automl.history.to_csv(history_file, index=False)

        if 'model' in requested:
            description_file = target_dir / "model_str.txt"
            description_file.write_text(str(naive_automl.chosen_model))
            pickle_file = target_dir / "model.pkl"
            with pickle_file.open('wb') as stream:
                pickle.dump(naive_automl.chosen_model, stream)
    except Exception:
        log.warning("Error when saving artifacts.", exc_info=True)
|
||
|
||
|
||
# Script entry point: hand the `run` function above to `call_run`.
if __name__ == "__main__":
    call_run(run)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pandas # https://github.com/fmohr/naiveautoml/issues/19 | ||
ConfigSpace<0.7.1 # https://github.com/fmohr/naiveautoml/issues/20 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/usr/bin/env bash
# Installs NaiveAutoML and records the installed version.
#
# Usage: setup.sh [VERSION] [REPO] [PKG]
#   VERSION  "stable" (default): newest release of PKG;
#            "latest": the "master" branch of REPO;
#            a value starting with a digit: that exact release of PKG;
#            "#<hash>" or anything else: that git ref of REPO.
#   REPO     git repository used for installs from source.
#   PKG      package name used for release installs.
HERE=$(dirname "$0")
VERSION=${1:-"stable"}
REPO=${2:-"https://github.com/fmohr/naiveautoml"}
PKG=${3:-"naiveautoml"}

echo "NaiveAutoML/setup.sh" "$@"

if [[ "$VERSION" == "latest" ]]; then
    VERSION="master"
fi

# The shared setup script is expected to provide the PIP and PY helpers and
# $pip_exec used below -- TODO confirm against shared/setup.sh.
. "${HERE}/../shared/setup.sh" "${HERE}" true

PIP install -r "${HERE}/requirements.txt"

# no __version__ available: https://github.com/fmohr/naiveautoml/issues/22
GET_VERSION_STABLE="import subprocess
import re
pip_list = subprocess.run('$pip_exec list'.split(), capture_output=True)
match = re.search(r'naiveautoml\s+([^\n]+)', pip_list.stdout.decode(), flags=re.IGNORECASE)
version, = match.groups()
print(version)"

if [[ "$VERSION" == "stable" ]]; then
    PIP install --no-cache-dir -U "${PKG}"
    # Replace the "stable" alias with the concrete version that was installed.
    VERSION=$(PY -c "${GET_VERSION_STABLE}")
elif [[ "$VERSION" =~ ^[0-9] ]]; then
    PIP install --no-cache-dir -U "${PKG}==${VERSION}"
else
    if [[ "$VERSION" =~ ^# ]]; then
        # Versions starting with a `#` are to be interpreted as commit hashes
        # The actual git clone command expects the hash without the `#` prefix.
        VERSION="${VERSION:1}"
    fi
    echo "Attempting to install from git+${REPO}.git@${VERSION}#egg=naiveautoml&subdirectory=python"
    PIP install -U "git+${REPO}.git@${VERSION}#egg=naiveautoml&subdirectory=python"
fi

echo "$VERSION" >> "${HERE}/.setup/installed"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -77,6 +77,9 @@ MLPlanWEKA: | |
mlr3automl: | ||
version: 'latest' | ||
|
||
NaiveAutoML: | ||
version: 'latest' | ||
|
||
oboe: | ||
version: 'latest' | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,6 +82,9 @@ MLPlanWEKA: | |
mlr3automl: | ||
version: 'stable' | ||
|
||
NaiveAutoML: | ||
version: 'stable' | ||
|
||
oboe: | ||
version: 'stable' | ||
|
||
|