Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gripql/python: passing None to in_(), out(), both(), etc. should have no effect #148

Merged
merged 2 commits into from
Sep 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions gripql/python/gripql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,29 @@
from gripql.util import AttrDict, Rate, process_url


def _wrap_value(value, typ):
wrapped = []
if isinstance(value, list):
if not all(isinstance(i, typ) for i in value):
raise TypeError("expected all values in array to be a string")
wrapped = value
elif isinstance(value, typ):
wrapped.append(value)
elif value is None:
pass
else:
raise TypeError("expected value to be a %s" % typ)
return wrapped


def _wrap_str_value(value):
return _wrap_value(value, str)


def _wrap_dict_value(value):
return _wrap_value(value, dict)


class Query:
def __init__(self, url, graph, user=None, password=None):
self.query = []
Expand All @@ -34,8 +57,7 @@ def V(self, id=[]):

"id" is an ID or a list of vertex IDs to start from. Optional.
"""
if not isinstance(id, list):
id = [id]
id = _wrap_str_value(id)
return self.__append({"v": id})

def E(self, id=[]):
Expand All @@ -44,8 +66,7 @@ def E(self, id=[]):

"id" is an ID to start from. Optional.
"""
if not isinstance(id, list):
id = [id]
id = _wrap_str_value(id)
return self.__append({"e": id})

def where(self, expression):
Expand All @@ -54,13 +75,12 @@ def where(self, expression):
"""
return self.__append({"where": expression})

def fields(self, fields=[]):
def fields(self, field=[]):
"""
Select document properties to be returned in document.
"""
if not isinstance(fields, list):
fields = [fields]
return self.__append({"fields": fields})
field = _wrap_str_value(field)
return self.__append({"fields": field})

def in_(self, label=[]):
"""
Expand All @@ -69,8 +89,7 @@ def in_(self, label=[]):
"label" is the label of the edge to follow.
"label" can be a list.
"""
if not isinstance(label, list):
label = [label]
label = _wrap_str_value(label)
return self.__append({"in": label})

def out(self, label=[]):
Expand All @@ -80,8 +99,7 @@ def out(self, label=[]):
"label" is the label of the edge to follow.
"label" can be a list.
"""
if not isinstance(label, list):
label = [label]
label = _wrap_str_value(label)
return self.__append({"out": label})

def both(self, label=[]):
Expand All @@ -91,8 +109,7 @@ def both(self, label=[]):
"label" is the label of the edge to follow.
"label" can be a list.
"""
if not isinstance(label, list):
label = [label]
label = _wrap_str_value(label)
return self.__append({"both": label})

def inEdge(self, label=[]):
Expand All @@ -104,8 +121,7 @@ def inEdge(self, label=[]):

Must be called from a vertex.
"""
if not isinstance(label, list):
label = [label]
label = _wrap_str_value(label)
return self.__append({"in_edge": label})

def outEdge(self, label=[]):
Expand All @@ -117,8 +133,7 @@ def outEdge(self, label=[]):

Must be called from a vertex.
"""
if not isinstance(label, list):
label = [label]
label = _wrap_str_value(label)
return self.__append({"out_edge": label})

def bothEdge(self, label=[]):
Expand All @@ -130,8 +145,7 @@ def bothEdge(self, label=[]):

Must be called from a vertex.
"""
if not isinstance(label, list):
label = [label]
label = _wrap_str_value(label)
return self.__append({"both_edge": label})

def mark(self, name):
Expand All @@ -155,8 +169,7 @@ def select(self, marks):
[A2, B2],
]
"""
if not isinstance(marks, list):
marks = [marks]
marks = _wrap_str_value(marks)
return self.__append({"select": {"marks": marks}})

def limit(self, n):
Expand All @@ -181,8 +194,7 @@ def distinct(self, props=[]):
"""
Select distinct elements based on the provided property list.
"""
if not isinstance(props, list):
props = [props]
props = _wrap_str_value(props)
return self.__append({"distinct": props})

def match(self, queries):
Expand All @@ -209,17 +221,22 @@ def aggregate(self, aggregations):
"""
Aggregate results of query output
"""
if not isinstance(aggregations, list):
aggregations = [aggregations]
aggregations = _wrap_dict_value(aggregations)
return self.__append({"aggregate": {"aggregations": aggregations}})

def toJson(self):
def to_json(self):
"""
Return the query as a JSON string.
"""
output = {"query": self.query}
return json.dumps(output)

def to_dict(self):
"""
Return the query as a dictionary.
"""
return {"query": self.query}

def __iter__(self):
return self.__stream()

Expand All @@ -244,12 +261,12 @@ def __stream(self, debug=False):
rate.init()
response = requests.post(
self.url,
json={"query": self.query},
json=self.to_dict(),
stream=True,
auth=(self.user, self.password)
)
logger.debug('POST %s', self.url)
logger.debug("BODY %s", self.toJson())
logger.debug("BODY %s", self.to_json())
logger.debug('STATUS CODE %s', response.status_code)
response.raise_for_status()

Expand Down
18 changes: 18 additions & 0 deletions gripql/python/tests/wrap_value_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest

from gripql.query import Query


class TestQueryFormat(unittest.TestCase):

def test_wrapping_none(self):
q = Query("localhost", "test")
self.assertEqual(q.V().in_().to_json(), q.V().in_(None).to_json())
self.assertEqual(q.V().out().to_json(), q.V().out(None).to_json())
self.assertEqual(q.V().both().to_json(), q.V().both(None).to_json())
with self.assertRaises(TypeError):
q.V().in_(["foo", None]).to_json()
with self.assertRaises(TypeError):
q.V().in_(["foo", 1]).to_json()
with self.assertRaises(TypeError):
q.V().in_(1).to_json()