-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
480 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
name: Lint | ||
|
||
on: | ||
push: | ||
branches: [main] | ||
pull_request: | ||
branches: [main] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
lint: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Check out repo | ||
uses: actions/checkout@v4 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Set up uv | ||
uses: astral-sh/setup-uv@v2 | ||
with: | ||
enable-cache: true | ||
|
||
- name: Install dependencies | ||
run: uv pip install pre-commit --system | ||
|
||
- name: Run pre-commit | ||
run: pre-commit run --all-files --show-diff-on-failure |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
name: Tests | ||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
testing: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ['3.9', '3.10', '3.11' ] | ||
steps: | ||
- name: Check out repo | ||
uses: actions/checkout@v4 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Set up uv | ||
uses: astral-sh/setup-uv@v2 | ||
with: | ||
enable-cache: true | ||
|
||
- name: Install Dependencies | ||
run: uv pip install -r requirements.txt --system | ||
|
||
- name: Run tests | ||
run: | | ||
export ADD_DUMMY_TYPES=True | ||
uvicorn --app-dir=app app:app & sleep 10 | ||
pytest | ||
kill %1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from fastapi import FastAPI | ||
from routers.candidates import router as candidates_router | ||
from routers.versions import router as versions_router | ||
from starlette.responses import RedirectResponse | ||
|
||
|
||
app = FastAPI(title="BoFire Candidates API", version="0.1.0", root_path="/") | ||
|
||
|
||
@app.get("/", include_in_schema=False) | ||
async def redirect(): | ||
return RedirectResponse(url="/docs") | ||
|
||
|
||
app.include_router(candidates_router) | ||
app.include_router(versions_router) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import Optional | ||
|
||
from bofire.data_models.dataframes.api import Candidates, Experiments | ||
from bofire.data_models.strategies.api import AnyStrategy | ||
from pydantic import BaseModel, Field, model_validator | ||
|
||
|
||
class CandidateRequest(BaseModel): | ||
strategy_data: AnyStrategy | ||
n_candidates: int = Field( | ||
default=1, gt=0, description="Number of candidates to generate" | ||
) | ||
experiments: Optional[Experiments] | ||
pendings: Optional[Candidates] | ||
|
||
@model_validator(mode="after") | ||
def validate_experiments(self): | ||
if self.experiments is not None: | ||
self.strategy_data.domain.validate_experiments(self.experiments.to_pandas()) | ||
if self.pendings is not None: | ||
self.strategy_data.domain.validate_candidates( | ||
self.pendings.to_pandas(), only_inputs=True | ||
) | ||
return self |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import bofire.strategies.api as strategies | ||
from bofire.data_models.dataframes.api import Candidates | ||
from fastapi import APIRouter, HTTPException | ||
from models.candidates import CandidateRequest | ||
|
||
|
||
router = APIRouter(prefix="", tags=["candidates"]) | ||
|
||
|
||
@router.post("/candidates/generate", response_model=Candidates) | ||
def generate( | ||
candidate_request: CandidateRequest, | ||
) -> Candidates: | ||
strategy = strategies.map(candidate_request.strategy_data) | ||
if candidate_request.experiments is not None: | ||
strategy.tell(candidate_request.experiments.to_pandas()) | ||
try: | ||
df_candidates = strategy.ask(candidate_request.n_candidates) | ||
except ValueError as e: | ||
if str(e) == "Not enough experiments available to execute the strategy.": | ||
raise HTTPException(status_code=404, detail=str(e)) | ||
else: | ||
raise HTTPException( | ||
status_code=500, detail=f"A server error occurred. Details: {e}" | ||
) | ||
return Candidates.from_pandas(df_candidates, candidate_request.strategy_data.domain) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import bofire | ||
from fastapi import APIRouter | ||
|
||
|
||
router = APIRouter(prefix="/versions", tags=["versions"]) | ||
|
||
|
||
@router.get("", response_model=dict[str, str]) | ||
def get_versions() -> dict[str, str]: | ||
return {"bofire": bofire.__version__} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
bofire>=0.0.14 | ||
uvicorn | ||
fastapi | ||
pytest | ||
requests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
target-version = "py39" | ||
line-length = 88 | ||
output-format = "concise" | ||
|
||
[lint] | ||
select = ["B", "C", "E", "F", "W", "I"] | ||
ignore = [ | ||
"E501", # don't enforce for comments and docstrings | ||
"B017", # required for tests | ||
"B027", # required for optional _tell method | ||
"B028", | ||
"B904", | ||
"B905", | ||
] | ||
isort.split-on-trailing-comma = false | ||
isort.lines-after-imports = 2 | ||
|
||
[lint.mccabe] | ||
max-complexity = 18 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
|
||
import requests | ||
from pytest import fixture | ||
|
||
|
||
HEADERS = {"accept": "application/json", "Content-Type": "application/json"} | ||
|
||
|
||
class Client: | ||
def __init__(self, base_url: str, requests=requests): | ||
self.base_url = base_url | ||
self.requests = requests | ||
|
||
def get(self, path: str) -> requests.Response: | ||
return self.requests.get(f"{self.base_url}{path}", headers=HEADERS) | ||
|
||
def post(self, path: str, request_body: str) -> requests.Response: | ||
return self.requests.post( | ||
f"{self.base_url}{path}", data=request_body, headers=HEADERS | ||
) | ||
|
||
|
||
@fixture | ||
def client() -> Client: | ||
return Client(base_url=os.getenv("CANDIDATES_URL", "http://localhost:8000")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import json | ||
from typing import Optional | ||
|
||
from bofire.benchmarks.api import DTLZ2, Himmelblau | ||
from bofire.data_models.dataframes.api import Candidates, Experiments | ||
from bofire.data_models.strategies.api import ( | ||
AlwaysTrueCondition, | ||
AnyStrategy, | ||
NumberOfExperimentsCondition, | ||
RandomStrategy, | ||
SoboStrategy, | ||
Step, | ||
StepwiseStrategy, | ||
) | ||
from pydantic import BaseModel, Field | ||
|
||
from tests.conftest import Client | ||
|
||
|
||
class CandidateRequest(BaseModel): | ||
strategy_data: AnyStrategy | ||
n_candidates: int = Field( | ||
default=1, gt=0, description="Number of candidates to generate" | ||
) | ||
experiments: Optional[Experiments] | ||
pendings: Optional[Candidates] | ||
|
||
|
||
bench = Himmelblau() | ||
bench2 = DTLZ2(dim=6) | ||
experiments = bench.f(bench.domain.inputs.sample(15), return_complete=True) | ||
experiments2 = bench2.f(bench2.domain.inputs.sample(15), return_complete=True) | ||
|
||
strategy_data = StepwiseStrategy( | ||
domain=bench.domain, | ||
steps=[ | ||
Step( | ||
condition=NumberOfExperimentsCondition(n_experiments=10), | ||
strategy_data=RandomStrategy(domain=bench.domain), | ||
), | ||
Step( | ||
condition=AlwaysTrueCondition(), | ||
strategy_data=SoboStrategy(domain=bench.domain), | ||
), | ||
], | ||
) | ||
|
||
|
||
def test_candidates_request_validation(client: Client): | ||
cr = CandidateRequest( | ||
strategy_data=strategy_data, | ||
n_candidates=1, | ||
experiments=Experiments.from_pandas(experiments2, bench2.domain), | ||
pendings=None, | ||
) | ||
|
||
response = client.post( | ||
path="/candidates/generate", request_body=cr.model_dump_json() | ||
) | ||
assert response.status_code == 422 | ||
assert ( | ||
json.loads(response.content)["detail"][0]["msg"] | ||
== "Value error, no col for input feature `y`" | ||
) | ||
|
||
|
||
def test_candidates_missing_experiments(client: Client): | ||
cr = CandidateRequest( | ||
strategy_data=SoboStrategy(domain=bench.domain), | ||
n_candidates=1, | ||
experiments=None, | ||
pendings=None, | ||
) | ||
response = client.post( | ||
path="/candidates/generate", request_body=cr.model_dump_json() | ||
) | ||
assert response.status_code == 404 | ||
assert ( | ||
json.loads(response.content)["detail"] | ||
== "Not enough experiments available to execute the strategy." | ||
) | ||
|
||
|
||
def test_candidates_generate(client: Client): | ||
cr = CandidateRequest( | ||
strategy_data=strategy_data, | ||
n_candidates=2, | ||
experiments=None, | ||
pendings=None, | ||
) | ||
response = client.post( | ||
path="/candidates/generate", request_body=cr.model_dump_json() | ||
) | ||
df_candidates = Candidates(**json.loads(response.content)).to_pandas() | ||
assert df_candidates.shape[0] == 2 | ||
assert df_candidates.shape[1] == 2 | ||
assert sorted(df_candidates.columns.tolist()) == sorted( | ||
bench.domain.inputs.get_keys() | ||
) | ||
|
||
cr = CandidateRequest( | ||
strategy_data=strategy_data, | ||
n_candidates=1, | ||
experiments=Experiments.from_pandas(experiments, bench.domain), | ||
pendings=None, | ||
) | ||
response = client.post( | ||
path="/candidates/generate", request_body=cr.model_dump_json() | ||
) | ||
df_candidates = Candidates(**json.loads(response.content)).to_pandas() | ||
assert df_candidates.shape[0] == 1 | ||
assert df_candidates.shape[1] == 5 | ||
assert sorted(df_candidates.columns.tolist()) == sorted( | ||
bench.domain.candidate_column_names | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import json | ||
|
||
from tests.conftest import Client | ||
|
||
|
||
def test_candidates_request_validation(client: Client): | ||
response = client.get(path="/versions") | ||
assert response.status_code == 200 | ||
assert list(json.loads(response.content).keys()) == ["bofire"] |
Oops, something went wrong.