From 035503fae3d4651a0dc51975fd17d5fabd877376 Mon Sep 17 00:00:00 2001 From: adamstruck Date: Fri, 21 Sep 2018 11:55:04 -0700 Subject: [PATCH 1/2] gripql/python: passing None to methods such as in_(), out(), both() has no effect --- gripql/python/gripql/query.py | 75 +++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/gripql/python/gripql/query.py b/gripql/python/gripql/query.py index 680ef07f..145a8a9b 100644 --- a/gripql/python/gripql/query.py +++ b/gripql/python/gripql/query.py @@ -8,6 +8,29 @@ from gripql.util import AttrDict, Rate, process_url +def _wrap_value(value, typ): + wrapped = [] + if isinstance(value, list): + if not all(isinstance(i, typ) for i in value): + raise TypeError("expected all values in array to be a string") + wrapped = value + elif isinstance(value, typ): + wrapped.append(value) + elif value is None: + pass + else: + raise TypeError("expected value to be a %s" % typ) + return wrapped + + +def _wrap_str_value(value): + return _wrap_value(value, str) + + +def _wrap_dict_value(value): + return _wrap_value(value, dict) + + class Query: def __init__(self, url, graph, user=None, password=None): self.query = [] @@ -34,8 +57,7 @@ def V(self, id=[]): "id" is an ID or a list of vertex IDs to start from. Optional. """ - if not isinstance(id, list): - id = [id] + id = _wrap_str_value(id) return self.__append({"v": id}) def E(self, id=[]): @@ -44,8 +66,7 @@ def E(self, id=[]): "id" is an ID to start from. Optional. """ - if not isinstance(id, list): - id = [id] + id = _wrap_str_value(id) return self.__append({"e": id}) def where(self, expression): @@ -54,13 +75,12 @@ def where(self, expression): """ return self.__append({"where": expression}) - def fields(self, fields=[]): + def fields(self, field=[]): """ Select document properties to be returned in document. """ - if not isinstance(fields, list): - fields = [fields] - return self.__append({"fields": fields}) + field = _wrap_str_value(field) + return self.__append({"fields": field}) def in_(self, label=[]): """ @@ -69,8 +89,7 @@ def in_(self, label=[]): "label" is the label of the edge to follow. "label" can be a list. """ - if not isinstance(label, list): - label = [label] + label = _wrap_str_value(label) return self.__append({"in": label}) def out(self, label=[]): @@ -80,8 +99,7 @@ def out(self, label=[]): "label" is the label of the edge to follow. "label" can be a list. """ - if not isinstance(label, list): - label = [label] + label = _wrap_str_value(label) return self.__append({"out": label}) def both(self, label=[]): @@ -91,8 +109,7 @@ def both(self, label=[]): "label" is the label of the edge to follow. "label" can be a list. """ - if not isinstance(label, list): - label = [label] + label = _wrap_str_value(label) return self.__append({"both": label}) def inEdge(self, label=[]): @@ -104,8 +121,7 @@ def inEdge(self, label=[]): Must be called from a vertex. """ - if not isinstance(label, list): - label = [label] + label = _wrap_str_value(label) return self.__append({"in_edge": label}) def outEdge(self, label=[]): @@ -117,8 +133,7 @@ def outEdge(self, label=[]): Must be called from a vertex. """ - if not isinstance(label, list): - label = [label] + label = _wrap_str_value(label) return self.__append({"out_edge": label}) def bothEdge(self, label=[]): @@ -130,8 +145,7 @@ def bothEdge(self, label=[]): Must be called from a vertex. """ - if not isinstance(label, list): - label = [label] + label = _wrap_str_value(label) return self.__append({"both_edge": label}) def mark(self, name): @@ -155,8 +169,7 @@ def select(self, marks): [A2, B2], ] """ - if not isinstance(marks, list): - marks = [marks] + marks = _wrap_str_value(marks) return self.__append({"select": {"marks": marks}}) def limit(self, n): @@ -181,8 +194,7 @@ def distinct(self, props=[]): """ Select distinct elements based on the provided property list. """ - if not isinstance(props, list): - props = [props] + props = _wrap_str_value(props) return self.__append({"distinct": props}) def match(self, queries): @@ -209,17 +221,22 @@ def aggregate(self, aggregations): """ Aggregate results of query output """ - if not isinstance(aggregations, list): - aggregations = [aggregations] + aggregations = _wrap_dict_value(aggregations) return self.__append({"aggregate": {"aggregations": aggregations}}) - def toJson(self): + def to_json(self): """ Return the query as a JSON string. """ output = {"query": self.query} return json.dumps(output) + def to_dict(self): + """ + Return the query as a dictionary. + """ + return {"query": self.query} + def __iter__(self): return self.__stream() @@ -244,12 +261,12 @@ def __stream(self, debug=False): rate.init() response = requests.post( self.url, - json={"query": self.query}, + json=self.to_dict(), stream=True, auth=(self.user, self.password) ) logger.debug('POST %s', self.url) - logger.debug("BODY %s", self.toJson()) + logger.debug("BODY %s", self.to_json()) logger.debug('STATUS CODE %s', response.status_code) response.raise_for_status() From f2f8f812fc8e368b939237ad7ce513541fabf22e Mon Sep 17 00:00:00 2001 From: adamstruck Date: Fri, 21 Sep 2018 12:09:44 -0700 Subject: [PATCH 2/2] gripql/python: added a basic test --- gripql/python/tests/wrap_value_test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 gripql/python/tests/wrap_value_test.py diff --git a/gripql/python/tests/wrap_value_test.py b/gripql/python/tests/wrap_value_test.py new file mode 100644 index 00000000..83836b68 --- /dev/null +++ b/gripql/python/tests/wrap_value_test.py @@ -0,0 +1,18 @@ +import unittest + +from gripql.query import Query + + +class TestQueryFormat(unittest.TestCase): + + def test_wrapping_none(self): + q = Query("localhost", "test") + self.assertEqual(q.V().in_().to_json(), q.V().in_(None).to_json()) + self.assertEqual(q.V().out().to_json(), q.V().out(None).to_json()) + self.assertEqual(q.V().both().to_json(), q.V().both(None).to_json()) + with self.assertRaises(TypeError): + q.V().in_(["foo", None]).to_json() + with self.assertRaises(TypeError): + q.V().in_(["foo", 1]).to_json() + with self.assertRaises(TypeError): + q.V().in_(1).to_json()