From c45c2adc74d7c26237ebcc7da3aee84441a02813 Mon Sep 17 00:00:00 2001 From: andrey-snowflake <42752788+andrey-snowflake@users.noreply.github.com> Date: Wed, 3 Apr 2019 21:52:49 -0700 Subject: [PATCH] v1.6.1 Polish and Fixes * Adds owner field to Violations, default values for missing fields * Adds config account selection option to installer * Adds data views for querying rules by tag * Adds query_name set in Alert Query Runners instead of AQ's * Fixes default value of alert event_time in WebUI * Fixes bugs in WebUI, installer, and ingestion script * Fixes bugs in VQ runner metadata run and error recording --- src/ingestion/list_aws_accounts.py | 2 +- src/ingestion/okta_ingestion.py | 17 ++-- src/runners/alert_handler.py | 7 +- src/runners/alert_queries_runner.py | 14 ++-- src/runners/alert_suppressions_runner.py | 9 ++- src/runners/config.py | 34 ++++---- src/runners/helpers/db.py | 80 ++++++++++--------- src/runners/helpers/dbconfig.py | 18 ++--- src/runners/helpers/log.py | 51 +++++++++--- src/runners/test_run_violations.py | 55 ++++++++++--- src/runners/violation_queries_runner.py | 15 ++-- src/samui/frontend/src/reducers/rules.ts | 18 ++--- .../src/routes/Dashboard/Policies.tsx | 2 +- .../src/routes/Dashboard/Violations.tsx | 6 ++ src/samui/frontend/src/store/rules.ts | 16 ++-- src/scripts/install.py | 43 +++------- .../installer-queries/create-udtfs.sql.fmt | 18 +++++ .../installer-queries/data-views.sql.fmt | 43 +++++++--- .../sample-violation-queries.sql.fmt | 1 + 19 files changed, 279 insertions(+), 170 deletions(-) create mode 100644 src/scripts/installer-queries/create-udtfs.sql.fmt diff --git a/src/ingestion/list_aws_accounts.py b/src/ingestion/list_aws_accounts.py index 44cbd42e7..221cf4746 100644 --- a/src/ingestion/list_aws_accounts.py +++ b/src/ingestion/list_aws_accounts.py @@ -94,7 +94,7 @@ def load_accounts_list(sf_client, accounts_list): def main(): sf_client = get_snowflake_client() current_time = datetime.datetime.now(datetime.timezone.utc) - last_time = sf_client.cursor().execute(f'SELECT max(timestamp) FROM {AWS_ACCOUNTS_TABLE}').fetchall[0][0] + last_time = sf_client.cursor().execute(f'SELECT max(timestamp) FROM {AWS_ACCOUNTS_TABLE}').fetchall()[0][0] if (current_time - last_time).seconds > 86400: client = get_aws_client() accounts_list = get_accounts_list(client) diff --git a/src/ingestion/okta_ingestion.py b/src/ingestion/okta_ingestion.py index 0616634ff..d4a86b61f 100644 --- a/src/ingestion/okta_ingestion.py +++ b/src/ingestion/okta_ingestion.py @@ -9,6 +9,8 @@ ALOOMA_SDK = alooma_pysdk.PythonSDK(INPUT_TOKEN) OKTA_API_KEY = environ.get('OKTA_API_KEY') AUTH = "SSWS "+OKTA_API_KEY +OKTA_URL = environ.get('OKTA_URL') +OKTA_TABLE = environ.get('OKTA_TABLE') HEADERS = {'Accept': 'application/json', 'Content-Type': 'application/json', 'Authorization': AUTH} @@ -24,17 +26,18 @@ def process_logs(logs): def get_timestamp(): # Once pipelines are more strongly integrated with the installer, this table should be a variable - timestamp_query = """ - SELECT PUBLISHED from SECURITY.ALOOMA.SNOWBIZ_OKTA - order by PUBLISHED desc + timestamp_query = f""" + SELECT EVENT_TIME from {OKTA_TABLE} + WHERE EVENT_TIME IS NOT NULL + order by EVENT_TIME desc limit 1 """ try: _, ts = db.connect_and_fetchall(timestamp_query) - print(ts) + log.info(ts) ts = ts[0][0] ts = ts.strftime("%Y-%m-%dT%H:%M:%S.000Z") - print(ts) + log.info(ts) if len(ts) < 1: log.error("The okta timestamp is too short or doesn't exist; defaulting to one hour ago") ts = datetime.datetime.now() - 
datetime.timedelta(hours=1) @@ -46,13 +49,13 @@ def get_timestamp(): ts = ts.strftime("%Y-%m-%dT%H:%M:%S.000Z") ret = {'since': ts} - print(ret) + log.info(ret) return ret def main(): - url = 'https://snowbiz.okta.com/api/v1/logs' + url = OKTA_URL log.info("starting loop") timestamp = get_timestamp() while 1: diff --git a/src/runners/alert_handler.py b/src/runners/alert_handler.py index 291d0658b..57fcfea4e 100755 --- a/src/runners/alert_handler.py +++ b/src/runners/alert_handler.py @@ -148,8 +148,11 @@ def main(): # Record the new ticket id in the alert table record_ticket_id(ctx, ticket_id, alert['ALERT_ID']) - if CLOUDWATCH_METRICS: - log.metric('Run', 'SnowAlert', [{'Name': 'Component', 'Value': 'Alert Handler'}], 1) + try: + if CLOUDWATCH_METRICS: + log.metric('Run', 'SnowAlert', [{'Name': 'Component', 'Value': 'Alert Handler'}], 1) + except Exception as e: + log.error("Cloudwatch metric logging failed", e) if __name__ == "__main__": diff --git a/src/runners/alert_queries_runner.py b/src/runners/alert_queries_runner.py index 511f560f3..618a1408f 100755 --- a/src/runners/alert_queries_runner.py +++ b/src/runners/alert_queries_runner.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import json -import hashlib import uuid import datetime from typing import Any, Dict, List, Tuple @@ -21,12 +20,6 @@ QUERY_HISTORY: List = [] -def alert_group(alert) -> str: - return hashlib.md5( - f"{alert['OBJECT']}{alert['DESCRIPTION']}".encode('utf-8') - ).hexdigest() - - def log_alerts(ctx, alerts): if len(alerts): print("Recording alerts.") @@ -121,8 +114,11 @@ def main(rule_name=None): } log.metadata_record(ctx, RUN_METADATA, table=RUN_METADATA_TABLE) - if CLOUDWATCH_METRICS: - log.metric('Run', 'SnowAlert', [{'Name': 'Component', 'Value': 'Alert Query Runner'}], 1) + try: + if CLOUDWATCH_METRICS: + log.metric('Run', 'SnowAlert', [{'Name': 'Component', 'Value': 'Alert Query Runner'}], 1) + except Exception as e: + log.error("Cloudwatch metric logging failed: ", e) if __name__ == '__main__': diff --git a/src/runners/alert_suppressions_runner.py b/src/runners/alert_suppressions_runner.py index 2992d6382..d16881055 100755 --- a/src/runners/alert_suppressions_runner.py +++ b/src/runners/alert_suppressions_runner.py @@ -74,7 +74,7 @@ def log_failure(ctx, suppression_name, e, event_data=None, description=None): alerts = [json.dumps({ 'ALERT_ID': uuid.uuid4().hex, 'QUERY_ID': 'b1d02051dd2c4d62bb75274f2ee5996a', - 'QUERY_NAME': 'Suppression Runner Failure', + 'QUERY_NAME': ['Suppression Runner Failure'], 'ENVIRONMENT': 'Suppressions', 'SOURCES': 'Suppression Runner', 'ACTOR': 'Suppression Runner', @@ -154,8 +154,11 @@ def main(): log.metadata_record(ctx, RUN_METADATA, table=RUN_METADATA_TABLE) - if CLOUDWATCH_METRICS: - log.metric('Run', 'SnowAlert', [{'Name': 'Component', 'Value': 'Alert Suppression Runner'}], 1) + try: + if CLOUDWATCH_METRICS: + log.metric('Run', 'SnowAlert', [{'Name': 'Component', 'Value': 'Alert Suppression Runner'}], 1) + except Exception as e: + log.error("Cloudwatch metric logging failed: ", e) if __name__ == '__main__': diff --git a/src/runners/config.py b/src/runners/config.py index bca0882bd..8f289d679 100644 --- a/src/runners/config.py +++ b/src/runners/config.py @@ -1,32 +1,34 @@ -import os +from os import environ import uuid from runners.helpers.dbconfig import DATABASE +ENV = environ.get('SA_ENV', 'unset') + # generated once per runtime RUN_ID = uuid.uuid4().hex # schema names -DATA_SCHEMA_NAME = os.environ.get('SA_DATA_SCHEMA_NAME', "data") -RULES_SCHEMA_NAME = 
os.environ.get('SA_RULES_SCHEMA_NAME', "rules") -RESULTS_SCHEMA_NAME = os.environ.get('SA_RESULTS_SCHEMA_NAME', "results") +DATA_SCHEMA_NAME = environ.get('SA_DATA_SCHEMA_NAME', "data") +RULES_SCHEMA_NAME = environ.get('SA_RULES_SCHEMA_NAME', "rules") +RESULTS_SCHEMA_NAME = environ.get('SA_RESULTS_SCHEMA_NAME', "results") # table names -RESULTS_ALERTS_TABLE_NAME = os.environ.get('SA_RESULTS_ALERTS_TABLE_NAME', "alerts") -RESULTS_VIOLATIONS_TABLE_NAME = os.environ.get('SA_RESULTS_VIOLATIONS_TABLE_NAME', "violations") -QUERY_METADATA_TABLE_NAME = os.environ.get('SA_QUERY_METADATA_TABLE_NAME', "query_metadata") -RUN_METADATA_TABLE_NAME = os.environ.get('SA_RUN_METADATA_TABLE_NAME', "run_metadata") +RESULTS_ALERTS_TABLE_NAME = environ.get('SA_RESULTS_ALERTS_TABLE_NAME', "alerts") +RESULTS_VIOLATIONS_TABLE_NAME = environ.get('SA_RESULTS_VIOLATIONS_TABLE_NAME', "violations") +QUERY_METADATA_TABLE_NAME = environ.get('SA_QUERY_METADATA_TABLE_NAME', "query_metadata") +RUN_METADATA_TABLE_NAME = environ.get('SA_RUN_METADATA_TABLE_NAME', "run_metadata") # schemas -DATA_SCHEMA = os.environ.get('SA_DATA_SCHEMA', f"{DATABASE}.{DATA_SCHEMA_NAME}") -RULES_SCHEMA = os.environ.get('SA_RULES_SCHEMA', f"{DATABASE}.{RULES_SCHEMA_NAME}") -RESULTS_SCHEMA = os.environ.get('SA_RESULTS_SCHEMA', f"{DATABASE}.{RESULTS_SCHEMA_NAME}") +DATA_SCHEMA = environ.get('SA_DATA_SCHEMA', f"{DATABASE}.{DATA_SCHEMA_NAME}") +RULES_SCHEMA = environ.get('SA_RULES_SCHEMA', f"{DATABASE}.{RULES_SCHEMA_NAME}") +RESULTS_SCHEMA = environ.get('SA_RESULTS_SCHEMA', f"{DATABASE}.{RESULTS_SCHEMA_NAME}") # tables -ALERTS_TABLE = os.environ.get('SA_ALERTS_TABLE', f"{RESULTS_SCHEMA}.{RESULTS_ALERTS_TABLE_NAME}") -VIOLATIONS_TABLE = os.environ.get('SA_VIOLATIONS_TABLE', f"{RESULTS_SCHEMA}.{RESULTS_VIOLATIONS_TABLE_NAME}") -QUERY_METADATA_TABLE = os.environ.get('SA_QUERY_METADATA_TABLE', f"{RESULTS_SCHEMA}.{QUERY_METADATA_TABLE_NAME}") -RUN_METADATA_TABLE = os.environ.get('SA_METADATA_RUN_TABLE', f"{RESULTS_SCHEMA}.{RUN_METADATA_TABLE_NAME}") +ALERTS_TABLE = environ.get('SA_ALERTS_TABLE', f"{RESULTS_SCHEMA}.{RESULTS_ALERTS_TABLE_NAME}") +VIOLATIONS_TABLE = environ.get('SA_VIOLATIONS_TABLE', f"{RESULTS_SCHEMA}.{RESULTS_VIOLATIONS_TABLE_NAME}") +QUERY_METADATA_TABLE = environ.get('SA_QUERY_METADATA_TABLE', f"{RESULTS_SCHEMA}.{QUERY_METADATA_TABLE_NAME}") +RUN_METADATA_TABLE = environ.get('SA_METADATA_RUN_TABLE', f"{RESULTS_SCHEMA}.{RUN_METADATA_TABLE_NAME}") # misc ALERT_QUERY_POSTFIX = "ALERT_QUERY" @@ -35,7 +37,7 @@ VIOLATION_SQUELCH_POSTFIX = "VIOLATION_SUPPRESSION" # enabling sends metrics to cloudwatch -CLOUDWATCH_METRICS = os.environ.get('CLOUDWATCH_METRICS', False) +CLOUDWATCH_METRICS = environ.get('CLOUDWATCH_METRICS', False) CONFIG_VARS = [ ('ALERTS_TABLE', ALERTS_TABLE), diff --git a/src/runners/helpers/db.py b/src/runners/helpers/db.py index d28b2ad8c..7ac5594df 100644 --- a/src/runners/helpers/db.py +++ b/src/runners/helpers/db.py @@ -83,11 +83,11 @@ def connect(): log.error(e, "Failed to connect.") -def fetch(ctx, query=None): +def fetch(ctx, query=None, fix_errors=True): if query is None: # TODO(andrey): swap args and refactor ctx, query = CACHED_CONNECTION, ctx - res = execute(ctx, query) + res = execute(ctx, query, fix_errors) cols = [c[0] for c in res.description] while True: row = res.fetchone() @@ -97,6 +97,7 @@ def fetch(ctx, query=None): def execute(ctx, query=None, fix_errors=True): + # TODO(andrey): don't ignore errors by default if query is None: # TODO(andrey): swap args and refactor ctx, query = CACHED_CONNECTION, ctx @@ 
-104,14 +105,15 @@ def execute(ctx, query=None, fix_errors=True): return ctx.cursor().execute(query) except snowflake.connector.errors.ProgrammingError as e: - if not fix_errors: - raise - if e.errno == int(MASTER_TOKEN_EXPIRED_GS_CODE): connect(run_preflight_checks=False, flush_cache=True) return execute(query) - log.error(e, f"Ignoring programming Error in query: {query}") + if not fix_errors: + log.debug(f"re-raising error '{e}' in query >{query}<") + raise + + log.info(e, f"ignoring error '{e}' in query >{query}<") return ctx.cursor().execute("SELECT 1 WHERE FALSE;") @@ -151,18 +153,25 @@ def load_rules(ctx, postfix) -> List[str]: return rules +INSERT_ALERTS_QUERY = f""" +INSERT INTO results.alerts (alert_time, event_time, alert) +SELECT PARSE_JSON(column1):ALERT_TIME + , PARSE_JSON(column1):EVENT_TIME + , PARSE_JSON(column1) +FROM VALUES {{values}} +""" + + +def sql_value_placeholders(n): + return ", ".join(["(%s)"] * n) + + def insert_alerts(alerts, ctx=None): if ctx is None: ctx = CACHED_CONNECTION or connect() - ctx.cursor().execute( - ( - f'INSERT INTO restuls.alerts (alert_time, event_time, alert) ' - f'SELECT PARSE_JSON(column1):ALERT_TIME, PARSE_JSON(column1):EVENT_TIME, PARSE_JSON(column1) ' - f'FROM values {", ".join(["(%s)"] * len(alerts))};' - ), - alerts - ) + query = INSERT_ALERTS_QUERY.format(values=sql_value_placeholders(len(alerts))) + return ctx.cursor().execute(query, alerts) def insert_alerts_query_run(query_name, from_time_sql, to_time_sql='CURRENT_TIMESTAMP()', ctx=None): @@ -202,16 +211,17 @@ def insert_alerts_query_run(query_name, from_time_sql, to_time_sql='CURRENT_TIME IFNULL( OBJECT_CONSTRUCT(*):IDENTITY, OBJECT_CONSTRUCT( - 'ENVIRONMENT', IFNULL(environment, PARSE_JSON('null')), - 'OBJECT', IFNULL(object, PARSE_JSON('null')), - 'TITLE', IFNULL(title, PARSE_JSON('null')), - 'ALERT_TIME', IFNULL(alert_time, PARSE_JSON('null')), - 'DESCRIPTION', IFNULL(description, PARSE_JSON('null')), - 'EVENT_DATA', IFNULL(event_data, PARSE_JSON('null')), - 'DETECTOR', IFNULL(detector, PARSE_JSON('null')), - 'SEVERITY', IFNULL(severity, PARSE_JSON('null')), - 'QUERY_ID', IFNULL(query_id, PARSE_JSON('null')), - 'QUERY_NAME', IFNULL(query_name, PARSE_JSON('null')) + 'ENVIRONMENT', IFNULL(OBJECT_CONSTRUCT(*):ENVIRONMENT, PARSE_JSON('null')), + 'OBJECT', IFNULL(OBJECT_CONSTRUCT(*):OBJECT, PARSE_JSON('null')), + 'OWNER', IFNULL(OBJECT_CONSTRUCT(*):OWNER, PARSE_JSON('null')), + 'TITLE', IFNULL(OBJECT_CONSTRUCT(*):TITLE, PARSE_JSON('null')), + 'ALERT_TIME', IFNULL(OBJECT_CONSTRUCT(*):ALERT_TIME, PARSE_JSON('null')), + 'DESCRIPTION', IFNULL(OBJECT_CONSTRUCT(*):DESCRIPTION, PARSE_JSON('null')), + 'EVENT_DATA', IFNULL(OBJECT_CONSTRUCT(*):EVENT_DATA, PARSE_JSON('null')), + 'DETECTOR', IFNULL(OBJECT_CONSTRUCT(*):DETECTOR, PARSE_JSON('null')), + 'SEVERITY', IFNULL(OBJECT_CONSTRUCT(*):SEVERITY, PARSE_JSON('null')), + 'QUERY_ID', IFNULL(OBJECT_CONSTRUCT(*):QUERY_ID, PARSE_JSON('null')), + 'QUERY_NAME', '{{query_name}}' ) ) )) @@ -229,17 +239,11 @@ def insert_violations_query_run(query_name, ctx=None) -> Tuple[int, int]: log.info(f"{query_name} processing...") try: - try: - result = next(fetch(ctx, INSERT_VIOLATIONS_WITH_ID_QUERY.format(**locals()))) - except Exception: - log.info('warning: missing STRING ID column in RESULTS.VIOLATIONS') - result = next(fetch(ctx, INSERT_VIOLATIONS_QUERY.format(**locals()))) - - num_rows_inserted = result['number of rows inserted'] - log.info(f"{query_name} created {num_rows_inserted} rows, updated 0 rows.") - log.info(f"{query_name} done.") - return 
num_rows_inserted, 0 - - except Exception as e: - log.info(f"{query_name} run threw an exception:", e) - return 0, 0 + result = next(fetch(INSERT_VIOLATIONS_WITH_ID_QUERY.format(**locals()), fix_errors=False)) + except Exception: + log.info('warning: missing STRING ID column in RESULTS.VIOLATIONS') + result = next(fetch(INSERT_VIOLATIONS_QUERY.format(**locals()), fix_errors=False)) + + num_rows_inserted = result['number of rows inserted'] + log.info(f"{query_name} created {num_rows_inserted} rows.") + return num_rows_inserted diff --git a/src/runners/helpers/dbconfig.py b/src/runners/helpers/dbconfig.py index 921bbdf67..e168df34f 100644 --- a/src/runners/helpers/dbconfig.py +++ b/src/runners/helpers/dbconfig.py @@ -1,17 +1,17 @@ from base64 import b64decode -import os +from os import environ # database & account properties -REGION = os.environ.get('REGION', "us-west-2") -ACCOUNT = os.environ.get('SNOWFLAKE_ACCOUNT', '') + ('' if REGION == 'us-west-2' else f'.{REGION}') +REGION = environ.get('REGION', "us-west-2") +ACCOUNT = environ.get('SNOWFLAKE_ACCOUNT', '') + ('' if REGION == 'us-west-2' else f'.{REGION}') -USER = os.environ.get('SA_USER', "snowalert") -PRIVATE_KEY_PASSWORD = os.environ.get('PRIVATE_KEY_PASSWORD', '').encode('utf-8') -PRIVATE_KEY = b64decode(os.environ['PRIVATE_KEY']) if os.environ.get('PRIVATE_KEY') else None +USER = environ.get('SA_USER', "snowalert") +PRIVATE_KEY_PASSWORD = environ.get('PRIVATE_KEY_PASSWORD', '').encode('utf-8') +PRIVATE_KEY = b64decode(environ['PRIVATE_KEY']) if environ.get('PRIVATE_KEY') else None -ROLE = os.environ.get('SA_ROLE', "snowalert") -WAREHOUSE = os.environ.get('SA_WAREHOUSE', "snowalert") -DATABASE = os.environ.get('SA_DATABASE', "snowalert") +ROLE = environ.get('SA_ROLE', "snowalert") +WAREHOUSE = environ.get('SA_WAREHOUSE', "snowalert") +DATABASE = environ.get('SA_DATABASE', "snowalert") # connection properties TIMEOUT = 500 diff --git a/src/runners/helpers/log.py b/src/runners/helpers/log.py index 8a8006d14..f87b3037b 100644 --- a/src/runners/helpers/log.py +++ b/src/runners/helpers/log.py @@ -1,17 +1,31 @@ -import traceback -import sys -import boto3 import datetime import json +import sys +import traceback + +import boto3 + +from ..config import ENV + + +def format_exception(e): + return ''.join(traceback.format_exception(type(e), e, e.__traceback__)) + + +def format_exception_only(e): + return ''.join(traceback.format_exception_only(type(e), e)).strip() def write(*args, stream=sys.stdout): for a in args: if isinstance(a, Exception): - traceback.print_exception(type(a), a, a.__traceback__, file=stream) - stream.flush() - else: - print(a, file=stream, flush=True) + a = format_exception(a) + print(a, file=stream, flush=True) + + +def debug(*args): + if ENV in ('dev', 'test'): + write(*args, stream=sys.stdout) def info(*args): @@ -41,7 +55,19 @@ def metric(metric, namespace, dimensions, value): def metadata_record(ctx, metadata, table, e=None): - metadata['EXCEPTION'] = ''.join(traceback.format_exception(type(e), e, e.__traceback__)) if e else None + if e is None and 'EXCEPTION' in metadata: + e = metadata['EXCEPTION'] + del metadata['EXCEPTION'] + + if e is not None: + exception_only = format_exception_only(e) + metadata['ERROR'] = { + 'EXCEPTION': format_exception(e), + 'EXCEPTION_ONLY': exception_only, + } + if exception_only.startswith('snowflake.connector.errors.ProgrammingError: '): + metadata['ERROR']['PROGRAMMING_ERROR'] = exception_only[45:] + metadata.setdefault('ROW_COUNT', {'INSERTED': 0, 'UPDATED': 0, 'SUPPRESSED': 0, 
'PASSED': 0}) metadata['END_TIME'] = datetime.datetime.utcnow() @@ -51,10 +77,13 @@ def metadata_record(ctx, metadata, table, e=None): record_type = metadata.get('QUERY_NAME', 'RUN') + metadata_json_sql = "'" + json.dumps(metadata).replace('\\', '\\\\').replace("'", "\\'") + "'" + sql = f''' - INSERT INTO {table} - (event_time, v) select '{metadata['START_TIME']}', - PARSE_JSON(column1) from values('{json.dumps(metadata)}') + INSERT INTO {table}(event_time, v) + SELECT '{metadata['START_TIME']}' + , PARSE_JSON(column1) + FROM VALUES({metadata_json_sql}) ''' try: diff --git a/src/runners/test_run_violations.py b/src/runners/test_run_violations.py index e42512232..40743f159 100644 --- a/src/runners/test_run_violations.py +++ b/src/runners/test_run_violations.py @@ -17,8 +17,29 @@ , OBJECT_CONSTRUCT('b', 1, 'a', 2) AS event_data , 'snowalert-test-detector' AS detector , 'low' AS severity + -- , 'test-missing-owner' AS owner + , 'test-violation-query-id' AS query_id +FROM (SELECT 1 AS test_data) +WHERE 1=1 + AND test_data=1 +""" + +TEST_INVALID_QUERY = f""" +CREATE OR REPLACE VIEW rules.test_invalid_violation_query COPY GRANTS + COMMENT='Test Invalid Violation Summary + @id test-invalid-violation-query-id + @tags test-invalid-violation-tag' +AS +SELECT NULL AS environment + , NULL AS object + , NULL AS title + , CURRENT_TIMESTAMP() AS alert_time + , NULL AS description + , 1/0 AS event_data + , NULL AS detector + , NULL AS severity + , NULL AS owner , 'test-violation-query-id' AS query_id - , 'TEST_VIOLATION_QUERY' AS query_name FROM (SELECT 1 AS test_data) WHERE 1=1 AND test_data=1 @@ -36,8 +57,9 @@ """ TEARDOWN_QUERIES = [ - f"DROP VIEW rules.test_violation_query", - f"DROP VIEW rules.test_violation_suppression", + f"DROP VIEW IF EXISTS rules.test_violation_query", + f"DROP VIEW IF EXISTS rules.test_violation_suppression", + f"DROP VIEW IF EXISTS rules.test_invalid_violation_query", f"DELETE FROM results.violations", f"DELETE FROM results.run_metadata", f"DELETE FROM results.query_metadata", @@ -56,6 +78,7 @@ def setup(): db.connect() db.execute(TEST_QUERY) db.execute(TEST_SUPPRESSION) + db.execute(TEST_INVALID_QUERY) def teardown(): @@ -82,7 +105,10 @@ def test_run_violations(): # violation_queries_runner.main() - v = next(db.fetch('SELECT * FROM data.violations')) + violations = list(db.fetch('SELECT * FROM data.violations')) + + assert len(violations) == 1 + v = violations[0] default_identity = { "ENVIRONMENT": "SnowAlert Test Runner", @@ -95,6 +121,7 @@ def test_run_violations(): "SEVERITY": "low", "QUERY_ID": "test-violation-query-id", "QUERY_NAME": "TEST_VIOLATION_QUERY", + "OWNER": None, } # basics @@ -106,13 +133,19 @@ def test_run_violations(): assert v['CREATED_TIME'] is not None # metadata - queries_run_record = next(db.fetch('SELECT * FROM data.violation_queries_runs')) - assert queries_run_record['NUM_VIOLATIONS_CREATED'] == 1 - assert queries_run_record['NUM_VIOLATIONS_UPDATED'] == 0 - - query_rule_run_record = next(db.fetch('SELECT * FROM data.violation_query_rule_runs')) - assert query_rule_run_record['NUM_VIOLATIONS_CREATED'] == 1 - assert query_rule_run_record['NUM_VIOLATIONS_UPDATED'] == 0 + queries_run_records = list(db.fetch('SELECT * FROM data.violation_queries_runs ORDER BY start_time DESC')) + assert len(queries_run_records) == 1 + assert queries_run_records[0]['NUM_VIOLATIONS_CREATED'] == 1 + + query_rule_run_record = list(db.fetch('SELECT * FROM data.violation_query_rule_runs ORDER BY start_time DESC')) + assert query_rule_run_record[0]['NUM_VIOLATIONS_CREATED'] == 
1 + assert query_rule_run_record[1]['NUM_VIOLATIONS_CREATED'] == 0 + + assert type(query_rule_run_record[1].get('ERROR')) is str + error = json.loads(query_rule_run_record[1]['ERROR']) + assert error['PROGRAMMING_ERROR'] == '100051 (22012): Division by zero' + assert 'snowflake.connector.errors.ProgrammingError' in error['EXCEPTION_ONLY'] + assert 'Traceback (most recent call last)' in error['EXCEPTION'] # # run supperessions diff --git a/src/runners/violation_queries_runner.py b/src/runners/violation_queries_runner.py index 20b71887d..12c77ad6e 100755 --- a/src/runners/violation_queries_runner.py +++ b/src/runners/violation_queries_runner.py @@ -2,14 +2,14 @@ import datetime -from runners.config import ( +from .config import ( QUERY_METADATA_TABLE, RUN_METADATA_TABLE, VIOLATION_QUERY_POSTFIX, CLOUDWATCH_METRICS, RUN_ID, ) -from runners.helpers import db, log +from .helpers import db, log METADATA_RECORDS = [] @@ -30,17 +30,22 @@ def main(): 'ATTEMPTS': 1, 'START_TIME': datetime.datetime.utcnow(), } - insert_count, update_count = db.insert_violations_query_run(query_name) + try: + insert_count = db.insert_violations_query_run(query_name) + except Exception as e: + log.info(f"{query_name} threw an exception.") + insert_count = 0 + metadata['EXCEPTION'] = e + metadata['ROW_COUNT'] = { 'INSERTED': insert_count, - 'UPDATED': update_count, } log.metadata_record(ctx, metadata, table=QUERY_METADATA_TABLE) + log.info(f"{query_name} done.") METADATA_RECORDS.append(metadata) RUN_METADATA['ROW_COUNT'] = { 'INSERTED': sum(r['ROW_COUNT']['INSERTED'] for r in METADATA_RECORDS), - 'UPDATED': sum(r['ROW_COUNT']['UPDATED'] for r in METADATA_RECORDS), } log.metadata_record(ctx, RUN_METADATA, table=RUN_METADATA_TABLE) diff --git a/src/samui/frontend/src/reducers/rules.ts b/src/samui/frontend/src/reducers/rules.ts index 628b61ef7..2072152b7 100644 --- a/src/samui/frontend/src/reducers/rules.ts +++ b/src/samui/frontend/src/reducers/rules.ts @@ -21,7 +21,7 @@ SELECT 'E' AS environment , ARRAY_CONSTRUCT('S') AS sources , 'Predicate' AS object , 'rule title' AS title - , CURRENT_TIMESTAMP() AS event_time + , NULL AS event_time , CURRENT_TIMESTAMP() AS alert_time , 'S: Subject Verb Predicate at ' || alert_time AS description , 'Subject' AS actor @@ -29,7 +29,6 @@ SELECT 'E' AS environment , 'SnowAlert' AS detector , OBJECT_CONSTRUCT(*) AS event_data , 'low' AS severity - , '${s}' AS query_name , '${qid}' AS query_id FROM data.\nWHERE 1=1\n AND 2=2\n;`; @@ -45,15 +44,15 @@ SELECT 'E' AS environment , OBJECT_CONSTRUCT(*) AS event_data , 'SnowAlert' AS detector , 'low' AS severity - , '${s}' AS query_name + , NULL AS owner , '${qid}' AS query_id FROM data.\nWHERE 1=1\n AND 2=2\n;`; const alertSuppressionBody = (s: string) => `CREATE OR REPLACE VIEW rules.${s}_ALERT_SUPPRESSION COPY GRANTS COMMENT='New Alert Suppression' AS -SELECT alert -FROM results.alerts +SELECT id +FROM data.alerts WHERE suppressed IS NULL AND ... ;`; @@ -62,7 +61,7 @@ const violationSuppressionBody = (s: string) => `CREATE OR REPLACE VIEW rules.${ COMMENT='New Violation Suppression' AS SELECT id -FROM results.violations +FROM data.violations WHERE suppressed IS NULL AND ... ;`; @@ -145,6 +144,9 @@ export const rules: Reducer = ( ...state, queries: state.queries.map(q => (q.viewName === viewName ? q.copy({isSaving: false}) : q)), suppressions: state.suppressions.map(s => (s.viewName === viewName ? s.copy({isSaving: false}) : s)), + policies: state.policies.map(p => + p.viewName === state.currentRuleView ? 
Object.assign(p, {isSaving: false}) : p, + ), }; } @@ -162,9 +164,7 @@ export const rules: Reducer = ( const {viewName, newDescription} = action.payload; return { ...state, - policies: state.policies.map(p => - viewName !== p.viewName ? p : Object.assign(p, {description: newDescription}), - ), + policies: state.policies.map(p => (viewName !== p.viewName ? p : Object.assign(p, {summary: newDescription}))), }; } diff --git a/src/samui/frontend/src/routes/Dashboard/Policies.tsx b/src/samui/frontend/src/routes/Dashboard/Policies.tsx index 1634a4b8a..061dbbc45 100644 --- a/src/samui/frontend/src/routes/Dashboard/Policies.tsx +++ b/src/samui/frontend/src/routes/Dashboard/Policies.tsx @@ -122,7 +122,7 @@ class Policies extends React.PureComponent { type="primary" disabled={policy.isSaving || !policy.isEdited} style={{marginRight: 10}} - onClick={() => this.props.saveRule(policy.raw)} + onClick={() => this.props.saveRule(Object.assign(policy.raw, {body: policy.body}))} > {policy.isSaving ? : 'Save'} diff --git a/src/samui/frontend/src/routes/Dashboard/Violations.tsx b/src/samui/frontend/src/routes/Dashboard/Violations.tsx index ed1c7dca3..156027824 100644 --- a/src/samui/frontend/src/routes/Dashboard/Violations.tsx +++ b/src/samui/frontend/src/routes/Dashboard/Violations.tsx @@ -100,6 +100,12 @@ class Violations extends React.PureComponent { getValue: (q: Query) => q.fields.select.alert_time, setValue: (q: Query, v: string) => q.copy({fields: {select: {alert_time: v}}}), }, + { + title: 'Owner', + type: 'string', + getValue: (q: Query) => q.fields.select.owner, + setValue: (q: Query, v: string) => q.copy({fields: {select: {owner: v}}}), + }, ], }, { diff --git a/src/samui/frontend/src/store/rules.ts b/src/samui/frontend/src/store/rules.ts index 983d40657..6ec192f89 100644 --- a/src/samui/frontend/src/store/rules.ts +++ b/src/samui/frontend/src/store/rules.ts @@ -20,12 +20,12 @@ export class Subpolicy { } const BLANK_POLICY = (viewName: string) => - `CREATE OR REPLACE VIEW x.y.${viewName}_POLICY_DEFINITION COPY GRANTS` + - ` COMMENT='Policy Title` + - `description goes here'` + - `AS` + - ` SELECT 'subpolicy title' AS title` + - ` , true AS passing` + + `CREATE OR REPLACE VIEW rules.${viewName}_POLICY_DEFINITION COPY GRANTS\n` + + ` COMMENT='Policy Title\n` + + `description goes here'\n` + + `AS\n` + + ` SELECT 'subpolicy title' AS title\n` + + ` , true AS passing\n` + `;`; function stripComment(body: string): {rest: string; comment: string; viewName: string} { @@ -212,6 +212,10 @@ export class Policy extends SQLBackedRule { } while (rest.replace(/\s/g, '')); } + get isEdited() { + return this.body !== this._raw.savedBody; + } + get body(): string { return ( `CREATE OR REPLACE VIEW rules.${this.viewName} COPY GRANTS\n` + diff --git a/src/scripts/install.py b/src/scripts/install.py index 182255c36..fc92de853 100755 --- a/src/scripts/install.py +++ b/src/scripts/install.py @@ -82,25 +82,7 @@ def read_queries(file): f"CREATE SCHEMA IF NOT EXISTS data", f"CREATE SCHEMA IF NOT EXISTS rules", f"CREATE SCHEMA IF NOT EXISTS results", -] - -CREATE_UDTF_FUNCTIONS = [ - f"USE SCHEMA data", - f"""CREATE OR REPLACE FUNCTION time_slices (n NUMBER, s TIMESTAMP, e TIMESTAMP) - RETURNS TABLE ( slice_start TIMESTAMP, slice_end TIMESTAMP ) - AS ' - SELECT - DATEADD(sec, DATEDIFF(sec, s, e) * SEQ4() / n, s) AS slice_start, - DATEADD(sec, DATEDIFF(sec, s, e) * 1 / n, slice_start) AS slice_end - FROM TABLE(GENERATOR(ROWCOUNT => n)) - ' - ; - """, - f"""CREATE OR REPLACE FUNCTION time_slices_before_t (num_slices NUMBER, 
seconds_in_slice NUMBER, t TIMESTAMP_NTZ) - RETURNS TABLE ( slice_start TIMESTAMP, slice_end TIMESTAMP ) - AS 'SELECT slice_start, slice_end FROM TABLE(time_slices( num_slices, DATEADD(sec, -seconds_in_slice * num_slices, t), t ))' - ; - """, + f"DROP SCHEMA IF EXISTS public", ] CREATE_TABLES_QUERIES = [ @@ -159,13 +141,14 @@ def parse_snowflake_url(url): return account, region -def login(): +def login(config_account): config = ConfigParser() - if config.read(os.path.expanduser('~/.snowsql/config')) and 'connections' in config: - account = config['connections'].get('accountname') - username = config['connections'].get('username') - password = config['connections'].get('password') - region = config['connections'].get('region') + config_section = f'connections.{config_account}' if config_account else 'connections' + if config.read(os.path.expanduser('~/.snowsql/config')) and config_section in config: + account = config[config_section].get('accountname') + username = config[config_section].get('username') + password = config[config_section].get('password') + region = config[config_section].get('region') else: account = None username = None @@ -248,7 +231,7 @@ def setup_schemas_and_tables(do_attempt, database): do_attempt(f"Use database {database}", f'USE DATABASE {database}') do_attempt("Creating schemas", CREATE_SCHEMAS_QUERIES) do_attempt("Creating alerts & violations tables", CREATE_TABLES_QUERIES) - do_attempt("Creating standard UDTFs", CREATE_UDTF_FUNCTIONS) + do_attempt("Creating standard UDTFs", read_queries('create-udtfs')) do_attempt("Creating standard data views", read_queries('data-views')) @@ -275,11 +258,11 @@ def jira_integration(setup_jira=None): setup_jira = True if uinput.startswith('y') else False if uinput.startswith('n') else None if setup_jira: - jira_user = input("Please enter the username for the SnowAlert user in Jira: ") - jira_password = getpass("Please enter the password for the SnowAlert user in Jira: ") jira_url = input("Please enter the URL for the Jira integration: ") if jira_url[:8] != "https://": jira_url = "https://" + jira_url + jira_user = input("Please enter the username for the SnowAlert user in Jira: ") + jira_password = getpass("Please enter the password for the SnowAlert user in Jira: ") print("Please enter the project tag for the alerts...") print("Note that this should be the text that will prepend the ticket id; if the project is SnowAlert") print("and the tickets will be SA-XXXX, then you should enter 'SA' for this prompt.") @@ -390,8 +373,8 @@ def do_kms_encrypt(kms, *args: str) -> List[str]: ] -def main(admin_role="accountadmin", samples=True, pk_passwd=None, jira=None): - ctx, account, region, do_attempt = login() +def main(admin_role="accountadmin", samples=True, pk_passwd=None, jira=None, config_account=None): + ctx, account, region, do_attempt = login(config_account) do_attempt(f"Use role {admin_role}", f"USE ROLE {admin_role}") if admin_role == "accountadmin": diff --git a/src/scripts/installer-queries/create-udtfs.sql.fmt b/src/scripts/installer-queries/create-udtfs.sql.fmt new file mode 100644 index 000000000..a3ff553d8 --- /dev/null +++ b/src/scripts/installer-queries/create-udtfs.sql.fmt @@ -0,0 +1,18 @@ +USE SCHEMA data; + +CREATE OR REPLACE FUNCTION time_slices (n NUMBER, s TIMESTAMP, e TIMESTAMP) +RETURNS TABLE ( slice_start TIMESTAMP, slice_end TIMESTAMP ) +AS ' + SELECT DATEADD(sec, DATEDIFF(sec, s, e) * SEQ4() / n, s) AS slice_start + , DATEADD(sec, DATEDIFF(sec, s, e) * 1 / n, slice_start) AS slice_end + FROM 
TABLE(GENERATOR(ROWCOUNT => n)) +' +; + +CREATE OR REPLACE FUNCTION time_slices_before_t (num_slices NUMBER, seconds_in_slice NUMBER, t TIMESTAMP_NTZ) +RETURNS TABLE ( slice_start TIMESTAMP, slice_end TIMESTAMP ) +AS ' + SELECT slice_start + , slice_end FROM TABLE(time_slices( num_slices, DATEADD(sec, -seconds_in_slice * num_slices, t), t )) +' +; diff --git a/src/scripts/installer-queries/data-views.sql.fmt b/src/scripts/installer-queries/data-views.sql.fmt index fdc05c17f..1a1c1b72f 100644 --- a/src/scripts/installer-queries/data-views.sql.fmt +++ b/src/scripts/installer-queries/data-views.sql.fmt @@ -27,7 +27,7 @@ FROM ( CREATE OR REPLACE VIEW data.alerts COPY GRANTS COMMENT='Reflects on existing Alerts, e.g. for writing alert suppressions' AS -SELECT alert:ALERT_ID AS id +SELECT alert:ALERT_ID::VARCHAR AS id , correlation_id , alert_time , event_time @@ -46,6 +46,7 @@ SELECT alert:ALERT_ID AS id , alert:DETECTOR::VARCHAR AS detector , alert:EVENT_DATA::VARIANT AS event_data , alert:SEVERITY::VARCHAR AS severity + , alert:OWNER::VARCHAR AS owner FROM results.alerts ; @@ -70,14 +71,31 @@ SELECT id FROM results.violations ; +CREATE OR REPLACE VIEW data.tags_foj_alerts COPY GRANTS + COMMENT='this view selects all tags, FOJed on alerts generated from queries having those tags' +AS +SELECT tag, alert.* +FROM data.rule_tags AS rule_tag +FULL OUTER JOIN data.alerts AS alert +ON rule_tag.query_id=alert.query_id +; + +CREATE OR REPLACE VIEW data.tags_foj_violations COPY GRANTS + COMMENT='this view selects all tags, FOJed on violations generated from queries having those tags' +AS +SELECT tag, violation.* +FROM data.rule_tags AS rule_tag +FULL OUTER JOIN data.violations AS violation +ON rule_tag.query_id=violation.query_id +; + CREATE OR REPLACE VIEW data.violation_queries_runs COPY GRANTS COMMENT='Stable interface to underlying metadata tables' AS SELECT V:RUN_ID::VARCHAR AS run_id - , V:START_TIME AS start_time - , V:END_TIME AS end_time + , V:START_TIME::TIMESTAMP AS start_time + , V:END_TIME::TIMESTAMP AS end_time , V:ROW_COUNT.INSERTED::INTEGER AS num_violations_created - , V:ROW_COUNT.UPDATED::INTEGER AS num_violations_updated FROM results.run_metadata WHERE V:RUN_TYPE='VIOLATION QUERY' ; @@ -86,8 +104,8 @@ CREATE OR REPLACE VIEW data.violation_suppressions_runs COPY GRANTS COMMENT='Stable interface to underlying metadata tables' AS SELECT V:RUN_ID::VARCHAR AS run_id - , V:START_TIME AS start_time - , V:END_TIME AS end_time + , V:START_TIME::TIMESTAMP AS start_time + , V:END_TIME::TIMESTAMP AS end_time , V:ROW_COUNT.PASSED::INTEGER AS num_violations_passed , V:ROW_COUNT.SUPPRESSED::INTEGER AS num_violations_suppressed FROM results.run_metadata @@ -98,9 +116,10 @@ CREATE OR REPLACE VIEW data.violation_query_rule_runs COPY GRANTS COMMENT='Stable interface to underlying metadata tables' AS SELECT V:RUN_ID::VARCHAR AS run_id - , V:QUERY_NAME AS query_name - , V:START_TIME AS start_time - , V:END_TIME AS end_time + , V:QUERY_NAME::VARCHAR AS query_name + , V:START_TIME::TIMESTAMP AS start_time + , V:END_TIME::TIMESTAMP AS end_time + , V:ERROR AS error , V:ROW_COUNT.INSERTED::INTEGER AS num_violations_created , V:ROW_COUNT.UPDATED::INTEGER AS num_violations_updated FROM results.query_metadata @@ -111,9 +130,9 @@ CREATE OR REPLACE VIEW data.violation_suppression_rule_runs COPY GRANTS COMMENT='Stable interface to underlying metadata tables' AS SELECT V:RUN_ID::VARCHAR AS run_id - , V:QUERY_NAME AS query_name - , V:START_TIME AS start_time - , V:END_TIME AS end_time + , V:QUERY_NAME::VARCHAR AS 
query_name + , V:START_TIME::TIMESTAMP AS start_time + , V:END_TIME::TIMESTAMP AS end_time , V:ROW_COUNT.SUPPRESSED::INTEGER AS num_violations_suppressed FROM results.query_metadata WHERE V:QUERY_NAME ILIKE '%_VIOLATION_SUPPRESSION' diff --git a/src/scripts/installer-queries/sample-violation-queries.sql.fmt b/src/scripts/installer-queries/sample-violation-queries.sql.fmt index 7da71f681..347f0676d 100644 --- a/src/scripts/installer-queries/sample-violation-queries.sql.fmt +++ b/src/scripts/installer-queries/sample-violation-queries.sql.fmt @@ -11,6 +11,7 @@ SELECT 'SnowAlert' AS environment , NULL AS event_data , 'snowalert' AS detector , 'low' AS severity + , NULL AS owner , 'tcl3emm98h' AS query_id , 'no_violation_queries_in_too_long' AS query_name FROM data.violations_in_days_past_v
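
Note on the reworked insert_alerts path in src/runners/helpers/db.py: below is a minimal, self-contained sketch (not part of the patch) of how the new INSERT_ALERTS_QUERY template and the sql_value_placeholders helper compose the parameterized multi-row insert that insert_alerts now executes. The template and helper are copied from the patch; the sample alert payload and the `ctx` connection are illustrative assumptions, not values the patch defines.

    import json
    import uuid

    INSERT_ALERTS_QUERY = """
    INSERT INTO results.alerts (alert_time, event_time, alert)
    SELECT PARSE_JSON(column1):ALERT_TIME
         , PARSE_JSON(column1):EVENT_TIME
         , PARSE_JSON(column1)
    FROM VALUES {values}
    """

    def sql_value_placeholders(n):
        # one "(%s)" group per alert row; the Snowflake connector binds
        # one JSON string to each placeholder at execute() time
        return ", ".join(["(%s)"] * n)

    # illustrative alert payload -- field names mirror those the runners emit
    alerts = [json.dumps({
        'ALERT_ID': uuid.uuid4().hex,
        'ALERT_TIME': '2019-04-03T21:52:49Z',
        'EVENT_TIME': '2019-04-03T21:52:49Z',
        'QUERY_NAME': 'TEST_ALERT_QUERY',
    })]

    query = INSERT_ALERTS_QUERY.format(values=sql_value_placeholders(len(alerts)))
    # with a live snowflake.connector connection ctx, this would run as:
    #   ctx.cursor().execute(query, alerts)
    print(query)

Binding each serialized alert as a query parameter (rather than interpolating it into the SQL text, as the removed f-string did) leaves quoting and escaping to the connector, which also sidesteps the kind of inline-SQL typo this hunk fixes ("restuls.alerts").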