Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async web3 #2819

Merged
merged 5 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions ens/async_ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@
)

if TYPE_CHECKING:
from web3 import Web3 # noqa: F401
from web3.contract import ( # noqa: F401
AsyncContract,
)
from web3.main import AsyncWeb3 # noqa: F401
from web3.providers import ( # noqa: F401
AsyncBaseProvider,
BaseProvider,
)
from web3.types import ( # noqa: F401
Middleware,
AsyncMiddleware,
TxParams,
)

Expand All @@ -88,11 +88,14 @@ class AsyncENS(BaseENS):
like: ``"0x314159265dD8dbb310642f98f50C066173C1259b"``
"""

# mypy types
w3: "AsyncWeb3"

def __init__(
self,
provider: "AsyncBaseProvider" = cast("AsyncBaseProvider", default),
addr: ChecksumAddress = None,
middlewares: Optional[Sequence[Tuple["Middleware", str]]] = None,
middlewares: Optional[Sequence[Tuple["AsyncMiddleware", str]]] = None,
) -> None:
"""
:param provider: a single provider used to connect to Ethereum
Expand All @@ -110,7 +113,7 @@ def __init__(
)

@classmethod
def from_web3(cls, w3: "Web3", addr: ChecksumAddress = None) -> "AsyncENS":
def from_web3(cls, w3: "AsyncWeb3", addr: ChecksumAddress = None) -> "AsyncENS":
"""
Generate an AsyncENS instance with web3

Expand All @@ -120,10 +123,15 @@ def from_web3(cls, w3: "Web3", addr: ChecksumAddress = None) -> "AsyncENS":
"""
provider = w3.manager.provider
middlewares = w3.middleware_onion.middlewares
return cls(
ns = cls(
cast("AsyncBaseProvider", provider), addr=addr, middlewares=middlewares
)

# inherit strict bytes checking from w3 instance
ns.strict_bytes_type_checking = w3.strict_bytes_type_checking

return ns

async def address(self, name: str) -> Optional[ChecksumAddress]:
"""
Look up the Ethereum address that `name` currently points to.
Expand Down Expand Up @@ -179,7 +187,7 @@ async def setup_address(
transact["from"] = owner

resolver: "AsyncContract" = await self._set_resolver(name, transact=transact)
return await resolver.functions.setAddr( # type: ignore
return await resolver.functions.setAddr(
raw_name_to_hash(name), address
).transact(transact)

Expand Down Expand Up @@ -394,9 +402,9 @@ async def set_text(
r = await self.resolver(normal_name)
if r:
if await _async_resolver_supports_interface(r, GET_TEXT_INTERFACE_ID):
return await r.functions.setText( # type: ignore
node, key, value
).transact(transaction_dict)
return await r.functions.setText(node, key, value).transact(
transaction_dict
)
else:
raise UnsupportedFunction(
f"Resolver for name `{name}` does not support `text` function"
Expand Down Expand Up @@ -494,7 +502,7 @@ async def _assert_control(
name: str,
parent_owned: Optional[str] = None,
) -> None:
if not address_in(account, await self.w3.eth.accounts): # type: ignore
if not address_in(account, await self.w3.eth.accounts):
raise UnauthorizedError(
f"in order to modify {name!r}, you must control account"
f" {account!r}, which owns {parent_owned or name!r}"
Expand Down Expand Up @@ -548,7 +556,7 @@ async def _setup_reverse(
transact = deepcopy(transact)
transact["from"] = address
reverse_registrar = await self._reverse_registrar()
return await reverse_registrar.functions.setName(name).transact(transact) # type: ignore # noqa: E501
return await reverse_registrar.functions.setName(name).transact(transact)

async def _reverse_registrar(self) -> "AsyncContract":
addr = await self.ens.caller.owner(
Expand Down
7 changes: 5 additions & 2 deletions ens/base_ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@
)

if TYPE_CHECKING:
from web3 import Web3 # noqa: F401
from web3 import ( # noqa: F401
AsyncWeb3,
Web3,
)
from web3.contract import ( # noqa: F401
AsyncContract,
Contract,
)


class BaseENS:
w3: "Web3" = None
w3: Union["AsyncWeb3", "Web3"] = None
ens: Union["Contract", "AsyncContract"] = None
_resolver_contract: Union[Type["Contract"], Type["AsyncContract"]] = None
_reverse_resolver_contract: Union[Type["Contract"], Type["AsyncContract"]] = None
Expand Down
3 changes: 3 additions & 0 deletions ens/ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class ENS(BaseENS):
like: ``"0x314159265dD8dbb310642f98f50C066173C1259b"``
"""

# mypy types
w3: "Web3"

def __init__(
self,
provider: "BaseProvider" = cast("BaseProvider", default),
Expand Down
14 changes: 8 additions & 6 deletions ens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

if TYPE_CHECKING:
from web3 import ( # noqa: F401
AsyncWeb3,
Web3 as _Web3,
)
from web3.providers import ( # noqa: F401
Expand All @@ -63,6 +64,7 @@
)
from web3.types import ( # noqa: F401
ABIFunction,
AsyncMiddleware,
Middleware,
RPCEndpoint,
)
Expand Down Expand Up @@ -288,10 +290,10 @@ def get_abi_output_types(abi: "ABIFunction") -> List[str]:

def init_async_web3(
provider: "AsyncBaseProvider" = cast("AsyncBaseProvider", default),
middlewares: Optional[Sequence[Tuple["Middleware", str]]] = (),
) -> "_Web3":
middlewares: Optional[Sequence[Tuple["AsyncMiddleware", str]]] = (),
) -> "AsyncWeb3":
from web3 import (
Web3 as Web3Main,
AsyncWeb3 as AsyncWeb3Main,
)
from web3.eth import (
AsyncEth as AsyncEthMain,
Expand All @@ -306,11 +308,11 @@ def init_async_web3(
middlewares.append((_async_ens_stalecheck_middleware, "stalecheck"))

if provider is default:
async_w3 = Web3Main(
async_w3 = AsyncWeb3Main(
middlewares=middlewares, ens=None, modules={"eth": (AsyncEthMain)}
)
else:
async_w3 = Web3Main(
async_w3 = AsyncWeb3Main(
provider,
middlewares=middlewares,
ens=None,
Expand All @@ -321,7 +323,7 @@ def init_async_web3(


async def _async_ens_stalecheck_middleware(
make_request: Callable[["RPCEndpoint", Any], Any], w3: "_Web3"
make_request: Callable[["RPCEndpoint", Any], Any], w3: "AsyncWeb3"
) -> "Middleware":
from web3.middleware import (
async_make_stalecheck_middleware,
Expand Down
4 changes: 1 addition & 3 deletions ethpm/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,7 @@ def get_contract_instance(self, name: ContractName, address: Address) -> Contrac
self.manifest["contractTypes"][name]
)
contract_instance = self.w3.eth.contract(address=address, **contract_kwargs)
# TODO: type ignore may be able to be removed after
# more of AsyncContract is finished
return contract_instance # type: ignore
return contract_instance

#
# Build Dependencies
Expand Down
1 change: 1 addition & 0 deletions newsfragments/2819.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use ``AsyncWeb3`` class and preserve typing for the async api calls.
8 changes: 2 additions & 6 deletions tests/core/contracts/test_contract_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def test_initial_greeting(foo_contract):

def test_can_update_greeting(w3, foo_contract):
# send transaction that updates the greeting
tx_hash = foo_contract.functions.setBar(
"testing contracts is easy",
).transact(
tx_hash = foo_contract.functions.setBar("testing contracts is easy").transact(
{
"from": w3.eth.accounts[1],
}
Expand All @@ -99,9 +97,7 @@ def test_can_update_greeting(w3, foo_contract):

def test_updating_greeting_emits_event(w3, foo_contract):
# send transaction that updates the greeting
tx_hash = foo_contract.functions.setBar(
"testing contracts is easy",
).transact(
tx_hash = foo_contract.functions.setBar("testing contracts is easy").transact(
{
"from": w3.eth.accounts[1],
}
Expand Down
10 changes: 6 additions & 4 deletions tests/core/middleware/test_eth_tester_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ def test_get_transaction_count_formatters(w3, block_number):


def test_get_block_formatters(w3):
latest_block = w3.eth.get_block("latest")
all_block_keys = BlockData.__annotations__.keys()
all_non_poa_block_keys = set(
[k for k in all_block_keys if k != "proofOfAuthorityData"]
)

all_block_keys = set(BlockData.__annotations__.keys())
latest_block = w3.eth.get_block("latest")
latest_block_keys = set(latest_block.keys())

assert all_block_keys == latest_block_keys
assert all_non_poa_block_keys == latest_block_keys


@pytest.mark.parametrize(
Expand Down
3 changes: 2 additions & 1 deletion tests/core/middleware/test_simple_cache_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid

from web3 import (
AsyncWeb3,
Web3,
)
from web3._utils.caching import (
Expand Down Expand Up @@ -166,7 +167,7 @@ async def _async_simple_cache_middleware_for_testing(make_request, async_w3):

@pytest.fixture
def async_w3():
return Web3(
return AsyncWeb3(
provider=AsyncEthereumTesterProvider(),
middlewares=[
(_async_simple_cache_middleware_for_testing, "simple_cache"),
Expand Down
12 changes: 6 additions & 6 deletions tests/core/providers/test_async_http_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
)

from web3 import (
Web3,
AsyncWeb3,
)
from web3._utils import (
request,
Expand All @@ -14,10 +14,10 @@
AsyncEth,
)
from web3.geth import (
AsyncGeth,
AsyncGethAdmin,
AsyncGethPersonal,
AsyncGethTxPool,
Geth,
)
from web3.middleware import (
async_attrdict_middleware,
Expand All @@ -37,25 +37,25 @@

def test_no_args():
provider = AsyncHTTPProvider()
w3 = Web3(provider)
w3 = AsyncWeb3(provider)
assert w3.manager.provider == provider
assert w3.manager.provider.is_async


def test_init_kwargs():
provider = AsyncHTTPProvider(endpoint_uri=URI, request_kwargs={"timeout": 60})
w3 = Web3(provider)
w3 = AsyncWeb3(provider)
assert w3.manager.provider == provider


def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> None:
async_w3 = Web3(AsyncHTTPProvider(endpoint_uri=URI))
async_w3 = AsyncWeb3(AsyncHTTPProvider(endpoint_uri=URI))

# assert default modules

assert isinstance(async_w3.eth, AsyncEth)
assert isinstance(async_w3.net, AsyncNet)
assert isinstance(async_w3.geth, Geth)
assert isinstance(async_w3.geth, AsyncGeth)
assert isinstance(async_w3.geth.admin, AsyncGethAdmin)
assert isinstance(async_w3.geth.personal, AsyncGethPersonal)
assert isinstance(async_w3.geth.txpool, AsyncGethTxPool)
Expand Down
8 changes: 2 additions & 6 deletions tests/ens/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@
simple_resolver_bytecode_runtime,
)
from web3 import (
AsyncWeb3,
Web3,
)
from web3.contract import (
AsyncContract,
Contract,
)
from web3.eth import (
AsyncEth,
)
from web3.providers.eth_tester import (
AsyncEthereumTesterProvider,
EthereumTesterProvider,
Expand Down Expand Up @@ -359,9 +357,7 @@ def TEST_ADDRESS(address_conversion_func):
@pytest_asyncio.fixture(scope="session")
def async_w3():
provider = AsyncEthereumTesterProvider()
_async_w3 = Web3(
provider, modules={"eth": [AsyncEth]}, middlewares=provider.middlewares
)
_async_w3 = AsyncWeb3(provider, middlewares=provider.middlewares)
return _async_w3


Expand Down
9 changes: 5 additions & 4 deletions tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_open_port,
)
from web3 import (
AsyncWeb3,
Web3,
)
from web3._utils.module_testing.go_ethereum_admin_module import (
Expand Down Expand Up @@ -121,7 +122,7 @@ class TestGoEthereumTxPoolModuleTest(GoEthereumTxPoolModuleTest):
@pytest_asyncio.fixture(scope="module")
async def async_w3(geth_process, endpoint_uri):
await wait_for_aiohttp(endpoint_uri)
_w3 = Web3(AsyncHTTPProvider(endpoint_uri))
_w3 = AsyncWeb3(AsyncHTTPProvider(endpoint_uri))
return _w3


Expand All @@ -130,19 +131,19 @@ class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest):
@pytest.mark.xfail(
reason="running geth with the --nodiscover flag doesn't allow peer addition"
)
async def test_admin_peers(self, async_w3: "Web3") -> None:
async def test_admin_peers(self, async_w3: "AsyncWeb3") -> None:
await super().test_admin_peers(async_w3)

@pytest.mark.asyncio
async def test_admin_start_stop_http(self, async_w3: "Web3") -> None:
async def test_admin_start_stop_http(self, async_w3: "AsyncWeb3") -> None:
# This test causes all tests after it to fail on CI if it's allowed to run
pytest.xfail(
reason="Only one HTTP endpoint is allowed to be active at any time"
)
await super().test_admin_start_stop_http(async_w3)

@pytest.mark.asyncio
async def test_admin_start_stop_ws(self, async_w3: "Web3") -> None:
async def test_admin_start_stop_ws(self, async_w3: "AsyncWeb3") -> None:
# This test causes all tests after it to fail on CI if it's allowed to run
pytest.xfail(reason="Only one WS endpoint is allowed to be active at any time")
await super().test_admin_start_stop_ws(async_w3)
Expand Down
6 changes: 5 additions & 1 deletion web3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from eth_account import Account # noqa: E402,
import pkg_resources

from web3.main import Web3 # noqa: E402,
from web3.main import (
AsyncWeb3,
Web3,
)
from web3.providers.async_rpc import ( # noqa: E402
AsyncHTTPProvider,
)
Expand All @@ -22,6 +25,7 @@

__all__ = [
"__version__",
"AsyncWeb3",
"Web3",
"HTTPProvider",
"IPCProvider",
Expand Down
Loading