-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from experimental-design/feature/jobber
Jobber based Candidate Generation
- Loading branch information
Showing
19 changed files
with
2,195 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
db.json | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import datetime | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from bofire.data_models.base import BaseModel | ||
from bofire.data_models.dataframes.api import Candidates, Experiments | ||
from bofire.data_models.strategies.api import AnyStrategy | ||
from pydantic import Field, model_validator | ||
|
||
|
||
class ProposalRequest(BaseModel): | ||
"""Request model for generating candidates.""" | ||
|
||
strategy_data: AnyStrategy = Field(description="BoFire strategy data") | ||
n_candidates: int = Field( | ||
default=1, gt=0, description="Number of candidates to generate" | ||
) | ||
experiments: Optional[Experiments] = Field( | ||
default=None, description="Experiments to provide to the strategy" | ||
) | ||
pendings: Optional[Candidates] = Field( | ||
default=None, description="Candidates that are pending to be executed" | ||
) | ||
|
||
@model_validator(mode="after") | ||
def validate_experiments(self): | ||
"""Validates the experiments.""" | ||
if self.experiments is not None: | ||
self.strategy_data.domain.validate_experiments(self.experiments.to_pandas()) | ||
return self | ||
|
||
@model_validator(mode="after") | ||
def validate_pendings(self): | ||
"""Validates that pendings are None.""" | ||
if self.pendings is not None: | ||
raise ValueError("Pendings must be None for proposals.") | ||
return self | ||
|
||
|
||
class StateEnum(str, Enum): | ||
"""Enum for the state of a proposal.""" | ||
|
||
CREATED = "CREATED" | ||
CLAIMED = "CLAIMED" | ||
FAILED = "FAILED" | ||
FINISHED = "FINISHED" | ||
|
||
|
||
class Proposal(ProposalRequest): | ||
"""Model for a candidates proposal.""" | ||
|
||
id: Optional[int] = Field(default=None, description="Proposal ID") | ||
candidates: Optional[Candidates] = Field( | ||
default=None, description="Candidates generated by the proposal" | ||
) | ||
created_at: datetime.datetime = Field( | ||
default_factory=datetime.datetime.now, | ||
description="Timestamp when the proposal was created", | ||
) | ||
last_updated_at: datetime.datetime = Field( | ||
default_factory=datetime.datetime.now, | ||
description="Timestamp when the proposal was last updated", | ||
) | ||
state: StateEnum = Field( | ||
default=StateEnum.CREATED, description="State of the proposal" | ||
) | ||
error_message: Optional[str] = Field( | ||
default=None, description="Error message if the proposal failed" | ||
) | ||
|
||
@model_validator(mode="after") | ||
def validate_candidates(self): | ||
"""Validates the candidates.""" | ||
if self.candidates is not None: | ||
self.strategy_data.domain.validate_candidates( | ||
self.candidates.to_pandas(), only_inputs=True | ||
) | ||
if len(self.candidates.rows) != self.n_candidates: | ||
raise ValueError( | ||
f"Number of candidates ({len(self.candidates.rows)}) does not " | ||
"match n_candidates ({self.n_candidates})." | ||
) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.