Skip to content

Commit

Permalink
test: [automl] add ccb test that checks for ft names
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo committed Mar 13, 2023
1 parent 37f4b19 commit 533e067
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/python_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
shell: bash
run: |
pip install -r requirements.txt
pip install pytest twine
pip install pytest vw-executor twine
pip install built_wheel/*.whl
twine check built_wheel/*.whl
python -m pytest ./python/tests/
Expand Down Expand Up @@ -133,7 +133,7 @@ jobs:
shell: bash
run: |
pip install -r requirements.txt
pip install pytest
pip install pytest vw-executor
- name: Run unit tests
shell: bash
run: |
Expand Down Expand Up @@ -212,7 +212,7 @@ jobs:
source .env/bin/activate && \
pip install --upgrade pip && \
pip install -r requirements.txt && \
pip install pytest twine && \
pip install pytest vw-executor twine && \
pip install built_wheel/*.whl && \
twine check built_wheel/*.whl && \
python --version && \
Expand Down Expand Up @@ -272,7 +272,7 @@ jobs:
shell: bash
run: |
pip install -r requirements.txt
pip install pytest twine
pip install pytest vw-executor twine
pip install built_wheel/*.whl
twine check built_wheel/*.whl
python -m pytest ./python/tests/
Expand Down Expand Up @@ -361,7 +361,7 @@ jobs:
export wheel_file="${wheel_files[0]}"
echo Installing ${wheel_file}...
pip install -r requirements.txt
pip install pytest twine
pip install pytest vw-executor twine
pip install ${wheel_file}
twine check ${wheel_file}
python -m pytest .\\python\\tests\\
Expand Down
57 changes: 57 additions & 0 deletions python/tests/test_ccb.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,60 @@ def test_ccb_non_slot_none_outcome():
# CCB label is set to UNSET by default.
assert label.type == vowpalwabbit.CCBLabelType.UNSET
assert label.outcome is None


def test_ccb_and_automl():
import random, json, os, shutil
import numpy as np
from vw_executor.vw import Vw

people_ccb = ["Tom", "Anna"]
topics_ccb = ["sports", "politics", "music"]

def my_ccb_simulation(n=10000, swap_after=5000, variance=0, bad_features=0, seed=0):
random.seed(seed)
np.random.seed(seed)

envs = [[[0.8, 0.4], [0.2, 0.4]]]
offset = 0
for i in range(1, n):
person = random.randint(0, 1)
chosen = [int(i) for i in np.random.permutation(2)]
rewards = [envs[offset][person][chosen[0]], envs[offset][person][chosen[1]]]

for i in range(len(rewards)):
rewards[i] += np.random.normal(0.5, variance)

yield {
"c": {
"shared": {"name": people_ccb[person]},
"_multi": [{"a": {"topic": topics_ccb[i]}} for i in range(2)],
"_slots": [{"_id": i} for i in range(2)],
},
"_outcomes": [
{"_label_cost": -min(rewards[i], 1), "_a": chosen[i:], "_p": [1.0 / (2 - i)] * (2 - i)}
for i in range(2)
],
}

def save_examples(examples, path):
with open(path, "w") as f:
for ex in examples:
f.write(f'{json.dumps(ex, separators=(",", ":"))}\n')

input_file = "ccb.json"
cache_dir = ".cache"
save_examples(my_ccb_simulation(n=1000, variance=0.1, bad_features=1, seed=0), input_file)

assert os.path.exists(input_file)

vw = Vw(cache_dir, "/root/vowpal_wabbit/build/vowpalwabbit/cli/vw")
q = vw.train(input_file, "-b 18 -q :: --ccb_explore_adf --dsjson", ["--invert_hash"])
fts_names_q = set([n for n in q[0].model9("--invert_hash").weights.index])

assert len(fts_names_q) == 39

os.remove(input_file)
shutil.rmtree(cache_dir)
assert not os.path.exists(input_file)
assert not os.path.exists(cache_dir)

0 comments on commit 533e067

Please sign in to comment.