Skip to content

Commit

Permalink
Update for 2.1 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitakuklev committed Aug 2, 2023
1 parent fdca591 commit 49f3b89
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion xopt/generators/bayesian/bax_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class BaxGenerator(BayesianGenerator):
alias = "BAX"
name = "BAX"
algorithm: Algorithm = Field(description="algorithm evaluated in the BAX process")
algorithm_results: Dict = Field(
None, description="dictionary results from algorithm", exclude=True
Expand Down
9 changes: 4 additions & 5 deletions xopt/generators/bayesian/bayesian_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ class BayesianGenerator(Generator, ABC):
None,
description="limits for travel distance between points in normalized space",
)
fixed_features: Dict[str, float] = Field(
fixed_features: Optional[Dict[str, float]] = Field(
None, description="fixed features used in Bayesian optimization"
)

@field_validator("model_constructor", mode='before')
def validate_model_constructor(cls, value):
print(f'Verifying model {value}')
constructor_dict = {"standard": StandardModelConstructor}
if value is None:
value = StandardModelConstructor()
Expand All @@ -83,9 +84,7 @@ def validate_model_constructor(cls, value):
else:
raise ValueError(f"{value} not found")
elif isinstance(value, dict):
#name = value.pop("name")
# name is ClassVar, not instance field
name = cls.name
name = value.pop("name")
if name in constructor_dict:
value = constructor_dict[name](**value)
else:
Expand All @@ -106,7 +105,7 @@ def validate_numerical_optimizer(cls, value):
else:
raise ValueError(f"{value} not found")
elif isinstance(value, dict):
name = cls.name
name = value.pop("name")
if name in optimizer_dict:
value = optimizer_dict[name](**value)
else:
Expand Down
2 changes: 1 addition & 1 deletion xopt/generators/bayesian/models/time_dependent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Union
from typing import Dict, List, Union

import pandas as pd

Expand Down
2 changes: 0 additions & 2 deletions xopt/numerical_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def optimize(self, function, bounds, n_candidates=1):

class LBFGSOptimizer(NumericalOptimizer):
name: str = Field("LBFGS", frozen=True)
name2: Literal['test'] = 'test2'
n_raw_samples: PositiveInt = Field(
20,
description="number of raw samples used to seed optimization",
Expand Down Expand Up @@ -71,7 +70,6 @@ class GridOptimizer(NumericalOptimizer):
"""

name: str = Field("grid", frozen=True)
name3: Literal['test'] = "test"
n_grid_points: PositiveInt = Field(
10, description="number of grid points per axis used for optimization"
)
Expand Down
18 changes: 14 additions & 4 deletions xopt/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def recursive_serialize(v, base_key=""):
# This will iterate model fields
for key, value in dict(v).items():
# print(f'{v=} {key=} {value=}')
print(f'{v=} {key=} {value=}')
if isinstance(value, dict):
v[key] = recursive_serialize(value, base_key=key)
elif isinstance(value, torch.nn.Module):
Expand Down Expand Up @@ -99,10 +99,20 @@ def orjson_dumps(v: BaseModel, *, base_key=""):
json_encoder = partial(custom_pydantic_encoder, JSON_ENCODERS)
return orjson_dumps_custom(v, default=json_encoder, base_key=base_key)

def orjson_dumps_except_root(v: BaseModel, *, base_key=""):
    """Serialize each top-level field of *v* to its own JSON string.

    Unlike a whole-model dump, the root model itself is not turned into a
    single JSON document: the model is first dumped and recursively
    serialized, then every top-level field value is encoded independently.
    This keeps pydantic's custom serializer machinery working, which needs
    a dict of per-field strings rather than one root-level string.

    Parameters
    ----------
    v : BaseModel
        The pydantic model to serialize.
    base_key : str, keyword-only
        Key prefix forwarded to ``recursive_serialize``.

    Returns
    -------
    dict
        Mapping of field name -> JSON-encoded string for that field.
    """
    # Encoder that knows about the project's custom JSON_ENCODERS table.
    encoder = partial(custom_pydantic_encoder, JSON_ENCODERS)
    # Dump the model and run the project-level recursive serializer first,
    # so nested models/tensors are already in a JSON-friendly form.
    serialized = recursive_serialize(v.model_dump(), base_key=base_key)
    # Encode each top-level field separately; orjson returns bytes, so decode.
    return {
        key: orjson.dumps(value, default=encoder).decode()
        for key, value in dict(serialized).items()
    }

def orjson_dumps_custom(v: BaseModel, *, default, base_key=""):
v = recursive_serialize(v.model_dump(), base_key=base_key)
return orjson.dumps(v, default=default).decode()
dump = orjson.dumps(v, default=default).decode()
return dump


def orjson_loads(v, default=None):
Expand Down Expand Up @@ -139,9 +149,9 @@ def validate_files(cls, value, info: FieldValidationInfo):

return value

@model_serializer(mode='plain', when_used='json', return_type='str')
@model_serializer(mode='plain', when_used='json')
def serialize_json(self, base_key=""):
return orjson_dumps(self, base_key=base_key)
return orjson_dumps_except_root(self, base_key=base_key)

# TODO: implement json load parsing on main object (json_loads is gone)

Expand Down

0 comments on commit 49f3b89

Please sign in to comment.