diff --git a/docs/source/api/msticpy.analysis.polling_detection.rst b/docs/source/api/msticpy.analysis.polling_detection.rst
new file mode 100644
index 000000000..b9c035c23
--- /dev/null
+++ b/docs/source/api/msticpy.analysis.polling_detection.rst
@@ -0,0 +1,7 @@
+msticpy.analysis.polling\_detection module
+==========================================
+
+.. automodule:: msticpy.analysis.polling_detection
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api/msticpy.analysis.rst b/docs/source/api/msticpy.analysis.rst
index b17d5e56a..615bb314c 100644
--- a/docs/source/api/msticpy.analysis.rst
+++ b/docs/source/api/msticpy.analysis.rst
@@ -25,5 +25,6 @@ Submodules
    msticpy.analysis.eventcluster
    msticpy.analysis.observationlist
    msticpy.analysis.outliers
+   msticpy.analysis.polling_detection
    msticpy.analysis.syslog_utils
    msticpy.analysis.timeseries
diff --git a/docs/source/data_acquisition/DataProv-Kusto-New.rst b/docs/source/data_acquisition/DataProv-Kusto-New.rst
index 8974e82b7..5e2e7645e 100644
--- a/docs/source/data_acquisition/DataProv-Kusto-New.rst
+++ b/docs/source/data_acquisition/DataProv-Kusto-New.rst
@@ -20,19 +20,27 @@ Changes from the previous implementation
 * Use the provider name ``Kusto_New`` when creating a QueryProvider
   instance. This will be changed to ``Kusto`` in a future release.
+* The driver supports asynchronous execution of queries. This is used
+  when you create a Query provider with multiple connections (e.g.
+  to different clusters) and when you split queries into time chunks.
+  See :ref:`multiple_connections` and :ref:`splitting_query_execution`
+  for more details.
 * The settings format has changed (although the existing format
   is still supported albeit with some limited functionality).
+* Supports user-specified timeout for queries.
+* Supports proxies (via MSTICPy config or the ``proxies`` parameter to
+  the ``connect`` method)
 * You could previously specify a new cluster to connect to in when
   executing a query. This is no longer supported. Once the provider
   is connected to a cluster it will only execute queries against
-  that cluster. (You can however, call the connect() function to connect
+  that cluster. (You can however, call the ``connect()`` function to connect
   the provider to a new cluster before running the query.)
 * Some of the previous parameters have been deprecated:
-  * ``mp_az_auth`` is replaced by ``auth_types`` (the former still works
+  * The ``mp_az_auth`` parameter is replaced by ``auth_types`` (the former still works
     but will be removed in a future release).
   * ``mp_az_auth_tenant_id`` is replaced by ``tenant_id`` (the former
-    is no longer supported
+    is no longer supported).

 Kusto Configuration
 -------------------
@@ -49,9 +57,9 @@ and :doc:`MSTICPy Settings Editor<../getting_started/SettingsEditor>`
 .. note:: The settings for the new Kusto provider are stored in the
    ``KustoClusters`` section of the configuration file.
    This cannot currently be edited from the MSTICPy Settings Editor - please
-   edit the *msticpyconfig.yaml* directly to edit these.
+   edit the *msticpyconfig.yaml* in a text editor to change these.

-To accommodate the use of multiple clusters the new provider supports
+To accommodate the use of multiple clusters, the new provider supports
 a different configuration format. The basic settings in the file
 should look like the following:

@@ -92,7 +100,7 @@ for *clientsecret* authentication.
The ClusterDefaults section
~~~~~~~~~~~~~~~~~~~~~~~~~~~

-If you have parameters that you want to apply to all clusters
+If you have parameters that you want to apply to all clusters,
 you can add these to a ``ClusterDefaults`` section.

 .. code:: yaml

@@ -117,7 +125,7 @@ cluster group name.
 This is useful if you have clusters in different regions that share
 the same schema and you want to run the same queries against all
 of them.

-This is used primarily to support query templates, to match
+ClusterGroups are used primarily to support query templates, to match
 queries to the correct cluster. See `Writing query templates for Kusto clusters`_
 later in this document.
diff --git a/docs/source/data_acquisition/DataProv-MSSentinel-New.rst b/docs/source/data_acquisition/DataProv-MSSentinel-New.rst
index c0417cb02..4abe5ccab 100644
--- a/docs/source/data_acquisition/DataProv-MSSentinel-New.rst
+++ b/docs/source/data_acquisition/DataProv-MSSentinel-New.rst
@@ -7,7 +7,7 @@ the (the earlier implementation used
 `Kqlmagic `__)

-.. note:: This provider currently in beta and is available for testing.
+.. warning:: This provider is currently in beta and is available for testing.
   It is available alongside the existing Sentinel provider for you
   to compare old and new. To use it you will need the ``azure-monitor-query``
   package installed. You can install this with ``pip install azure-monitor-query``
@@ -25,6 +25,12 @@ Changes from the previous implementation
 * Supports user-specified timeout for queries.
 * Supports proxies (via MSTICPy config or the ``proxies`` parameter to
   the ``connect`` method)
+* The driver supports asynchronous execution of queries. This is used
+  when you create a Query provider with multiple connections (e.g.
+  to different clusters) and when you split queries into time chunks.
+  See :ref:`multiple_connections` and :ref:`splitting_query_execution`
+  for more details. This is independent of the ability to specify
+  multiple workspaces in a single connection as described above.
 * Some of the previous parameters have been deprecated:

   * ``mp_az_auth`` is replaced by ``auth_types`` (the former still works
     but will be removed in a future release).
@@ -129,7 +135,8 @@ Connecting to a MS Sentinel Workspace

 Once you've created a QueryProvider you need to authenticate to Sentinel
 Workspace. This is done by calling the connect() function of the Query
-Provider. See :py:meth:`connect() `
+Provider. See
+:py:meth:`connect() `

 This function takes an initial parameter (called ``connection_str``
 for historical reasons) that can be one of the following:
@@ -160,7 +167,6 @@ an instance of WorkspaceConfig to the query provider's ``connect`` method.

    qry_prov.connect(WorkspaceConfig(workspace="MyOtherWorkspace"))

-
 MS Sentinel Authentication options
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -177,15 +183,26 @@ You can override several authentication parameters including:
 * tenant_id - the Azure tenant ID to use for authentication

 If you are using a Sovereign cloud rather than the Azure global cloud,
-you should follow the guidance in :doc:`Azure Authentication <../getting_started/AzureAuthentication>`
+you should follow the guidance in
+:doc:`Azure Authentication <../getting_started/AzureAuthentication>`
 to configure the correct cloud.

+
 Connecting to multiple Sentinel workspaces
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-The Sentinel data provider supports connecting to multiple workspaces.
-You can pass a list of workspace names or workspace IDs to the ``connect`` method.
+There are two mechanisms for querying multiple MS Sentinel workspaces.
+
+One is a generic method common to all data providers. For more
+information on this see :ref:`multiple_connections` in the main
+Data Providers documentation.
+
+The other is specific to the Sentinel data provider and is provided
+by the underlying Azure Monitor client. This latter capability is described in
+this section.
+
+The Sentinel data provider supports connecting to multiple workspaces by
+passing a list of workspace names or workspace IDs to the ``connect`` method,
 using the ``workspaces`` or ``workspace_ids`` parameters respectively.

 ``workspace_ids`` should be a list or tuple of workspace IDs.
diff --git a/docs/source/data_acquisition/DataProv-OSQuery.rst b/docs/source/data_acquisition/DataProv-OSQuery.rst
index 4f17b0d79..b6895ea9a 100644
--- a/docs/source/data_acquisition/DataProv-OSQuery.rst
+++ b/docs/source/data_acquisition/DataProv-OSQuery.rst
@@ -152,7 +152,7 @@ from the logs.
    qry_prov.osquery.processes()

================================== ================ ========================= ===== ========== ========= ====== ======== ======== ===== ==========
-name hostIdentifier unixTime ... username cmdline euid name_ parent uid username
+name hostIdentifier unixTime ... username cmdline euid tname_ parent uid username
================================== ================ ========================= ===== ========== ========= ====== ======== ======== ===== ==========
pack_osquery-custom-pack_processes jumpvm 2023-03-16 03:08:58+00:00 ... LOGIN 0 kthreadd 2 0 root
pack_osquery-custom-pack_processes jumpvm 2023-03-16 03:08:58+00:00 ... LOGIN 0 kthreadd 2 0 root
diff --git a/docs/source/data_acquisition/DataProviders.rst b/docs/source/data_acquisition/DataProviders.rst
index ddd229f6c..65f6561a1 100644
--- a/docs/source/data_acquisition/DataProviders.rst
+++ b/docs/source/data_acquisition/DataProviders.rst
@@ -449,6 +449,75 @@ TimeGenerated AlertDisplayName Severity
2019-07-22 07:02:42 Traffic from unrecommended IP addresses was de... Low Azure security center has detected incoming tr... {\r\n "Destination Port": "3389",\r\n "Proto... [\r\n {\r\n "$id": "4",\r\n "ResourceId... Detection
=================== ================================================= ========== ================================================= ================================================ ========================================== ==============

+.. _multiple_connections:
+
+Running a query across multiple connections
+-------------------------------------------
+
+It is common for data services to be spread across multiple tenants or
+workloads. For example, you may have multiple Sentinel workspaces,
+Microsoft Defender subscriptions or Splunk instances. You can use the
+``QueryProvider`` to run a query across multiple connections and return
+the results in a single DataFrame.
+
+.. note:: This feature only works for multiple instances using the same
+   ``DataEnvironment`` (e.g. "MSSentinel", "Splunk", etc.)
+
+To create a multi-instance provider you first need to create an
+instance of a QueryProvider for your data source and execute
+the ``connect()`` method to connect to the first instance of your
+data service. Which instance you choose is not important.
+Then use the
+:py:meth:`add_connection() `
+method. This takes the same parameters as the
+:py:meth:`connect() `
+method (the parameters for this method vary by data provider).
+
+``add_connection()`` also supports an ``alias`` parameter to allow
+you to refer to the connection by a friendly name. Otherwise, the
+connection is just assigned an index number in the order that it was
+added.
+
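For example, the same pattern works for any provider type, not just MS
Sentinel. A minimal, hypothetical sketch with the new Kusto provider
(the cluster URIs and the alias below are purely illustrative):

.. code:: ipython3

    from msticpy.data.core.data_providers import QueryProvider

    # Illustrative only - substitute your own configured clusters.
    kusto_prov = QueryProvider("Kusto_New")
    kusto_prov.connect(cluster="https://mycluster1.kusto.windows.net")
    # add a second cluster under a friendly alias
    kusto_prov.add_connection(
        cluster="https://mycluster2.kusto.windows.net", alias="cluster2"
    )

Queries run from ``kusto_prov`` are then executed against both clusters
and the results combined, as described below.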
+Use the
+:py:meth:`list_connections() `
+method to see all of the current connections.
+
+.. code:: ipython3
+
+    qry_prov = QueryProvider("MSSentinel")
+    qry_prov.connect(workspace="Workspace1")
+    qry_prov.add_connection(workspace="Workspace2", alias="Workspace2")
+    qry_prov.list_connections()
+
+When you now run a query for this provider, the query will be run on
+all of the connections and the results will be returned as a single
+DataFrame.
+
+.. code:: ipython3
+
+    test_query = '''
+    SecurityAlert
+    | take 5
+    '''
+
+    query_test = qry_prov.exec_query(query=test_query)
+    query_test.head()
+
+Some of the MSTICPy drivers support asynchronous execution of queries
+against multiple instances, so that the time taken to run the query is
+much reduced compared to running the queries sequentially. Drivers
+that support asynchronous queries will use this automatically.
+
+By default, the queries will use at most 4 concurrent threads. You can
+override this by initializing the QueryProvider with the
+``max_threads`` parameter to set it to the number of threads you want.
+
+.. code:: ipython3
+
+    qry_prov = QueryProvider("MSSentinel", max_threads=10)
+
+
+.. _splitting_query_execution:

 Splitting Query Execution into Chunks
 -------------------------------------
@@ -459,6 +528,11 @@ split a query into time ranges. Each sub-range is
 run as an independent query and the results are combined before being
 returned as a DataFrame.

+.. note:: Some data drivers support running queries asynchronously.
+   This means that the time taken to run all chunks of the query is much reduced
+   compared to running these sequentially. Drivers that support
+   asynchronous queries will use this automatically.
+
 To use this feature you must specify the keyword parameter ``split_query_by``
 when executing the query function. The value to this parameter is a string
 that specifies a time period. The time range specified by the
@@ -471,7 +545,7 @@ chunks.
    than the split period that you specified in the *split_query_by*
    parameter. This can happen if *start* and *end* are not aligned
    exactly on time boundaries (e.g. if you used a one hour split period
-   and *end* is 10 hours 15 min after *start*. The query split logic
+   and *end* is 10 hours 15 min after *start*). The query split logic
   will create a larger final slice if *end* is close to the final time
   range or it will insert an extra time range to ensure that the full
   *start** to *end* time range is covered.
diff --git a/docs/source/data_acquisition/SentinelIncidents.rst b/docs/source/data_acquisition/SentinelIncidents.rst
index 03435270e..da4d68610 100644
--- a/docs/source/data_acquisition/SentinelIncidents.rst
+++ b/docs/source/data_acquisition/SentinelIncidents.rst
@@ -16,10 +16,12 @@ See :py:meth:`list_incidents `_
 and include the following key items:

  - $top: this controls how many incidents are returned
-  - $filter: this accepts an OData query that filters the returned item. https://learn.microsoft.com/graph/filter-query-parameter
+  - $filter: this accepts an OData query that filters the returned items
+    (see `$filter parameter <https://learn.microsoft.com/graph/filter-query-parameter>`_)
  - $orderby: this allows for sorting results by a specific column
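For example, a minimal sketch of passing these parameters (this assumes a
``MicrosoftSentinel`` instance named ``sentinel`` that is already connected,
and that the method accepts a ``params`` dictionary as described above; the
filter and ordering expressions are illustrative only):

.. code:: ipython3

    # Illustrative OData values - adjust to your own filtering needs.
    incidents = sentinel.list_incidents(
        params={
            "$top": 10,
            "$filter": "properties/severity eq 'High'",
            "$orderby": "properties/createdTimeUtc desc",
        }
    )
    incidents.head()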
.. code:: ipython3
diff --git a/docs/source/getting_started/Installing.rst b/docs/source/getting_started/Installing.rst
index faac0004e..b9c079d41 100644
--- a/docs/source/getting_started/Installing.rst
+++ b/docs/source/getting_started/Installing.rst
@@ -201,22 +201,24 @@ exception message:
    functionality you are trying to use.

 Installing in Managed Spark compute in Azure Machine Learning Notebooks
-^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-*MSTICPy* installation for Managed (Automatic) Spark Compute in Azure Machine Learning workspace requires
-different instructions since library installation is different.
+*MSTICPy* installation for Managed (Automatic) Spark Compute in Azure Machine Learning workspace requires
+different instructions since library installation is different.

-.. note:: These notebook requires Azure ML Spark Compute. If you are using it for the first time, follow the guidelines mentioned here :Attach and manage a Synapse Spark pool in Azure Machine Learning (preview):
-.. _Attach and manage a Synapse Spark pool in Azure Machine Learning (preview):
-https://learn.microsoft.com/en-us/azure/machine-learning/how-to-manage-synapse-spark-pool?tabs=studio-ui
+.. note:: These notebooks require Azure ML Spark Compute. If you are using
+   it for the first time, follow the guidelines mentioned here:
+   `Attach and manage a Synapse Spark pool in Azure Machine Learning (preview) <https://learn.microsoft.com/en-us/azure/machine-learning/how-to-manage-synapse-spark-pool?tabs=studio-ui>`_

 Once you have completed the pre-requisites, you will see AzureML Spark Compute in the dropdown menu for Compute. Select it and run any cell to start Spark Session.

-Please refer the docs _Managed (Automatic) Spark compute in Azure Machine Learning Notebooks: for more detailed steps along with screenshots.
-.. _Managed (Automatic) Spark compute in Azure Machine Learning Notebooks:
-https://learn.microsoft.com/en-us/azure/machine-learning/interactive-data-wrangling-with-apache-spark-azure-ml
+Please refer to
+`Managed (Automatic) Spark compute in Azure Machine Learning Notebooks <https://learn.microsoft.com/en-us/azure/machine-learning/interactive-data-wrangling-with-apache-spark-azure-ml>`_
+for more detailed steps along with screenshots.
+
+
-In order to install any libraries in Spark compute, you need to use a conda file to configure a Spark session.
+In order to install any libraries in Spark compute, you need to use a conda file to configure a Spark session.
 Please save below file as conda.yml , check the Upload conda file checkbox.
 You can modify the version number as needed. Then, select Browse, and choose the conda file saved earlier with the Spark session configuration you want.
diff --git a/msticpy/analysis/polling_detection.py b/msticpy/analysis/polling_detection.py
index f1803a2ac..9f35b5eb2 100644
--- a/msticpy/analysis/polling_detection.py
+++ b/msticpy/analysis/polling_detection.py
@@ -35,11 +35,6 @@ class PeriodogramPollingDetector:
         Dataframe containing the data to be analysed. Must contain a
         column of edges and a column of timestamps

-    Methods
-    -------
-    detect_polling(timestamps, process_start, process_end, interval)
-        Detect strong periodic frequencies
-
     """

     def __init__(self, data: pd.DataFrame, copy: bool = False) -> None:
diff --git a/msticpy/data/core/data_providers.py b/msticpy/data/core/data_providers.py
index cbd681821..66a86157c 100644
--- a/msticpy/data/core/data_providers.py
+++ b/msticpy/data/core/data_providers.py
@@ -4,14 +4,12 @@
 # license information.
# -------------------------------------------------------------------------- """Data provider loader.""" -from datetime import datetime +import logging from functools import partial -from itertools import tee from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import pandas as pd -from tqdm.auto import tqdm from ..._version import VERSION from ...common.pkg_config import get_config @@ -24,7 +22,6 @@ from .query_defns import DataEnvironment from .query_provider_connections_mixin import QueryProviderConnectionsMixin from .query_provider_utils_mixin import QueryProviderUtilsMixin -from .query_source import QuerySource from .query_store import QueryStore __version__ = VERSION @@ -39,6 +36,8 @@ "kusto_new": ["kusto"], } +logger = logging.getLogger(__name__) + # These are mixin classes that do not have an __init__ method # pylint: disable=super-init-not-called @@ -114,6 +113,8 @@ def __init__( # noqa: MC0001 driver.get_driver_property(DriverProps.EFFECTIVE_ENV) or self.environment_name ) + logger.info("Using data environment %s", self.environment_name) + logger.info("Driver class: %s", self.driver_class.__name__) self._additional_connections: Dict[str, DriverBase] = {} self._query_provider = driver @@ -124,15 +125,19 @@ def __init__( # noqa: MC0001 # Add any query files data_env_queries: Dict[str, QueryStore] = {} + self._query_paths = query_paths if driver.use_query_paths: + logger.info("Using query paths %s", query_paths) data_env_queries.update( self._read_queries_from_paths(query_paths=query_paths) ) self.query_store = data_env_queries.get( self.environment_name, QueryStore(self.environment_name) ) + logger.info("Adding query functions to provider") self._add_query_functions() self._query_time = QueryTime(units="day") + logger.info("Initialization complete.") def _check_environment( self, data_environment @@ -179,6 +184,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): Connection string for the data source """ + logger.info("Calling connect on driver") self._query_provider.connect(connection_str=connection_str, **kwargs) # If the driver has any attributes to expose via the provider @@ -194,6 +200,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): refresh_query_funcs = True # Add any built-in or dynamically retrieved queries from driver if self._query_provider.has_driver_queries: + logger.info("Adding driver queries to provider") driver_queries = self._query_provider.driver_queries self._add_driver_queries(queries=driver_queries) refresh_query_funcs = True @@ -202,6 +209,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): self._add_query_functions() # Since we're now connected, add Pivot functions + logger.info("Adding query pivot functions") self._add_pivots(lambda: self._query_time.timespan) def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]: @@ -230,12 +238,13 @@ def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]: """ query_options = kwargs.pop("query_options", {}) or kwargs query_source = kwargs.pop("query_source", None) - result = self._query_provider.query( - query, query_source=query_source, **query_options - ) + + logger.info("Executing query '%s...'", query[:40]) if not self._additional_connections: - return result - return self._exec_additional_connections(query, result, **kwargs) + return self._query_provider.query( + query, query_source=query_source, **query_options + ) + return self._exec_additional_connections(query, **kwargs) @property 
def query_time(self): @@ -265,6 +274,7 @@ def _execute_query(self, *args, **kwargs) -> Union[pd.DataFrame, Any]: return None params, missing = extract_query_params(query_source, *args, **kwargs) + logger.info("Parameters for query: %s", params) query_options = { "default_time_params": self._check_for_time_params(params, missing) } @@ -272,12 +282,14 @@ def _execute_query(self, *args, **kwargs) -> Union[pd.DataFrame, Any]: query_source.help() raise ValueError(f"No values found for these parameters: {missing}") - split_by = kwargs.pop("split_query_by", None) + split_by = kwargs.pop("split_query_by", kwargs.pop("split_by", None)) if split_by: + logger.info("Split query selected - interval - %s", split_by) split_result = self._exec_split_query( split_by=split_by, query_source=query_source, query_params=params, + debug=_debug_flag(*args, **kwargs), args=args, **kwargs, ) @@ -292,7 +304,10 @@ def _execute_query(self, *args, **kwargs) -> Union[pd.DataFrame, Any]: return query_str # Handle any query options passed and run the query - query_options.update(_get_query_options(params, kwargs)) + query_options.update(self._get_query_options(params, kwargs)) + logger.info( + "Running query '%s...' with params: %s", query_str[:40], query_options + ) return self.exec_query(query_str, query_source=query_source, **query_options) def _check_for_time_params(self, params, missing) -> bool: @@ -342,6 +357,7 @@ def _read_queries_from_paths(self, query_paths) -> Dict[str, QueryStore]: if param_qry_path: all_query_paths.append(param_qry_path) if all_query_paths: + logger.info("Reading queries from %s", all_query_paths) return QueryStore.import_files( source_path=all_query_paths, recursive=True, @@ -395,80 +411,23 @@ def _add_driver_queries(self, queries: Iterable[Dict[str, str]]): # queries it should not be noticeable. self._add_query_functions() - def _exec_split_query( - self, - split_by: str, - query_source: QuerySource, - query_params: Dict[str, Any], - args, - **kwargs, - ) -> Union[pd.DataFrame, str, None]: - start = query_params.pop("start", None) - end = query_params.pop("end", None) - if not (start or end): - print( - "Cannot split a query that does not have 'start' and 'end' parameters" - ) - return None - try: - split_delta = pd.Timedelta(split_by) - except ValueError: - split_delta = pd.Timedelta("1D") - - ranges = _calc_split_ranges(start, end, split_delta) - - split_queries = [ - query_source.create_query( - formatters=self._query_provider.formatters, - start=q_start, - end=q_end, - **query_params, - ) - for q_start, q_end in ranges - ] - # This looks for any of the "print query" debug args in args or kwargs - if _debug_flag(*args, **kwargs): - return "\n\n".join(split_queries) - - # Retrieve any query options passed (other than query params) - # and send to query function. 
- query_options = _get_query_options(query_params, kwargs) - query_dfs = [ - self.exec_query(query_str, query_source=query_source, **query_options) - for query_str in tqdm(split_queries, unit="sub-queries", desc="Running") - ] - - return pd.concat(query_dfs) - - -def _calc_split_ranges(start: datetime, end: datetime, split_delta: pd.Timedelta): - """Return a list of time ranges split by `split_delta`.""" - # Use pandas date_range and split the result into 2 iterables - s_ranges, e_ranges = tee(pd.date_range(start, end, freq=split_delta)) - next(e_ranges, None) # skip to the next item in the 2nd iterable - # Zip them together to get a list of (start, end) tuples of ranges - # Note: we subtract 1 nanosecond from the 'end' value of each range so - # to avoid getting duplicated records at the boundaries of the ranges. - # Some providers don't have nanosecond granularity so we might - # get duplicates in these cases - ranges = [ - (s_time, e_time - pd.Timedelta("1ns")) - for s_time, e_time in zip(s_ranges, e_ranges) - ] - - # Since the generated time ranges are based on deltas from 'start' - # we need to adjust the end time on the final range. - # If the difference between the calculated last range end and - # the query 'end' that the user requested is small (< 10% of a delta), - # we just replace the last "end" time with our query end time. - if (ranges[-1][1] - end) < (split_delta / 10): - ranges[-1] = ranges[-1][0], end - else: - # otherwise append a new range starting after the last range - # in ranges and ending in 'end" - # note - we need to add back our subtracted 1 nanosecond - ranges.append((ranges[-1][0] + pd.Timedelta("1ns"), end)) - return ranges + @staticmethod + def _get_query_options( + params: Dict[str, Any], kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + # sourcery skip: inline-immediately-returned-variable, use-or-for-fallback + """Return any kwargs not already in params.""" + query_options = kwargs.pop("query_options", {}) + if not query_options: + # Any kwargs left over we send to the query provider driver + query_options = { + key: val for key, val in kwargs.items() if key not in params + } + query_options["time_span"] = { + "start": params.get("start"), + "end": params.get("end"), + } + return query_options def _resolve_package_path(config_path: str) -> Path: @@ -500,19 +459,3 @@ def _debug_flag(*args, **kwargs) -> bool: return any(db_arg for db_arg in _DEBUG_FLAGS if db_arg in args) or any( db_arg for db_arg in _DEBUG_FLAGS if kwargs.get(db_arg, False) ) - - -def _get_query_options( - params: Dict[str, Any], kwargs: Dict[str, Any] -) -> Dict[str, Any]: - # sourcery skip: inline-immediately-returned-variable, use-or-for-fallback - """Return any kwargs not already in params.""" - query_options = kwargs.pop("query_options", {}) - if not query_options: - # Any kwargs left over we send to the query provider driver - query_options = {key: val for key, val in kwargs.items() if key not in params} - query_options["time_span"] = { - "start": params.get("start"), - "end": params.get("end"), - } - return query_options diff --git a/msticpy/data/core/query_provider_connections_mixin.py b/msticpy/data/core/query_provider_connections_mixin.py index e436a3144..b40a757bd 100644 --- a/msticpy/data/core/query_provider_connections_mixin.py +++ b/msticpy/data/core/query_provider_connections_mixin.py @@ -4,19 +4,29 @@ # license information. 
# -------------------------------------------------------------------------- """Query Provider additional connection methods.""" -from typing import Any, Dict, List, Optional, Protocol +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from functools import partial +from itertools import tee +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union import pandas as pd +from tqdm.auto import tqdm from ..._version import VERSION from ...common.exceptions import MsticpyDataQueryError -from ..drivers.driver_base import DriverBase +from ..drivers.driver_base import DriverBase, DriverProps +from .query_source import QuerySource __version__ = VERSION __author__ = "Ian Hellen" +logger = logging.getLogger(__name__) -# pylint: disable=too-few-public-methods + +# pylint: disable=too-few-public-methods, unnecessary-ellipsis class QueryProviderProtocol(Protocol): """Protocol for required properties of QueryProvider class.""" @@ -25,6 +35,16 @@ class QueryProviderProtocol(Protocol): _additional_connections: Dict[str, Any] _query_provider: DriverBase + def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]: + """Execute a query against the provider.""" + ... + + @staticmethod + def _get_query_options( + params: Dict[str, Any], kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + ... + # pylint: disable=super-init-not-called class QueryProviderConnectionsMixin(QueryProviderProtocol): @@ -49,7 +69,7 @@ def add_connection( Other Parameters ---------------- kwargs : Dict[str, Any] - Other parameters passed to the driver constructor. + Other connection parameters passed to the driver. Notes ----- @@ -67,7 +87,7 @@ def add_connection( def list_connections(self) -> List[str]: """ - Return a list of current connections or the default connection. + Return a list of current connections. Returns ------- @@ -81,18 +101,314 @@ def list_connections(self) -> List[str]: ] return [f"Default: {self._query_provider.current_connection}", *add_connections] - def _exec_additional_connections(self, query, result, **kwargs) -> pd.DataFrame: - """Return results of query run query against additional connections.""" - query_source = kwargs.get("query_source") - query_options = kwargs.get("query_options", {}) - results = [result] + # pylint: disable=too-many-locals + def _exec_additional_connections(self, query, **kwargs) -> pd.DataFrame: + """ + Return results of query run query against additional connections. + + Parameters + ---------- + query : str + The query to execute. + progress: bool, optional + Show progress bar, by default True + retry_on_error: bool, optional + Retry failed queries, by default False + **kwargs : Dict[str, Any] + Additional keyword arguments to pass to the query method. + + Returns + ------- + pd.DataFrame + The concatenated results of the query executed against all connections. + + Notes + ----- + This method executes the specified query against all additional connections + added to the query provider. + If the driver supports threading or async execution, the per-connection + queries are executed asynchronously. + Otherwise, the queries are executed sequentially. 
+ + """ + progress = kwargs.pop("progress", True) + retry = kwargs.pop("retry_on_error", False) + # Add the initial connection + query_tasks = { + self._query_provider.current_connection + or "0": partial( + self._query_provider.query, + query, + **kwargs, + ) + } + # add the additional connections + query_tasks.update( + { + name: partial(connection.query, query, **kwargs) + for name, connection in self._additional_connections.items() + } + ) + + logger.info("Running queries for %s connections.", len(query_tasks)) + # Run the queries threaded if supported + if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING): + logger.info("Running threaded queries.") + event_loop = _get_event_loop() + return event_loop.run_until_complete( + self._exec_queries_threaded(query_tasks, progress, retry) + ) + + # standard synchronous execution print(f"Running query for {len(self._additional_connections)} connections.") - for con_name, connection in self._additional_connections.items(): - print(f"{con_name}...") + return self._exec_synchronous_queries(progress, query_tasks) + + def _exec_split_query( + self, + split_by: str, + query_source: QuerySource, + query_params: Dict[str, Any], + **kwargs, + ) -> Union[pd.DataFrame, str, None]: + """ + Execute a query that is split into multiple queries. + + Parameters + ---------- + split_by : str + The time interval to split the query by. + query_source : QuerySource + The query to execute. + query_params : Dict[str, Any] + The parameters to pass to the query. + + Other Parameters + ---------------- + debug: bool, optional + Return queries to be executed rather than execute them, by default False + progress: bool, optional + Show progress bar, by default True + retry_on_error: bool, optional + Retry failed queries, by default False + **kwargs : Dict[str, Any] + Additional keyword arguments to pass to the query method. + + Returns + ------- + pd.DataFrame + The concatenated results of the query executed against all connections. + + Notes + ----- + This method executes the time-chunks of the split query. + If the driver supports threading or async execution, the sub-queries are + executed asynchronously. Otherwise, the queries are executed sequentially. 
+ + """ + start = query_params.pop("start", None) + end = query_params.pop("end", None) + progress = kwargs.pop("progress", True) + retry = kwargs.pop("retry_on_error", False) + debug = kwargs.pop("debug", False) + if not (start or end): + print("Cannot split a query with no 'start' and 'end' parameters") + return None + + split_queries = self._create_split_queries( + query_source=query_source, + query_params=query_params, + start=start, + end=end, + split_by=split_by, + ) + if debug: + return "\n\n".join( + f"{start}-{end}\n{query}" + for (start, end), query in split_queries.items() + ) + + query_tasks = self._create_split_query_tasks( + query_source, query_params, split_queries, **kwargs + ) + # Run the queries threaded if supported + if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING): + logger.info("Running threaded queries.") + event_loop = _get_event_loop() + return event_loop.run_until_complete( + self._exec_queries_threaded(query_tasks, progress, retry) + ) + + # or revert to standard synchronous execution + return self._exec_synchronous_queries(progress, query_tasks) + + def _create_split_query_tasks( + self, + query_source: QuerySource, + query_params: Dict[str, Any], + split_queries, + **kwargs, + ) -> Dict[str, partial]: + """Return dictionary of partials to execute queries.""" + # Retrieve any query options passed (other than query params) + query_options = self._get_query_options(query_params, kwargs) + logger.info("query_options: %s", query_options) + logger.info("kwargs: %s", kwargs) + if "time_span" in query_options: + del query_options["time_span"] + return { + f"{start}-{end}": partial( + self.exec_query, + query=query_str, + query_source=query_source, + time_span={"start": start, "end": end}, + **query_options, + ) + for (start, end), query_str in split_queries.items() + } + + @staticmethod + def _exec_synchronous_queries( + progress: bool, query_tasks: Dict[str, Any] + ) -> pd.DataFrame: + logger.info("Running queries sequentially.") + results: List[pd.DataFrame] = [] + if progress: + query_iter = tqdm(query_tasks.items(), unit="sub-queries", desc="Running") + else: + query_iter = query_tasks.items() + for con_name, query_task in query_iter: try: - results.append( - connection.query(query, query_source=query_source, **query_options) - ) + results.append(query_task()) except MsticpyDataQueryError: print(f"Query {con_name} failed.") return pd.concat(results) + + def _create_split_queries( + self, + query_source: QuerySource, + query_params: Dict[str, Any], + start: datetime, + end: datetime, + split_by: str, + ) -> Dict[Tuple[datetime, datetime], str]: + """Return separate queries for split time ranges.""" + try: + split_delta = pd.Timedelta(split_by) + except ValueError: + split_delta = pd.Timedelta("1D") + logger.info("Using split delta %s", split_delta) + + ranges = _calc_split_ranges(start, end, split_delta) + + split_queries = { + (q_start, q_end): query_source.create_query( + formatters=self._query_provider.formatters, + start=q_start, + end=q_end, + **query_params, + ) + for q_start, q_end in ranges + } + logger.info("Split query into %s chunks", len(split_queries)) + return split_queries + + async def _exec_queries_threaded( + self, + query_tasks: Dict[str, partial], + progress: bool = True, + retry: bool = False, + ) -> pd.DataFrame: + """Return results of multiple queries run as threaded tasks.""" + logger.info("Running threaded queries for %d connections.", len(query_tasks)) + + event_loop = _get_event_loop() + + with 
ThreadPoolExecutor( + max_workers=self._query_provider.get_driver_property( + DriverProps.MAX_PARALLEL + ) + ) as executor: + # add the additional connections + thread_tasks = { + query_id: event_loop.run_in_executor(executor, query_func) + for query_id, query_func in query_tasks.items() + } + results: List[pd.DataFrame] = [] + failed_tasks: Dict[str, asyncio.Future] = {} + if progress: + task_iter = tqdm( + asyncio.as_completed(thread_tasks.values()), + unit="sub-queries", + desc="Running", + ) + else: + task_iter = asyncio.as_completed(thread_tasks.values()) + ids_and_tasks = dict(zip(thread_tasks, task_iter)) + for query_id, thread_task in ids_and_tasks.items(): + try: + result = await thread_task + logger.info("Query task '%s' completed successfully.", query_id) + results.append(result) + except Exception: # pylint: disable=broad-except + logger.warning( + "Query task '%s' failed with exception", query_id, exc_info=True + ) + failed_tasks[query_id] = thread_task + + if retry and failed_tasks: + for query_id, thread_task in failed_tasks.items(): + try: + logger.info("Retrying query task '%s'", query_id) + result = await thread_task + results.append(result) + except Exception: # pylint: disable=broad-except + logger.warning( + "Retried query task '%s' failed with exception", + query_id, + exc_info=True, + ) + # Sort the results by the order of the tasks + results = [result for _, result in sorted(zip(thread_tasks, results))] + + return pd.concat(results, ignore_index=True) + + +def _get_event_loop() -> asyncio.AbstractEventLoop: + """Return the current event loop, or create a new one.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + +def _calc_split_ranges(start: datetime, end: datetime, split_delta: pd.Timedelta): + """Return a list of time ranges split by `split_delta`.""" + # Use pandas date_range and split the result into 2 iterables + s_ranges, e_ranges = tee(pd.date_range(start, end, freq=split_delta)) + next(e_ranges, None) # skip to the next item in the 2nd iterable + # Zip them together to get a list of (start, end) tuples of ranges + # Note: we subtract 1 nanosecond from the 'end' value of each range so + # to avoid getting duplicated records at the boundaries of the ranges. + # Some providers don't have nanosecond granularity so we might + # get duplicates in these cases + ranges = [ + (s_time, e_time - pd.Timedelta("1ns")) + for s_time, e_time in zip(s_ranges, e_ranges) + ] + + # Since the generated time ranges are based on deltas from 'start' + # we need to adjust the end time on the final range. + # If the difference between the calculated last range end and + # the query 'end' that the user requested is small (< 10% of a delta), + # we just replace the last "end" time with our query end time. + if (ranges[-1][1] - end) < (split_delta / 10): + ranges[-1] = ranges[-1][0], end + else: + # otherwise append a new range starting after the last range + # in ranges and ending in 'end" + # note - we need to add back our subtracted 1 nanosecond + ranges.append((ranges[-1][0] + pd.Timedelta("1ns"), end)) + + return ranges diff --git a/msticpy/data/drivers/azure_kusto_driver.py b/msticpy/data/drivers/azure_kusto_driver.py index 815787a69..160cb14a5 100644 --- a/msticpy/data/drivers/azure_kusto_driver.py +++ b/msticpy/data/drivers/azure_kusto_driver.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Kusto Driver subclass.""" +import base64 import dataclasses import json import logging @@ -17,6 +18,8 @@ KustoClient, KustoConnectionStringBuilder, ) +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.serialization import pkcs12 from ..._version import VERSION from ...auth.azure_auth import az_connect, get_default_resource_name @@ -34,6 +37,7 @@ from ..core.query_source import QuerySource from .driver_base import DriverBase, DriverProps +# pylint: disable=ungrouped-imports try: from azure.kusto.data.exceptions import KustoApiError, KustoServiceError from azure.kusto.data.helpers import dataframe_from_result_table @@ -78,6 +82,7 @@ class ConfigFields: CLIENT_SEC = "ClientSecret" ARGS = "Args" CLUSTER_GROUPS = "ClusterGroups" + CERTIFICATE = "Certificate" # pylint: disable=no-member @property @@ -174,6 +179,10 @@ def __init__(self, connection_str: Optional[str] = None, **kwargs): self.set_driver_property(DriverProps.PUBLIC_ATTRS, self._set_public_attribs()) self.set_driver_property(DriverProps.FILTER_ON_CONNECT, True) self.set_driver_property(DriverProps.EFFECTIVE_ENV, DataEnvironment.Kusto.name) + self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True) + self.set_driver_property( + DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4) + ) self._loaded = True def _set_public_attribs(self): @@ -204,16 +213,12 @@ def current_connection(self, value: str): @property def cluster_uri(self) -> str: """Return current cluster URI.""" - if not self._current_config: - return "" - return self._current_config.cluster + return "" if not self._current_config else self._current_config.cluster @property def cluster_name(self) -> str: """Return current cluster URI.""" - if not self._current_config: - return "" - return self._current_config.name + return self._current_config.name if self._current_config else "" @property def cluster_config_name(self) -> str: @@ -314,6 +319,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): ) cluster = kwargs.pop("cluster", None) + self.current_connection = connection_str or self.current_connection if not connection_str and not cluster: raise MsticpyParameterError( "Must specify either a connection string or a cluster name", @@ -322,15 +328,18 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): if cluster: self._current_config = self._lookup_cluster_settings(cluster) + if not self._az_tenant_id: + self._az_tenant_id = self._current_config.tenant_id logger.info( - "Using cluster id: %s, retrieved %s", + "Using cluster id: %s, retrieved url %s to build connection string", cluster, self.cluster_uri, ) kusto_cs = self._get_connection_string_for_cluster(self._current_config) + self.current_connection = cluster else: logger.info("Using connection string %s", connection_str) - self._current_connection = connection_str + self.current_connection = connection_str kusto_cs = connection_str self.client = KustoClient(kusto_cs) @@ -551,14 +560,23 @@ def _get_connection_string_for_cluster( ) -> KustoConnectionStringBuilder: """Return full cluster URI and credential for cluster name or URI.""" auth_params = self._get_auth_params_from_config(cluster_config) + connect_auth_types = self._az_auth_types or AzureCloudConfig().auth_methods if auth_params.method == "clientsecret": - connect_auth_types = self._az_auth_types or AzureCloudConfig().auth_methods + logger.info("Client secret specified in config - using 
client secret authn") if "clientsecret" not in connect_auth_types: - connect_auth_types.append("clientsecret") + connect_auth_types.insert(0, "clientsecret") credential = az_connect( auth_types=connect_auth_types, **(auth_params.params) ) + elif auth_params.method == "certificate": + logger.info("Certificate specified in config - using certificate authn") + connect_auth_types.insert(0, "certificate") + credential = az_connect( + auth_types=self._az_auth_types, **(auth_params.params) + ) + return self._create_kusto_cert_connection_str(auth_params) else: + logger.info("Using integrated authn") credential = az_connect( auth_types=self._az_auth_types, **(auth_params.params) ) @@ -570,6 +588,41 @@ def _get_connection_string_for_cluster( user_token=token.token, ) + def _create_kql_cert_connection_str( + self, auth_params: AuthParams + ) -> KustoConnectionStringBuilder: + logger.info("Creating KQL connection string for certificate authentication") + if not self._az_tenant_id: + raise ValueError( + "Azure tenant ID must be set in config or connect parameter", + "to use certificate authentication", + ) + cert_bytes = base64.b64decode(auth_params.params["certificate"]) + ( + private_key, + certificate, + _, + ) = pkcs12.load_key_and_certificates(data=cert_bytes, password=None) + if private_key is None or certificate is None: + raise ValueError( + f"Could not load certificate for cluster {self.cluster_uri}" + ) + private_cert = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + public_cert = certificate.public_bytes(encoding=serialization.Encoding.PEM) + thumbprint = certificate.fingerprint(hashes.SHA1()) + return KustoConnectionStringBuilder.with_aad_application_certificate_sni_authentication( + connection_string=self.cluster_uri, + aad_app_id=auth_params.params["client_id"], + private_certificate=private_cert.decode("utf-8"), + public_certificate=public_cert.decode("utf-8"), + thumbprint=thumbprint.hex().upper(), + authority_id=self._az_tenant_id, + ) + def _get_auth_params_from_config(self, cluster_config: KustoConfig) -> AuthParams: """Get authentication parameters for cluster from KustoConfig values.""" method = "integrated" @@ -581,6 +634,16 @@ def _get_auth_params_from_config(self, cluster_config: KustoConfig) -> AuthParam logger.info( "Using client secret authentication because client_secret in config" ) + elif ( + KFields.CERTIFICATE in cluster_config + and KFields.CLIENT_ID in cluster_config + ): + method = "certificate" + auth_params_dict["client_id"] = cluster_config.ClientId + auth_params_dict["certificate"] = cluster_config.Certificate + logger.info( + "Using client secret authentication because client_secret in config" + ) elif KFields.INTEG_AUTH in cluster_config: logger.info("Using integrated auth.") auth_params_dict["tenant_id"] = cluster_config.tenant_id diff --git a/msticpy/data/drivers/azure_monitor_driver.py b/msticpy/data/drivers/azure_monitor_driver.py index b54281b7b..c7e2f775b 100644 --- a/msticpy/data/drivers/azure_monitor_driver.py +++ b/msticpy/data/drivers/azure_monitor_driver.py @@ -133,6 +133,7 @@ def __init__(self, connection_str: Optional[str] = None, **kwargs): self._query_client: Optional[LogsQueryClient] = None self._az_tenant_id: Optional[str] = None self._ws_config: Optional[WorkspaceConfig] = None + self._ws_name: Optional[str] = None self._workspace_id: Optional[str] = None self._workspace_ids: List[str] = [] 
self._def_connection_str: Optional[str] = connection_str @@ -143,6 +144,10 @@ def __init__(self, connection_str: Optional[str] = None, **kwargs): self.set_driver_property( DriverProps.EFFECTIVE_ENV, DataEnvironment.MSSentinel.name ) + self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True) + self.set_driver_property( + DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4) + ) logger.info( "AzureMonitorDriver loaded. connect_str %s, kwargs: %s", connection_str, @@ -160,6 +165,29 @@ def url_endpoint(self) -> str: return f"{base_url}v1" return base_url + @property + def current_connection(self) -> str: + """Return the current connection name.""" + connection = self._ws_name + if ( + not connection + and self._ws_config + and WorkspaceConfig.CONF_WS_NAME_KEY in self._ws_config + ): + connection = self._ws_config[WorkspaceConfig.CONF_WS_NAME_KEY] + return ( + connection + or self._def_connection_str + or self._workspace_id + or next(iter(self._workspace_ids), "") + or "AzureMonitor" + ) + + @current_connection.setter + def current_connection(self, value: str): + """Allow attrib to be set but ignore.""" + del value + def connect(self, connection_str: Optional[str] = None, **kwargs): """ Connect to data source. @@ -303,13 +331,21 @@ def query_with_results( title="Workspace not connected.", help_uri=_HELP_URL, ) - logger.info("Query to run %s", query) time_span_value = self._get_time_span_value(**kwargs) - server_timeout = kwargs.pop("timeout", self._def_timeout) workspace_id = next(iter(self._workspace_ids), None) or self._workspace_id additional_workspaces = self._workspace_ids[1:] if self._workspace_ids else None + logger.info("Query to run %s", query) + logger.info( + "Workspaces %s", ",".join(self._workspace_ids) or self._workspace_id + ) + logger.info( + "Time span %s - %s", + str(time_span_value[0]) if time_span_value else "none", + str(time_span_value[1]) if time_span_value else "none", + ) + logger.info("Timeout %s", server_timeout) try: result = self._query_client.query_workspace( workspace_id=workspace_id, # type: ignore[arg-type] @@ -413,6 +449,7 @@ def _get_workspaces(self, connection_str: Optional[str] = None, **kwargs): help_uri=_HELP_URL, ) self._ws_config = ws_config + self._ws_name = workspace_name or ws_config.workspace_id if not self._az_tenant_id and WorkspaceConfig.CONF_TENANT_ID_KEY in ws_config: self._az_tenant_id = ws_config[WorkspaceConfig.CONF_TENANT_ID_KEY] self._workspace_id = ws_config[WorkspaceConfig.CONF_WS_ID_KEY] @@ -464,7 +501,19 @@ def _get_time_span_value(self, **kwargs): start=time_params["start"], end=time_params["end"], ) - time_span_value = time_span.start, time_span.end + # Azure Monitor API expects datetime objects, so + # convert to datetimes if we have pd.Timestamps + t_start = ( + time_span.start.to_pydatetime(warn=False) + if isinstance(time_span.start, pd.Timestamp) + else time_span.start + ) + t_end = ( + time_span.end.to_pydatetime(warn=False) + if isinstance(time_span.end, pd.Timestamp) + else time_span.end + ) + time_span_value = t_start, t_end logger.info("Time parameters set %s", str(time_span)) return time_span_value diff --git a/msticpy/data/drivers/driver_base.py b/msticpy/data/drivers/driver_base.py index b4c5a6def..bfb3d79de 100644 --- a/msticpy/data/drivers/driver_base.py +++ b/msticpy/data/drivers/driver_base.py @@ -88,6 +88,7 @@ def __init__(self, **kwargs): self.data_environment = kwargs.get("data_environment") self._query_filter: Dict[str, Set[str]] = defaultdict(set) self._instance: Optional[str] = None + 
self.properties = DriverProps.defaults() self.set_driver_property( name=DriverProps.EFFECTIVE_ENV, @@ -97,6 +98,8 @@ def __init__(self, **kwargs): else self.data_environment or "" ), ) + self.set_driver_property(DriverProps.SUPPORTS_THREADING, False) + self.set_driver_property(DriverProps.MAX_PARALLEL, kwargs.get("max_threads", 4)) def __getattr__(self, attrib): """Return item from the properties dictionary as an attribute.""" diff --git a/msticpy/data/drivers/mdatp_driver.py b/msticpy/data/drivers/mdatp_driver.py index 19cbe0d39..0c1bf9214 100644 --- a/msticpy/data/drivers/mdatp_driver.py +++ b/msticpy/data/drivers/mdatp_driver.py @@ -4,7 +4,7 @@ # license information. # -------------------------------------------------------------------------- """MDATP OData Driver class.""" -from typing import Any, Union +from typing import Any, Optional, Union import pandas as pd @@ -27,7 +27,9 @@ class MDATPDriver(OData): CONFIG_NAME = "MicrosoftDefender" _ALT_CONFIG_NAMES = ["MDATPApp"] - def __init__(self, connection_str: str = None, instance: str = "Default", **kwargs): + def __init__( + self, connection_str: Optional[str] = None, instance: str = "Default", **kwargs + ): """ Instantiate MSDefenderDriver and optionally connect. @@ -74,7 +76,7 @@ def __init__(self, connection_str: str = None, instance: str = "Default", **kwar self.connect(connection_str) def query( - self, query: str, query_source: QuerySource = None, **kwargs + self, query: str, query_source: Optional[QuerySource] = None, **kwargs ) -> Union[pd.DataFrame, Any]: """ Execute query string and return DataFrame of results. @@ -89,7 +91,7 @@ def query( Returns ------- Union[pd.DataFrame, results.ResultSet] - A DataFrame (if successfull) or + A DataFrame (if successful) or the underlying provider result if an error. """ diff --git a/msticpy/data/drivers/odata_driver.py b/msticpy/data/drivers/odata_driver.py index ebc3204d4..fd2c1c394 100644 --- a/msticpy/data/drivers/odata_driver.py +++ b/msticpy/data/drivers/odata_driver.py @@ -18,7 +18,7 @@ from ...common.pkg_config import get_config from ...common.provider_settings import get_provider_settings from ...common.utility import mp_ua_header -from .driver_base import DriverBase, QuerySource +from .driver_base import DriverBase, DriverProps, QuerySource __version__ = VERSION __author__ = "Pete Bryan" @@ -66,6 +66,11 @@ def __init__(self, **kwargs): self.scopes = None self.msal_auth = None + self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True) + self.set_driver_property( + DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4) + ) + @abc.abstractmethod def query( self, query: str, query_source: QuerySource = None, **kwargs diff --git a/msticpy/data/drivers/security_graph_driver.py b/msticpy/data/drivers/security_graph_driver.py index 48b3aa671..a4f5c0433 100644 --- a/msticpy/data/drivers/security_graph_driver.py +++ b/msticpy/data/drivers/security_graph_driver.py @@ -4,7 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Security Graph OData Driver class.""" -from typing import Any, Union +from typing import Any, Optional, Union import pandas as pd @@ -24,7 +24,7 @@ class SecurityGraphDriver(OData): CONFIG_NAME = "MicrosoftGraph" _ALT_CONFIG_NAMES = ["SecurityGraphApp"] - def __init__(self, connection_str: str = None, **kwargs): + def __init__(self, connection_str: Optional[str] = None, **kwargs): """ Instantiate MSGraph driver and optionally connect. 
@@ -54,7 +54,7 @@ def __init__(self, connection_str: str = None, **kwargs): self.connect(connection_str) def query( - self, query: str, query_source: QuerySource = None, **kwargs + self, query: str, query_source: Optional[QuerySource] = None, **kwargs ) -> Union[pd.DataFrame, Any]: """ Execute query string and return DataFrame of results. @@ -69,7 +69,7 @@ def query( Returns ------- Union[pd.DataFrame, results.ResultSet] - A DataFrame (if successfull) or + A DataFrame (if successful) or the underlying provider result if an error. """ diff --git a/msticpy/init/nbinit.py b/msticpy/init/nbinit.py index 737d57d27..170d606ed 100644 --- a/msticpy/init/nbinit.py +++ b/msticpy/init/nbinit.py @@ -421,7 +421,7 @@ def init_notebook( check_version() output = stdout_cap.getvalue() _pr_output(output) - logger.info(output) + logger.info("Check version failures: %s", output) if _detect_env("synapse", **kwargs) and is_in_synapse(): synapse_params = { @@ -438,7 +438,7 @@ def init_notebook( ) output = stdout_cap.getvalue() _pr_output(output) - logger.info(output) + logger.info("Import failures: %s", output) # Configuration check if no_config_check: @@ -468,7 +468,7 @@ def init_notebook( _load_pivots(namespace=namespace) output = stdout_cap.getvalue() _pr_output(output) - logger.info(output) + logger.info("Pivot load failures: %s", output) # User defaults stdout_cap = io.StringIO() @@ -478,6 +478,7 @@ def init_notebook( output = stdout_cap.getvalue() _pr_output(output) logger.info(output) + logger.info("User default load failures: %s", output) if prov_dict: namespace.update(prov_dict) diff --git a/tests/data/drivers/test_azure_kusto_driver.py b/tests/data/drivers/test_azure_kusto_driver.py index db8e8f0e1..be1663913 100644 --- a/tests/data/drivers/test_azure_kusto_driver.py +++ b/tests/data/drivers/test_azure_kusto_driver.py @@ -93,9 +93,14 @@ def get_test_df(): def test_init(): - # Test that __init__ sets the current_connection property correctly - driver = AzureKustoDriver(connection_str="https://test.kusto.windows.net") - assert driver.current_connection == "https://test.kusto.windows.net" + """Test initialization of AzureKustoDriver.""" + driver = AzureKustoDriver( + connection_str="cluster='https://test.kusto.windows.net', db='Security'" + ) + assert ( + driver.current_connection + == "cluster='https://test.kusto.windows.net', db='Security'" + ) # Test that __init__ sets the _connection_props property correctly driver = AzureKustoDriver(timeout=300) diff --git a/tests/data/test_async_queries.py b/tests/data/test_async_queries.py new file mode 100644 index 000000000..50c052de9 --- /dev/null +++ b/tests/data/test_async_queries.py @@ -0,0 +1,164 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +"""Test async connections and split queries.""" + +from datetime import datetime, timedelta, timezone + +import pandas as pd +import pytest_check as check + +from msticpy.data.core.data_providers import QueryProvider +from msticpy.data.core.query_provider_connections_mixin import _calc_split_ranges +from msticpy.data.drivers.driver_base import DriverProps + +from ..unit_test_lib import get_test_data_path + +_LOCAL_DATA_PATHS = [str(get_test_data_path().joinpath("localdata"))] + +# pylint: disable=protected-access + + +def test_multiple_connections_sync(): + """Test adding connection instance to provider.""" + prov_args = dict(query_paths=_LOCAL_DATA_PATHS, data_paths=_LOCAL_DATA_PATHS) + # create local provider and run a query + local_prov = QueryProvider("LocalData", **prov_args) + start = datetime.now(timezone.utc) - timedelta(days=1) + end = datetime.now(timezone.utc) + single_results = local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", start=start, end=end + ) + + # add another connection (to same folder) + local_prov.add_connection(alias="SecondInst", **prov_args) + connections = local_prov.list_connections() + # verify second connection is listed + check.equal(len(connections), 2) + check.is_in("Default:", connections[0]) + check.is_in("SecondInst:", connections[1]) + + # run query again + multi_results = local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", start=start, end=end + ) + # verify len of result is 2x single_result + check.equal(single_results.shape[0] * 2, multi_results.shape[0]) + # verify columns/schema is the same. + check.equal(list(single_results.columns), list(multi_results.columns)) + + +def test_multiple_connections_threaded(): + """Test adding connection instance to provider.""" + prov_args = dict(query_paths=_LOCAL_DATA_PATHS, data_paths=_LOCAL_DATA_PATHS) + # create local provider and run a query + local_prov = QueryProvider("LocalData", **prov_args) + local_prov._query_provider.set_driver_property( + DriverProps.SUPPORTS_THREADING, value=True + ) + start = datetime.now(timezone.utc) - timedelta(days=1) + end = datetime.now(timezone.utc) + single_results = local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", start=start, end=end + ) + + # add another 2 named connections + for idx in range(1, 3): + local_prov.add_connection(alias=f"Instance {idx}", **prov_args) + # add another 2 unnamed connections + for _ in range(2): + local_prov.add_connection(**prov_args) + + connections = local_prov.list_connections() + # verify second connection is listed + check.equal(len(connections), 5) + check.is_in("Default:", connections[0]) + check.is_in("Instance 1", connections[1]) + + # run query again + multi_results = local_prov.WindowsSecurity.list_host_logons( + host_name="DESKTOP-12345", start=start, end=end + ) + # verify len of result is 2x single_result + check.equal(single_results.shape[0] * 5, multi_results.shape[0]) + # verify columns/schema is the same. 
+    check.equal(list(single_results.columns), list(multi_results.columns))
+
+
+def test_split_queries_sync():
+    """Test queries split into time segments."""
+    prov_args = dict(query_paths=_LOCAL_DATA_PATHS, data_paths=_LOCAL_DATA_PATHS)
+    local_prov = QueryProvider("LocalData", **prov_args)
+
+    start = datetime.now(timezone.utc) - pd.Timedelta("5H")
+    end = datetime.now(timezone.utc) + pd.Timedelta("5min")
+    delta = pd.Timedelta("1H")
+
+    ranges = _calc_split_ranges(start, end, delta)
+    local_prov.WindowsSecurity.list_host_logons(
+        host_name="DESKTOP-12345", start=start, end=end
+    )
+    result_queries = local_prov.WindowsSecurity.list_host_logons(
+        "print", host_name="DESKTOP-12345", start=start, end=end, split_query_by="1H"
+    )
+    queries = result_queries.split("\n\n")
+    check.equal(len(queries), 5)
+
+    for idx, (st_time, e_time) in enumerate(ranges):
+        check.is_in(st_time.isoformat(sep=" "), queries[idx])
+        check.is_in(e_time.isoformat(sep=" "), queries[idx])
+    check.is_in(start.isoformat(sep=" "), queries[0])
+    check.is_in(end.isoformat(sep=" "), queries[-1])
+
+    single_results = local_prov.WindowsSecurity.list_host_logons(
+        host_name="DESKTOP-12345", start=start, end=end
+    )
+    result_queries = local_prov.WindowsSecurity.list_host_logons(
+        host_name="DESKTOP-12345", start=start, end=end, split_query_by="1H"
+    )
+    # verify len of result is 5x single_result
+    check.equal(single_results.shape[0] * 5, result_queries.shape[0])
+    # verify columns/schema is the same.
+    check.equal(list(single_results.columns), list(result_queries.columns))
+
+
+def test_split_queries_async():
+    """Test queries split into time segments with threaded execution."""
+    prov_args = dict(query_paths=_LOCAL_DATA_PATHS, data_paths=_LOCAL_DATA_PATHS)
+    local_prov = QueryProvider("LocalData", **prov_args)
+    local_prov._query_provider.set_driver_property(
+        DriverProps.SUPPORTS_THREADING, value=True
+    )
+
+    start = datetime.now(timezone.utc) - pd.Timedelta("5H")
+    end = datetime.now(timezone.utc) + pd.Timedelta("5min")
+    delta = pd.Timedelta("1H")
+
+    ranges = _calc_split_ranges(start, end, delta)
+    local_prov.WindowsSecurity.list_host_logons(
+        host_name="DESKTOP-12345", start=start, end=end
+    )
+    result_queries = local_prov.WindowsSecurity.list_host_logons(
+        "print", host_name="DESKTOP-12345", start=start, end=end, split_query_by="1H"
+    )
+    queries = result_queries.split("\n\n")
+    check.equal(len(queries), 5)
+
+    for idx, (st_time, e_time) in enumerate(ranges):
+        check.is_in(st_time.isoformat(sep=" "), queries[idx])
+        check.is_in(e_time.isoformat(sep=" "), queries[idx])
+    check.is_in(start.isoformat(sep=" "), queries[0])
+    check.is_in(end.isoformat(sep=" "), queries[-1])
+
+    single_results = local_prov.WindowsSecurity.list_host_logons(
+        host_name="DESKTOP-12345", start=start, end=end
+    )
+    result_queries = local_prov.WindowsSecurity.list_host_logons(
+        host_name="DESKTOP-12345", start=start, end=end, split_query_by="1H"
+    )
+    # verify len of result is 5x single_result
+    check.equal(single_results.shape[0] * 5, result_queries.shape[0])
+    # verify columns/schema is the same.
+ check.equal(list(single_results.columns), list(result_queries.columns)) diff --git a/tests/data/test_dataqueries.py b/tests/data/test_dataqueries.py index df88c93c7..756c23180 100644 --- a/tests/data/test_dataqueries.py +++ b/tests/data/test_dataqueries.py @@ -20,8 +20,9 @@ from msticpy.common import pkg_config from msticpy.common.exceptions import MsticpyException -from msticpy.data.core.data_providers import QueryProvider, _calc_split_ranges +from msticpy.data.core.data_providers import QueryProvider from msticpy.data.core.query_container import QueryContainer +from msticpy.data.core.query_provider_connections_mixin import _calc_split_ranges from msticpy.data.core.query_source import QuerySource from msticpy.data.drivers.driver_base import DriverBase, DriverProps @@ -404,7 +405,7 @@ def test_split_queries_err(self): queries = result_queries.split("\n\n") # if no start and end - provider prints message and returns None self.assertEqual(len(queries), 1) - self.assertIn("Cannot split a query that", mssg.getvalue()) + self.assertIn("Cannot split a query", mssg.getvalue()) # With invalid split_query_by value it will default to 1D start = datetime.utcnow() - pd.Timedelta("5D") @@ -420,7 +421,7 @@ def test_split_queries_err(self): _LOCAL_DATA_PATHS = [str(get_test_data_path().joinpath("localdata"))] -def test_add_provider(): +def test_multiple_connections(): """Test adding connection instance to provider.""" prov_args = dict(query_paths=_LOCAL_DATA_PATHS, data_paths=_LOCAL_DATA_PATHS) # create local provider and run a query