Skip to content

Commit

Permalink
working example
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh committed Dec 16, 2022
1 parent 9387879 commit 1d52c1a
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 16 deletions.
3 changes: 0 additions & 3 deletions superset/databases/ssh_tunnel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

import sqlalchemy as sa
from flask import current_app
from flask_appbuilder import Model
Expand All @@ -31,7 +29,6 @@
)

app_config = current_app.config
ssh_manager = app_config["SSH_TUNNEL_MANAGER"]


class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
Expand Down
1 change: 0 additions & 1 deletion superset/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Any, Callable, Dict, List, Optional

import celery
import sshtunnel
from cachelib.base import BaseCache
from flask import Flask
from flask_appbuilder import AppBuilder, SQLA
Expand Down
22 changes: 12 additions & 10 deletions superset/extensions/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,40 @@
# under the License.

import importlib
from typing import Dict
from typing import TYPE_CHECKING

from flask import Flask
from sshtunnel import open_tunnel, SSHTunnelForwarder

from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe

if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel

class SSHManager: # pylint: disable=too-few-public-methods

class SSHManager:
def __init__(self, app: Flask) -> None:
super().__init__()
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]

def build_sqla_url(self, sqlalchemy_url: str, server: SSHTunnelForwarder) -> str:
def build_sqla_url( # pylint: disable=no-self-use
self, sqlalchemy_url: str, server: SSHTunnelForwarder
) -> str:
# override any ssh tunnel configuration object
url = make_url_safe(sqlalchemy_url)
return url.set(
host=server.local_bind_address,
host=server.local_bind_address[0],
port=server.local_bind_port,
)

def create_tunnel(
self,
ssh_tunnel: SSHTunnel,
ssh_tunnel: "SSHTunnel",
sqlalchemy_database_uri: str,
) -> SSHTunnelForwarder:
url = make_url_safe(sqlalchemy_database_uri)
params = {
"ssh_address_or_host": ssh_tunnel.ssh_address_or_host,
"ssh_address_or_host": ssh_tunnel.server_address,
"ssh_port": ssh_tunnel.server_port,
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (url.host, url.port), # bind_port, bind_host
Expand All @@ -58,9 +62,7 @@ def create_tunnel(
params["private_key"] = ssh_tunnel.private_key
params["private_key_password"] = ssh_tunnel.private_key_password

tunnel = open_tunnel(params)

return tunnel, self.build_sqla_url(sqlalchemy_database_uri, tunnel)
return open_tunnel(**params)


class SSHManagerFactory:
Expand Down
5 changes: 5 additions & 0 deletions superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
migrate,
profiling,
results_backend_manager,
ssh_manager_factory,
talisman,
)
from superset.security import SupersetSecurityManager
Expand Down Expand Up @@ -417,6 +418,7 @@ def init_app_in_ctx(self) -> None:
self.configure_data_sources()
self.configure_auth_provider()
self.configure_async_queries()
self.configure_ssh_manager()

# Hook that provides administrators a handle on the Flask APP
# after initialization
Expand Down Expand Up @@ -474,6 +476,9 @@ def init_app(self) -> None:
def configure_auth_provider(self) -> None:
machine_auth_provider_factory.init_app(self.superset_app)

def configure_ssh_manager(self) -> None:
ssh_manager_factory.init_app(self.superset_app)

def setup_event_logger(self) -> None:
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
self.superset_app.config.get("EVENT_LOGGER", DBEventLogger())
Expand Down
4 changes: 2 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import numpy
import pandas as pd
import sqlalchemy as sqla
import sshtunnel
from flask import g, request
from flask_appbuilder import Model
from sqlalchemy import (
Expand Down Expand Up @@ -389,9 +388,10 @@ def get_sqla_engine_with_context(
database_id=self.id
):
# if ssh_tunnel is available build engine with information
engine_context, sqlalchemy_uri = ssh_manager_factory.instance.create_tunnel(
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(sqlalchemy_uri, server_context)

with engine_context as server_context:
yield self._get_sqla_engine(
Expand Down

0 comments on commit 1d52c1a

Please sign in to comment.