diff --git a/splitio/client/client.py b/splitio/client/client.py index d4c37fa4..8e71030e 100644 --- a/splitio/client/client.py +++ b/splitio/client/client.py @@ -201,7 +201,7 @@ def __init__(self, factory, recorder, labels_enabled=True): :rtype: Client """ ClientBase.__init__(self, factory, recorder, labels_enabled) - self._context_factory = EvaluationDataFactory(factory._get_storage('splits'), factory._get_storage('segments')) + self._context_factory = EvaluationDataFactory(factory._get_storage('splits'), factory._get_storage('segments'), factory._get_storage('rule_based_segments')) def destroy(self): """ @@ -668,7 +668,7 @@ def __init__(self, factory, recorder, labels_enabled=True): :rtype: Client """ ClientBase.__init__(self, factory, recorder, labels_enabled) - self._context_factory = AsyncEvaluationDataFactory(factory._get_storage('splits'), factory._get_storage('segments')) + self._context_factory = AsyncEvaluationDataFactory(factory._get_storage('splits'), factory._get_storage('segments'), factory._get_storage('rule_based_segments')) async def destroy(self): """ diff --git a/splitio/engine/evaluator.py b/splitio/engine/evaluator.py index f913ebba..80a75eec 100644 --- a/splitio/engine/evaluator.py +++ b/splitio/engine/evaluator.py @@ -6,10 +6,11 @@ from splitio.models.grammar.condition import ConditionType from splitio.models.grammar.matchers.misc import DependencyMatcher from splitio.models.grammar.matchers.keys import UserDefinedSegmentMatcher +from splitio.models.grammar.matchers.rule_based_segment import RuleBasedSegmentMatcher from splitio.optional.loaders import asyncio CONTROL = 'control' -EvaluationContext = namedtuple('EvaluationContext', ['flags', 'segment_memberships']) +EvaluationContext = namedtuple('EvaluationContext', ['flags', 'segment_memberships', 'segment_rbs_memberships', 'segment_rbs_conditions']) _LOGGER = logging.getLogger(__name__) @@ -98,9 +99,10 @@ def _treatment_for_flag(self, flag, key, bucketing, attributes, ctx): class EvaluationDataFactory: - def __init__(self, split_storage, segment_storage): + def __init__(self, split_storage, segment_storage, rbs_segment_storage): self._flag_storage = split_storage self._segment_storage = segment_storage + self._rbs_segment_storage = rbs_segment_storage def context_for(self, key, feature_names): """ @@ -114,28 +116,50 @@ def context_for(self, key, feature_names): pending = set(feature_names) splits = {} pending_memberships = set() + pending_rbs_memberships = set() while pending: fetched = self._flag_storage.fetch_many(list(pending)) features = filter_missing(fetched) splits.update(features) pending = set() for feature in features.values(): - cf, cs = get_dependencies(feature) + cf, cs, crbs = get_dependencies(feature) pending.update(filter(lambda f: f not in splits, cf)) pending_memberships.update(cs) - - return EvaluationContext(splits, { - segment: self._segment_storage.segment_contains(segment, key) - for segment in pending_memberships - }) - + pending_rbs_memberships.update(crbs) + + rbs_segment_memberships = {} + rbs_segment_conditions = {} + key_membership = False + segment_memberhsip = False + for rbs_segment in pending_rbs_memberships: + key_membership = key in self._rbs_segment_storage.get(rbs_segment).excluded.get_excluded_keys() + segment_memberhsip = False + for segment_name in self._rbs_segment_storage.get(rbs_segment).excluded.get_excluded_segments(): + if self._segment_storage.segment_contains(segment_name, key): + segment_memberhsip = True + break + + rbs_segment_memberships.update({rbs_segment: segment_memberhsip or key_membership}) + if not (segment_memberhsip or key_membership): + rbs_segment_conditions.update({rbs_segment: [condition for condition in self._rbs_segment_storage.get(rbs_segment).conditions]}) + + return EvaluationContext( + splits, + { segment: self._segment_storage.segment_contains(segment, key) + for segment in pending_memberships + }, + rbs_segment_memberships, + rbs_segment_conditions + ) class AsyncEvaluationDataFactory: - def __init__(self, split_storage, segment_storage): + def __init__(self, split_storage, segment_storage, rbs_segment_storage): self._flag_storage = split_storage self._segment_storage = segment_storage - + self._rbs_segment_storage = rbs_segment_storage + async def context_for(self, key, feature_names): """ Recursively iterate & fetch all data required to evaluate these flags. @@ -148,23 +172,47 @@ async def context_for(self, key, feature_names): pending = set(feature_names) splits = {} pending_memberships = set() + pending_rbs_memberships = set() while pending: fetched = await self._flag_storage.fetch_many(list(pending)) features = filter_missing(fetched) splits.update(features) pending = set() for feature in features.values(): - cf, cs = get_dependencies(feature) + cf, cs, crbs = get_dependencies(feature) pending.update(filter(lambda f: f not in splits, cf)) pending_memberships.update(cs) - + pending_rbs_memberships.update(crbs) + segment_names = list(pending_memberships) segment_memberships = await asyncio.gather(*[ self._segment_storage.segment_contains(segment, key) for segment in segment_names ]) - return EvaluationContext(splits, dict(zip(segment_names, segment_memberships))) + rbs_segment_memberships = {} + rbs_segment_conditions = {} + key_membership = False + segment_memberhsip = False + for rbs_segment in pending_rbs_memberships: + rbs_segment_obj = await self._rbs_segment_storage.get(rbs_segment) + key_membership = key in rbs_segment_obj.excluded.get_excluded_keys() + segment_memberhsip = False + for segment_name in rbs_segment_obj.excluded.get_excluded_segments(): + if await self._segment_storage.segment_contains(segment_name, key): + segment_memberhsip = True + break + + rbs_segment_memberships.update({rbs_segment: segment_memberhsip or key_membership}) + if not (segment_memberhsip or key_membership): + rbs_segment_conditions.update({rbs_segment: [condition for condition in rbs_segment_obj.conditions]}) + + return EvaluationContext( + splits, + dict(zip(segment_names, segment_memberships)), + rbs_segment_memberships, + rbs_segment_conditions + ) def get_dependencies(feature): @@ -173,14 +221,17 @@ def get_dependencies(feature): """ feature_names = [] segment_names = [] + rbs_segment_names = [] for condition in feature.conditions: for matcher in condition.matchers: + if isinstance(matcher,RuleBasedSegmentMatcher): + rbs_segment_names.append(matcher._rbs_segment_name) if isinstance(matcher,UserDefinedSegmentMatcher): segment_names.append(matcher._segment_name) elif isinstance(matcher, DependencyMatcher): feature_names.append(matcher._split_name) - return feature_names, segment_names + return feature_names, segment_names, rbs_segment_names def filter_missing(features): return {k: v for (k, v) in features.items() if v is not None} diff --git a/tests/engine/test_evaluator.py b/tests/engine/test_evaluator.py index 67c7387d..6268ad1d 100644 --- a/tests/engine/test_evaluator.py +++ b/tests/engine/test_evaluator.py @@ -2,12 +2,108 @@ import logging import pytest -from splitio.models.splits import Split +from splitio.models.splits import Split, Status from splitio.models.grammar.condition import Condition, ConditionType from splitio.models.impressions import Label +from splitio.models.grammar import condition +from splitio.models import rule_based_segments from splitio.engine import evaluator, splitters from splitio.engine.evaluator import EvaluationContext +from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, InMemoryRuleBasedSegmentStorage, \ + InMemorySplitStorageAsync, InMemorySegmentStorageAsync, InMemoryRuleBasedSegmentStorageAsync +from splitio.engine.evaluator import EvaluationDataFactory, AsyncEvaluationDataFactory +rbs_raw = { + "changeNumber": 123, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": False, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] +} + +split_conditions = [ + condition.from_raw({ + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "sample_rule_based_segment" + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ], + "label": "in rule based segment sample_rule_based_segment" + }), + condition.from_raw({ + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "ALL_KEYS", + "negate": False + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ], + "label": "default rule" + }) +] + class EvaluatorTests(object): """Test evaluator behavior.""" @@ -27,7 +123,7 @@ def test_evaluate_treatment_killed_split(self, mocker): mocked_split.killed = True mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = '{"some_property": 123}' - ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set()) + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), segment_rbs_memberships={}, segment_rbs_conditions={}) result = e.eval_with_context('some_key', 'some_bucketing_key', 'some', {}, ctx) assert result['treatment'] == 'off' assert result['configurations'] == '{"some_property": 123}' @@ -45,7 +141,7 @@ def test_evaluate_treatment_ok(self, mocker): mocked_split.killed = False mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = '{"some_property": 123}' - ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set()) + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), segment_rbs_memberships={}, segment_rbs_conditions={}) result = e.eval_with_context('some_key', 'some_bucketing_key', 'some', {}, ctx) assert result['treatment'] == 'on' assert result['configurations'] == '{"some_property": 123}' @@ -54,7 +150,6 @@ def test_evaluate_treatment_ok(self, mocker): assert mocked_split.get_configurations_for.mock_calls == [mocker.call('on')] assert result['impressions_disabled'] == mocked_split.impressions_disabled - def test_evaluate_treatment_ok_no_config(self, mocker): """Test that a killed split returns the default treatment.""" e = self._build_evaluator_with_mocks(mocker) @@ -65,7 +160,7 @@ def test_evaluate_treatment_ok_no_config(self, mocker): mocked_split.killed = False mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = None - ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set()) + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), segment_rbs_memberships={}, segment_rbs_conditions={}) result = e.eval_with_context('some_key', 'some_bucketing_key', 'some', {}, ctx) assert result['treatment'] == 'on' assert result['configurations'] == None @@ -92,7 +187,7 @@ def test_evaluate_treatments(self, mocker): mocked_split2.change_number = 123 mocked_split2.get_configurations_for.return_value = None - ctx = EvaluationContext(flags={'feature2': mocked_split, 'feature4': mocked_split2}, segment_memberships=set()) + ctx = EvaluationContext(flags={'feature2': mocked_split, 'feature4': mocked_split2}, segment_memberships=set(), segment_rbs_memberships={}, segment_rbs_conditions={}) results = e.eval_many_with_context('some_key', 'some_bucketing_key', ['feature2', 'feature4'], {}, ctx) result = results['feature4'] assert result['configurations'] == None @@ -115,7 +210,7 @@ def test_get_gtreatment_for_split_no_condition_matches(self, mocker): mocked_split.change_number = '123' mocked_split.conditions = [] mocked_split.get_configurations_for = None - ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set()) + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), segment_rbs_memberships={}, segment_rbs_conditions={}) assert e._treatment_for_flag(mocked_split, 'some_key', 'some_bucketing', {}, ctx) == ( 'off', Label.NO_CONDITION_MATCHED @@ -132,6 +227,64 @@ def test_get_gtreatment_for_split_non_rollout(self, mocker): mocked_split = mocker.Mock(spec=Split) mocked_split.killed = False mocked_split.conditions = [mocked_condition_1] - treatment, label = e._treatment_for_flag(mocked_split, 'some_key', 'some_bucketing', {}, EvaluationContext(None, None)) + treatment, label = e._treatment_for_flag(mocked_split, 'some_key', 'some_bucketing', {}, EvaluationContext(None, None, None, None)) assert treatment == 'on' assert label == 'some_label' + + def test_evaluate_treatment_with_rule_based_segment(self, mocker): + """Test that a non-killed split returns the appropriate treatment.""" + e = evaluator.Evaluator(splitters.Splitter()) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False) + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), segment_rbs_memberships={'sample_rule_based_segment': False}, segment_rbs_conditions={'sample_rule_based_segment': rule_based_segments.from_raw(rbs_raw).conditions}) + result = e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx) + assert result['treatment'] == 'on' + + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), segment_rbs_memberships={'sample_rule_based_segment': True}, segment_rbs_conditions={'sample_rule_based_segment': []}) + result = e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx) + assert result['treatment'] == 'off' + +class EvaluationDataFactoryTests(object): + """Test evaluation factory class.""" + + def test_get_context(self): + """Test context.""" + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False) + flag_storage = InMemorySplitStorage([]) + segment_storage = InMemorySegmentStorage() + rbs_segment_storage = InMemoryRuleBasedSegmentStorage() + flag_storage.update([mocked_split], [], -1) + rbs = rule_based_segments.from_raw(rbs_raw) + rbs_segment_storage.update([rbs], [], -1) + + eval_factory = EvaluationDataFactory(flag_storage, segment_storage, rbs_segment_storage) + ec = eval_factory.context_for('bilal@split.io', ['some']) + assert ec.segment_rbs_conditions == {'sample_rule_based_segment': rbs.conditions} + assert ec.segment_rbs_memberships == {'sample_rule_based_segment': False} + + ec = eval_factory.context_for('mauro@split.io', ['some']) + assert ec.segment_rbs_conditions == {} + assert ec.segment_rbs_memberships == {'sample_rule_based_segment': True} + +class EvaluationDataFactoryAsyncTests(object): + """Test evaluation factory class.""" + + @pytest.mark.asyncio + async def test_get_context(self): + """Test context.""" + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False) + flag_storage = InMemorySplitStorageAsync([]) + segment_storage = InMemorySegmentStorageAsync() + rbs_segment_storage = InMemoryRuleBasedSegmentStorageAsync() + await flag_storage.update([mocked_split], [], -1) + rbs = rule_based_segments.from_raw(rbs_raw) + await rbs_segment_storage.update([rbs], [], -1) + + eval_factory = AsyncEvaluationDataFactory(flag_storage, segment_storage, rbs_segment_storage) + ec = await eval_factory.context_for('bilal@split.io', ['some']) + assert ec.segment_rbs_conditions == {'sample_rule_based_segment': rbs.conditions} + assert ec.segment_rbs_memberships == {'sample_rule_based_segment': False} + + ec = await eval_factory.context_for('mauro@split.io', ['some']) + assert ec.segment_rbs_conditions == {} + assert ec.segment_rbs_memberships == {'sample_rule_based_segment': True}