Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a critical bug of Splunk results reader, lack of pagination #657

Merged
153 changes: 129 additions & 24 deletions msticpy/data/drivers/splunk_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# license information.
# --------------------------------------------------------------------------
"""Splunk Driver class."""
from datetime import datetime
import logging
from datetime import datetime, timedelta
from time import sleep
from typing import Any, Dict, Iterable, Optional, Tuple, Union

Expand All @@ -14,6 +15,7 @@
from ..._version import VERSION
from ...common.exceptions import (
MsticpyConnectionError,
MsticpyDataQueryError,
MsticpyImportExtraError,
MsticpyUserConfigError,
)
Expand All @@ -35,6 +37,8 @@
__version__ = VERSION
__author__ = "Ashwin Patil"

logger = logging.getLogger(__name__)


SPLUNK_CONNECT_ARGS = {
"host": "(string) The host name (the default is 'localhost').",
Expand Down Expand Up @@ -73,7 +77,9 @@ def __init__(self, **kwargs):
self.service = None
self._loaded = True
self._connected = False
self._debug = kwargs.get("debug", False)
if kwargs.get("debug", False):
logger.setLevel(logging.DEBUG)

self.public_attribs = {
"client": self.service,
"saved_searches": self._saved_searches,
Expand All @@ -84,7 +90,7 @@ def __init__(self, **kwargs):
Formatters.LIST: self._format_list,
}

def connect(self, connection_str: str = None, **kwargs):
def connect(self, connection_str: Optional[str] = None, **kwargs):
"""
Connect to Splunk via splunk-sdk.

Expand Down Expand Up @@ -188,9 +194,17 @@ def query(
Other Parameters
----------------
count : int, optional
Passed to Splunk oneshot method if `oneshot` is True, by default, 0
Passed to Splunk job that indicates the maximum number
of entities to return. A value of 0 indicates no maximum,
by default, 0
oneshot : bool, optional
Set to True for oneshot (blocking) mode, by default False
page_size : int, optional
    Number of results fetched per request from the Splunk
    results reader when paginating through the result set,
    by default, 100
timeout : int, optional
Amount of time to wait for results, by default 60

Returns
-------
Expand All @@ -206,35 +220,38 @@ def query(
# default to unlimited query unless count is specified
count = kwargs.pop("count", 0)

# Normal, oneshot or blocking searches. Defaults to non-blocking
# Oneshot is a blocking HTTP call which may cause time-outs
# https://dev.splunk.com/enterprise/docs/python/sdk-python/howtousesplunkpython/howtorunsearchespython
# Get sets of N results at a time, N=100 by default
page_size = kwargs.pop("page_size", 100)

# Normal (non-blocking) searches or oneshot (blocking) searches.
# Defaults to Normal(non-blocking)

# Oneshot is a blocking search that is scheduled to run immediately.
# Instead of returning a search job, this mode returns the results
# of the search once completed.
# Because this is a blocking search, the results are not available
# until the search has finished.
# https://dev.splunk.com/enterprise/docs/python/
# sdk-python/howtousesplunkpython/howtorunsearchespython
is_oneshot = kwargs.get("oneshot", False)

if is_oneshot is True:
kwargs["output_mode"] = "json"
query_results = self.service.jobs.oneshot(query, count=count, **kwargs)
reader = sp_results.ResultsReader(query_results)

reader = sp_results.JSONResultsReader( # pylint: disable=no-member
query_results
) # due to DeprecationWarning of normal ResultsReader
resp_rows = [row for row in reader if isinstance(row, dict)]
else:
# Set mode and initialize async job
kwargs_normalsearch = {"exec_mode": "normal"}
query_job = self.service.jobs.create(query, **kwargs_normalsearch)

# Initiate progress bar and start while loop, waiting for async query to complete
progress_bar = tqdm(total=100, desc="Waiting Splunk job to complete")
while not query_job.is_done():
current_state = query_job.state
progress = float(current_state["content"]["doneProgress"]) * 100
progress_bar.update(progress)
sleep(1)

# Update progress bar indicating completion and fetch results
progress_bar.update(100)
progress_bar.close()
reader = sp_results.ResultsReader(query_job.results())
query_job = self.service.jobs.create(
query, count=count, **kwargs_normalsearch
)
resp_rows, reader = self._exec_async_search(query_job, page_size, **kwargs)

resp_rows = [row for row in reader if isinstance(row, dict)]
if not resp_rows:
if len(resp_rows) == 0 or not resp_rows:
print("Warning - query did not return any results.")
return [row for row in reader if isinstance(row, sp_results.Message)]
return pd.DataFrame(resp_rows)
Expand Down Expand Up @@ -310,6 +327,94 @@ def driver_queries(self) -> Iterable[Dict[str, Any]]:
]
return []

def _exec_async_search(self, query_job, page_size, timeout=60, **kwargs):
    """
    Execute an async search, polling until it completes or times out.

    Parameters
    ----------
    query_job : splunklib.client.Job
        Job created via ``self.service.jobs.create``.
    page_size : int
        Number of results fetched per page when retrieving results.
    timeout : int, optional
        Maximum number of seconds to wait for job completion,
        by default 60.

    Returns
    -------
    Tuple[List[Dict[str, Any]], Any]
        The decoded result rows and the last results reader used.

    Raises
    ------
    MsticpyDataQueryError
        If the job does not complete within `timeout` seconds.

    """
    # `query` forwards its remaining **kwargs (which may still contain
    # e.g. "oneshot") - accept and ignore them rather than raising
    # a TypeError for an unexpected keyword argument.
    del kwargs
    # Initiate progress bar and start while loop, waiting for async query to complete
    progress_bar = tqdm(total=100, desc="Waiting Splunk job to complete")
    offset = 0  # Start at result 0
    end_time = datetime.now() + timedelta(seconds=timeout)
    while True:
        while not query_job.is_ready():
            sleep(1)
        # Pass the bar's running total (`progress_bar.n`) as prev_progress
        # so only the delta since the previous poll is added. Passing a
        # constant 0 here would add the absolute progress on every poll
        # and over-count the bar.
        if self._retrieve_job_status(query_job, progress_bar, progress_bar.n):
            break
        if datetime.now() > end_time:
            raise MsticpyDataQueryError(
                "Timeout waiting for Splunk query to complete",
                f"Job completion reported {query_job['doneProgress']}",
                title="Splunk query timeout",
            )
        sleep(1)
    # Update progress bar indicating job completion (clamped so the
    # bar never exceeds its total of 100)
    progress_bar.update(max(0, 100 - progress_bar.n))
    progress_bar.close()
    sleep(2)

    logger.info("Implicit parameter dump - 'page_size': %d", page_size)
    return self._retrieve_results(query_job, offset, page_size)

@staticmethod
def _retrieve_job_status(query_job, progress_bar, prev_progress):
"""Poll the status of a job and update the progress bar."""
stats = {
"is_done": query_job["isDone"],
"done_progress": float(query_job["doneProgress"]) * 100,
"scan_count": int(query_job["scanCount"]),
"event_count": int(query_job["eventCount"]),
"result_count": int(query_job["resultCount"]),
}
status = (
"\r%(done_progress)03.1f%% %(scan_count)d scanned "
"%(event_count)d matched %(result_count)d results"
) % stats
if prev_progress == 0:
progress = stats["done_progress"]
else:
progress = stats["done_progress"] - prev_progress
prev_progress = stats["done_progress"]
progress_bar.update(progress)

if stats["is_done"] == "1":
logger.info(status)
logger.info("Splunk job completed.")
return True
return False

@staticmethod
def _retrieve_results(query_job, offset, page_size):
    """
    Retrieve the results of a job page by page, decode and return them.

    Parameters
    ----------
    query_job : splunklib.client.Job
        Completed job to read results from.
    offset : int
        Index of the first result to fetch.
    page_size : int
        Number of results fetched per request.

    Returns
    -------
    Tuple[List[Dict[str, Any]], Any]
        Decoded result rows and the last JSONResultsReader used
        (an empty iterator when the job returned no results).

    """
    # Retrieving all the results by paginate
    result_count = int(
        query_job["resultCount"]
    )  # Number of results this job returned

    resp_rows = []
    # Default reader so the return below cannot raise UnboundLocalError
    # when the job produced zero results (loop body never runs).
    reader = iter([])
    progress_bar_paginate = tqdm(
        total=result_count, desc="Waiting Splunk result to retrieve"
    )
    while offset < result_count:
        kwargs_paginate = {
            "count": page_size,
            "offset": offset,
            "output_mode": "json",
        }
        # Get the search results and display them
        search_results = query_job.results(**kwargs_paginate)
        # due to DeprecationWarning of normal ResultsReader
        reader = sp_results.JSONResultsReader(  # pylint: disable=no-member
            search_results
        )
        resp_rows.extend(row for row in reader if isinstance(row, dict))
        # Clamp the final (possibly short) page so the bar never
        # advances past result_count.
        progress_bar_paginate.update(min(page_size, result_count - offset))
        offset += page_size
    # The bar is already at result_count here; the previous extra
    # update(result_count) doubled the displayed total.
    progress_bar_paginate.close()
    logger.info("Retrieved %d results.", len(resp_rows))
    return resp_rows, reader

@property
def _saved_searches(self) -> Union[pd.DataFrame, Any]:
"""
Expand Down
28 changes: 27 additions & 1 deletion tests/data/drivers/test_splunk_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from msticpy.common.exceptions import (
MsticpyConnectionError,
MsticpyDataQueryError,
MsticpyNotConnectedError,
MsticpyUserConfigError,
)
Expand Down Expand Up @@ -69,15 +70,35 @@ def __init__(self, name, count):


class _MockAsyncResponse:
stats = {
"isDone": "0",
"doneProgress": 0.0,
"scanCount": 1,
"eventCount": 100,
"resultCount": 100,
}

def __init__(self, query):
self.query = query

def results(self):
def __getitem__(self, key):
"""Mock method."""
return self.stats[key]

def results(self, **kwargs):
return self.query

def is_done(self):
return True

def is_ready(self):
return True

@classmethod
def set_done(cls):
cls.stats["isDone"] = "1"
cls.stats["doneProgress"] = 1


class _MockSplunkCall:
def create(query, **kwargs):
Expand Down Expand Up @@ -260,6 +281,7 @@ def test_splunk_query_success(splunk_client, splunk_results):
splunk_client.connect = cli_connect
sp_driver = SplunkDriver()
splunk_results.ResultsReader = _results_reader
splunk_results.JSONResultsReader = _results_reader

# trying to get these before connecting should throw
with pytest.raises(MsticpyNotConnectedError) as mp_ex:
Expand All @@ -279,6 +301,10 @@ def test_splunk_query_success(splunk_client, splunk_results):
check.is_not_instance(response, pd.DataFrame)
check.equal(len(response), 0)

with pytest.raises(MsticpyDataQueryError):
df_result = sp_driver.query("some query", timeout=1)

_MockAsyncResponse.set_done()
df_result = sp_driver.query("some query")
check.is_instance(df_result, pd.DataFrame)
check.equal(len(df_result), 10)
Expand Down