Skip to content

Commit

Permalink
Actually use client returned from '.options()' (elastic#1710)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-deam authored May 2, 2023
1 parent c175607 commit 4a0294c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
6 changes: 3 additions & 3 deletions esrally/driver/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def __init__(self):
async def __call__(self, es, params):
params, request_params, transport_params, headers = self._transport_request_params(params)
# we don't set headers at the options level because the Query runner sets them via the client's '_perform_request' method
es.options(**transport_params)
es = es.options(**transport_params)
# Mandatory to ensure it is always provided. This is especially important when this runner is used in a
# composite context where there is no actual parameter source and the entire request structure must be provided
# by the composite's parameter source.
Expand Down Expand Up @@ -1946,7 +1946,7 @@ def __repr__(self, *args, **kwargs):
class RawRequest(Runner):
async def __call__(self, es, params):
params, request_params, transport_params, headers = self._transport_request_params(params)
es.options(**transport_params)
es = es.options(**transport_params)

path = mandatory(params, "path", self)

Expand Down Expand Up @@ -2747,7 +2747,7 @@ class Downsample(Runner):

async def __call__(self, es, params):
params, request_params, transport_params, request_headers = self._transport_request_params(params)
es.options(**transport_params)
es = es.options(**transport_params)

fixed_interval = mandatory(params, "fixed-interval", self)
if fixed_interval is None:
Expand Down
1 change: 1 addition & 0 deletions tests/driver/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,7 @@ async def test_execute_schedule_throughput_throttled(self, es):
async def perform_request(*args, **kwargs):
return None

es.options.return_value = es
es.init_request_context.return_value = {"request_start": 0, "request_end": 10}
# as this method is called several times we need to return a fresh instance every time as the previous
# one has been "consumed".
Expand Down
31 changes: 31 additions & 0 deletions tests/driver/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,7 @@ class TestQueryRunner:
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_match_only_request_body_defined(self, es):
es.options.return_value = es
search_response = {
"timed_out": False,
"took": 5,
Expand Down Expand Up @@ -1601,6 +1602,7 @@ async def test_query_match_only_request_body_defined(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_with_timeout_and_headers(self, es):
es.options.return_value = es
search_response = {
"timed_out": False,
"took": 5,
Expand Down Expand Up @@ -1655,6 +1657,7 @@ async def test_query_with_timeout_and_headers(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_match_using_request_params(self, es):
es.options.return_value = es
response = {
"timed_out": False,
"took": 62,
Expand Down Expand Up @@ -1713,6 +1716,7 @@ async def test_query_match_using_request_params(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_no_detailed_results(self, es):
es.options.return_value = es
response = {
"timed_out": False,
"took": 62,
Expand Down Expand Up @@ -1766,6 +1770,7 @@ async def test_query_no_detailed_results(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_hits_total_as_number(self, es):
es.options.return_value = es
search_response = {
"timed_out": False,
"took": 5,
Expand Down Expand Up @@ -1822,6 +1827,7 @@ async def test_query_hits_total_as_number(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_match_all(self, es):
es.options.return_value = es
search_response = {
"timed_out": False,
"took": 5,
Expand Down Expand Up @@ -1881,6 +1887,7 @@ async def test_query_match_all(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_match_all_doc_type_fallback(self, es):
es.options.return_value = es
search_response = {
"timed_out": False,
"took": 5,
Expand Down Expand Up @@ -1937,6 +1944,7 @@ async def test_query_match_all_doc_type_fallback(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_scroll_query_only_one_page(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2000,6 +2008,7 @@ async def test_scroll_query_only_one_page(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_scroll_query_no_request_cache(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2058,6 +2067,7 @@ async def test_scroll_query_no_request_cache(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_scroll_query_only_one_page_only_request_body_defined(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2120,6 +2130,7 @@ async def test_scroll_query_only_one_page_only_request_body_defined(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_scroll_query_with_explicit_number_of_pages(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2193,6 +2204,7 @@ async def test_scroll_query_with_explicit_number_of_pages(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_scroll_query_cannot_clear_scroll(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2243,6 +2255,7 @@ async def test_scroll_query_cannot_clear_scroll(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_scroll_query_request_all_pages(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2311,6 +2324,7 @@ async def test_scroll_query_request_all_pages(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_runner_search_with_pages_logs_warning_and_executes(self, es):
es.options.return_value = es
# page 1
search_response = {
"_scroll_id": "some-scroll-id",
Expand Down Expand Up @@ -2368,6 +2382,7 @@ async def test_query_runner_search_with_pages_logs_warning_and_executes(self, es
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_query_runner_fails_with_unknown_operation_type(self, es):
es.options.return_value = es
query_runner = runner.Query()

params = {
Expand Down Expand Up @@ -3603,6 +3618,7 @@ class TestRawRequestRunner:
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_raises_missing_slash(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock()
r = runner.RawRequest()

Expand All @@ -3619,6 +3635,7 @@ async def test_raises_missing_slash(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_issue_request_with_defaults(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock()
r = runner.RawRequest()

Expand All @@ -3630,6 +3647,7 @@ async def test_issue_request_with_defaults(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_issue_delete_index(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock()
r = runner.RawRequest()

Expand All @@ -3648,6 +3666,7 @@ async def test_issue_delete_index(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_issue_create_index(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock()
r = runner.RawRequest()

Expand All @@ -3671,6 +3690,7 @@ async def test_issue_create_index(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_issue_msearch(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock()
r = runner.RawRequest()

Expand Down Expand Up @@ -3702,6 +3722,7 @@ async def test_issue_msearch(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_raw_with_timeout_and_opaqueid(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock()
r = runner.RawRequest()

Expand Down Expand Up @@ -5376,6 +5397,7 @@ class TestDownsampleRunner:
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_index_downsample(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock(return_value=io.BytesIO(json.dumps(self.default_response).encode()))

sql_runner = runner.Downsample()
Expand All @@ -5402,6 +5424,7 @@ async def test_index_downsample(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_mandatory_fixed_interval_in_body_param(self, es):
es.options.return_value = es
sql_runner = runner.Downsample()
params = {"operation-type": "downsample", "source-index": "source-index", "target-index": "target-index"}

Expand All @@ -5415,6 +5438,7 @@ async def test_mandatory_fixed_interval_in_body_param(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_mandatory_source_index_in_body_param(self, es):
es.options.return_value = es
sql_runner = runner.Downsample()
params = {"operation-type": "downsample", "fixed-interval": "1d", "target-index": "target-index"}

Expand All @@ -5428,6 +5452,7 @@ async def test_mandatory_source_index_in_body_param(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_mandatory_target_index_in_body_param(self, es):
es.options.return_value = es
sql_runner = runner.Downsample()
params = {"operation-type": "downsample", "fixed-interval": "1d", "source-index": "source-index"}

Expand Down Expand Up @@ -5578,6 +5603,7 @@ class TestQueryWithSearchAfterScroll:
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_search_after_with_pit(self, es):
es.options.return_value = es
pit_op = "open-point-in-time1"
pit_id = "0123456789abcdef"
params = {
Expand Down Expand Up @@ -5690,6 +5716,7 @@ async def test_search_after_with_pit(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_search_after_without_pit(self, es):
es.options.return_value = es
params = {
"name": "search-with-pit",
"operation-type": "paginated-search",
Expand Down Expand Up @@ -5852,6 +5879,7 @@ class TestCompositeAgg:
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_composite_agg_without_pit(self, es):
es.options.return_value = es
params = {
"name": "composite-agg-without-pit",
"operation-type": "composite-agg",
Expand Down Expand Up @@ -5978,6 +6006,7 @@ async def test_composite_agg_without_pit(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_composite_agg_with_pit(self, es):
es.options.return_value = es
pit_op = "open-point-in-time1"
pit_id = "0123456789abcdef"
params = {
Expand Down Expand Up @@ -6353,6 +6382,7 @@ def teardown_method(self, method):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_execute_multiple_streams(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock(
side_effect=[
# raw-request
Expand Down Expand Up @@ -6426,6 +6456,7 @@ async def test_execute_multiple_streams(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@pytest.mark.asyncio
async def test_propagates_violated_assertions(self, es):
es.options.return_value = es
es.perform_request = mock.AsyncMock(
side_effect=[
# search
Expand Down

0 comments on commit 4a0294c

Please sign in to comment.