Skip to content

Commit

Permalink
Add global Gremlin connection_protocol field to notebook configuration (#621)
Browse files Browse the repository at this point in the history

* Add Gremlin connection protocol field to %%graph_notebook_config

* update changelog
  • Loading branch information
michaelnchin authored Jun 17, 2024
1 parent bccd7f2 commit e650772
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 40 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd
## Upcoming

- Added `--connection-protocol` option to `%%gremlin` ([Link to PR](https://github.com/aws/graph-notebook/pull/617))
- Added global Gremlin `connection_protocol` setting to `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/621))
- Restored left alignment of numeric value columns in results table widget ([Link to PR](https://github.com/aws/graph-notebook/pull/620))

## Release 4.4.0 (June 10, 2024)
Expand Down
54 changes: 44 additions & 10 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
import os
from enum import Enum

from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, DEFAULT_GREMLIN_SERIALIZER, \
DEFAULT_GREMLIN_TRAVERSAL_SOURCE, DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \
GRAPHSONV3_VARIANTS, GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS, \
NEPTUNE_DB_SERVICE_NAME, normalize_service_name
from graph_notebook.neptune.client import (SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION,
DEFAULT_GREMLIN_SERIALIZER, DEFAULT_GREMLIN_TRAVERSAL_SOURCE,
DEFAULT_GREMLIN_PROTOCOL, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL,
HTTP_PROTOCOL_FORMATS, WS_PROTOCOL_FORMATS,
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE,
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants,
GRAPHSONV3_VARIANTS, GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS,
NEPTUNE_DB_SERVICE_NAME, normalize_service_name)

DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')

Expand Down Expand Up @@ -53,13 +56,15 @@ class GremlinSection(object):
"""

def __init__(self, traversal_source: str = '', username: str = '', password: str = '',
message_serializer: str = ''):
message_serializer: str = '', connection_protocol: str = '', include_protocol: bool = False):
"""
:param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
connected to an endpoint that can access multiple graphs.
:param username: used to specify a username for authenticating to Gremlin Server, if the endpoint supports it.
:param password: used to specify a password for authenticating to Gremlin Server, if the endpoint supports it.
:param message_serializer: used to specify a serializer for encoding the data to and from Gremlin Server.
:param connection_protocol: used to specify a protocol for the Gremlin connection.
:param include_protocol: used to specify whether to include connection_protocol in the Gremlin config.
"""

if traversal_source == '':
Expand All @@ -80,12 +85,25 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str
f'Valid serializers: [graphsonv3, graphsonv2, graphbinaryv1].')
message_serializer = DEFAULT_GREMLIN_SERIALIZER


self.traversal_source = traversal_source
self.username = username
self.password = password
self.message_serializer = message_serializer

if include_protocol:
protocol_lower = connection_protocol.lower()
if protocol_lower == '':
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
elif protocol_lower in HTTP_PROTOCOL_FORMATS:
connection_protocol = DEFAULT_HTTP_PROTOCOL
elif protocol_lower in WS_PROTOCOL_FORMATS:
connection_protocol = DEFAULT_WS_PROTOCOL
else:
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_GREMLIN_PROTOCOL}. "
f"Valid protocols: [websockets, http].")
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
self.connection_protocol = connection_protocol

def to_dict(self):
return self.__dict__

Expand Down Expand Up @@ -142,8 +160,20 @@ def __init__(self, host: str, port: int,
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.aws_region = aws_region
self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer) \
if gremlin_section is not None else GremlinSection()
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else DEFAULT_GREMLIN_PROTOCOL
if gremlin_section is not None:
if hasattr(gremlin_section, "connection_protocol"):
if self._proxy_host != '' and gremlin_section.connection_protocol != DEFAULT_HTTP_PROTOCOL:
print("Enforcing HTTP connection protocol for proxy connections.")
final_protocol = DEFAULT_HTTP_PROTOCOL
else:
final_protocol = gremlin_section.connection_protocol
else:
final_protocol = default_protocol
self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer,
connection_protocol=final_protocol, include_protocol=True)
else:
self.gremlin = GremlinSection(connection_protocol=default_protocol, include_protocol=True)
self.neo4j = Neo4JSection()
else:
self.is_neptune_config = False
Expand Down Expand Up @@ -267,6 +297,9 @@ def generate_default_config():
parser.add_argument("--gremlin_serializer",
help="the serializer to use as the encoding format when creating Gremlin connections",
default=DEFAULT_GREMLIN_SERIALIZER)
parser.add_argument("--gremlin_connection_protocol",
help="the connection protocol to use for Gremlin connections",
default=DEFAULT_GREMLIN_PROTOCOL)
parser.add_argument("--neo4j_username", help="the username to use for Neo4J connections",
default=DEFAULT_NEO4J_USERNAME)
parser.add_argument("--neo4j_password", help="the password to use for Neo4J connections",
Expand All @@ -285,7 +318,8 @@ def generate_default_config():
args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port),
SparqlSection(args.sparql_path, ''),
GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
args.gremlin_password, args.gremlin_serializer),
args.gremlin_password, args.gremlin_serializer,
args.gremlin_connection_protocol),
Neo4JSection(args.neo4j_username, args.neo4j_password,
args.neo4j_auth, args.neo4j_database),
args.neptune_hosts)
Expand Down
19 changes: 15 additions & 4 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, NEPTUNE_DB_CONFIG_NAMES, NEPTUNE_ANALYTICS_CONFIG_NAMES

neptune_params = ['neptune_service', 'auth_mode', 'load_from_s3_arn', 'aws_region']
neptune_gremlin_params = ['connection_protocol']


def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:
Expand All @@ -21,17 +22,22 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I
else:
ssl_verify = True
sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection()
neo4j_section = Neo4JSection(**data['neo4j']) if 'neo4j' in data else Neo4JSection('', '', True, '')
neo4j_section = Neo4JSection(**data['neo4j']) \
if 'neo4j' in data else Neo4JSection('', '', True, '')
proxy_host = str(data['proxy_host']) if 'proxy_host' in data else ''
proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182

is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts)

if is_neptune_host:
neptune_service = data['neptune_service'] if 'neptune_service' in data else NEPTUNE_DB_SERVICE_NAME
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
if 'gremlin' in data:
data['gremlin']['include_protocol'] = True
gremlin_section = GremlinSection(**data['gremlin'])
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
else:
gremlin_section = GremlinSection(include_protocol=True)
if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \
or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD:
print('Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n')
Expand All @@ -49,9 +55,13 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I
for p in neptune_params:
if p in data:
excluded_params.append(p)
for gp in neptune_gremlin_params:
if gp in data['gremlin']:
excluded_params.append(gp)
if excluded_params:
print(f"The provided configuration contains the following parameters that are incompatible with the "
f"specified host: {str(excluded_params)}. These parameters have not been saved.\n")
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection()

config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], ssl_verify=ssl_verify,
sparql_section=sparql_section, gremlin_section=gremlin_section, neo4j_section=neo4j_section,
Expand All @@ -63,4 +73,5 @@ def get_config(path: str = DEFAULT_CONFIG_LOCATION,
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:
with open(path) as config_file:
data = json.load(config_file)
print(data)
return get_config_from_dict(data=data, neptune_hosts=neptune_hosts)
65 changes: 41 additions & 24 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,24 +458,24 @@ def graph_notebook_config(self, line='', cell='', local_ns: dict = None):
self.graph_notebook_config = config
self._generate_client_from_config(config)
print('set notebook config to:')
print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
print(json.dumps(self.graph_notebook_config.to_dict(), indent=4))
elif args.mode == 'reset':
self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist)
print('reset notebook config to:')
print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
print(json.dumps(self.graph_notebook_config.to_dict(), indent=4))
elif args.mode == 'silent':
"""
silent option so that our neptune_menu extension can receive json instead
of a python Configuration object
"""
config_dict = self.graph_notebook_config.to_dict()
store_to_ns(args.store_to, json.dumps(config_dict, indent=2), local_ns)
return print(json.dumps(config_dict, indent=2))
store_to_ns(args.store_to, json.dumps(config_dict, indent=4), local_ns)
return print(json.dumps(config_dict, indent=4))
else:
config_dict = self.graph_notebook_config.to_dict()
print(json.dumps(config_dict, indent=2))
print(json.dumps(config_dict, indent=4))

store_to_ns(args.store_to, json.dumps(self.graph_notebook_config.to_dict(), indent=2), local_ns)
store_to_ns(args.store_to, json.dumps(self.graph_notebook_config.to_dict(), indent=4), local_ns)

return self.graph_notebook_config

Expand Down Expand Up @@ -1024,16 +1024,16 @@ def sparql_status(self, line='', local_ns: dict = None):
@cell_magic
@needs_local_scope
@display_exceptions
@neptune_db_only
def gremlin(self, line, cell, local_ns: dict = None):
parser = argparse.ArgumentParser()
parser.add_argument('query_mode', nargs='?', default='query',
help='query mode (default=query) [query|explain|profile]')
parser.add_argument('-cp', '--connection-protocol', type=str.lower, default='',
help=f'Neptune endpoints only. Connection protocol to use for connecting to the Gremlin '
f'database - either Websockets or HTTP. Valid inputs: {GREMLIN_PROTOCOL_FORMATS}. '
f'Default is {DEFAULT_GREMLIN_PROTOCOL}. Please note that this option has no effect '
f'on the Profile and Explain modes, which must use HTTP.')
f'If not specified, defaults to the value of the gremlin.connection_protocol field '
f'in %graph_notebook_config. Please note that this option has no effect on the '
f'Profile and Explain modes, which must use HTTP.')
parser.add_argument('--explain-type', type=str.lower, default='',
help='Explain mode to use when using the explain query mode.')
parser.add_argument('-p', '--path-pattern', default='', help='path pattern')
Expand Down Expand Up @@ -1120,9 +1120,14 @@ def gremlin(self, line, cell, local_ns: dict = None):
transport_args = {'max_content_length': mcl_bytes}

if mode == QueryMode.EXPLAIN:
res = self.client.gremlin_explain(cell,
args={'explain.mode': args.explain_type} if args.explain_type else {})
res.raise_for_status()
try:
res = self.client.gremlin_explain(cell,
args={'explain.mode': args.explain_type} if args.explain_type else {})
res.raise_for_status()
except Exception as e:
if self.client.is_analytics_domain():
print("%%gremlin is incompatible with Neptune Analytics.")
raise e
# Replace strikethrough character bytes, can't be encoded to ASCII
explain_bytes = res.content.replace(b'\xcc', b'-')
explain_bytes = explain_bytes.replace(b'\xb6', b'')
Expand Down Expand Up @@ -1156,8 +1161,13 @@ def gremlin(self, line, cell, local_ns: dict = None):
except JSONDecodeError:
print('--profile-misc-args received invalid input, please check that you are passing in a valid '
'string representation of a map, ex. "{\'profile.x\':\'true\'}"')
res = self.client.gremlin_profile(query=cell, args=profile_args)
res.raise_for_status()
try:
res = self.client.gremlin_profile(query=cell, args=profile_args)
res.raise_for_status()
except Exception as e:
if self.client.is_analytics_domain():
print("%%gremlin is incompatible with Neptune Analytics.")
raise e
profile_bytes = res.content.replace(b'\xcc', b'-')
profile_bytes = profile_bytes.replace(b'\xb6', b'')
query_res = profile_bytes.decode('utf-8')
Expand All @@ -1175,16 +1185,23 @@ def gremlin(self, line, cell, local_ns: dict = None):
using_http = False
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
if self.client.is_neptune_domain():
connection_protocol = normalize_protocol_name(args.connection_protocol)
if self.graph_notebook_config.proxy_host != '' or connection_protocol == DEFAULT_HTTP_PROTOCOL:
using_http = True
query_res_http = self.client.gremlin_http_query(cell, headers={
'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
query_res_http.raise_for_status()
query_res_http_json = query_res_http.json()
query_res = query_res_http_json['result']['data']
else:
query_res = self.client.gremlin_query(cell, transport_args=transport_args)
connection_protocol = normalize_protocol_name(args.connection_protocol) \
if args.connection_protocol != '' \
else self.graph_notebook_config.gremlin.connection_protocol
try:
if connection_protocol == DEFAULT_HTTP_PROTOCOL:
using_http = True
query_res_http = self.client.gremlin_http_query(cell, headers={
'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
query_res_http.raise_for_status()
query_res_http_json = query_res_http.json()
query_res = query_res_http_json['result']['data']
else:
query_res = self.client.gremlin_query(cell, transport_args=transport_args)
except Exception as e:
if self.client.is_analytics_domain():
print("%%gremlin is incompatible with Neptune Analytics.")
raise e
else:
query_res = self.client.gremlin_query(cell, transport_args=transport_args)
query_time = time.time() * 1000 - query_start
Expand Down
Loading

0 comments on commit e650772

Please sign in to comment.