-
Notifications
You must be signed in to change notification settings - Fork 0
/
call-retrain.py
82 lines (80 loc) · 2.92 KB
/
call-retrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import requests
import valohai
import os
"""Trigger the Valohai retraining pipeline (preprocess -> train_model).

Reads the API token from the ``VALOHAI_API_TOKEN`` environment variable and
the input dataset URI from the ``dataset_path`` Valohai parameter, then
creates a two-node pipeline through the Valohai REST API.
"""

# Valohai REST endpoint that creates a pipeline.
PIPELINES_URL = "https://app.valohai.com/api/v0/pipelines/"
# Valohai project / execution environment the pipeline runs in.
PROJECT_ID = "0190fe5c-ca46-bf10-880d-d92127f69fd2"
ENVIRONMENT_ID = "0167d05d-a1d7-cc02-8256-6455a6ecfa56"
# Pinned packages shared by both nodes; kept in one place so the
# preprocessing and training steps cannot drift apart.
PIP_REQUIREMENTS = "evidently==0.4.16 pandas==2.2.1 scikit-learn==1.4.1.post1"
# Seconds to wait on the API; without a timeout `requests` may block forever.
REQUEST_TIMEOUT = 60


def _execution_node(name, step, command, inputs):
    """Return one pipeline node dict that runs *step* as a Valohai execution.

    Every node shares the same environment, commit, image and template
    defaults; only the node name, step, shell command and inputs differ.

    Args:
        name: Node name referenced by pipeline edges.
        step: Step name to run.
        command: Shell command string executed in the container.
        inputs: Mapping of input name -> list of input URIs.
    """
    return {
        "name": name,
        "type": "execution",
        "on_error": "stop-all",
        "template": {
            "environment": ENVIRONMENT_ID,
            "commit": "main",
            "step": step,
            "image": "python:3.9",
            "command": command,
            "inputs": inputs,
            "parameters": {},
            "runtime_config": {},
            "inherit_environment_variables": True,
            "environment_variable_groups": [],
            "tags": [],
            "time_limit": 0,
            "environment_variables": {},
            "allow_reuse": False,
        },
    }


def build_pipeline_payload(dataset_path):
    """Return the JSON body for ``POST /api/v0/pipelines/``.

    The pipeline has two nodes. ``preprocess-node`` reads *dataset_path*, and
    a single edge wires its ``*.csv`` outputs into the ``dataset`` input of
    ``train-model-node``.

    Args:
        dataset_path: Input URI passed to the preprocessing node.

    Returns:
        dict ready to be sent as the request's JSON body.
    """
    return {
        "edges": [
            {
                "source_node": "preprocess-node",
                "source_key": "*.csv",
                "source_type": "output",
                "target_node": "train-model-node",
                "target_type": "input",
                "target_key": "dataset",
            }
        ],
        "nodes": [
            _execution_node(
                "preprocess-node",
                "preprocess",
                f"pip install {PIP_REQUIREMENTS} valohai\npython preprocess.py",
                {"dataset": [dataset_path]},
            ),
            _execution_node(
                "train-model-node",
                "train_model",
                f"pip install {PIP_REQUIREMENTS}\npython train_model.py",
                # NOTE(review): static default input; presumably the edge from
                # preprocess-node supplies the real files at run time — confirm.
                {"dataset": ["datum://california_housing_dataset"]},
            ),
        ],
        "project": PROJECT_ID,
        "tags": [],
        "parameters": {},
        "title": "train-pipeline-triggered-because-drift-detected",
    }


auth_token = os.environ.get("VALOHAI_API_TOKEN")
if not auth_token:
    # Fail fast: otherwise the header would silently be sent as "Token None".
    raise RuntimeError("VALOHAI_API_TOKEN environment variable is not set")

dataset_path = valohai.parameters('dataset_path').value

resp = requests.post(
    PIPELINES_URL,
    headers={"Authorization": f"Token {auth_token}"},
    json=build_pipeline_payload(dataset_path),
    timeout=REQUEST_TIMEOUT,
)
if resp.status_code == 400:
    # Surface the API's validation details rather than a bare HTTPError;
    # fall back to the raw text if the error body is not JSON.
    try:
        detail = resp.json()
    except ValueError:
        detail = resp.text
    raise RuntimeError(detail)
resp.raise_for_status()
data = resp.json()