Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run dbt on WebAssembly using Pyodide #5803

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20220909-154722.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Run dbt on WebAssembly using Pyodide
time: 2022-09-09T15:47:22.228524-04:00
custom:
Author: arieldbt
Issue: "1970"
PR: "5803"
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from time import sleep
import sys

# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
# dbt.clients.parallel.RLock is a function returning this type
from dbt.clients.parallel import RLock
from threading import get_ident
from typing import (
Any,
Expand Down
107 changes: 107 additions & 0 deletions core/dbt/clients/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from abc import ABCMeta, abstractmethod
import json
from typing import Any, Dict, Optional
from urllib.parse import urlencode

import requests
from requests import Response

from dbt import flags


class Http(metaclass=ABCMeta):
    """Abstract HTTP client interface.

    Concrete implementations (``PyodideHttp``, ``Requests``) let dbt issue
    HTTP requests both on regular CPython and inside Pyodide/WebAssembly,
    where the ``requests`` library cannot be used.
    """

    @abstractmethod
    def get_json(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Perform a GET request and return the response body parsed as JSON.

        :param url: target URL.
        :param params: optional query-string parameters.
        :param timeout: optional timeout in seconds.
        """
        raise NotImplementedError

    @abstractmethod
    def get_response(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Perform a GET request and return the raw ``requests.Response``."""
        raise NotImplementedError

    @abstractmethod
    def post(
        self,
        url: str,
        data: Any = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Perform a POST request and return the raw ``requests.Response``."""
        raise NotImplementedError


class PyodideHttp(Http):
    """HTTP client for Pyodide/WebAssembly environments.

    Uses ``pyodide.http.open_url`` (a synchronous browser request) instead of
    the ``requests`` library, which is unavailable under Pyodide.
    """

    def __init__(self) -> None:
        super().__init__()
        # Imported lazily so this module can still be imported outside Pyodide.
        from pyodide.http import open_url

        self._open_url = open_url

    def get_json(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """GET ``url`` and return the response body parsed as JSON.

        NOTE(review): ``timeout`` is accepted for interface compatibility but
        ignored here — ``open_url`` is called without one.
        """
        if params is not None:
            url += f"?{urlencode(params)}"
        r = self._open_url(url=url)
        return json.load(r)

    def get_response(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Not supported under Pyodide: no ``requests.Response`` can be built."""
        raise NotImplementedError

    def post(
        self,
        url: str,
        data: Any = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Not supported under Pyodide: no ``requests.Response`` can be built."""
        raise NotImplementedError


class Requests(Http):
    """HTTP client backed by the ``requests`` library (regular CPython)."""

    def get_json(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """GET ``url`` and return the response body parsed as JSON."""
        return self.get_response(url=url, params=params, timeout=timeout).json()

    def get_response(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """GET ``url`` and return the raw ``requests.Response``."""
        return requests.get(url=url, params=params, timeout=timeout)

    def post(
        self,
        url: str,
        data: Any = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """POST ``data`` to ``url`` and return the raw ``requests.Response``."""
        return requests.post(url=url, data=data, headers=headers, timeout=timeout)


# Select the HTTP implementation once at import time; callers use this
# module-level singleton (``from dbt.clients.http import http``).
if flags.IS_PYODIDE:
    http = PyodideHttp()
else:
    http = Requests()
34 changes: 34 additions & 0 deletions core/dbt/clients/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dbt import flags
from threading import Lock as PyodideLock
from threading import RLock as PyodideRLock

if flags.IS_PYODIDE:
pass # multiprocessing doesn't work in pyodide
else:
from multiprocessing.dummy import Pool as MultiprocessingThreadPool
from multiprocessing.synchronize import Lock as MultiprocessingLock
from multiprocessing.synchronize import RLock as MultiprocessingRLock


class PyodideThreadPool:
    """No-op stand-in for ``multiprocessing.dummy.Pool`` under Pyodide.

    Pyodide cannot run real worker threads, so dbt forces single-threaded
    execution there; this class only needs to satisfy the pool lifecycle
    interface used by callers (close/join/terminate, context manager).
    """

    def __init__(self, num_threads: int) -> None:
        # Kept for interface compatibility with Pool(num_threads);
        # no workers are ever created.
        self.num_threads = num_threads

    def close(self) -> None:
        """No-op: there are no workers to close."""

    def join(self) -> None:
        """No-op: there are no workers to join."""

    def terminate(self) -> None:
        """No-op: there are no workers to terminate."""

    # multiprocessing.pool.Pool supports the context-manager protocol
    # (__exit__ terminates the pool); mirror it so ``with`` works too.
    def __enter__(self) -> "PyodideThreadPool":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.terminate()


# Re-export concurrency primitives under uniform names so callers can do
# e.g. ``from dbt.clients.parallel import ThreadPool``.  Under Pyodide the
# ``threading`` primitives (and the no-op pool above) are used; otherwise
# the ``multiprocessing`` equivalents imported at the top of this module.
if flags.IS_PYODIDE:
    Lock = PyodideLock
    ThreadPool = PyodideThreadPool
    RLock = PyodideRLock
else:
    Lock = MultiprocessingLock
    ThreadPool = MultiprocessingThreadPool
    RLock = MultiprocessingRLock
5 changes: 3 additions & 2 deletions core/dbt/clients/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
from typing import Any, Dict, List
import requests
from dbt.clients.http import http
from dbt.events.functions import fire_event
from dbt.events.types import (
RegistryProgressMakingGETRequest,
Expand Down Expand Up @@ -40,7 +41,7 @@ def _get(package_name, registry_base_url=None):
url = _get_url(package_name, registry_base_url)
fire_event(RegistryProgressMakingGETRequest(url=url))
# all exceptions from requests get caught in the retry logic so no need to wrap this here
resp = requests.get(url, timeout=30)
resp = http.get_response(url, timeout=30)
fire_event(RegistryProgressGETResponse(url=url, resp_code=resp.status_code))
resp.raise_for_status()

Expand Down Expand Up @@ -164,7 +165,7 @@ def _get_index(registry_base_url=None):
url = _get_url("index", registry_base_url)
fire_event(RegistryIndexProgressMakingGETRequest(url=url))
# all exceptions from requests get caught in the retry logic so no need to wrap this here
resp = requests.get(url, timeout=30)
resp = http.get_response(url, timeout=30)
fire_event(RegistryIndexProgressGETResponse(url=url, resp_code=resp.status_code))
resp.raise_for_status()

Expand Down
4 changes: 2 additions & 2 deletions core/dbt/clients/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import subprocess
import sys
import tarfile
import requests
import stat
from typing import Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union

Expand All @@ -22,6 +21,7 @@
SystemStdErrMsg,
SystemReportReturnCode,
)
from dbt.clients.http import http
import dbt.exceptions
from dbt.utils import _connection_exception_retry as connection_exception_retry

Expand Down Expand Up @@ -451,7 +451,7 @@ def download(
) -> None:
path = convert_path(path)
connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
response = requests.get(url, timeout=connection_timeout)
response = http.get_response(url, timeout=connection_timeout)
with open(path, "wb") as handle:
for block in response.iter_content(1024 * 64):
handle.write(block)
Expand Down
20 changes: 15 additions & 5 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass, field
from itertools import chain, islice
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from dbt.clients.parallel import Lock

from typing import (
Dict,
List,
Expand Down Expand Up @@ -641,10 +642,19 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
default_factory=ParsingInfo,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_lock: Lock = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
if flags.IS_PYODIDE:
# Not sure how to avoid this change
# Fails with this error:
# mashumaro.exceptions.UnserializableDataError: <built-in function allocate_lock> as a field type is not supported by mashumaro
_lock: Callable = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
else:
_lock: Lock = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

def __pre_serialize__(self):
# serialization won't work with anything except an empty source_patches because
Expand Down
29 changes: 19 additions & 10 deletions core/dbt/flags.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import os
import multiprocessing

if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore
from pathlib import Path
import sys
from typing import Optional

# PROFILES_DIR must be set before the other flags
Expand Down Expand Up @@ -45,6 +41,7 @@
CACHE_SELECTED_ONLY = None
TARGET_PATH = None
LOG_PATH = None
IS_PYODIDE = "pyodide" in sys.modules # whether dbt is running via pyodide

_NON_BOOLEAN_FLAGS = [
"LOG_FORMAT",
Expand Down Expand Up @@ -117,13 +114,25 @@ def env_set_path(key: str) -> Optional[Path]:
ENABLE_LEGACY_LOGGER = env_set_truthy("DBT_ENABLE_LEGACY_LOGGER")


def _get_context():
# TODO: change this back to use fork() on linux when we have made that safe
return multiprocessing.get_context("spawn")
# This is not a flag, it's a place to store the lock
if IS_PYODIDE:
from typing import NamedTuple
from threading import Lock as PyodideLock
from threading import RLock as PyodideRLock

class PyodideContext(NamedTuple):
Lock = PyodideLock
RLock = PyodideRLock

# This is not a flag, it's a place to store the lock
MP_CONTEXT = _get_context()
MP_CONTEXT = PyodideContext()
else:
import multiprocessing

if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore
# TODO: change this back to use fork() on linux when we have made that safe
MP_CONTEXT = multiprocessing.get_context("spawn")


def set_from_args(args, user_config):
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):

# Construct a phony config
config = RuntimeConfig.from_args(
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
RuntimeArgs(
project_dir, profiles_dir, single_threaded or flags.IS_PYODIDE, profile, target
)
)
# Clear previously registered adapters--
# this fixes caching behavior on the dbt-server
Expand Down
19 changes: 13 additions & 6 deletions core/dbt/parser/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
from dbt.clients.jinja import get_rendered
import dbt.tracking as tracking
from dbt import utils
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
from functools import reduce
from itertools import chain
import random
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

# No support for compiled dependencies on pyodide
if flags.IS_PYODIDE:
pass
else:
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
# New for Python models :p
import ast
from dbt.dataclass_schema import ValidationError
Expand Down Expand Up @@ -283,12 +287,15 @@ def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None:
exp_sample_node = deepcopy(node)
exp_sample_config = deepcopy(config)
model_parser_copy.populate(exp_sample_node, exp_sample_config, experimental_sample)
# use the experimental parser exclusively if the flag is on
if flags.USE_EXPERIMENTAL_PARSER:
statically_parsed = self.run_experimental_parser(node)
# run the stable static parser unless it is explicitly turned off
if flags.IS_PYODIDE:
pass
else:
statically_parsed = self.run_static_parser(node)
# use the experimental parser exclusively if the flag is on
if flags.USE_EXPERIMENTAL_PARSER:
statically_parsed = self.run_experimental_parser(node)
# run the stable static parser unless it is explicitly turned off
else:
statically_parsed = self.run_static_parser(node)

# if the static parser succeeded, extract some data in easy-to-compare formats
if isinstance(statically_parsed, dict):
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class BaseTask(metaclass=ABCMeta):

def __init__(self, args, config):
self.args = args
self.args.single_threaded = False
self.args.single_threaded = False or flags.IS_PYODIDE
self.config = config

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from abc import abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet

from .printer import (
print_run_result_error,
print_run_end_messages,
)

from dbt.clients.parallel import ThreadPool
from dbt.clients.system import write_file
from dbt.task.base import ConfiguredTask
from dbt.adapters.base import BaseRelation
Expand Down Expand Up @@ -266,7 +266,7 @@ def _submit(self, pool, args, callback):

This does still go through the callback path for result collection.
"""
if self.config.args.single_threaded:
if self.config.args.single_threaded or flags.IS_PYODIDE:
callback(self.call_runner(*args))
else:
pool.apply_async(self.call_runner, args=args, callback=callback)
Expand Down
Loading