Skip to content

Commit

Permalink
Allow comparison of SpectralType directly to string (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
teutoburg authored Sep 17, 2024
2 parents 4bff8c9 + 1407f2a commit 79f6524
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
37 changes: 30 additions & 7 deletions astar_utils/spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ class SpectralType:
In this context, the luminosity class (if any) is ignored for sorting and
comparison (<, >, <=, >=), as it represents a second physical dimension.
However, instances of this class may also be compared for equality (== and
!=), in which case all three attributes are considered.
!=), in which case all three attributes are considered. It is also possible
to compare instances directly to strings, if the string is a valid
construtor for this class.
Attributes
----------
Expand Down Expand Up @@ -106,7 +108,7 @@ class SpectralType:
def __post_init__(self, spectype) -> None:
"""Validate input and populate fields."""
if not (match := self._regex.fullmatch(spectype)):
raise ValueError(spectype)
raise ValueError(f"{spectype!r} is not a valid spectral type.")

classes = match.groupdict()
# Circumvent frozen as per the docs...
Expand Down Expand Up @@ -178,14 +180,35 @@ def _comp_tuple(self) -> tuple[int, float]:
sub_cls = 5
return (self._spec_cls_idx, sub_cls)

@classmethod
def _comp_guard(cls, other):
if isinstance(other, str):
other = cls(other)
if not isinstance(other, cls):
raise TypeError("Can only compare equal types or valid str.")
return other

def __eq__(self, other) -> bool:
"""Return self == other."""
other = self._comp_guard(other)
return self._comp_tuple == other._comp_tuple

def __lt__(self, other) -> bool:
"""Return self < other."""
if not isinstance(other, self.__class__):
raise TypeError("Can only compare equal types.")
other = self._comp_guard(other)
return self._comp_tuple < other._comp_tuple

def __le__(self, other) -> bool:
"""Return self < other."""
if not isinstance(other, self.__class__):
raise TypeError("Can only compare equal types.")
"""Return self <= other."""
other = self._comp_guard(other)
return self._comp_tuple <= other._comp_tuple

def __gt__(self, other) -> bool:
"""Return self > other."""
other = self._comp_guard(other)
return self._comp_tuple > other._comp_tuple

def __ge__(self, other) -> bool:
"""Return self >= other."""
other = self._comp_guard(other)
return self._comp_tuple >= other._comp_tuple
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "astar-utils"
version = "0.3.1a1"
version = "0.3.1a2"
description = "Contains commonly-used utilities for AstarVienna's projects."
license = "GPL-3.0-or-later"
authors = ["Fabian Haberhauer <fabian.haberhauer@univie.ac.at>"]
Expand Down
29 changes: 29 additions & 0 deletions tests/test_spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,35 @@ def test_throws_on_invalid_compare(self, operation):
operation(SpectralType("A0V"), 42)


class TestComparesStr:
def test_lt(self):
assert SpectralType("A0V") < "A7V"

def test_le(self):
assert SpectralType("A0V") <= "A0V"

def test_gt(self):
assert SpectralType("A0V") > "B7V"

def test_ge(self):
assert SpectralType("A0V") >= "A0V"

def test_eq(self):
assert SpectralType("A0V") == "A0V"

def test_ne(self):
assert SpectralType("A0V") != "A1V"

def test_reverse_le(self):
assert "A0" <= SpectralType("A0V")

def test_reverse_gt(self):
assert "A7" > SpectralType("A0V")

def test_reverse_ne(self):
assert "A1" != SpectralType("A0V")


class TestRepresentations:
@pytest.mark.parametrize(("ssl_cls", "exptcted"),
[("A0V", "SpectralType('A0V')"),
Expand Down

0 comments on commit 79f6524

Please sign in to comment.