Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed Mandatory Encryption in Neo4jHook #30418

19 changes: 16 additions & 3 deletions airflow/providers/neo4j/hooks/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

from typing import Any
from urllib.parse import urlsplit

from neo4j import Driver, GraphDatabase

Expand Down Expand Up @@ -61,12 +62,24 @@ def get_conn(self) -> Driver:

is_encrypted = self.connection.extra_dejson.get("encrypted", False)

self.client = GraphDatabase.driver(
uri, auth=(self.connection.login, self.connection.password), encrypted=is_encrypted
)
self.client = self.get_client(self.connection, is_encrypted, uri)

return self.client

def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver:
"""
Function to determine that relevant driver based on extras.
:param conn: Connection object.
:param encrypted: boolean if encrypted connection or not.
:param uri: uri string for connection.
:return: Driver
"""
parsed_uri = urlsplit(uri)
kwargs = {}
if parsed_uri.scheme in ["bolt", "neo4j"]:
kwargs["encrypted"] = encrypted
return GraphDatabase.driver(uri, auth=(conn.login, conn.password), **kwargs)

def get_uri(self, conn: Connection) -> str:
"""
Build the uri based on extras
Expand Down
40 changes: 36 additions & 4 deletions tests/providers/neo4j/hooks/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@
import pytest

from airflow.models import Connection
from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
from airflow.providers.neo4j.hooks.neo4j import Driver, Neo4jHook


class TestNeo4jHookConn:
@pytest.mark.parametrize(
"conn_extra, expected_uri",
[
({}, "bolt://host:7687"),
({"bolt_scheme": True}, "bolt://host:7687"),
({"certs_self_signed": True, "bolt_scheme": True}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "bolt_scheme": True}, "bolt+s://host:7687"),
({"neo4j_scheme": False}, "bolt://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": False}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": False}, "bolt+s://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": True}, "neo4j+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": True}, "neo4j+s://host:7687"),
],
)
def test_get_uri_neo4j_scheme(self, conn_extra, expected_uri):
Expand Down Expand Up @@ -101,3 +103,33 @@ def test_run_without_schema(self, mock_graph_database):
)
session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value
assert op_result == session.run.return_value.data.return_value

@pytest.mark.parametrize(
"conn_extra, expected",
[
({"certs_self_signed": True, "neo4j_scheme": False, "encrypted": True}, True),
({"certs_self_signed": True, "neo4j_scheme": False, "encrypted": False}, True),
({"certs_trusted_ca": True, "neo4j_scheme": False, "encrypted": False}, True),
({"certs_self_signed": True, "neo4j_scheme": True, "encrypted": False}, True),
({"certs_trusted_ca": True, "neo4j_scheme": True, "encrypted": False}, True),
({"certs_trusted_ca": False, "neo4j_scheme": False, "encrypted": True}, True),
],
)
def test_get_client(self, conn_extra, expected):
connection = Connection(
conn_type="neo4j",
login="login",
password="password",
host="host",
schema="schema",
extra=conn_extra,
)
# Use the environment variable mocking to test saving the configuration as a URI and
# to avoid mocking Airflow models class
with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()):
neo4j_hook = Neo4jHook()
is_encrypted = conn_extra.get("encrypted", False)
with neo4j_hook.get_client(
conn=connection, encrypted=is_encrypted, uri=neo4j_hook.get_uri(connection)
) as client:
assert isinstance(client, Driver) == expected
eladkal marked this conversation as resolved.
Show resolved Hide resolved