Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #86 from unparalleled-js/refactor/jules/use-endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Sep 1, 2022
2 parents 0d389be + b17288a commit e8503a1
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 47 deletions.
6 changes: 4 additions & 2 deletions ape_starknet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ape_starknet.conversion import StarknetAddressConverter
from ape_starknet.ecosystems import Starknet
from ape_starknet.explorer import StarknetExplorer
from ape_starknet.provider import StarknetProvider
from ape_starknet.provider import StarknetDevnetProvider, StarknetProvider
from ape_starknet.tokens import TokenManager
from ape_starknet.utils import NETWORKS, PLUGIN_NAME

Expand Down Expand Up @@ -41,9 +41,11 @@ def networks():

@plugins.register(plugins.ProviderPlugin)
def providers():
for network_name in network_names:
for network_name in list(NETWORKS.keys()):
yield PLUGIN_NAME, network_name, StarknetProvider

yield PLUGIN_NAME, LOCAL_NETWORK_NAME, StarknetDevnetProvider


@plugins.register(plugins.AccountPlugin)
def account_types():
Expand Down
17 changes: 11 additions & 6 deletions ape_starknet/accounts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import json
import random
from dataclasses import dataclass
from math import ceil
from pathlib import Path
Expand Down Expand Up @@ -116,13 +115,19 @@ def public_key_addresses(self) -> Iterator[AddressType]:
for account in self.accounts:
yield account.address

@cached_property
@property
def test_accounts(self) -> List["StarknetDevnetAccount"]:
random_generator = random.Random()
random_generator.seed(self.devnet_account_seed)
if self.provider.network.name != LOCAL_NETWORK_NAME:
return []

return self._test_accounts

@cached_property
def _test_accounts(self):
predeployed_accounts = self.provider.devnet_client.predeployed_accounts
devnet_accounts = [
StarknetDevnetAccount(private_key=random_generator.getrandbits(128))
for _ in range(self.number_of_devnet_accounts)
StarknetDevnetAccount(private_key=int(acc["private_key"], 16))
for acc in predeployed_accounts
]

# Track all devnet account contracts in chain manager for look-up purposes
Expand Down
115 changes: 76 additions & 39 deletions ape_starknet/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from urllib.parse import urlparse
from urllib.request import urlopen

import requests
from ape.api import BlockAPI, ProviderAPI, ReceiptAPI, SubprocessProvider, TransactionAPI
from ape.api.networks import LOCAL_NETWORK_NAME
from ape.contracts import ContractInstance
from ape.exceptions import ProviderNotConnectedError, TransactionError
from ape.types import AddressType, BlockID, ContractLog, LogFilter
from ape.utils import DEFAULT_NUMBER_OF_TEST_ACCOUNTS, cached_property, raises_not_implemented
from ethpm_types import ContractType
from requests import Session
from starknet_py.net.client_models import (
BlockSingleTransactionTrace,
ContractCode,
Expand Down Expand Up @@ -44,7 +44,31 @@
from ape_starknet.utils.basemodel import StarknetBase


class StarknetProvider(SubprocessProvider, ProviderAPI, StarknetBase):
class DevnetClient:
def __init__(self, host_address: str):
self.session = Session()
self.host_address = host_address

@cached_property
def predeployed_accounts(self) -> List[Dict]:
return self._get("predeployed_accounts")

def increase_time(self, amount: int):
return self._post("increase_time", json={"time": amount})

def _get(self, uri: str, **kwargs):
return self._request("get", uri, **kwargs)

def _post(self, uri: str, **kwargs):
return self._request("post", uri, **kwargs)

def _request(self, method: str, uri: str, **kwargs):
response = self.session.request(method.upper(), url=f"{self.host_address}/{uri}", **kwargs)
response.raise_for_status()
return response.json()


class StarknetProvider(ProviderAPI, StarknetBase):
"""
A Starknet provider.
"""
Expand All @@ -54,10 +78,6 @@ class StarknetProvider(SubprocessProvider, ProviderAPI, StarknetBase):
token_manager: TokenManager = TokenManager()
cached_code: Dict[int, ContractCode] = {}

@property
def process_name(self) -> str:
return "starknet-devnet"

@property
def is_connected(self) -> bool:
was_successful = False
Expand All @@ -81,20 +101,6 @@ def starknet_client(self) -> GatewayClient:

return self.client

def build_command(self) -> List[str]:
parts = urlparse(self.uri)
return [
self.process_name,
"--host",
str(parts.hostname),
"--port",
str(parts.port),
"--accounts",
str(DEFAULT_NUMBER_OF_TEST_ACCOUNTS),
"--seed",
str(DEFAULT_ACCOUNT_SEED),
]

@cached_property
def plugin_config(self) -> StarknetConfig:
return self.config_manager.get_config(PLUGIN_NAME) or StarknetConfig() # type: ignore
Expand All @@ -108,13 +114,6 @@ def uri(self) -> str:
return network_config.get("uri") or f"http://127.0.0.1:{DEFAULT_PORT}"

def connect(self):
if self.network.name == LOCAL_NETWORK_NAME:
# Behave like a 'SubprocessProvider'
if not self.is_connected:
super().connect()

self.start()

self.client = GatewayClient(self.uri, chain=self.chain_id)

def disconnect(self):
Expand Down Expand Up @@ -172,6 +171,7 @@ def gas_price(self) -> int:
"""
**NOTE**: Currently, the gas price is fixed to always be 100 gwei.
"""

return self.conversion_manager.convert("100 gwei", int)

@handle_client_errors
Expand Down Expand Up @@ -295,17 +295,6 @@ def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]:
def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI:
return txn

def set_timestamp(self, new_timestamp: int):
pending_timestamp = self.get_block("pending").timestamp
seconds_to_increase = new_timestamp - pending_timestamp
response = requests.post(
url=f"{self.uri}/increase_time", json={"time": seconds_to_increase}
)
response.raise_for_status()
response_data = response.json()
if "timestamp_increased_by" not in response_data:
raise StarknetProviderError(response_data)

def get_virtual_machine_error(self, exception: Exception):
return get_virtual_machine_error(exception)

Expand All @@ -324,4 +313,52 @@ def declare(self, contract_type: ContractType) -> ContractDeclaration:
return self.provider.send_transaction(transaction)


__all__ = ["StarknetProvider"]
class StarknetDevnetProvider(SubprocessProvider, StarknetProvider):
"""
A subprocess provider for the starknet-devnet process.
"""

@property
def process_name(self) -> str:
return "starknet-devnet"

@cached_property
def devnet_client(self) -> DevnetClient:
return DevnetClient(self.uri)

def connect(self):
if self.network.name == LOCAL_NETWORK_NAME:
# Behave like a 'SubprocessProvider'
if not self.is_connected:
super().connect()

self.start()

self.client = GatewayClient(self.uri, chain=self.chain_id)

def build_command(self) -> List[str]:
parts = urlparse(self.uri)
return [
self.process_name,
"--host",
str(parts.hostname),
"--port",
str(parts.port),
"--accounts",
str(DEFAULT_NUMBER_OF_TEST_ACCOUNTS),
"--seed",
str(DEFAULT_ACCOUNT_SEED),
]

def set_timestamp(self, new_timestamp: int):
if self.devnet_client is None:
raise StarknetProviderError("Must be connected to starknet-devnet to use this feature.")

pending_timestamp = self.get_block("pending").timestamp
seconds_to_increase = new_timestamp - pending_timestamp
result = self.devnet_client.increase_time(seconds_to_increase)
if "timestamp_increased_by" not in result:
raise StarknetProviderError(result)


__all__ = ["StarknetProvider", "StarknetDevnetProvider"]

0 comments on commit e8503a1

Please sign in to comment.