From 62cfaa543bc851628275ee4fb3d13b6d192bdbd8 Mon Sep 17 00:00:00 2001
From: hackeT <40039738+Tatsuya-hasegawa@users.noreply.github.com>
Date: Thu, 11 May 2023 08:44:09 +0900
Subject: [PATCH] Fix a critical bug in the Splunk results reader: missing
 pagination (#657)

* fix a critical bug in the Splunk results reader

* typo pagenate -> paginate

* Refactored code and reformatted long lines. Updated failing tests for new
  code.

---------

Co-authored-by: Ian Hellen
---
 msticpy/data/drivers/splunk_driver.py    | 151 +++++++++++++++++++----
 tests/data/drivers/test_splunk_driver.py |  28 ++++-
 2 files changed, 155 insertions(+), 24 deletions(-)

diff --git a/msticpy/data/drivers/splunk_driver.py b/msticpy/data/drivers/splunk_driver.py
index 68099da6a..b0c5694ef 100644
--- a/msticpy/data/drivers/splunk_driver.py
+++ b/msticpy/data/drivers/splunk_driver.py
@@ -4,7 +4,8 @@
 # license information.
 # --------------------------------------------------------------------------
 """Splunk Driver class."""
-from datetime import datetime
+import logging
+from datetime import datetime, timedelta
 from time import sleep
 from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
@@ -14,6 +15,7 @@
 from ..._version import VERSION
 from ...common.exceptions import (
     MsticpyConnectionError,
+    MsticpyDataQueryError,
     MsticpyImportExtraError,
     MsticpyUserConfigError,
 )
@@ -35,6 +37,8 @@
 __version__ = VERSION
 __author__ = "Ashwin Patil"
 
+logger = logging.getLogger(__name__)
+
 SPLUNK_CONNECT_ARGS = {
     "host": "(string) The host name (the default is 'localhost').",
@@ -73,7 +77,9 @@ def __init__(self, **kwargs):
         self.service = None
         self._loaded = True
         self._connected = False
-        self._debug = kwargs.get("debug", False)
+        if kwargs.get("debug", False):
+            logger.setLevel(logging.DEBUG)
+
         self.set_driver_property(
             DriverProps.PUBLIC_ATTRS,
             {
@@ -194,9 +200,17 @@ def query(
         Other Parameters
         ----------------
         count : int, optional
-            Passed to Splunk oneshot method if `oneshot` is True, by default, 0
+            Passed to the Splunk job; the maximum number of entities
+            to return. A value of 0 indicates no maximum,
+            by default, 0
         oneshot : bool, optional
            Set to True for oneshot (blocking) mode, by default False
+        page_size : int, optional
+            Passed to the Splunk results reader; the number of results
+            fetched per request when paging through the result set,
+            by default, 100
+        timeout : int, optional
+            Amount of time to wait for results, by default 60
 
         Returns
         -------
@@ -212,35 +226,38 @@ def query(
         # default to unlimited query unless count is specified
         count = kwargs.pop("count", 0)
 
-        # Normal, oneshot or blocking searches. Defaults to non-blocking
-        # Oneshot is a blocking HTTP call which may cause time-outs
-        # https://dev.splunk.com/enterprise/docs/python/sdk-python/howtousesplunkpython/howtorunsearchespython
+        # Get sets of N results at a time, N=100 by default
+        page_size = kwargs.pop("page_size", 100)
+
+        # Normal (non-blocking) or oneshot (blocking) searches.
+        # Defaults to normal (non-blocking).
+
+        # Oneshot is a blocking search that is scheduled to run immediately.
+        # Instead of returning a search job, this mode returns the results
+        # of the search once completed.
+        # Because this is a blocking search, the results are not available
+        # until the search has finished.
+        # https://dev.splunk.com/enterprise/docs/python/
+        # sdk-python/howtousesplunkpython/howtorunsearchespython
         is_oneshot = kwargs.get("oneshot", False)
         if is_oneshot is True:
+            kwargs["output_mode"] = "json"
             query_results = self.service.jobs.oneshot(query, count=count, **kwargs)
-            reader = sp_results.ResultsReader(query_results)
+            reader = sp_results.JSONResultsReader(  # pylint: disable=no-member
+                query_results
+            )  # due to DeprecationWarning of the normal ResultsReader
+            resp_rows = [row for row in reader if isinstance(row, dict)]
         else:
             # Set mode and initialize async job
             kwargs_normalsearch = {"exec_mode": "normal"}
-            query_job = self.service.jobs.create(query, **kwargs_normalsearch)
-
-            # Initiate progress bar and start while loop, waiting for async query to complete
-            progress_bar = tqdm(total=100, desc="Waiting Splunk job to complete")
-            while not query_job.is_done():
-                current_state = query_job.state
-                progress = float(current_state["content"]["doneProgress"]) * 100
-                progress_bar.update(progress)
-                sleep(1)
-
-            # Update progress bar indicating completion and fetch results
-            progress_bar.update(100)
-            progress_bar.close()
-            reader = sp_results.ResultsReader(query_job.results())
+            query_job = self.service.jobs.create(
+                query, count=count, **kwargs_normalsearch
+            )
+            resp_rows, reader = self._exec_async_search(query_job, page_size, **kwargs)
 
-        resp_rows = [row for row in reader if isinstance(row, dict)]
         if not resp_rows:
             print("Warning - query did not return any results.")
             return [row for row in reader if isinstance(row, sp_results.Message)]
         return pd.DataFrame(resp_rows)
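
The pagination pattern the new code relies on is plain Splunk SDK usage:
`Job.results()` accepts `count` and `offset`, and `JSONResultsReader`
replaces the deprecated `ResultsReader`. A minimal standalone sketch of that
pattern (the connection values are placeholders and the search string is
arbitrary; this is illustrative, not the driver's code):

    from time import sleep

    import splunklib.client as sp_client
    import splunklib.results as sp_results

    # Placeholder connection values - substitute a real host and credentials.
    service = sp_client.connect(host="localhost", username="admin", password="***")
    job = service.jobs.create("search index=_internal | head 1000", exec_mode="normal")
    while not job.is_done():
        sleep(1)

    rows, offset, page_size = [], 0, 100
    total = int(job["resultCount"])  # total results reported by the finished job
    while offset < total:
        # Fetch one page; output_mode="json" is required by JSONResultsReader
        stream = job.results(count=page_size, offset=offset, output_mode="json")
        rows.extend(
            row
            for row in sp_results.JSONResultsReader(stream)
            if isinstance(row, dict)
        )
        offset += page_size
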
@@ -316,6 +333,94 @@ def driver_queries(self) -> Iterable[Dict[str, Any]]:
             ]
         return []
 
+    def _exec_async_search(self, query_job, page_size, timeout=60):
+        """Execute an async search and return results."""
+        # Initiate progress bar and start while loop,
+        # waiting for the async query to complete
+        progress_bar = tqdm(total=100, desc="Waiting for Splunk job to complete")
+        prev_progress = 0
+        offset = 0  # Start at result 0
+        start_time = datetime.now()
+        end_time = start_time + timedelta(seconds=timeout)
+        while True:
+            while not query_job.is_ready():
+                sleep(1)
+            done, prev_progress = self._retrieve_job_status(
+                query_job, progress_bar, prev_progress
+            )
+            if done:
+                break
+            if datetime.now() > end_time:
+                raise MsticpyDataQueryError(
+                    "Timeout waiting for Splunk query to complete",
+                    f"Job completion reported {query_job['doneProgress']}",
+                    title="Splunk query timeout",
+                )
+            sleep(1)
+        # Update progress bar indicating job completion
+        progress_bar.update(100 - prev_progress)
+        progress_bar.close()
+        sleep(2)
+
+        logger.info("Implicit parameter dump - 'page_size': %d", page_size)
+        return self._retrieve_results(query_job, offset, page_size)
+
+    @staticmethod
+    def _retrieve_job_status(query_job, progress_bar, prev_progress):
+        """Poll the status of a job, update the progress bar, return state."""
+        stats = {
+            "is_done": query_job["isDone"],
+            "done_progress": float(query_job["doneProgress"]) * 100,
+            "scan_count": int(query_job["scanCount"]),
+            "event_count": int(query_job["eventCount"]),
+            "result_count": int(query_job["resultCount"]),
+        }
+        status = (
+            "\r%(done_progress)03.1f%% %(scan_count)d scanned "
+            "%(event_count)d matched %(result_count)d results"
+        ) % stats
+        # Advance the bar only by the progress made since the last poll
+        progress_bar.update(stats["done_progress"] - prev_progress)
+        prev_progress = stats["done_progress"]
+
+        if stats["is_done"] == "1":
+            logger.info(status)
+            logger.info("Splunk job completed.")
+            return True, prev_progress
+        return False, prev_progress
+
+    @staticmethod
+    def _retrieve_results(query_job, offset, page_size):
+        """Retrieve the results of a job, decode and return them."""
+        # Retrieve all of the results by paginating through them
+        result_count = int(
+            query_job["resultCount"]
+        )  # Number of results this job returned
+
+        resp_rows = []
+        reader = []  # Keep reader defined even if there are no result pages
+        progress_bar_paginate = tqdm(
+            total=result_count, desc="Waiting for Splunk results to be retrieved"
+        )
+        while offset < result_count:
+            kwargs_paginate = {
+                "count": page_size,
+                "offset": offset,
+                "output_mode": "json",
+            }
+            # Get the search results and display them
+            search_results = query_job.results(**kwargs_paginate)
+            # due to DeprecationWarning of the normal ResultsReader
+            reader = sp_results.JSONResultsReader(  # pylint: disable=no-member
+                search_results
+            )
+            resp_rows.extend(row for row in reader if isinstance(row, dict))
+            # Cap the update so the bar does not overshoot on the final page
+            progress_bar_paginate.update(min(page_size, result_count - offset))
+            offset += page_size
+        progress_bar_paginate.close()
+        logger.info("Retrieved %d results.", len(resp_rows))
+        return resp_rows, reader
+
     @property
     def _saved_searches(self) -> Union[pd.DataFrame, Any]:
         """
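
For reviewers, a minimal usage sketch of the new behavior (host and
credential values are placeholders; `connect()` keyword arguments follow
SPLUNK_CONNECT_ARGS):

    from msticpy.data.drivers.splunk_driver import SplunkDriver

    driver = SplunkDriver(debug=True)  # debug=True now raises the logger to DEBUG
    driver.connect(host="splunk.example.com", username="user", password="***")

    # Async (default) search: results are fetched in pages of `page_size`;
    # MsticpyDataQueryError is raised if the job exceeds `timeout` seconds.
    df = driver.query("search index=main | head 5000", page_size=500, timeout=120)

    # Oneshot (blocking) search: returns up to `count` results in a single call.
    df_oneshot = driver.query("search index=main | head 100", oneshot=True, count=100)
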
completed.") + return True + return False + + @staticmethod + def _retrieve_results(query_job, offset, page_size): + """Retrieve the results of a job, decode and return them.""" + # Retrieving all the results by paginate + result_count = int( + query_job["resultCount"] + ) # Number of results this job returned + + resp_rows = [] + progress_bar_paginate = tqdm( + total=result_count, desc="Waiting Splunk result to retrieve" + ) + while offset < result_count: + kwargs_paginate = { + "count": page_size, + "offset": offset, + "output_mode": "json", + } + # Get the search results and display them + search_results = query_job.results(**kwargs_paginate) + # due to DeprecationWarning of normal ResultsReader + reader = sp_results.JSONResultsReader( # pylint: disable=no-member + search_results + ) + resp_rows.extend([row for row in reader if isinstance(row, dict)]) + progress_bar_paginate.update(page_size) + offset += page_size + # Update progress bar indicating fetch results + progress_bar_paginate.update(result_count) + progress_bar_paginate.close() + logger.info("Retrieved %d results.", len(resp_rows)) + return resp_rows, reader + @property def _saved_searches(self) -> Union[pd.DataFrame, Any]: """ diff --git a/tests/data/drivers/test_splunk_driver.py b/tests/data/drivers/test_splunk_driver.py index 15fc3a8a3..68c2d6f0f 100644 --- a/tests/data/drivers/test_splunk_driver.py +++ b/tests/data/drivers/test_splunk_driver.py @@ -13,6 +13,7 @@ from msticpy.common.exceptions import ( MsticpyConnectionError, + MsticpyDataQueryError, MsticpyNotConnectedError, MsticpyUserConfigError, ) @@ -69,15 +70,35 @@ def __init__(self, name, count): class _MockAsyncResponse: + stats = { + "isDone": "0", + "doneProgress": 0.0, + "scanCount": 1, + "eventCount": 100, + "resultCount": 100, + } + def __init__(self, query): self.query = query - def results(self): + def __getitem__(self, key): + """Mock method.""" + return self.stats[key] + + def results(self, **kwargs): return self.query def is_done(self): return True + def is_ready(self): + return True + + @classmethod + def set_done(cls): + cls.stats["isDone"] = "1" + cls.stats["doneProgress"] = 1 + class _MockSplunkCall: def create(query, **kwargs): @@ -260,6 +281,7 @@ def test_splunk_query_success(splunk_client, splunk_results): splunk_client.connect = cli_connect sp_driver = SplunkDriver() splunk_results.ResultsReader = _results_reader + splunk_results.JSONResultsReader = _results_reader # trying to get these before connecting should throw with pytest.raises(MsticpyNotConnectedError) as mp_ex: @@ -279,6 +301,10 @@ def test_splunk_query_success(splunk_client, splunk_results): check.is_not_instance(response, pd.DataFrame) check.equal(len(response), 0) + with pytest.raises(MsticpyDataQueryError): + df_result = sp_driver.query("some query", timeout=1) + + _MockAsyncResponse.set_done() df_result = sp_driver.query("some query") check.is_instance(df_result, pd.DataFrame) check.equal(len(df_result), 10)