Optimize signal ingestion pipeline and create perf-test cli util #3337

Merged · 6 commits · May 4, 2023
101 changes: 101 additions & 0 deletions src/dispatch/cli.py
@@ -1,9 +1,12 @@
import logging
import os
import time

import click
import requests
import uvicorn
from dispatch import __version__, config
from dispatch.config import DISPATCH_UI_URL
from dispatch.enums import UserRoles
from dispatch.plugin.models import PluginInstance

@@ -689,6 +692,104 @@ def start_tasks(tasks, exclude, eager):
scheduler.start()


@dispatch_scheduler.command("perf-test")
@click.option("--num-instances", default=1000, help="Number of signal instances to send.")
@click.option("--num-workers", default=1000, help="Number of threads to use.")
@click.option(
"--api-endpoint",
default=f"{DISPATCH_UI_URL}/api/v1/default/signals/instances",
required=True,
help="API endpoint to send the signal instances.",
)
@click.option(
"--api-token",
required=True,
help="API token to use.",
)
@click.option(
"--project",
default="Test",
required=True,
help="The Dispatch project to send the instances to",
)
def perf_test(
num_instances: int, num_workers: int, api_endpoint: str, api_token: str, project: str
) -> None:
"""Performance testing utility for creating signal instances."""
import concurrent.futures
from fastapi import status

NUM_SIGNAL_INSTANCES = num_instances
NUM_WORKERS = num_workers

session = requests.Session()
session.headers.update(
{
"Content-Type": "application/json",
"Authorization": f"Bearer {api_token}",
}
)
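    # Note: Content-Type and Authorization are set here on the session and again
    # per request in _send_signal_instance below; either place alone would suffice.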
start_time = time.time()

def _send_signal_instance(
api_endpoint: str,
api_token: str,
session: requests.Session,
signal_instance: dict[str, str],
) -> None:
        r = None  # ensure `r` is defined if the request itself raises
        try:
r = session.post(
api_endpoint,
json=signal_instance,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_token}",
},
)
log.info(f"Response: {r.json()}")
if r.status_code == status.HTTP_401_UNAUTHORIZED:
raise PermissionError(
"Unauthorized. Please check your bearer token. You can find it in the Dev Tools under Request Headers -> Authorization."
)

r.raise_for_status()

except requests.exceptions.RequestException as e:
log.error(f"Unable to send finding. Reason: {e} Response: {r.json() if r else 'N/A'}")
else:
log.info(f"{signal_instance.get('raw', {}).get('id')} created succesfully")

def send_signal_instances(
api_endpoint: str, api_token: str, signal_instances: list[dict[str, str]]
):
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
futures = [
executor.submit(
_send_signal_instance,
api_endpoint=api_endpoint,
api_token=api_token,
session=session,
signal_instance=signal_instance,
)
for signal_instance in signal_instances
]
results = [future.result() for future in concurrent.futures.as_completed(futures)]

log.info(f"\nSent {len(results)} of {NUM_SIGNAL_INSTANCES} signal instances")

signal_instances = [
{
"project": {"name": project},
"raw": {},
},
] * NUM_SIGNAL_INSTANCES
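    # Note: list multiplication above repeats references to one shared dict; that is
    # fine here because the payload is never mutated per request.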

send_signal_instances(api_endpoint, api_token, signal_instances)

elapsed_time = time.time() - start_time
click.echo(f"Elapsed time: {elapsed_time:.2f} seconds")
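
For reference, a minimal sketch of exercising the new command end to end via Click's test runner — assuming `dispatch_scheduler` is registered as the `scheduler` group on `dispatch_cli`, and with a placeholder token and a reachable API:

from click.testing import CliRunner

from dispatch.cli import dispatch_cli

runner = CliRunner()
result = runner.invoke(
    dispatch_cli,
    [
        "scheduler",
        "perf-test",
        "--num-instances", "100",
        "--num-workers", "10",
        "--api-token", "example-token",  # placeholder; use a real bearer token
        "--project", "Test",
    ],
)
print(result.output)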


@dispatch_cli.group("server")
def dispatch_server():
"""Container for all dispatch server commands."""
41 changes: 36 additions & 5 deletions src/dispatch/signal/flows.py
@@ -1,4 +1,6 @@
from datetime import timedelta
import logging
from cachetools import TTLCache

from email_validator import validate_email, EmailNotValidError
from sqlalchemy.orm import Session
@@ -40,13 +42,13 @@ def signal_instance_create_flow(
db_session.commit()

# we don't need to continue if a filter action took place
    if signal_service.filter_signal(
        db_session=db_session,
        signal_instance=signal_instance,
    ):
        # If the signal was deduplicated, we can assume a case exists,
        # and we need to update the corresponding signal message
        if _should_update_signal_message(signal_instance):
update_signal_message(
db_session=db_session,
signal_instance=signal_instance,
@@ -226,3 +228,32 @@ def update_signal_message(db_session: Session, signal_instance: SignalInstance)
db_session=db_session,
thread_id=signal_instance.case.signal_thread_ts,
)


_last_nonupdated_signal_cache = TTLCache(maxsize=4, ttl=60)


def _should_update_signal_message(signal_instance: SignalInstance) -> bool:
"""
Determine if the signal message should be updated based on the filter action and time since the last update.
"""
global _last_nonupdated_signal_cache

case_id = str(signal_instance.case_id)

if case_id not in _last_nonupdated_signal_cache:
_last_nonupdated_signal_cache[case_id] = signal_instance
return True

last_nonupdated_signal = _last_nonupdated_signal_cache[case_id]
time_since_last_update = signal_instance.created_at - last_nonupdated_signal.created_at

if (
signal_instance.filter_action == SignalFilterAction.deduplicate
and signal_instance.case.signal_thread_ts # noqa
and time_since_last_update >= timedelta(seconds=5) # noqa
):
_last_nonupdated_signal_cache[case_id] = signal_instance
return True
else:
return False
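
To make the throttling behavior concrete, here is a minimal, self-contained sketch of the same TTLCache pattern — it isolates only the time-window logic (the real function also checks filter_action and signal_thread_ts), with timestamps standing in for full SignalInstance objects and the same 4-entry/60-second cache and 5-second window as above:

from datetime import datetime, timedelta

from cachetools import TTLCache

cache = TTLCache(maxsize=4, ttl=60)  # entries silently expire after 60 seconds

def should_update(case_id: str, created_at: datetime) -> bool:
    if case_id not in cache:
        cache[case_id] = created_at  # first signal for this case: always update
        return True
    if created_at - cache[case_id] >= timedelta(seconds=5):
        cache[case_id] = created_at  # window elapsed: update and reset the clock
        return True
    return False  # throttled; the cache entry is deliberately left untouched

now = datetime(2023, 5, 4, 12, 0, 0)
assert should_update("case-1", now)                             # first: update
assert not should_update("case-1", now + timedelta(seconds=1))  # throttled
assert should_update("case-1", now + timedelta(seconds=6))      # window elapsed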
93 changes: 74 additions & 19 deletions src/dispatch/signal/scheduled.py
@@ -4,14 +4,19 @@
:copyright: (c) 2022 by Netflix Inc., see AUTHORS for more
:license: Apache, see LICENSE for more details.
"""
from datetime import datetime, timedelta, timezone
import logging
import queue
from sqlalchemy import asc
from sqlalchemy.orm import scoped_session

from schedule import every
from dispatch.database.core import SessionLocal
from dispatch.database.core import SessionLocal, sessionmaker, engine
from dispatch.scheduler import scheduler
from dispatch.project.models import Project
from dispatch.plugin import service as plugin_service
from dispatch.signal import flows as signal_flows
from dispatch.decorators import scheduled_project_task
from dispatch.decorators import scheduled_project_task, timer
from dispatch.signal.models import SignalInstance

log = logging.getLogger(__name__)
@@ -47,27 +52,77 @@ def consume_signals(db_session: SessionLocal, project: Project):
log.debug(signal_instance_data)
log.exception(e)

if signal_instances:
plugin.instance.delete()

@timer
def process_signal_instance(db_session: SessionLocal, signal_instance: SignalInstance) -> None:
try:
signal_flows.signal_instance_create_flow(
db_session=db_session,
signal_instance_id=signal_instance.id,
)
except Exception as e:
log.debug(signal_instance)
log.exception(e)
finally:
db_session.close()


MAX_SIGNAL_INSTANCES = 500
signal_queue = queue.Queue(maxsize=MAX_SIGNAL_INSTANCES)


@timer
@scheduler.add(every(1).minutes, name="signal-process")
@scheduled_project_task
def process_signals(db_session: SessionLocal, project: Project):
"""Processes signals and create cases if appropriate."""
"""
Process signals and create cases if appropriate.

This function processes signals within a given project, creating cases if necessary.
It runs every minute, processing signals that meet certain criteria within the last 5 minutes.
Signals are added to a queue for processing, and then each signal instance is processed.

Args:
db_session: The database session used to query and update the database.
project: The project for which the signals will be processed.

Notes:
The function is decorated with three decorators:
- scheduler.add: schedules the function to run every minute.
- scheduled_project_task: ensures that the function is executed as a scheduled project task.

The function uses a queue to process signal instances in a first-in-first-out (FIFO) order
This ensures that signals are processed in the order they were added to the queue.

A scoped session is used to create a new database session for each signal instance
This ensures that each signal instance is processed using its own separate database connection,
preventing potential issues with concurrent connections.
"""
one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
    signal_instances = (
        (
            db_session.query(SignalInstance)
            .filter(SignalInstance.project_id == project.id)
            .filter(SignalInstance.filter_action == None)  # noqa
            .filter(SignalInstance.case_id == None)  # noqa
            .filter(SignalInstance.created_at >= one_hour_ago)
        )
        .order_by(asc(SignalInstance.created_at))
        .limit(MAX_SIGNAL_INSTANCES)
    )
# Add each signal_instance to the queue for processing
    for signal_instance in signal_instances:
        signal_queue.put(signal_instance)

schema_engine = engine.execution_options(
schema_translate_map={
None: "dispatch_organization_default",
}
)
session = scoped_session(sessionmaker(bind=schema_engine))

# Process each signal instance in the queue
while not signal_queue.empty():
signal_instance = signal_queue.get()
db_session = session()
process_signal_instance(db_session, signal_instance)
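
The schema-translated scoped session above is what lets each queued instance be processed against the organization's schema on its own connection. A minimal standalone sketch of the pattern (the DSN is a placeholder; the schema name mirrors the one above):

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("postgresql://user:pass@localhost/dispatch")  # hypothetical DSN
schema_engine = engine.execution_options(
    schema_translate_map={None: "dispatch_organization_default"}
)
session_factory = scoped_session(sessionmaker(bind=schema_engine))

db_session = session_factory()  # thread-local Session bound to the translated engine
try:
    # schema-unqualified tables now resolve to dispatch_organization_default
    ...
finally:
    session_factory.remove()  # close the session and clear the thread-local registry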
10 changes: 4 additions & 6 deletions src/dispatch/signal/service.py
@@ -3,7 +3,7 @@
from typing import Optional

from pydantic.error_wrappers import ErrorWrapper, ValidationError
from sqlalchemy import desc, asc
from sqlalchemy.orm import Session

from dispatch.auth.models import DispatchUser
@@ -552,13 +552,11 @@ def filter_signal(*, db_session: Session, signal_instance: SignalInstance) -> bool:
SignalInstance.signal_id == signal_instance.signal_id,
SignalInstance.created_at >= default_dedup_window,
SignalInstance.id != signal_instance.id,
            SignalInstance.case_id.isnot(None),  # noqa
        )
        .order_by(desc(SignalInstance.created_at))
    )

    if default_dedup_query.all():
signal_instance.case_id = default_dedup_query[0].case_id
signal_instance.filter_action = SignalFilterAction.deduplicate
filtered = True