Skip to content

Commit

Permalink
fix(bc): only remove terms when asked
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jun 10, 2024
1 parent ccb2e9e commit 6f06908
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
updated_cfg = binding_constraints[index]
updated_cfg.update(obj)

updated_terms = set(self.coeffs) if self.coeffs else set()

# Remove the terms not in the current update but existing in the config
terms_to_remove = {key for key in updated_cfg if ("%" in key or "." in key) and key not in updated_terms}
for term_id in terms_to_remove:
updated_cfg.pop(term_id, None)
updated_properties = self.dict(exclude={"command_context", "command_name", "version", "id"}, exclude_none=True)
# This 2nd check is here to remove the last term.
if self.coeffs or updated_properties == {"coeffs": {}}:
# Remove terms which IDs contain a "%" or a "." in their name
term_ids = {k for k in updated_cfg if "%" in k or "." in k}
binding_constraints[index] = {k: v for k, v in updated_cfg.items() if k not in term_ids}

return super().apply_binding_constraint(study_data, binding_constraints, index, self.id, old_groups=old_groups)

Expand Down
17 changes: 17 additions & 0 deletions tests/integration/study_data_blueprint/test_binding_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,23 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st
]
assert constraint_terms == expected

# Update random field, shouldn't remove the term.
res = client.put(
f"v1/studies/{study_id}/bindingconstraints/{bc_id}",
json={"enabled": False},
headers=user_headers,
)
assert res.status_code == 200

res = client.get(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}",
headers=user_headers,
)
assert res.status_code == 200, res.json()
binding_constraint = res.json()
constraint_terms = binding_constraint["terms"]
assert constraint_terms == expected

# =============================
# GENERAL EDITION
# =============================
Expand Down

0 comments on commit 6f06908

Please sign in to comment.