From 8ad4272d9a545066a217ec5cf540690796709c9d Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Thu, 31 Aug 2023 22:17:28 +0100 Subject: [PATCH 1/6] #26 allow user-defined params --- bpx/schema.py | 22 +++++++++++++++++++--- tests/test_schema.py | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/bpx/schema.py b/bpx/schema.py index 07c76f6..bd13ed8 100644 --- a/bpx/schema.py +++ b/bpx/schema.py @@ -1,10 +1,10 @@ -from typing import List, Literal, Union, Dict +from typing import List, Literal, Union, Dict, get_args -from pydantic import BaseModel, Field, Extra +from pydantic import BaseModel, Field, Extra, root_validator from bpx import Function, InterpolatedTable -FloatFunctionTable = Union[float, Function, InterpolatedTable] +FloatFunctionTable = Union[int, float, Function, InterpolatedTable] class ExtraBaseModel(BaseModel): @@ -248,6 +248,18 @@ class Electrode(Contact): ) +class UserDefined(BaseModel): + class Config: + extra = Extra.allow + + @root_validator(pre=True) + def validate_extra_fields(cls, values): + for k, v in values.items(): + if not isinstance(v, get_args(FloatFunctionTable)): + raise TypeError(f"{k} must be of type 'FloatFunctionTable'") + return values + + class Experiment(ExtraBaseModel): time: List[float] = Field( alias="Time [s]", @@ -288,6 +300,10 @@ class Parameterisation(ExtraBaseModel): separator: Contact = Field( alias="Separator", ) + user_defined: UserDefined = Field( + None, + alias="User defined", + ) class BPX(ExtraBaseModel): diff --git a/tests/test_schema.py b/tests/test_schema.py index 797910e..e03488e 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -144,6 +144,26 @@ def test_validation_data(self): }, } + def test_user_defined(self): + test = copy.copy(self.base) + test["Parameterisation"]["User defined"] = { + "a": 1, + "b": 2.0, + "c": 3.0, + } + obj = parse_obj_as(BPX, test) + self.assertEqual(obj.parameterisation.user_defined.a, 1) + self.assertEqual(obj.parameterisation.user_defined.b, 2) + self.assertEqual(obj.parameterisation.user_defined.c, 3) + + def test_bad_user_defined(self): + test = copy.copy(self.base) + test["Parameterisation"]["User defined"] = { + "bad": "strings aren't allowed", + } + with self.assertRaises(ValidationError): + parse_obj_as(BPX, test) + if __name__ == "__main__": unittest.main() From 274dc79a37d1497d515eddb5c849776454e88f63 Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Thu, 31 Aug 2023 22:20:23 +0100 Subject: [PATCH 2/6] #26 changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 31b32cf..7993cc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# Unreleased +- Allow user-defined parameters to be added using the field ["Parameterisation"]["User-Defined"] ([#44](https://github.com/pybamm-team/BPX/pull/44)) + + # [v0.3.0](https://github.com/pybamm-team/BPX/releases/tag/v0.3.1) - Temporarily pin Pydantic version ([#35](https://github.com/pybamm-team/BPX/pull/35)) # [v0.3.0](https://github.com/pybamm-team/BPX/releases/tag/v0.3.0) From fc3806201ca311e274ce3fa44f72e5f33719235b Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Thu, 5 Oct 2023 13:51:18 +0100 Subject: [PATCH 3/6] #26 fix types for user-defined parameters --- bpx/schema.py | 16 ++++++++++++++-- tests/test_schema.py | 18 +++++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/bpx/schema.py b/bpx/schema.py index bd58f6a..1e2a8a2 100644 --- a/bpx/schema.py +++ b/bpx/schema.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Union, Dict, get_args +from typing import List, Literal, Union, Dict from pydantic import BaseModel, Field, Extra, root_validator from bpx import Function, InterpolatedTable from warnings import warn @@ -273,10 +273,22 @@ class UserDefined(BaseModel): class Config: extra = Extra.allow + def __init__(Self, **data): + """ + Overwrite the default __init__ to convert strings to Function objects and + dicts to InterpolatedTable objects + """ + for k, v in data.items(): + if isinstance(v, str): + data[k] = Function(v) + elif isinstance(v, dict): + data[k] = InterpolatedTable(**v) + super().__init__(**data) + @root_validator(pre=True) def validate_extra_fields(cls, values): for k, v in values.items(): - if not isinstance(v, get_args(FloatFunctionTable)): + if not isinstance(v, (float, Function, InterpolatedTable)): raise TypeError(f"{k} must be of type 'FloatFunctionTable'") return values diff --git a/tests/test_schema.py b/tests/test_schema.py index 4353137..6bb4840 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -270,10 +270,26 @@ def test_user_defined(self): self.assertEqual(obj.parameterisation.user_defined.b, 2) self.assertEqual(obj.parameterisation.user_defined.c, 3) + def test_user_defined_table(self): + test = copy.copy(self.base) + test["Parameterisation"]["User-defined"] = { + "a": { + "x": [1.0, 2.0], + "y": [2.3, 4.5], + }, + } + parse_obj_as(BPX, test) + + def test_user_defined_function(self): + test = copy.copy(self.base) + test["Parameterisation"]["User-defined"] = {"a": "2.0 * x"} + parse_obj_as(BPX, test) + def test_bad_user_defined(self): test = copy.copy(self.base) + # bool not allowed type test["Parameterisation"]["User-defined"] = { - "bad": "strings aren't allowed", + "bad": True, } with self.assertRaises(ValidationError): parse_obj_as(BPX, test) From f116efd5221b2929a713955737474cb621331498 Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Thu, 5 Oct 2023 13:52:16 +0100 Subject: [PATCH 4/6] #26 use get_args --- bpx/schema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bpx/schema.py b/bpx/schema.py index 1e2a8a2..3908f6b 100644 --- a/bpx/schema.py +++ b/bpx/schema.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Union, Dict +from typing import List, Literal, Union, Dict, get_args from pydantic import BaseModel, Field, Extra, root_validator from bpx import Function, InterpolatedTable from warnings import warn @@ -288,7 +288,7 @@ def __init__(Self, **data): @root_validator(pre=True) def validate_extra_fields(cls, values): for k, v in values.items(): - if not isinstance(v, (float, Function, InterpolatedTable)): + if not isinstance(v, get_args(FloatFunctionTable)): raise TypeError(f"{k} must be of type 'FloatFunctionTable'") return values From dd506d8470519471c8fc2f6c19b1a78442083a0c Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Thu, 5 Oct 2023 13:53:21 +0100 Subject: [PATCH 5/6] #26 use get_args --- bpx/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bpx/schema.py b/bpx/schema.py index 3908f6b..888cafc 100644 --- a/bpx/schema.py +++ b/bpx/schema.py @@ -273,7 +273,7 @@ class UserDefined(BaseModel): class Config: extra = Extra.allow - def __init__(Self, **data): + def __init__(self, **data): """ Overwrite the default __init__ to convert strings to Function objects and dicts to InterpolatedTable objects From b797fea2246c98bbda5f2e7df93fdb3af5dbef2e Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Thu, 5 Oct 2023 15:48:50 +0100 Subject: [PATCH 6/6] #26 update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7637464..aa70310 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Unreleased -- Allow user-defined parameters to be added using the field ["Parameterisation"]["User-Defined"] ([#44](https://github.com/pybamm-team/BPX/pull/44)) +- Allow user-defined parameters to be added using the field ["Parameterisation"]["User-defined"] ([#44](https://github.com/pybamm-team/BPX/pull/44)) - Added validation based on models: SPM, SPMe, DFN ([#34](https://github.com/pybamm-team/BPX/pull/34)). A warning will be produced if the user-defined model type does not match the parameter set (e.g., if the model is `SPM`, but the full DFN model parameters are provided). - Added support for well-mixed, blended electrodes that contain more than one active material ([#33](https://github.com/pybamm-team/BPX/pull/33))