Skip to content

Commit

Permalink
Merge b86489c into 554bbfa
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike authored Aug 12, 2022
2 parents 554bbfa + b86489c commit 237f7cd
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 14 deletions.
113 changes: 113 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
item_from_fields,
item_from_fields_sync,
)
from web_poet.fields import fields_dict


@attrs.define
Expand Down Expand Up @@ -271,3 +272,115 @@ def to_item(self) -> Item:
page = ExtendedPage2(response=EXAMPLE_RESPONSE)
item = page.to_item()
assert item == Item(name="Hello!", price="$123")


def test_field_meta():
class MyPage(ItemPage):
@field(meta={"good": True})
def field1(self):
return "foo"

@field
def field2(self):
return "foo"

def to_item(self):
return item_from_fields_sync(self)

page = MyPage()
for fields in [fields_dict(MyPage), fields_dict(page)]:
assert list(fields.keys()) == ["field1", "field2"]
assert fields["field1"].name == "field1"
assert fields["field1"].meta == {"good": True}

assert fields["field2"].name == "field2"
assert fields["field2"].meta is None


def test_field_extra():
@attrs.define
class OnlyNameItem:
name: str

@attrs.define
class OnlyPriceItem:
price: str

class BasePage(ItemPage):
item_cls = OnlyNameItem

@field
def name(self): # noqa: D102
return "name"

@field(extra=True)
def price(self): # noqa: D102
return "price"

def to_item(self): # noqa: D102
return item_from_fields_sync(self, self.item_cls)

# BasePage contains field which is not in item class,
# but the field is defined as extra, so an exception is not raised
page = BasePage()
assert page.to_item() == OnlyNameItem(name="name")

class FullItemPage(BasePage):
item_cls = Item

# extra field is available in an item, so it's used now
page = FullItemPage()
assert page.to_item() == Item(name="name", price="price")

class OnlyPricePage(BasePage):
item_cls = OnlyPriceItem

# regular fields are always passed
page = OnlyPricePage()
with pytest.raises(TypeError, match="unexpected keyword argument 'name'"):
page.to_item()


@pytest.mark.asyncio
async def test_field_extra_async():
@attrs.define
class OnlyNameItem:
name: str

@attrs.define
class OnlyPriceItem:
price: str

class BasePage(ItemPage):
item_cls = OnlyNameItem

@field
async def name(self): # noqa: D102
return "name"

@field(extra=True)
async def price(self): # noqa: D102
return "price"

async def to_item(self): # noqa: D102
return await item_from_fields(self, self.item_cls)

# BasePage contains field which is not in item class,
# but the field is defined as extra, so an exception is not raised
page = BasePage()
assert await page.to_item() == OnlyNameItem(name="name")

class FullItemPage(BasePage):
item_cls = Item

# extra field is available in an item, so it's used now
page = FullItemPage()
assert await page.to_item() == Item(name="name", price="price")

class OnlyPricePage(BasePage):
item_cls = OnlyPriceItem

# regular fields are always passed
page = OnlyPricePage()
with pytest.raises(TypeError, match="unexpected keyword argument 'name'"):
await page.to_item()
63 changes: 49 additions & 14 deletions web_poet/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,39 @@ async def to_item(self):
"""
from functools import update_wrapper
from typing import Dict, Optional

import attrs
from itemadapter import ItemAdapter

from web_poet.utils import cached_method, ensure_awaitable

_FIELDS_ATTRIBUTE = "_marked_as_fields"
_FIELDS_INFO_ATTRIBUTE = "_web_poet_fields_info"


def field(method=None, *, cached=False):
@attrs.define
class FieldInfo:
name: str
meta: Optional[dict] = None
extra: bool = False


def field(method=None, *, cached: bool = False, meta: Optional[dict] = None, extra: bool = False):
"""
Page Object method decorated with ``@field`` decorator becomes a property,
which is used by :func:`item_from_fields` or :func:`item_from_fields_sync`
to populate item attributes.
By default, the value is computed on each property access.
Use ``@field(cached=True)`` to cache the property value.
Fields decorated with ``@field(extra=True)`` are not passed to item
classes by :func:`item_from_fields` if items don't support them, regardless
of ``item_cls_fields`` argument.
``meta`` parameter allows to store arbitrary information for the
field - e.g. ``@field(meta={"expensive": True})``. This information
can be later retrieved for all fields using :func:`fields_dict` function.
"""

class _field:
Expand All @@ -53,10 +70,11 @@ def __init__(self, method):
self.unbound_method = method

def __set_name__(self, owner, name):
if not hasattr(owner, _FIELDS_ATTRIBUTE):
# dict is used instead of set to preserve the insertion order
setattr(owner, _FIELDS_ATTRIBUTE, {})
getattr(owner, _FIELDS_ATTRIBUTE)[name] = True
if not hasattr(owner, _FIELDS_INFO_ATTRIBUTE):
setattr(owner, _FIELDS_INFO_ATTRIBUTE, {})

field_info = FieldInfo(name=name, meta=meta, extra=extra)
getattr(owner, _FIELDS_INFO_ATTRIBUTE)[name] = field_info

def __get__(self, instance, owner=None):
return self.unbound_method(instance)
Expand All @@ -71,30 +89,47 @@ def __get__(self, instance, owner=None):
return _field


def fields_dict(cls_or_instance) -> Dict[str, FieldInfo]:
"""Return a dictionary with information about the fields defined
for the class"""
return getattr(cls_or_instance, _FIELDS_INFO_ATTRIBUTE, {})


async def item_from_fields(obj, item_cls=dict, *, item_cls_fields=False):
"""Return an item of ``item_cls`` type, with its attributes populated
from the ``obj`` methods decorated with :class:`field` decorator.
If ``item_cls_fields`` is True, ``@fields`` whose names don't match
any of the ``item_cls`` attributes are not passed to ``item_cls.__init__``.
When ``item_cls_fields`` is False (default), all ``@fields`` are passed
to ``item_cls.__init__``.
to ``item_cls.__init__``, unless they're created with ``extra=True``
argument.
"""
item_dict = item_from_fields_sync(obj, item_cls=dict, item_cls_fields=False)
field_names = item_dict.keys()
if item_cls_fields:
field_names = _without_unsupported_field_names(item_cls, field_names)
field_names = _final_field_names(obj, item_cls, item_cls_fields)
item_dict = {name: getattr(obj, name) for name in field_names}
return item_cls(**{name: await ensure_awaitable(item_dict[name]) for name in field_names})


def item_from_fields_sync(obj, item_cls=dict, *, item_cls_fields=False):
"""Synchronous version of :func:`item_from_fields`."""
field_names = list(getattr(obj, _FIELDS_ATTRIBUTE, {}))
if item_cls_fields:
field_names = _without_unsupported_field_names(item_cls, field_names)
field_names = _final_field_names(obj, item_cls, item_cls_fields)
return item_cls(**{name: getattr(obj, name) for name in field_names})


def _final_field_names(obj, item_cls, item_cls_fields):
fields = fields_dict(obj)
extra_field_names = _without_unsupported_field_names(
item_cls, [info.name for info in fields.values() if info.extra]
)

regular_field_names = [info.name for info in fields.values() if not info.extra]
if item_cls_fields:
regular_field_names = _without_unsupported_field_names(item_cls, regular_field_names)

return regular_field_names + extra_field_names


def _without_unsupported_field_names(item_cls, field_names):
item_field_names = ItemAdapter.get_field_names_from_class(item_cls)
if item_field_names is None: # item_cls doesn't define field names upfront
Expand Down

0 comments on commit 237f7cd

Please sign in to comment.