From e65077274f3857b9105e1a7a8e54d1e01f50b849 Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Mon, 17 Jun 2024 09:45:29 -0700 Subject: [PATCH] Add global Gremlin `connection_protocol` field to notebook configuration (#621) * Add Gremlin connection protocol field to %%graph_notebook_config * update changelog --- ChangeLog.md | 1 + .../configuration/generate_config.py | 54 ++++++++++++--- .../configuration/get_config.py | 19 ++++-- src/graph_notebook/magics/graph_magic.py | 65 ++++++++++++------- test/unit/configuration/test_configuration.py | 28 +++++++- 5 files changed, 127 insertions(+), 40 deletions(-) diff --git a/ChangeLog.md b/ChangeLog.md index 73adfb20..eb77bb16 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Added `--connection-protocol` option to `%%gremlin` ([Link to PR](https://github.com/aws/graph-notebook/pull/617)) +- Added global Gremlin `connection_protocol` setting to `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/621)) - Restored left alignment of numeric value columns in results table widget ([Link to PR](https://github.com/aws/graph-notebook/pull/620)) ## Release 4.4.0 (June 10, 2024) diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index e868a4d3..ad50aa62 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -8,11 +8,14 @@ import os from enum import Enum -from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, DEFAULT_GREMLIN_SERIALIZER, \ - DEFAULT_GREMLIN_TRAVERSAL_SOURCE, DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \ - NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \ - GRAPHSONV3_VARIANTS, GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS, \ - NEPTUNE_DB_SERVICE_NAME, normalize_service_name +from graph_notebook.neptune.client import (SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, + DEFAULT_GREMLIN_SERIALIZER, DEFAULT_GREMLIN_TRAVERSAL_SOURCE, + DEFAULT_GREMLIN_PROTOCOL, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, + HTTP_PROTOCOL_FORMATS, WS_PROTOCOL_FORMATS, + DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, + NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, + GRAPHSONV3_VARIANTS, GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS, + NEPTUNE_DB_SERVICE_NAME, normalize_service_name) DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json') @@ -53,13 +56,15 @@ class GremlinSection(object): """ def __init__(self, traversal_source: str = '', username: str = '', password: str = '', - message_serializer: str = ''): + message_serializer: str = '', connection_protocol: str = '', include_protocol: bool = False): """ :param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are connected to an endpoint that can access multiple graphs. :param username: used to specify a username for authenticating to Gremlin Server, if the endpoint supports it. :param password: used to specify a password for authenticating to Gremlin Server, if the endpoint supports it. :param message_serializer: used to specify a serializer for encoding the data to and from Gremlin Server. + :param connection_protocol: used to specify a protocol for the Gremlin connection. + :param include_protocol: used to specify whether to include connection_protocol in the Gremlin config. """ if traversal_source == '': @@ -80,12 +85,25 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str f'Valid serializers: [graphsonv3, graphsonv2, graphbinaryv1].') message_serializer = DEFAULT_GREMLIN_SERIALIZER - self.traversal_source = traversal_source self.username = username self.password = password self.message_serializer = message_serializer + if include_protocol: + protocol_lower = connection_protocol.lower() + if protocol_lower == '': + connection_protocol = DEFAULT_GREMLIN_PROTOCOL + elif protocol_lower in HTTP_PROTOCOL_FORMATS: + connection_protocol = DEFAULT_HTTP_PROTOCOL + elif protocol_lower in WS_PROTOCOL_FORMATS: + connection_protocol = DEFAULT_WS_PROTOCOL + else: + print(f"Invalid connection protocol specified, defaulting to {DEFAULT_GREMLIN_PROTOCOL}. " + f"Valid protocols: [websockets, http].") + connection_protocol = DEFAULT_GREMLIN_PROTOCOL + self.connection_protocol = connection_protocol + def to_dict(self): return self.__dict__ @@ -142,8 +160,20 @@ def __init__(self, host: str, port: int, self.auth_mode = auth_mode self.load_from_s3_arn = load_from_s3_arn self.aws_region = aws_region - self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer) \ - if gremlin_section is not None else GremlinSection() + default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else DEFAULT_GREMLIN_PROTOCOL + if gremlin_section is not None: + if hasattr(gremlin_section, "connection_protocol"): + if self._proxy_host != '' and gremlin_section.connection_protocol != DEFAULT_HTTP_PROTOCOL: + print("Enforcing HTTP connection protocol for proxy connections.") + final_protocol = DEFAULT_HTTP_PROTOCOL + else: + final_protocol = gremlin_section.connection_protocol + else: + final_protocol = default_protocol + self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer, + connection_protocol=final_protocol, include_protocol=True) + else: + self.gremlin = GremlinSection(connection_protocol=default_protocol, include_protocol=True) self.neo4j = Neo4JSection() else: self.is_neptune_config = False @@ -267,6 +297,9 @@ def generate_default_config(): parser.add_argument("--gremlin_serializer", help="the serializer to use as the encoding format when creating Gremlin connections", default=DEFAULT_GREMLIN_SERIALIZER) + parser.add_argument("--gremlin_connection_protocol", + help="the connection protocol to use for Gremlin connections", + default=DEFAULT_GREMLIN_PROTOCOL) parser.add_argument("--neo4j_username", help="the username to use for Neo4J connections", default=DEFAULT_NEO4J_USERNAME) parser.add_argument("--neo4j_password", help="the password to use for Neo4J connections", @@ -285,7 +318,8 @@ def generate_default_config(): args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port), SparqlSection(args.sparql_path, ''), GremlinSection(args.gremlin_traversal_source, args.gremlin_username, - args.gremlin_password, args.gremlin_serializer), + args.gremlin_password, args.gremlin_serializer, + args.gremlin_connection_protocol), Neo4JSection(args.neo4j_username, args.neo4j_password, args.neo4j_auth, args.neo4j_database), args.neptune_hosts) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index 334b3a3b..0c5d713f 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -12,6 +12,7 @@ NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, NEPTUNE_DB_CONFIG_NAMES, NEPTUNE_ANALYTICS_CONFIG_NAMES neptune_params = ['neptune_service', 'auth_mode', 'load_from_s3_arn', 'aws_region'] +neptune_gremlin_params = ['connection_protocol'] def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration: @@ -21,8 +22,8 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I else: ssl_verify = True sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('') - gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection() - neo4j_section = Neo4JSection(**data['neo4j']) if 'neo4j' in data else Neo4JSection('', '', True, '') + neo4j_section = Neo4JSection(**data['neo4j']) \ + if 'neo4j' in data else Neo4JSection('', '', True, '') proxy_host = str(data['proxy_host']) if 'proxy_host' in data else '' proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182 @@ -30,8 +31,13 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I if is_neptune_host: neptune_service = data['neptune_service'] if 'neptune_service' in data else NEPTUNE_DB_SERVICE_NAME - if gremlin_section.to_dict()['traversal_source'] != 'g': - print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n') + if 'gremlin' in data: + data['gremlin']['include_protocol'] = True + gremlin_section = GremlinSection(**data['gremlin']) + if gremlin_section.to_dict()['traversal_source'] != 'g': + print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n') + else: + gremlin_section = GremlinSection(include_protocol=True) if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \ or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD: print('Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n') @@ -49,9 +55,13 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I for p in neptune_params: if p in data: excluded_params.append(p) + for gp in neptune_gremlin_params: + if gp in data['gremlin']: + excluded_params.append(gp) if excluded_params: print(f"The provided configuration contains the following parameters that are incompatible with the " f"specified host: {str(excluded_params)}. These parameters have not been saved.\n") + gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection() config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], ssl_verify=ssl_verify, sparql_section=sparql_section, gremlin_section=gremlin_section, neo4j_section=neo4j_section, @@ -63,4 +73,5 @@ def get_config(path: str = DEFAULT_CONFIG_LOCATION, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration: with open(path) as config_file: data = json.load(config_file) + print(data) return get_config_from_dict(data=data, neptune_hosts=neptune_hosts) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 40cdce66..97298ed2 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -458,24 +458,24 @@ def graph_notebook_config(self, line='', cell='', local_ns: dict = None): self.graph_notebook_config = config self._generate_client_from_config(config) print('set notebook config to:') - print(json.dumps(self.graph_notebook_config.to_dict(), indent=2)) + print(json.dumps(self.graph_notebook_config.to_dict(), indent=4)) elif args.mode == 'reset': self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist) print('reset notebook config to:') - print(json.dumps(self.graph_notebook_config.to_dict(), indent=2)) + print(json.dumps(self.graph_notebook_config.to_dict(), indent=4)) elif args.mode == 'silent': """ silent option to that our neptune_menu extension can receive json instead of python Configuration object """ config_dict = self.graph_notebook_config.to_dict() - store_to_ns(args.store_to, json.dumps(config_dict, indent=2), local_ns) - return print(json.dumps(config_dict, indent=2)) + store_to_ns(args.store_to, json.dumps(config_dict, indent=4), local_ns) + return print(json.dumps(config_dict, indent=4)) else: config_dict = self.graph_notebook_config.to_dict() - print(json.dumps(config_dict, indent=2)) + print(json.dumps(config_dict, indent=4)) - store_to_ns(args.store_to, json.dumps(self.graph_notebook_config.to_dict(), indent=2), local_ns) + store_to_ns(args.store_to, json.dumps(self.graph_notebook_config.to_dict(), indent=4), local_ns) return self.graph_notebook_config @@ -1024,7 +1024,6 @@ def sparql_status(self, line='', local_ns: dict = None): @cell_magic @needs_local_scope @display_exceptions - @neptune_db_only def gremlin(self, line, cell, local_ns: dict = None): parser = argparse.ArgumentParser() parser.add_argument('query_mode', nargs='?', default='query', @@ -1032,8 +1031,9 @@ def gremlin(self, line, cell, local_ns: dict = None): parser.add_argument('-cp', '--connection-protocol', type=str.lower, default='', help=f'Neptune endpoints only. Connection protocol to use for connecting to the Gremlin ' f'database - either Websockets or HTTP. Valid inputs: {GREMLIN_PROTOCOL_FORMATS}. ' - f'Default is {DEFAULT_GREMLIN_PROTOCOL}. Please note that this option has no effect ' - f'on the Profile and Explain modes, which must use HTTP.') + f'If not specified, defaults to the value of the gremlin.connection_protocol field ' + f'in %graph_notebook_config. Please note that this option has no effect on the ' + f'Profile and Explain modes, which must use HTTP.') parser.add_argument('--explain-type', type=str.lower, default='', help='Explain mode to use when using the explain query mode.') parser.add_argument('-p', '--path-pattern', default='', help='path pattern') @@ -1120,9 +1120,14 @@ def gremlin(self, line, cell, local_ns: dict = None): transport_args = {'max_content_length': mcl_bytes} if mode == QueryMode.EXPLAIN: - res = self.client.gremlin_explain(cell, - args={'explain.mode': args.explain_type} if args.explain_type else {}) - res.raise_for_status() + try: + res = self.client.gremlin_explain(cell, + args={'explain.mode': args.explain_type} if args.explain_type else {}) + res.raise_for_status() + except Exception as e: + if self.client.is_analytics_domain(): + print("%%gremlin is incompatible with Neptune Analytics.") + raise e # Replace strikethrough character bytes, can't be encoded to ASCII explain_bytes = res.content.replace(b'\xcc', b'-') explain_bytes = explain_bytes.replace(b'\xb6', b'') @@ -1156,8 +1161,13 @@ def gremlin(self, line, cell, local_ns: dict = None): except JSONDecodeError: print('--profile-misc-args received invalid input, please check that you are passing in a valid ' 'string representation of a map, ex. "{\'profile.x\':\'true\'}"') - res = self.client.gremlin_profile(query=cell, args=profile_args) - res.raise_for_status() + try: + res = self.client.gremlin_profile(query=cell, args=profile_args) + res.raise_for_status() + except Exception as e: + if self.client.is_analytics_domain(): + print("%%gremlin is incompatible with Neptune Analytics.") + raise e profile_bytes = res.content.replace(b'\xcc', b'-') profile_bytes = profile_bytes.replace(b'\xb6', b'') query_res = profile_bytes.decode('utf-8') @@ -1175,16 +1185,23 @@ def gremlin(self, line, cell, local_ns: dict = None): using_http = False query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms if self.client.is_neptune_domain(): - connection_protocol = normalize_protocol_name(args.connection_protocol) - if self.graph_notebook_config.proxy_host != '' or connection_protocol == DEFAULT_HTTP_PROTOCOL: - using_http = True - query_res_http = self.client.gremlin_http_query(cell, headers={ - 'Accept': 'application/vnd.gremlin-v1.0+json;types=false'}) - query_res_http.raise_for_status() - query_res_http_json = query_res_http.json() - query_res = query_res_http_json['result']['data'] - else: - query_res = self.client.gremlin_query(cell, transport_args=transport_args) + connection_protocol = normalize_protocol_name(args.connection_protocol) \ + if args.connection_protocol != '' \ + else self.graph_notebook_config.gremlin.connection_protocol + try: + if connection_protocol == DEFAULT_HTTP_PROTOCOL: + using_http = True + query_res_http = self.client.gremlin_http_query(cell, headers={ + 'Accept': 'application/vnd.gremlin-v1.0+json;types=false'}) + query_res_http.raise_for_status() + query_res_http_json = query_res_http.json() + query_res = query_res_http_json['result']['data'] + else: + query_res = self.client.gremlin_query(cell, transport_args=transport_args) + except Exception as e: + if self.client.is_analytics_domain(): + print("%%gremlin is incompatible with Neptune Analytics.") + raise e else: query_res = self.client.gremlin_query(cell, transport_args=transport_args) query_time = time.time() * 1000 - query_start diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index d4e4089b..e76647b8 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -9,7 +9,8 @@ from graph_notebook.configuration.get_config import get_config from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, AuthModeEnum, \ generate_config, generate_default_config, GremlinSection -from graph_notebook.neptune.client import NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME +from graph_notebook.neptune.client import NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, \ + DEFAULT_GREMLIN_PROTOCOL, DEFAULT_HTTP_PROTOCOL class TestGenerateConfiguration(unittest.TestCase): @@ -49,6 +50,7 @@ def test_generate_default_config(self): self.assertEqual('g', config.gremlin.traversal_source) self.assertEqual('', config.gremlin.username) self.assertEqual('', config.gremlin.password) + self.assertEqual(DEFAULT_GREMLIN_PROTOCOL, config.gremlin.connection_protocol) self.assertEqual('graphsonv3', config.gremlin.message_serializer) self.assertEqual('neo4j', config.neo4j.username) self.assertEqual('password', config.neo4j.password) @@ -219,6 +221,7 @@ def test_generate_configuration_override_defaults_generic(self): c = generate_config(config.host, config.port, ssl=config.ssl) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) + print(config_from_file.to_dict()) self.assertEqual(config.to_dict(), config_from_file.to_dict()) def test_configuration_neptune_host_with_whitespace(self): @@ -249,6 +252,7 @@ def test_configuration_gremlinsection_generic_default(self): self.assertEqual(config.gremlin.username, '') self.assertEqual(config.gremlin.password, '') self.assertEqual(config.gremlin.message_serializer, 'graphsonv3') + self.assertFalse(hasattr(config.gremlin, "connection_protocol")) def test_configuration_gremlinsection_generic_override(self): config = Configuration('localhost', @@ -262,6 +266,7 @@ def test_configuration_gremlinsection_generic_override(self): self.assertEqual(config.gremlin.username, 'foo') self.assertEqual(config.gremlin.password, 'bar') self.assertEqual(config.gremlin.message_serializer, 'graphbinaryv1') + self.assertFalse(hasattr(config.gremlin, "connection_protocol")) def test_configuration_gremlinsection_neptune_default(self): config = Configuration(self.neptune_host_reg, self.port) @@ -269,6 +274,7 @@ def test_configuration_gremlinsection_neptune_default(self): self.assertEqual(config.gremlin.username, '') self.assertEqual(config.gremlin.password, '') self.assertEqual(config.gremlin.message_serializer, 'graphsonv3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_GREMLIN_PROTOCOL) def test_configuration_gremlinsection_neptune_override(self): config = Configuration(self.neptune_host_reg, @@ -276,12 +282,30 @@ def test_configuration_gremlinsection_neptune_override(self): gremlin_section=GremlinSection(traversal_source='t', username='foo', password='bar', - message_serializer='graphbinary'), + message_serializer='graphbinary', + connection_protocol='http', + include_protocol=True), ) self.assertEqual(config.gremlin.traversal_source, 'g') self.assertEqual(config.gremlin.username, '') self.assertEqual(config.gremlin.password, '') self.assertEqual(config.gremlin.message_serializer, 'graphbinaryv1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_protocol_neptune_default_with_proxy(self): + config = Configuration(self.neptune_host_reg, + self.port, + proxy_host='test_proxy') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_protocol_neptune_override_with_proxy(self): + config = Configuration(self.neptune_host_reg, + self.port, + proxy_host='test_proxy', + gremlin_section=GremlinSection(connection_protocol='ws', + include_protocol=True) + ) + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) def test_configuration_neptune_service_default(self): config = Configuration(self.neptune_host_reg, self.port)