Skip to content

Commit

Permalink
feat: add additional operations to BaseFlags (#1486)
Browse files Browse the repository at this point in the history
  • Loading branch information
celsiusnarhwal authored Jul 12, 2022
1 parent fedee05 commit 7accb95
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 2 deletions.
148 changes: 146 additions & 2 deletions discord/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Tuple,
Type,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -140,6 +141,44 @@ def __iter__(self) -> Iterator[Tuple[str, bool]]:
if isinstance(value, flag_value):
yield (name, self._has_flag(value.flag))

def __and__(self, other):
if isinstance(other, self.__class__):
return self.__class__._from_value(self.value & other.value)
elif isinstance(other, flag_value):
return self.__class__._from_value(self.value & other.flag)
else:
raise TypeError(f"'&' not supported between instances of {type(self)} and {type(other)}")

def __or__(self, other):
if isinstance(other, self.__class__):
return self.__class__._from_value(self.value | other.value)
elif isinstance(other, flag_value):
return self.__class__._from_value(self.value | other.flag)
else:
raise TypeError(f"'|' not supported between instances of {type(self)} and {type(other)}")

def __add__(self, other):
try:
return self | other
except TypeError:
raise TypeError(f"'+' not supported between instances of {type(self)} and {type(other)}")

def __sub__(self, other):
if isinstance(other, self.__class__):
return self.__class__._from_value(self.value & ~other.value)
elif isinstance(other, flag_value):
return self.__class__._from_value(self.value & ~other.flag)
else:
raise TypeError(f"'-' not supported between instances of {type(self)} and {type(other)}")

def __invert__(self):
return self.__class__._from_value(~self.value)

__rand__: Callable[[Union[BaseFlags, flag_value]], bool] = __and__
__ror__: Callable[[Union[BaseFlags, flag_value]], bool] = __or__
__radd__: Callable[[Union[BaseFlags, flag_value]], bool] = __add__
__rsub__: Callable[[Union[BaseFlags, flag_value]], bool] = __sub__

def _has_flag(self, o: int) -> bool:
return (self.value & o) == o

Expand Down Expand Up @@ -169,6 +208,21 @@ class SystemChannelFlags(BaseFlags):
.. describe:: x != y
Checks if two flags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down Expand Up @@ -242,6 +296,21 @@ class MessageFlags(BaseFlags):
.. describe:: x != y
Checks if two flags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down Expand Up @@ -337,6 +406,21 @@ class PublicUserFlags(BaseFlags):
.. describe:: x != y
Checks if two PublicUserFlags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down Expand Up @@ -482,6 +566,21 @@ class Intents(BaseFlags):
.. describe:: x != y
Checks if two flags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down Expand Up @@ -970,7 +1069,7 @@ def scheduled_events(self):
- :meth:`Guild.get_scheduled_event`
"""
return 1 << 16

@flag_value
def auto_moderation_configuration(self):
""":class:`bool`: Whether guild auto moderation configuration events are enabled.
Expand All @@ -982,7 +1081,7 @@ def auto_moderation_configuration(self):
- :func:`on_auto_moderation_rule_delete`
"""
return 1 << 20

@flag_value
def auto_moderation_execution(self):
""":class:`bool`: Whether guild auto moderation execution events are enabled.
Expand Down Expand Up @@ -1022,6 +1121,21 @@ class MemberCacheFlags(BaseFlags):
.. describe:: x != y
Checks if two flags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down Expand Up @@ -1146,6 +1260,21 @@ class ApplicationFlags(BaseFlags):
.. describe:: x != y
Checks if two ApplicationFlags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down Expand Up @@ -1244,6 +1373,21 @@ class ChannelFlags(BaseFlags):
.. describe:: x != y
Checks if two ChannelFlags are not equal.
.. describe:: x + y
Adds two flags together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two flags from each other.
.. describe:: x | y
Returns the union of two flags. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two flags.
.. describe:: ~x
Returns the inverse of a flag.
.. describe:: hash(x)
Return the flag's hash.
Expand Down
16 changes: 16 additions & 0 deletions discord/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ class Permissions(BaseFlags):
Checks if a permission is a strict subset of another permission.
.. describe:: x > y
.. describe:: x + y
Adds two permissions together. Equivalent to ``x | y``.
.. describe:: x - y
Subtracts two permissions from each other.
.. describe:: x | y
Returns the union of two permissions. Equivalent to ``x + y``.
.. describe:: x & y
Returns the intersection of two permissions.
.. describe:: ~x
Returns the inverse of a permission.
Checks if a permission is a strict superset of another permission.
.. describe:: hash(x)
Expand Down

0 comments on commit 7accb95

Please sign in to comment.