diff --git a/esrally/driver/scheduler.py b/esrally/driver/scheduler.py index 1613885b3..f741d8913 100644 --- a/esrally/driver/scheduler.py +++ b/esrally/driver/scheduler.py @@ -90,10 +90,11 @@ def scheduler_for(task: esrally.track.Task): if not task.throttled: return Unthrottled() + schedule = task.schedule or "deterministic" try: - scheduler_class = __SCHEDULERS[task.schedule] + scheduler_class = __SCHEDULERS[schedule] except KeyError: - raise exceptions.RallyError(f"No scheduler available for name [{task.schedule}]") + raise exceptions.RallyError(f"No scheduler available for name [{schedule}]") # for backwards-compatibility - treat existing schedulers as top-level schedulers if is_legacy_scheduler(scheduler_class): @@ -113,7 +114,7 @@ def is_legacy_scheduler(scheduler_class): target throughput. """ constructor_params = inspect.signature(scheduler_class.__init__).parameters - return len(constructor_params) == 2 and "params" in constructor_params + return len(constructor_params) >= 2 and "params" in constructor_params def is_simple_scheduler(scheduler_class): @@ -195,7 +196,7 @@ def __init__(self, task, legacy_scheduler_class): self.legacy_scheduler = legacy_scheduler_class(task.params) def next(self, current): - return self.legacy_scheduler(current) + return self.legacy_scheduler.next(current) class Unthrottled(Scheduler): diff --git a/esrally/track/loader.py b/esrally/track/loader.py index 7ac92f06a..19afe9624 100644 --- a/esrally/track/loader.py +++ b/esrally/track/loader.py @@ -796,7 +796,7 @@ def post_process_for_test_mode(t): logger.debug("Resetting measurement time period for [%s] to [%d] seconds.", str(leaf_task), leaf_task.time_period) # Keep throttled to expose any errors but increase the target throughput for short execution times. - if leaf_task.throttled: + if leaf_task.throttled and leaf_task.target_throughput: original_throughput = leaf_task.target_throughput leaf_task.params.pop("target-throughput", None) leaf_task.params.pop("target-interval", None) @@ -1321,7 +1321,7 @@ def parse_task(self, task_spec, ops, challenge_name, default_warmup_iterations=N # may as well an inline operation op = self.parse_operation(op_spec, error_ctx="inline operation in challenge %s" % challenge_name) - schedule = self._r(task_spec, "schedule", error_ctx=op.name, mandatory=False, default_value="deterministic") + schedule = self._r(task_spec, "schedule", error_ctx=op.name, mandatory=False) task_name = self._r(task_spec, "name", error_ctx=op.name, mandatory=False, default_value=op.name) task = track.Task(name=task_name, operation=op, diff --git a/esrally/track/track.py b/esrally/track/track.py index 71b031b02..f6829c9a2 100644 --- a/esrally/track/track.py +++ b/esrally/track/track.py @@ -764,9 +764,8 @@ def __eq__(self, other): class Task: THROUGHPUT_PATTERN = re.compile(r"(?P(\d*\.)?\d+)\s(?P\w+/s)") - def __init__(self, name, operation, meta_data=None, warmup_iterations=None, iterations=None, warmup_time_period=None, time_period=None, - clients=1, - completes_parent=False, schedule="deterministic", params=None): + def __init__(self, name, operation, meta_data=None, warmup_iterations=None, iterations=None, warmup_time_period=None, + time_period=None, clients=1, completes_parent=False, schedule=None, params=None): self.name = name self.operation = operation self.meta_data = meta_data if meta_data else {} @@ -824,7 +823,7 @@ def numeric(v): @property def throttled(self): - return self.target_throughput is not None + return self.schedule is not None or self.target_throughput is not None def __hash__(self): # Note that we do not include `params` in __hash__ and __eq__ (the other attributes suffice to uniquely define a task) diff --git a/tests/driver/scheduler_test.py b/tests/driver/scheduler_test.py index d79914fbb..503c96342 100644 --- a/tests/driver/scheduler_test.py +++ b/tests/driver/scheduler_test.py @@ -126,8 +126,14 @@ class LegacyScheduler: def __init__(self, params): pass + class LegacySchedulerWithAdditionalArgs: + # pylint: disable=unused-variable + def __init__(self, params, my_default_param=True): + pass + def test_detects_legacy_scheduler(self): self.assertTrue(scheduler.is_legacy_scheduler(SchedulerCategorizationTests.LegacyScheduler)) + self.assertTrue(scheduler.is_legacy_scheduler(SchedulerCategorizationTests.LegacySchedulerWithAdditionalArgs)) def test_a_regular_scheduler_is_not_a_legacy_scheduler(self): self.assertFalse(scheduler.is_legacy_scheduler(scheduler.DeterministicScheduler)) @@ -138,3 +144,32 @@ def test_is_simple_scheduler(self): def test_is_not_simple_scheduler(self): self.assertFalse(scheduler.is_simple_scheduler(scheduler.UnitAwareScheduler)) + + +class LegacyWrappingSchedulerTests(TestCase): + class SimpleLegacyScheduler: + # pylint: disable=unused-variable + def __init__(self, params): + pass + + def next(self, current): + return current + + def setUp(self): + scheduler.register_scheduler("simple", LegacyWrappingSchedulerTests.SimpleLegacyScheduler) + + def tearDown(self): + scheduler.remove_scheduler("simple") + + def test_legacy_scheduler(self): + task = track.Task(name="raw-request", + operation=track.Operation( + name="raw", + operation_type=track.OperationType.RawRequest.name), + clients=1, + schedule="simple") + + s = scheduler.scheduler_for(task) + + self.assertEqual(0, s.next(0)) + self.assertEqual(0, s.next(0)) diff --git a/tests/track/track_test.py b/tests/track/track_test.py index 4466359fa..683eff54e 100644 --- a/tests/track/track_test.py +++ b/tests/track/track_test.py @@ -226,26 +226,34 @@ def test_cannot_union_mixed_document_corpora_by_meta_data(self): class TaskTests(TestCase): - def task(self, target_throughput=None, target_interval=None): + def task(self, schedule=None, target_throughput=None, target_interval=None): op = track.Operation("bulk-index", track.OperationType.Bulk.name) params = {} if target_throughput: params["target-throughput"] = target_throughput if target_interval: params["target-interval"] = target_interval - return track.Task("test", op, params=params) + return track.Task("test", op, schedule=schedule, params=params) def test_unthrottled_task(self): task = self.task() self.assertIsNone(task.target_throughput) + self.assertFalse(task.throttled) + + def test_task_with_scheduler_is_throttled(self): + task = self.task(schedule="daily-traffic-pattern") + self.assertIsNone(task.target_throughput) + self.assertTrue(task.throttled) def test_valid_throughput_with_unit(self): task = self.task(target_throughput="5 MB/s") self.assertEqual(track.Throughput(5.0, "MB/s"), task.target_throughput) + self.assertTrue(task.throttled) def test_valid_throughput_numeric(self): task = self.task(target_throughput=3.2) self.assertEqual(track.Throughput(3.2, "ops/s"), task.target_throughput) + self.assertTrue(task.throttled) def test_invalid_throughput_format_is_rejected(self): task = self.task(target_throughput="3.2 docs")