Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Fix typing issues of StringSetFlag #107

Merged
merged 2 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/107.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix typing issues of `StringSetFlag` by refactoring it using a separate interface definition file
57 changes: 57 additions & 0 deletions src/ai/backend/common/enum_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import enum

__all__ = (
'StringSetFlag',
)


class StringSetFlag(enum.Flag):

def __eq__(self, other):
return self.value == other

def __hash__(self):
return hash(self.value)

def __or__(self, other):
if isinstance(other, type(self)):
other = other.value
if not isinstance(other, (set, frozenset)):
other = set((other,))
return set((self.value,)) | other

__ror__ = __or__

def __and__(self, other):
if isinstance(other, (set, frozenset)):
return self.value in other
if isinstance(other, str):
return self.value == other
raise TypeError

__rand__ = __and__

def __xor__(self, other):
if isinstance(other, (set, frozenset)):
return set((self.value,)) ^ other
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __rxor__(self, other):
if isinstance(other, (set, frozenset)):
return other ^ set((self.value,))
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __str__(self):
return self.value
22 changes: 22 additions & 0 deletions src/ai/backend/common/enum_extension.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import enum


class StringSetFlag(enum.Flag):
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __or__( # type: ignore[override]
self,
other: StringSetFlag | str | set[str] | frozenset[str],
) -> set[str]: ...
def __and__( # type: ignore[override]
self,
other: StringSetFlag | str | set[str] | frozenset[str],
) -> bool: ...
def __xor__( # type: ignore[override]
self,
other: StringSetFlag | str | set[str] | frozenset[str],
) -> set[str]: ...
def __ror__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
def __rand__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> bool: ...
def __rxor__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
def __str__(self) -> str: ...
52 changes: 1 addition & 51 deletions src/ai/backend/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import base64
from collections import OrderedDict
from datetime import timedelta
import enum
from itertools import chain
import numbers
import random
Expand Down Expand Up @@ -34,6 +33,7 @@
current_loop,
run_through,
)
from .enum_extension import StringSetFlag # for legacy imports # noqa
from .files import AsyncFileWriter # for legacy imports # noqa
from .networking import ( # for legacy imports # noqa
curl,
Expand Down Expand Up @@ -198,56 +198,6 @@ def str_to_timedelta(tstr: str) -> timedelta:
return timedelta(**params) # type: ignore


class StringSetFlag(enum.Flag):

def __eq__(self, other):
return self.value == other

def __hash__(self):
return hash(self.value)

def __or__(self, other):
if isinstance(other, type(self)):
other = other.value
if not isinstance(other, (set, frozenset)):
other = set((other,))
return set((self.value,)) | other

__ror__ = __or__

def __and__(self, other):
if isinstance(other, (set, frozenset)):
return self.value in other
if isinstance(other, str):
return self.value == other
raise TypeError

__rand__ = __and__

def __xor__(self, other):
if isinstance(other, (set, frozenset)):
return set((self.value,)) ^ other
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __rxor__(self, other):
if isinstance(other, (set, frozenset)):
return other ^ set((self.value,))
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __str__(self):
return self.value


class FstabEntry:
"""
Entry class represents a non-comment line on the `fstab` file.
Expand Down
8 changes: 3 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import pytest

from ai.backend.common.asyncio import AsyncBarrier, run_through
from ai.backend.common.enum_extension import StringSetFlag
from ai.backend.common.files import AsyncFileWriter
from ai.backend.common.networking import curl
from ai.backend.common.utils import (
odict, dict2kvlist, nmget,
generate_uuid, get_random_seq,
readable_size_to_bytes,
str_to_timedelta,
StringSetFlag,
)
from ai.backend.common.testutils import (
mock_corofunc, mock_awaitable, AsyncContextManagerMock,
Expand Down Expand Up @@ -156,9 +156,7 @@ async def test_curl_returns_default_value_if_not_success(mocker) -> None:

def test_string_set_flag() -> None:

# FIXME: Remove "type: ignore" when mypy gets released with
# python/mypy#11579.
class MyFlags(StringSetFlag): # type: ignore
class MyFlags(StringSetFlag):
A = 'a'
B = 'b'

Expand All @@ -182,7 +180,7 @@ class MyFlags(StringSetFlag): # type: ignore
assert {'b'} == MyFlags.A ^ {'a', 'b'}
assert {'a', 'b', 'c'} == MyFlags.A ^ {'b', 'c'}
with pytest.raises(TypeError):
123 & MyFlags.A
123 & MyFlags.A # type: ignore[operator]

assert {'a', 'c'} & MyFlags.A
assert not {'a', 'c'} & MyFlags.B
Expand Down