Skip to content

Commit

Permalink
Make Provider multichain (#574)
Browse files Browse the repository at this point in the history
  • Loading branch information
calina-c authored Mar 17, 2023
1 parent adae5e4 commit 459cb4b
Show file tree
Hide file tree
Showing 59 changed files with 566 additions and 397 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
rev: 'refs/tags/22.6.0:refs/tags/22.6.0'
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.7.9
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
2 changes: 0 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ RUN python3.8 -m pip install setuptools
RUN python3.8 -m pip install wheel
RUN python3.8 -m pip install .

# config.ini configuration file variables
ENV NETWORK_URL='http://127.0.0.1:8545'

ENV PROVIDER_PRIVATE_KEY=''
Expand All @@ -50,7 +49,6 @@ ENV AZURE_SHARE_OUTPUT='output'

ENV OCEAN_PROVIDER_URL='http://0.0.0.0:8030'

# docker-entrypoint.sh configuration file variables
ENV OCEAN_PROVIDER_WORKERS='1'
ENV OCEAN_PROVIDER_TIMEOUT='9000'
ENV ALLOW_NON_PUBLIC_IP=False
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ Add the corresponding environment variables in your `.env` file. Here is an exam
```
FLASK_APP=ocean_provider/run.py
PROVIDER_ADDRESS=your ethereum address goes here
PROVIDER_PRIVATE_KEY=the private key
PROVIDER_FEE_TOKEN = address of ERC20 token used to get fees
PROVIDER_PRIVATE_KEY= the private key or string containing a dict of chain_id to private key pairs
PROVIDER_FEE_TOKEN = the address of ERC20 token used to get fees, or string containing a dict of chain_id to token address pairs
```

You might also want to set `FLASK_ENV=development`. Then run ```flask run --port=8030```
Expand Down
14 changes: 7 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from eth_account import Account
from ocean_provider.run import app
from ocean_provider.utils.basics import get_web3, send_ether
from ocean_provider.utils.basics import get_provider_private_key, get_web3, send_ether
from ocean_provider.utils.provider_fees import get_c2d_environments

app = app
Expand Down Expand Up @@ -42,7 +42,7 @@ def consumer_address(consumer_wallet):

@pytest.fixture
def ganache_wallet():
web3 = get_web3()
web3 = get_web3(8996)
if (
web3.eth.accounts
and web3.eth.accounts[0].lower()
Expand All @@ -57,7 +57,7 @@ def ganache_wallet():

@pytest.fixture
def provider_wallet():
pk = os.environ.get("PROVIDER_PRIVATE_KEY")
pk = get_provider_private_key(8996)
return Account.from_key(pk)


Expand All @@ -68,7 +68,7 @@ def provider_address(provider_wallet):

@pytest.fixture(autouse=True)
def setup_all(provider_address, consumer_address, ganache_wallet):
web3 = get_web3()
web3 = get_web3(8996)
if ganache_wallet:
if (
web3.fromWei(
Expand All @@ -91,13 +91,13 @@ def setup_all(provider_address, consumer_address, ganache_wallet):

@pytest.fixture
def web3():
return get_web3()
return get_web3(8996)


@pytest.fixture
def free_c2d_env():
try:
environments = get_c2d_environments()
environments = get_c2d_environments(flat=True)
except AssertionError:
pytest.skip("C2D connection failed. Need fix in #610")

Expand All @@ -107,7 +107,7 @@ def free_c2d_env():
@pytest.fixture
def paid_c2d_env():
try:
environments = get_c2d_environments()
environments = get_c2d_environments(flat=True)
except AssertionError:
pytest.skip("C2D connection failed. Need fix in #610")

Expand Down
3 changes: 1 addition & 2 deletions ocean_provider/file_types/file_types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
import logging
import os
import json
from typing import Any, Optional, Tuple
from urllib.parse import urljoin
from uuid import uuid4

from enforce_typing import enforce_types

from ocean_provider.file_types.definitions import EndUrlType, FilesType

logger = logging.getLogger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions ocean_provider/file_types/file_types_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from typing import Any, Tuple

from enforce_typing import enforce_types

from ocean_provider.file_types.file_types import (
ArweaveFile,
GraphqlQuery,
IpfsFile,
UrlFile,
GraphqlQuery,
)
from ocean_provider.file_types.types.smartcontract import SmartContractCall

Expand Down Expand Up @@ -55,6 +54,7 @@ def validate_and_create(file_obj) -> Tuple[bool, Any]:
elif file_obj["type"] == "smartcontract":
instance = SmartContractCall(
address=file_obj.get("address"),
chain_id=file_obj.get("chainId"),
abi=file_obj.get("abi"),
userdata=file_obj.get("userdata"),
)
Expand Down
46 changes: 31 additions & 15 deletions ocean_provider/file_types/types/smartcontract.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ class SmartContractCall(FilesType):
def __init__(
self,
address: Optional[str] = None,
chain_id: Optional[int] = None,
abi: Optional[dict] = None,
userdata=None,
) -> None:
self.address = address
self.chain_id = chain_id
self.type = "smartcontract"
self.abi = abi
self.userdata = None
Expand All @@ -37,56 +39,70 @@ def validate_dict(self) -> Tuple[bool, Any]:
if not self.address:
return False, "malformed smartcontract type, missing contract address"
# validate abi

inputs = self.abi.get("inputs")
type = self.abi.get("type")
if inputs is None or type != "function":
t_type = self.abi.get("type")
if inputs is None or t_type != "function":
return False, "invalid abi"

mutability = self.abi.get("stateMutability", None)
if mutability not in ["view", "pure"]:
return False, "only view or pure functions are allowed"

if not self.abi.get("name"):
return False, "missing name"

# check that all inputs have a match in userdata
if len(inputs) > 0 and self.userdata is None:
return False, "Missing parameters"
for input in inputs:
value = self.userdata.get(input.get("name"))

missing_inputs = []
for input_item in inputs:
value = self.userdata.get(input_item.get("name"))
if not value:
return False, f"Missing userparam: {input.name}"
missing_inputs.append(input_item.name)

if missing_inputs:
return False, "Missing userparams: " + ",".join(missing_inputs)

return True, self

@enforce_types
def get_filename(self) -> str:
return uuid4().hex

def fetch_smartcontract_call(self):
web3 = get_web3()
web3 = get_web3(self.chain_id)
contract = web3.eth.contract(
address=web3.toChecksumAddress(self.address), abi=[self.abi]
)
function = contract.functions[self.abi.get("name")]
args = dict()
for input in self.abi.get("inputs"):
args[input.get("name")] = self.userdata.get(input.get("name"))
if input.get("type") == "address":
args[input.get("name")] = web3.toChecksumAddress(
args[input.get("name")]
)

for input_item in self.abi.get("inputs"):
name = input_item.get("name")
args[name] = self.userdata.get(name)
if input_item.get("type") == "address":
args[name] = web3.toChecksumAddress(args[name])

result = function(**args).call()

if isinstance(result, object):
return json.dumps(result), "application/json"

return result, "application/text"

def check_details(self, with_checksum=False):
try:
result, type = self.fetch_smartcontract_call()
details = {"contentLength": len(result) or "", "contentType": type}
result, t_type = self.fetch_smartcontract_call()
details = {"contentLength": len(result) or "", "contentType": t_type}

if with_checksum:
sha = hashlib.sha256()
sha.update(result.encode("utf-8"))
details["checksumType"] = "sha256"
details["checksum"] = sha.hexdigest()

return True, details
except Exception:
return False, {}
Expand All @@ -97,7 +113,7 @@ def build_download_response(
validate_url=True,
):
try:
result, type = self.fetch_smartcontract_call()
result, t_type = self.fetch_smartcontract_call()
return Response(
result,
200,
Expand Down
2 changes: 0 additions & 2 deletions ocean_provider/myapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

"""
This module creates an instance of flask `app`, creates `user_nonce` table if not exists, and sets the environment configuration.
If `PROVIDER_CONFIG_FILE` is not found in environment variables, default `config.ini` file is used.
"""

from flask import Flask, _app_ctx_stack
from flask_cors import CORS
from flask_sieve import Sieve
Expand Down
4 changes: 2 additions & 2 deletions ocean_provider/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
#
import logging
import os

import jwt
from flask import jsonify, request
Expand All @@ -13,6 +12,7 @@
force_restore_token,
is_token_valid,
)
from ocean_provider.utils.basics import get_provider_private_key
from ocean_provider.utils.util import get_request_data
from ocean_provider.validation.provider_requests import (
CreateTokenRequest,
Expand Down Expand Up @@ -63,7 +63,7 @@ def create_auth_token():
address = data.get("address")
expiration = int(data.get("expiration"))

pk = os.environ.get("PROVIDER_PRIVATE_KEY")
pk = get_provider_private_key(use_universal_key=True)
token = jwt.encode({"exp": expiration, "address": address}, pk, algorithm="HS256")
token = token.decode("utf-8") if isinstance(token, bytes) else token

Expand Down
17 changes: 6 additions & 11 deletions ocean_provider/routes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ocean_provider.utils.basics import (
get_metadata_url,
get_provider_wallet,
get_web3,
validate_timestamp,
)
from ocean_provider.utils.compute import (
Expand Down Expand Up @@ -53,7 +52,6 @@

from . import services

provider_wallet = get_provider_wallet()
requests_session = get_requests_session()

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -125,10 +123,9 @@ def initializeCompute():
logger,
)

if not check_environment_exists(get_c2d_environments(), compute_env):
if not check_environment_exists(get_c2d_environments(flat=True), compute_env):
return error_response("Compute environment does not exist", 400, logger)

web3 = get_web3()
approve_params = {"datasets": []} if datasets else {}

index_for_provider_fees = comb_for_valid_transfer_and_fees(
Expand All @@ -152,6 +149,7 @@ def initializeCompute():
return error_response("DID is not a valid algorithm", 400, logger)

algo_service = algo_ddo.get_service_by_id(algorithm.get("serviceId"))
provider_wallet = get_provider_wallet(algo_ddo.chain_id)
algo_files_checksum, algo_container_checksum = get_algo_checksums(
algo_service, provider_wallet, algo_ddo
)
Expand All @@ -160,9 +158,7 @@ def initializeCompute():
dataset["algorithm"] = algorithm
dataset["consumerAddress"] = consumer_address
input_item_validator = InputItemValidator(
web3,
consumer_address,
provider_wallet,
dataset,
{"environment": compute_env},
i,
Expand Down Expand Up @@ -441,7 +437,7 @@ def computeStart():
logger.info(f"computeStart called. arguments = {data}")

consumer_address = data.get("consumerAddress")
validator = WorkflowValidator(get_web3(), consumer_address, provider_wallet, data)
validator = WorkflowValidator(consumer_address, data)

status = validator.validate()
if not status:
Expand All @@ -453,8 +449,8 @@ def computeStart():

compute_env = data.get("environment")

provider_wallet = get_provider_wallet(use_universal_key=True)
nonce, provider_signature = sign_for_compute(provider_wallet, consumer_address)
web3 = get_web3()
payload = {
"workflow": workflow,
"providerSignature": provider_signature,
Expand All @@ -464,7 +460,7 @@ def computeStart():
"environment": compute_env,
"validUntil": validator.valid_until,
"nonce": nonce,
"chainId": web3.chain_id,
"chainId": validator.chain_id,
}

response = requests_session.post(
Expand Down Expand Up @@ -527,18 +523,17 @@ def computeResult():
url = get_compute_result_endpoint()
consumer_address = data.get("consumerAddress")
job_id = data.get("jobId")
provider_wallet = get_provider_wallet(use_universal_key=True)
nonce, provider_signature = sign_for_compute(
provider_wallet, consumer_address, job_id
)
web3 = get_web3()
params = {
"index": data.get("index"),
"owner": data.get("consumerAddress"),
"jobId": job_id,
"consumerSignature": data.get("signature"),
"providerSignature": provider_signature,
"nonce": nonce,
"chainId": web3.chain_id,
}
req = PreparedRequest()
req.prepare_url(url, params)
Expand Down
Loading

0 comments on commit 459cb4b

Please sign in to comment.