161
161
start = "start" ,
162
162
)
163
163
164
- __version__ = "0.8 .0"
164
+ __version__ = "0.9 .0"
165
165
166
166
167
167
_ALPHABET = string .ascii_lowercase + string .digits
@@ -235,8 +235,8 @@ def _is_edge_attr_match(
235
235
motif_edges = _aggregate_edge_labels (motif_edges )
236
236
host_edges = _aggregate_edge_labels (host_edges )
237
237
238
- motif_types = motif_edges .get (' __labels__' , set ())
239
- host_types = host_edges .get (' __labels__' , set ())
238
+ motif_types = motif_edges .get (" __labels__" , set ())
239
+ host_types = host_edges .get (" __labels__" , set ())
240
240
241
241
if motif_types and not motif_types .intersection (host_types ):
242
242
return False
@@ -246,7 +246,7 @@ def _is_edge_attr_match(
246
246
continue
247
247
if host_edges .get (attr ) != val :
248
248
return False
249
-
249
+
250
250
return True
251
251
252
252
@@ -271,6 +271,7 @@ def _aggregate_edge_labels(edges: Dict) -> Dict:
271
271
aggregated [edge_id ] = attrs
272
272
return aggregated
273
273
274
+
274
275
def _get_entity_from_host (
275
276
host : Union [nx .DiGraph , nx .MultiDiGraph ], entity_name , entity_attribute = None
276
277
):
@@ -288,7 +289,7 @@ def _get_entity_from_host(
288
289
edge_data = host .get_edge_data (* entity_name )
289
290
if not edge_data :
290
291
return None # print(f"Nothing found for {entity_name} {entity_attribute}")
291
-
292
+
292
293
if entity_attribute :
293
294
# looking for edge attribute:
294
295
if isinstance (host , nx .MultiDiGraph ):
@@ -491,15 +492,16 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
491
492
for r in ret :
492
493
r_attr = {}
493
494
for i , v in r .items ():
494
- r_attr [(i , list (v .get ('__labels__' ))[0 ])] = v .get (entity_attribute , None )
495
+ r_attr [(i , list (v .get ("__labels__" ))[0 ])] = v .get (
496
+ entity_attribute , None
497
+ )
495
498
# eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
496
499
ret_with_attr .append (r_attr )
497
-
500
+
498
501
ret = ret_with_attr
499
502
500
503
result [data_path ] = list (ret )[offset_limit ]
501
504
502
-
503
505
return result
504
506
505
507
def return_clause (self , clause ):
@@ -519,7 +521,6 @@ def return_clause(self, clause):
519
521
item = str (item .value )
520
522
self ._return_requests .append (item )
521
523
522
-
523
524
def order_clause (self , order_clause ):
524
525
self ._order_by = []
525
526
for item in order_clause [0 ].children :
@@ -544,7 +545,6 @@ def skip_clause(self, skip):
544
545
skip = int (skip [- 1 ])
545
546
self ._skip = skip
546
547
547
-
548
548
def aggregate (self , func , results , entity , group_keys ):
549
549
# Collect data based on group keys
550
550
grouped_data = {}
@@ -558,12 +558,24 @@ def _collate_data(data, unique_labels, func):
558
558
# for ["COUNT", "SUM", "AVG"], we treat None as 0
559
559
if func in ["COUNT" , "SUM" , "AVG" ]:
560
560
collated_data = {
561
- label : [(v or 0 ) for rel in data for k , v in rel .items () if k [1 ] == label ] for label in unique_labels
561
+ label : [
562
+ (v or 0 )
563
+ for rel in data
564
+ for k , v in rel .items ()
565
+ if k [1 ] == label
566
+ ]
567
+ for label in unique_labels
562
568
}
563
569
# for ["MAX", "MIN"], we treat None as non-existent
564
570
elif func in ["MAX" , "MIN" ]:
565
571
collated_data = {
566
- label : [v for rel in data for k , v in rel .items () if (k [1 ] == label and v is not None )] for label in unique_labels
572
+ label : [
573
+ v
574
+ for rel in data
575
+ for k , v in rel .items ()
576
+ if (k [1 ] == label and v is not None )
577
+ ]
578
+ for label in unique_labels
567
579
}
568
580
569
581
return collated_data
@@ -583,7 +595,14 @@ def _collate_data(data, unique_labels, func):
583
595
elif func == "AVG" :
584
596
sum_data = {label : sum (data ) for label , data in collated_data .items ()}
585
597
count_data = {label : len (data ) for label , data in collated_data .items ()}
586
- avg_data = {label : sum_data [label ] / count_data [label ] if count_data [label ] > 0 else 0 for label in sum_data }
598
+ avg_data = {
599
+ label : (
600
+ sum_data [label ] / count_data [label ]
601
+ if count_data [label ] > 0
602
+ else 0
603
+ )
604
+ for label in sum_data
605
+ }
587
606
aggregate_results [group ] = avg_data
588
607
elif func == "MAX" :
589
608
max_data = {label : max (data ) for label , data in collated_data .items ()}
@@ -602,7 +621,11 @@ def returns(self, ignore_limit=False):
602
621
offset_limit = slice (0 , None ),
603
622
)
604
623
if len (self ._aggregate_functions ) > 0 :
605
- group_keys = [key for key in results .keys () if not any (key .endswith (func [1 ]) for func in self ._aggregate_functions )]
624
+ group_keys = [
625
+ key
626
+ for key in results .keys ()
627
+ if not any (key .endswith (func [1 ]) for func in self ._aggregate_functions )
628
+ ]
606
629
607
630
aggregated_results = {}
608
631
for func , entity in self ._aggregate_functions :
@@ -865,7 +888,9 @@ def flatten_tokens(edge_tokens):
865
888
flat_tokens = []
866
889
for token in edge_tokens :
867
890
if isinstance (token , Tree ):
868
- flat_tokens .extend (flatten_tokens (token .children )) # Recursively flatten the tree
891
+ flat_tokens .extend (
892
+ flatten_tokens (token .children )
893
+ ) # Recursively flatten the tree
869
894
else :
870
895
flat_tokens .append (token )
871
896
return flat_tokens
0 commit comments