Skip to content

Commit

Permalink
Hotfix for v2.5.1 (#672)
Browse files Browse the repository at this point in the history
Add more defense against import errors in msticpy/__init__.py — import failures previously surfaced when help(msticpy) was used, because it triggers loading of all dynamic attributes
Better exception message on import error in azure_data.py
Moving ResourceGraph query provider to only instantiate the provider when needed.
Made data_query_reader.py produce warnings rather than throwing exceptions when encountering a bad query file
  • Loading branch information
ianhelle authored Jun 2, 2023
1 parent 1f87529 commit 9466a77
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 53 deletions.
11 changes: 7 additions & 4 deletions msticpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"""
import importlib
import os
import warnings
from typing import Any, Iterable, Union

from . import nbwidgets
Expand All @@ -124,6 +125,7 @@
from ._version import VERSION
from .common import pkg_config as settings
from .common.check_version import check_version
from .common.exceptions import MsticpyException
from .common.utility import search_name as search
from .init.logging import set_logging_level, setup_logging

Expand All @@ -146,8 +148,6 @@
"GeoLiteLookup": "msticpy.context.geoip",
"init_notebook": "msticpy.init.nbinit",
"reset_ipython_exception_handler": "msticpy.init.nbinit",
"IPStackLookup": "msticpy.context.geoip",
"MicrosoftSentinel": "msticpy.context.azure",
"MpConfigEdit": "msticpy.config.mp_config_edit",
"MpConfigFile": "msticpy.config.mp_config_file",
"QueryProvider": "msticpy.data",
def __getattr__(attrib: str) -> Any:
    """
    Return the module attribute `attrib`, importing its module on first use.

    Parameters
    ----------
    attrib : str
        Name of one of msticpy's lazily-loaded attributes
        (a key of `_DEFAULT_IMPORTS`).

    Returns
    -------
    Any
        The imported attribute, or None if the import failed
        (a warning is emitted instead of raising).

    Raises
    ------
    AttributeError
        If `attrib` is not a known dynamic attribute.

    """
    if attrib in _DEFAULT_IMPORTS:
        try:
            return getattr(importlib.import_module(_DEFAULT_IMPORTS[attrib]), attrib)
        except (ImportError, MsticpyException):
            # Defensive: help(msticpy) touches every dynamic attribute, so a
            # broken optional dependency must warn rather than crash.
            # Fix: the original message lacked the f-prefix, so the literal
            # text "{attrib}" was printed instead of the attribute name.
            warnings.warn(f"Unable to import msticpy.{attrib}", ImportWarning)
            return None
    raise AttributeError(f"msticpy has no attribute {attrib}")


Expand Down
2 changes: 1 addition & 1 deletion msticpy/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Version file."""
VERSION = "2.5.0"
VERSION = "2.5.1"
6 changes: 5 additions & 1 deletion msticpy/context/azure/azure_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
from azure.mgmt.compute.models import VirtualMachineInstanceView
except ImportError as imp_err:
raise MsticpyImportExtraError(
"Cannot use this feature without azure packages installed",
"Cannot use this feature without these azure packages installed",
"azure.mgmt.network",
"azure.mgmt.resource",
"azure.mgmt.monitor",
"azure.mgmt.compute",
title="Error importing azure module",
extra="azure",
) from imp_err
Expand Down
25 changes: 15 additions & 10 deletions msticpy/context/azure/sentinel_workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SentinelWorkspacesMixin:
"""Mixin class for Sentinel workspaces."""

_TENANT_URI = "{cloud_endpoint}/{tenant_name}/.well-known/openid-configuration"
_RES_GRAPH_PROV = QueryProvider("ResourceGraph")
_RES_GRAPH_PROV: Optional[QueryProvider] = None

@classmethod
def get_resource_id_from_url(cls, portal_url: str) -> str:
Expand Down Expand Up @@ -238,34 +238,39 @@ def get_workspace_settings_by_name(
)
return {}

@classmethod
def _get_resource_graph_provider(cls) -> QueryProvider:
    """Return the shared ResourceGraph query provider, creating and connecting it lazily."""
    provider = cls._RES_GRAPH_PROV
    if not provider:
        # Instantiate on first use so that importing the class stays cheap.
        provider = QueryProvider("ResourceGraph")
        cls._RES_GRAPH_PROV = provider
    if not provider.connected:
        provider.connect()  # pragma: no cover
    return provider

@classmethod
def _lookup_workspace_by_name(
    cls,
    workspace_name: str,
    subscription_id: str = "",
    resource_group: str = "",
) -> pd.DataFrame:
    """
    Return DataFrame of Sentinel workspaces matching `workspace_name`.

    Optionally narrow the search by `subscription_id` and/or `resource_group`.
    """
    # Shared provider is created/connected on demand (removed the stale
    # pre-commit lines that connected cls._RES_GRAPH_PROV directly).
    res_graph_prov = cls._get_resource_graph_provider()
    return res_graph_prov.Sentinel.list_sentinel_workspaces_for_name(
        workspace_name=workspace_name,
        subscription_id=subscription_id,
        resource_group=resource_group,
    )

@classmethod
def _lookup_workspace_by_ws_id(cls, workspace_id: str) -> pd.DataFrame:
    """Return DataFrame of Sentinel workspace details for `workspace_id`."""
    # Shared provider is created/connected on demand (removed the stale
    # pre-commit lines that connected cls._RES_GRAPH_PROV directly).
    res_graph_prov = cls._get_resource_graph_provider()
    return res_graph_prov.Sentinel.get_sentinel_workspace_for_workspace_id(
        workspace_id=workspace_id
    )

@classmethod
def _lookup_workspace_by_res_id(cls, resource_id: str) -> pd.DataFrame:
    """Return DataFrame of Sentinel workspace details for Azure `resource_id`."""
    # Shared provider is created/connected on demand (removed the stale
    # pre-commit lines that connected cls._RES_GRAPH_PROV directly).
    # Added return annotation for consistency with sibling lookup methods.
    res_graph_prov = cls._get_resource_graph_provider()
    return res_graph_prov.Sentinel.get_sentinel_workspace_for_resource_id(
        resource_id=resource_id
    )

Expand Down
20 changes: 10 additions & 10 deletions msticpy/data/core/data_query_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
# license information.
# --------------------------------------------------------------------------
"""Data query definition reader."""
import logging
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterable, Tuple

import yaml

from ..._version import VERSION
from .query_defns import DataEnvironment

__version__ = VERSION
__author__ = "Ian Hellen"

logger = logging.getLogger(__name__)


def find_yaml_files(source_path: str, recursive: bool = True) -> Iterable[Path]:
"""
Expand Down Expand Up @@ -69,12 +71,16 @@ def read_query_def_file(query_file: str) -> Tuple[Dict, Dict, Dict]:
# use safe_load instead load
data_map = yaml.safe_load(f_handle)

validate_query_defs(query_def_dict=data_map)
try:
validate_query_defs(query_def_dict=data_map)
except ValueError as err:
logger.warning("Validation failed for %s\n%s", query_file, err, exc_info=True)

defaults = data_map.get("defaults", {})
sources = data_map.get("sources", {})
metadata = data_map.get("metadata", {})

logger.info("Read %s queries from %s", len(sources), query_file)
return sources, defaults, metadata


Expand All @@ -99,6 +105,8 @@ def validate_query_defs(query_def_dict: Dict[str, Any]) -> bool:
exception message (arg[0])
"""
if query_def_dict is None or not query_def_dict:
raise ValueError("Imported file is empty")
# verify that sources and metadata are in the data dict
if "sources" not in query_def_dict or not query_def_dict["sources"]:
raise ValueError("Imported file has no sources defined")
Expand All @@ -119,14 +127,6 @@ def _validate_data_categories(query_def_dict: Dict):
):
raise ValueError("Imported file has no data_environments defined")

for env in query_def_dict["metadata"]["data_environments"]:
if not DataEnvironment.parse(env):
raise ValueError(
f"Unknown data environment {env} in metadata. ",
"Valid values are\n",
", ".join(e.name for e in DataEnvironment),
)

if (
"data_families" not in query_def_dict["metadata"]
or not query_def_dict["metadata"]["data_families"]
Expand Down
3 changes: 2 additions & 1 deletion msticpy/data/drivers/sentinel_query_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
# --------------------------------------------------------------------------
"""Github Sentinel Query repo import class and helpers."""

import logging
import os
import re
import warnings
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional
import logging

import attr
import httpx
import yaml
from attr import attrs
from tqdm.notebook import tqdm

from ..._version import VERSION

__version__ = VERSION
Expand Down
4 changes: 3 additions & 1 deletion tests/context/azure/test_sentinel_workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from msticpy.auth.azure_auth_core import AzureCloudConfig
from msticpy.context.azure import MicrosoftSentinel
from msticpy.data import QueryProvider

# pylint: disable=protected-access

Expand Down Expand Up @@ -380,7 +381,8 @@ def test_param_checks():


def _patch_qry_prov(patcher):
qry_prov = getattr(MicrosoftSentinel, "_RES_GRAPH_PROV")
qry_prov = QueryProvider("ResourceGraph")
setattr(MicrosoftSentinel, "_RES_GRAPH_PROV", qry_prov)
qry_prov._query_provider._loaded = True
qry_prov._query_provider._connected = True
patcher.setattr(qry_prov, "connect", lambda: True)
Expand Down
71 changes: 47 additions & 24 deletions tests/data/drivers/test_sentinel_query_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ def test_get_sentinel_queries_from_github():


def test_read_yaml_files():
    """Verify read_yaml_files returns file content keyed by source path."""
    yaml_files = read_yaml_files(
        parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections"
    )
    # The known sample query file should be present with non-empty content.
    # (Removed the stale pre-reformat duplicate lines left by the diff.)
    assert yaml_files[
        str(BASE_DIR_TEST_FOLDER.joinpath("Detections/Anomalies/UnusualAnomaly.yaml"))
    ]


def test__import_sentinel_query():
yaml_files = read_yaml_files(parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections")
yaml_files = read_yaml_files(
parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections"
)
query_type = "Detections"
yaml_path = str(BASE_DIR_TEST_FOLDER.joinpath("Detections/Anomalies/UnusualAnomaly.yaml"))
yaml_path = str(
BASE_DIR_TEST_FOLDER.joinpath("Detections/Anomalies/UnusualAnomaly.yaml")
)
yaml_text = yaml_files[yaml_path]
sample_query = SentinelQuery(
query_id="d0255b5f-2a3c-4112-8744-e6757af3283a",
Expand Down Expand Up @@ -84,8 +92,12 @@ def test__import_sentinel_query():


def test_import_sentinel_query():
yaml_files = read_yaml_files(parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections")
yaml_path = str(BASE_DIR_TEST_FOLDER.joinpath("Detections/Anomalies/UnusualAnomaly.yaml"))
yaml_files = read_yaml_files(
parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections"
)
yaml_path = str(
BASE_DIR_TEST_FOLDER.joinpath("Detections/Anomalies/UnusualAnomaly.yaml")
)
sample_query = SentinelQuery(
query_id="d0255b5f-2a3c-4112-8744-e6757af3283a",
name="Unusual Anomaly",
Expand All @@ -109,9 +121,7 @@ def test_import_sentinel_query():
source_file_name=yaml_path,
query_type="Detections",
)
assert (
sample_query in import_sentinel_queries(yaml_files, query_type="Detections")
)
assert sample_query in import_sentinel_queries(yaml_files, query_type="Detections")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -165,45 +175,58 @@ def test__format_query_name(initial_str, expected_result):
version="1.0.1",
kind="Scheduled",
folder_name="Anomalies",
source_file_name=str(BASE_DIR_TEST_FOLDER.joinpath("Detections/Anomalies/UnusualAnomaly.yaml")),
source_file_name=str(
BASE_DIR_TEST_FOLDER.joinpath(
"Detections/Anomalies/UnusualAnomaly.yaml"
)
),
query_type="Detections",
)
],
),
],
)
def test__organize_query_list_by_folder(dict_section, expected_result):
    """Check that queries are grouped by folder with the expected keys/values."""
    yaml_files = read_yaml_files(
        parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections"
    )
    query_list = import_sentinel_queries(yaml_files=yaml_files, query_type="Detections")
    # (Removed the stale pre-reformat assert variants left by the diff.)
    if dict_section == "keys":
        assert sorted(
            list(_organize_query_list_by_folder(query_list=query_list).keys())
        ) == sorted(expected_result)
    else:
        assert sorted(
            _organize_query_list_by_folder(query_list=query_list)[dict_section]
        ) == sorted(expected_result)


def test__create_queryfile_metadata():
    """Check generated query-file metadata matches the expected values."""
    # timing may differ but doesn't matter for test purposes
    ignore_keys = ["last_updated"]
    generated_dict = {
        k: v
        for k, v in _create_queryfile_metadata(folder_name="Detections")[
            "metadata"
        ].items()
        if k not in ignore_keys
    }
    test_dict = {
        "version": 1,
        "description": "Sentinel Alert Queries - Detections",
        "data_environments": ["MSSentinel"],
        "data_families": "Detections",
    }
    assert generated_dict == test_dict


# original test case for generating new yaml files
@pytest.mark.skip(reason="requires downloading the file directly during the test")
def test_write_to_yaml():
yaml_files = read_yaml_files(parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections")
yaml_files = read_yaml_files(
parent_dir=BASE_DIR_TEST_FOLDER, child_dir="Detections"
)
query_list = import_sentinel_queries(yaml_files=yaml_files, query_type="Detections")
query_list = [l for l in query_list if l is not None]
write_to_yaml(
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_dataqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_graph_load_query_exec(self):
def test_load_yaml_def(self):
    """Test query loader rejecting badly formed query files."""
    la_provider = self.la_provider
    # KeyError added to the expected set: a malformed file may now surface
    # a missing-key failure instead of a validation exception.
    # (Removed the stale pre-commit assertRaises line left by the diff.)
    with self.assertRaises((MsticpyException, ValueError, KeyError)) as cm:
        file_path = Path(_TEST_DATA, "data_q_meta_fail.yaml")
        la_provider.import_query_file(query_file=file_path)
    self.assertIn("no data families defined", str(cm.exception))
Expand Down

0 comments on commit 9466a77

Please sign in to comment.