From 267836475145dc2b29d587f7789ac0dbe77d30b6 Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Thu, 18 Apr 2024 16:45:17 -0700 Subject: [PATCH] Add type hinting to status.py and status_reporter.py --- grizzly/common/status.py | 918 ++++++++++++++++-------------- grizzly/common/status_reporter.py | 326 ++++++----- grizzly/common/test_status.py | 80 ++- grizzly/session.py | 4 +- 4 files changed, 725 insertions(+), 603 deletions(-) diff --git a/grizzly/common/status.py b/grizzly/common/status.py index 20cefa26..a300742d 100644 --- a/grizzly/common/status.py +++ b/grizzly/common/status.py @@ -2,16 +2,21 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. """Manage Grizzly status reports.""" +from abc import ABC from collections import defaultdict, namedtuple from contextlib import closing, contextmanager from copy import deepcopy +from dataclasses import dataclass from json import dumps, loads from logging import getLogger from os import getpid -from sqlite3 import OperationalError, connect +from pathlib import Path +from sqlite3 import Connection, OperationalError, connect from time import perf_counter, time +from typing import Callable, Dict, Generator, List, Optional, Set, Union, cast -from ..common.utils import grz_tmp +from .reporter import FuzzManagerReporter +from .utils import grz_tmp __all__ = ("ReadOnlyStatus", "ReductionStatus", "Status", "SimpleStatus") __author__ = "Tyson Smith" @@ -33,20 +38,32 @@ LOG = getLogger(__name__) -ProfileEntry = namedtuple("ProfileEntry", "count max min name total") -ResultEntry = namedtuple("ResultEntry", "rid count desc") +@dataclass(eq=False, frozen=True) +class ProfileEntry: + count: int + max: float + min: float + name: str + total: float -def _db_version_check(con, expected=DB_VERSION): +@dataclass(eq=False, frozen=True) +class ResultEntry: + rid: str + count: int + desc: str + + +def _db_version_check(con: Connection, expected: int = DB_VERSION) -> bool: """Perform version check and remove obsolete tables if required. Args: - con (sqlite3.Connection): An open database connection. - expected (int): The latest database version. + con: An open database connection. + expected: The latest database version. Returns: - bool: True if database was reset otherwise False. + True if database was reset otherwise False. """ assert expected > 0 cur = con.cursor() @@ -57,7 +74,7 @@ def _db_version_check(con, expected=DB_VERSION): cur.execute("BEGIN EXCLUSIVE;") # check db version again while locked to avoid race cur.execute("PRAGMA user_version;") - version = cur.fetchone()[0] + version = cast(int, cur.fetchone()[0]) if version < expected: LOG.debug("db version %d < %d", version, expected) # remove ALL tables from the database @@ -73,18 +90,323 @@ def _db_version_check(con, expected=DB_VERSION): return False -class BaseStatus: +class SimpleResultCounter: + __slots__ = ("_count", "_desc", "pid") + + def __init__(self, pid: int) -> None: + assert pid >= 0 + self._count: Dict[str, int] = defaultdict(int) + self._desc: Dict[str, str] = {} + self.pid = pid + + def __iter__(self) -> Generator[ResultEntry, None, None]: + """Yield all result data. + + Args: + None + + Yields: + Contains ID, count and description for each result entry. 
+        """
+        for result_id, count in self._count.items():
+            if count > 0:
+                yield ResultEntry(result_id, count, self._desc[result_id])
+
+    def blockers(
+        self, iterations: int, iters_per_result: int = 100
+    ) -> Generator[ResultEntry, None, None]:
+        """Any result with an iterations-per-result ratio less than or equal to
+        the given limit is considered a 'blocker'. Results with a count <= 1 are
+        not included.
+
+        Args:
+            iterations: Total iterations.
+            iters_per_result: Iterations-per-result threshold.
+
+        Yields:
+            ResultEntry: ID, count and description of blocking result.
+        """
+        assert iters_per_result > 0
+        if iterations > 0:
+            for entry in self:
+                if entry.count > 1 and iterations / entry.count <= iters_per_result:
+                    yield entry
+
+    def count(self, result_id: str, desc: str) -> int:
+        """Count occurrences of a result.
+
+        Args:
+            result_id: Result ID.
+            desc: User friendly description.
+
+        Returns:
+            Current count for given result_id.
+        """
+        assert isinstance(result_id, str)
+        self._count[result_id] += 1
+        if self._count[result_id] == 1:
+            assert result_id not in self._desc
+            self._desc[result_id] = desc
+        else:
+            assert result_id in self._desc
+        return self._count[result_id]
+
+    def get(self, result_id: str) -> Optional[ResultEntry]:
+        """Get count and description for given result id.
+
+        Args:
+            result_id: Result ID.
+
+        Returns:
+            ResultEntry if an entry exists for result_id otherwise None.
+        """
+        assert isinstance(result_id, str)
+        if result_id not in self._count:
+            assert result_id not in self._desc
+            return None
+        assert result_id in self._desc
+        return ResultEntry(result_id, self._count[result_id], self._desc[result_id])
+
+    @property
+    def total(self) -> int:
+        """Get total count of all results.
+
+        Args:
+            None
+
+        Returns:
+            Total result count.
+        """
+        return sum(self._count.values())
+
+
+class ReadOnlyResultCounter(SimpleResultCounter):
+    def count(self, result_id: str, desc: str) -> int:
+        raise NotImplementedError("Read only!")  # pragma: no cover
+
+    @classmethod
+    def load(
+        cls, db_file: Path, time_limit: float = 0
+    ) -> List["ReadOnlyResultCounter"]:
+        """Load existing entries from the database and populate
+        ReadOnlyResultCounter objects.
+
+        Args:
+            db_file: Database file.
+            time_limit: Used to filter older entries.
+
+        Returns:
+            Loaded ReadOnlyResultCounter objects.
+ """ + assert time_limit >= 0 + with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: + cur = con.cursor() + try: + # collect entries + if time_limit: + cur.execute( + """SELECT pid, + result_id, + description, + count + FROM results + WHERE timestamp > ?;""", + (time() - time_limit,), + ) + else: + cur.execute( + """SELECT pid, result_id, description, count FROM results""" + ) + entries = cur.fetchall() + except OperationalError as exc: + if not str(exc).startswith("no such table:"): + raise # pragma: no cover + entries = [] + + loaded = {} + for pid, result_id, desc, count in entries: + if pid not in loaded: + loaded[pid] = cls(pid) + loaded[pid]._desc[result_id] = desc # pylint: disable=protected-access + loaded[pid]._count[result_id] = count # pylint: disable=protected-access + + return list(loaded.values()) + + +class ResultCounter(SimpleResultCounter): + __slots__ = ("_db_file", "_frequent", "_initial", "_limit", "last_found") + + def __init__( + self, + pid: int, + db_file: Path, + life_time: int = RESULTS_EXPIRE, + report_limit: int = 0, + ) -> None: + super().__init__(pid) + assert db_file + assert report_limit >= 0 + self._db_file = db_file + self._frequent: Set[str] = set() + self._initial = False + # use zero to disable report limit + self._limit = report_limit + self.last_found = 0.0 + self._init_db(db_file, pid, life_time) + + @staticmethod + def _init_db(db_file: Path, pid: int, life_time: float) -> None: + # prepare database + LOG.debug("resultcounter using db %s", db_file) + with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: + _db_version_check(con) + cur = con.cursor() + with con: + # create table if needed + cur.execute( + """CREATE TABLE IF NOT EXISTS results ( + count INTEGER NOT NULL, + description TEXT NOT NULL, + pid INTEGER NOT NULL, + result_id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + PRIMARY KEY(pid, result_id));""" + ) + # remove expired entries + if life_time > 0: + cur.execute( + """DELETE FROM results WHERE timestamp <= ?;""", + (time() - life_time,), + ) + # avoid (unlikely) pid reuse collision + cur.execute("""DELETE FROM results WHERE pid = ?;""", (pid,)) + # remove results for jobs that have been removed + try: + cur.execute( + """DELETE FROM results + WHERE pid NOT IN (SELECT pid FROM status);""" + ) + except OperationalError as exc: + if not str(exc).startswith("no such table:"): + raise # pragma: no cover + + def count(self, result_id: str, desc: str) -> int: + """Count results and write results to the database. + + Args: + result_id: Result ID. + desc: User friendly description. + + Returns: + Local count for given result_id. + """ + super().count(result_id, desc) + self._initial = False + timestamp = time() + with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con: + cur = con.cursor() + with con: + cur.execute( + """UPDATE results + SET timestamp = ?, + count = ? + WHERE pid = ? 
+                       AND result_id = ?;""",
+                    (timestamp, self._count[result_id], self.pid, result_id),
+                )
+                if cur.rowcount < 1:
+                    cur.execute(
+                        """SELECT pid FROM results WHERE result_id = ?;""",
+                        (result_id,),
+                    )
+                    self._initial = cur.fetchone() is None
+                    cur.execute(
+                        """INSERT INTO results(
+                               pid,
+                               result_id,
+                               description,
+                               timestamp,
+                               count)
+                           VALUES (?, ?, ?, ?, ?);""",
+                        (self.pid, result_id, desc, timestamp, self._count[result_id]),
+                    )
+        self.last_found = timestamp
+        return self._count[result_id]
+
+    def is_frequent(self, result_id: str) -> bool:
+        """Scan all results including results from other running instances
+        to determine if the limit has been exceeded. Local count must be >1 before
+        limit is checked.
+
+        Args:
+            result_id: Result ID.
+
+        Returns:
+            True if limit has been exceeded otherwise False.
+        """
+        assert isinstance(result_id, str)
+        if self._limit < 1:
+            return False
+        if result_id in self._frequent:
+            return True
+        # get local total
+        total = self._count.get(result_id, 0)
+        # only check the db for parallel results if
+        # - result has been found locally more than once
+        # - limit has not been exceeded locally
+        if self._limit >= total > 1:
+            with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con:
+                cur = con.cursor()
+                # look up total count from all processes
+                cur.execute(
+                    """SELECT COALESCE(SUM(count), 0)
+                       FROM results WHERE result_id = ?;""",
+                    (result_id,),
+                )
+                global_total = cur.fetchone()[0]
+            assert global_total >= total
+            total = global_total
+        if total > self._limit:
+            self._frequent.add(result_id)
+            return True
+        return False
+
+    def is_initial(self) -> bool:
+        """Check if most recently added result is unique (includes parallel
+        instances).
+
+        Args:
+            None
+
+        Returns:
+            True if only seen once and only by this instance otherwise False.
+        """
+        return self._initial
+
+    def mark_frequent(self, result_id: str) -> None:
+        """Mark given result ID as frequent locally.
+
+        Args:
+            result_id: Result ID.
+
+        Returns:
+            None
+        """
+        assert isinstance(result_id, str)
+        if result_id not in self._frequent:
+            self._frequent.add(result_id)
+
+
+class BaseStatus(ABC):
     """Record and manage status information.
 
     Attributes:
-        _profiles (dict): Profiling data.
-        ignored (int): Ignored result count.
-        iteration (int): Iteration count.
-        log_size (int): Log size in bytes.
-        pid (int): Python process ID.
-        results (None): Placeholder for result data.
-        start_time (float): Start time of session.
-        test_name (str): Current test name.
+        _profiles: Profiling data.
+        ignored: Ignored result count.
+        iteration: Iteration count.
+        log_size: Log size in bytes.
+        pid: Python process ID.
+        start_time: Start time of session.
+        test_name: Current test name.
""" __slots__ = ( @@ -93,28 +415,33 @@ class BaseStatus: "iteration", "log_size", "pid", - "results", "start_time", "test_name", ) - def __init__(self, pid, start_time, ignored=0, iteration=0, log_size=0): + def __init__( + self, + pid: int, + start_time: float, + ignored: int = 0, + iteration: int = 0, + log_size: int = 0, + ) -> None: assert pid >= 0 assert ignored >= 0 assert iteration >= 0 assert log_size >= 0 assert isinstance(start_time, float) assert start_time >= 0 - self._profiles = {} + self._profiles: Dict[str, Dict[str, Union[float, int]]] = {} self.ignored = ignored self.iteration = iteration self.log_size = log_size self.pid = pid - self.results = None self.start_time = start_time self.test_name = None - def profile_entries(self): + def profile_entries(self) -> Generator[ProfileEntry, None, None]: """Used to retrieve profiling data. Args: @@ -125,30 +452,34 @@ def profile_entries(self): """ for name, entry in self._profiles.items(): yield ProfileEntry( - entry["count"], entry["max"], entry["min"], name, entry["total"] + cast(int, entry["count"]), + entry["max"], + entry["min"], + name, + entry["total"], ) @property - def rate(self): + def rate(self) -> float: """Calculate the average iteration rate in seconds. Args: None Returns: - float: Number of iterations performed per second. + Number of iterations performed per second. """ return self.iteration / self.runtime if self.runtime else 0 @property - def runtime(self): + def runtime(self) -> float: """Calculate the number of seconds since start() was called. Args: None Returns: - int: Total runtime in seconds. + Total runtime in seconds. """ return max(time() - self.start_time, 0) @@ -157,34 +488,45 @@ class ReadOnlyStatus(BaseStatus): """Store status information. Attributes: - _profiles (dict): Profiling data. - ignored (int): Ignored result count. - iteration (int): Iteration count. - log_size (int): Log size in bytes. - pid (int): Python process ID. + _profiles: Profiling data. + ignored: Ignored result count. + iteration: Iteration count. + log_size: Log size in bytes. + pid: Python process ID. results (None): Placeholder for result data. start_time (float): Start time of session. test_name (str): Test name. timestamp (float): Last time data was saved to database. """ - __slots__ = ("timestamp",) + __slots__ = ("results", "timestamp") - def __init__(self, pid, start_time, timestamp, ignored=0, iteration=0, log_size=0): + def __init__( + self, + pid: int, + start_time: float, + timestamp: float, + results: Optional[ReadOnlyResultCounter] = None, + ignored: int = 0, + iteration: int = 0, + log_size: int = 0, + ) -> None: super().__init__( pid, start_time, ignored=ignored, iteration=iteration, log_size=log_size ) - assert isinstance(timestamp, float) assert timestamp >= start_time + self.results = results or ReadOnlyResultCounter(pid) self.timestamp = timestamp @classmethod - def load_all(cls, db_file, time_limit=300): + def load_all( + cls, db_file: Path, time_limit: float = 300 + ) -> Generator["ReadOnlyStatus", None, None]: """Load all status reports found in `db_file`. Args: - db_file (Path): Database containing status data. - time_limit (int): Filter entries by age. Use zero for no limit. + db_file: Database containing status data. + time_limit: Filter entries by age. Use zero for no limit. Yields: ReadOnlyStatus: Successfully loaded objects. 
@@ -210,38 +552,35 @@ def load_all(cls, db_file, time_limit=300):
             except OperationalError as exc:
                 if not str(exc).startswith("no such table:"):
                     raise  # pragma: no cover
-                entries = ()
+                entries = []
 
         # Load all results
-        results = ReadOnlyResultCounter.load(db_file, time_limit=0)
+        results = {}
+        for counter in ReadOnlyResultCounter.load(db_file, time_limit=0):
+            results[counter.pid] = counter
+
         for entry in entries:
             status = cls(
                 entry[0],
                 entry[5],
                 entry[6],
+                results=results.get(entry[0]),
                 ignored=entry[2],
                 iteration=entry[3],
                 log_size=entry[4],
             )
             status._profiles = loads(entry[1])
-            for counter in results:
-                if counter.pid == status.pid:
-                    status.results = counter
-                    break
-            else:
-                # no existing ReadOnlyResultCounter with matching pid found
-                status.results = ReadOnlyResultCounter(status.pid)
             yield status
 
     @property
-    def runtime(self):
+    def runtime(self) -> float:
         """Calculate total runtime in seconds relative to 'timestamp'.
 
         Args:
             None
 
         Returns:
-            int: Total runtime in seconds.
+            Total runtime in seconds.
         """
         return self.timestamp - self.start_time
 
 
@@ -250,29 +589,31 @@ class SimpleStatus(BaseStatus):
     """Record and manage status information.
 
     Attributes:
-        _profiles (dict): Profiling data.
-        ignored (int): Ignored result count.
-        iteration (int): Iteration count.
-        log_size (int): Log size in bytes.
-        pid (int): Python process ID.
-        results (None): Placeholder for result data.
-        start_time (float): Start time of session.
-        test_name (str): Current test name.
+        _profiles: Profiling data.
+        ignored: Ignored result count.
+        iteration: Iteration count.
+        log_size: Log size in bytes.
+        pid: Python process ID.
+        results: Results data. Used to count occurrences of results.
+        start_time: Start time of session.
+        test_name: Current test name.
     """
 
+    __slots__ = ("results",)
+
-    def __init__(self, pid, start_time):
+    def __init__(self, pid: int, start_time: float) -> None:
         super().__init__(pid, start_time)
         self.results = SimpleResultCounter(pid)
 
     @classmethod
-    def start(cls):
+    def start(cls) -> "SimpleStatus":
         """Create a unique SimpleStatus object.
 
         Args:
             None
 
         Returns:
-            SimpleStatus: Active status report.
+            Active status report.
         """
         return cls(getpid(), time())
 
 
@@ -281,30 +622,30 @@ class Status(BaseStatus):
     """Status records status information and stores it in a database.
 
     Attributes:
-        _db_file (Path): Database file containing data.
-        _enable_profiling (bool): Profiling support status.
-        _profiles (dict): Profiling data.
-        ignored (int): Ignored result count.
-        iteration (int): Iteration count.
-        log_size (int): Log size in bytes.
-        pid (int): Python process ID.
-        results (ResultCounter): Results data. Used to count occurrences of results.
-        start_time (float): Start time of session.
-        test_name (str): Current test name.
-        timestamp (float): Last time data was saved to database.
+        _db_file: Database file containing data.
+        _enable_profiling: Profiling support status.
+        _profiles: Profiling data.
+        ignored: Ignored result count.
+        iteration: Iteration count.
+        log_size: Log size in bytes.
+        pid: Python process ID.
+        results: Results data. Used to count occurrences of results.
+        start_time: Start time of session.
+        test_name: Current test name.
+        timestamp: Last time data was saved to database.
""" - __slots__ = ("_db_file", "_enable_profiling", "timestamp") + __slots__ = ("_db_file", "_enable_profiling", "results", "timestamp") def __init__( self, - pid, - start_time, - db_file, - enable_profiling=False, - life_time=REPORTS_EXPIRE, - report_limit=0, - ): + pid: int, + start_time: float, + db_file: Path, + enable_profiling: bool = False, + life_time: float = REPORTS_EXPIRE, + report_limit: int = 0, + ) -> None: super().__init__(pid, start_time) assert life_time >= 0 assert report_limit >= 0 @@ -315,7 +656,7 @@ def __init__( self.timestamp = start_time @staticmethod - def _init_db(db_file, pid, life_time): + def _init_db(db_file: Path, pid: int, life_time: float) -> None: # prepare database LOG.debug("status using db %s", db_file) with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: @@ -343,11 +684,11 @@ def _init_db(db_file, pid, life_time): cur.execute("""DELETE FROM status WHERE pid = ?;""", (pid,)) @contextmanager - def measure(self, name): + def measure(self, name: str) -> Generator[None, None, None]: """Used to simplify collecting profiling data. Args: - name (str): Used to group the entries. + name: Used to group the entries. Yields: None @@ -359,13 +700,13 @@ def measure(self, name): else: yield - def record(self, name, duration): + def record(self, name: str, duration: float) -> None: """Used to add profiling data. This is intended to be used to make rough calculations to identify major configuration issues. Args: - name (str): Used to group the entries. - duration (int, float): Stored to be later used for measurements. + name: Used to group the entries. + duration: Stored to be later used for measurements. Returns: None @@ -388,22 +729,23 @@ def record(self, name, duration): "total": duration, } - def report(self, force=False, report_rate=REPORT_RATE): + def report(self, force: bool = False, report_rate: int = REPORT_RATE) -> bool: """Write status report to database. Reports are only written periodically. It is limited by `report_rate`. The specified number of seconds must elapse before another write will be performed unless `force` is True. Args: - force (bool): Ignore report frequency limiting. - report_rate (int): Minimum number of seconds between writes to database. + force: Ignore report frequency limiting. + report_rate: Minimum number of seconds between writes to database. Returns: - bool: True if the report was successful otherwise False. + True if the report was successful otherwise False. """ now = time() if self.results.last_found > self.timestamp: LOG.debug("results have been found since last report, force update") force = True + assert report_rate >= 0 if not force and now < (self.timestamp + report_rate): return False assert self.start_time <= now @@ -456,13 +798,15 @@ def report(self, force=False, report_rate=REPORT_RATE): return True @classmethod - def start(cls, db_file, enable_profiling=False, report_limit=0): + def start( + cls, db_file: Path, enable_profiling: bool = False, report_limit: int = 0 + ) -> "Status": """Create a unique Status object. Args: - db_file (Path): Database containing status data. - enable_profiling (bool): Record profiling data. - report_limit (int): Number of times a unique result will be reported. + db_file: Database containing status data. + enable_profiling: Record profiling data. + report_limit: Number of times a unique result will be reported. Returns: Status: Active status report. 
@@ -478,285 +822,6 @@ def start(cls, db_file, enable_profiling=False, report_limit=0): return status -class SimpleResultCounter: - __slots__ = ("_count", "_desc", "pid") - - def __init__(self, pid): - assert pid >= 0 - self._count = defaultdict(int) - self._desc = {} - self.pid = pid - - def __iter__(self): - """Yield all result data. - - Args: - None - - Yields: - ResultEntry: Contains ID, count and description for each result entry. - """ - for result_id, count in self._count.items(): - if count > 0: - yield ResultEntry(result_id, count, self._desc.get(result_id, None)) - - def blockers(self, iterations, iters_per_result=100): - """Any result with an iterations-per-result ratio of less than or equal the - given limit are considered 'blockers'. Results with a count <= 1 are not - included. - - Args: - iterations (int): Total iterations. - iters_per_result (int): Iterations-per-result threshold. - - Yields: - ResultEntry: ID, count and description of blocking result. - """ - assert iters_per_result > 0 - if iterations > 0: - for entry in self: - if entry.count > 1 and iterations / entry.count <= iters_per_result: - yield entry - - def count(self, result_id, desc): - """ - - Args: - result_id (str): Result ID. - desc (str): User friendly description. - - Returns: - int: Current count for given result_id. - """ - assert isinstance(result_id, str) - self._count[result_id] += 1 - if result_id not in self._desc: - self._desc[result_id] = desc - return self._count[result_id] - - def get(self, result_id): - """Get count and description for given result id. - - Args: - result_id (str): Result ID. - - Returns: - ResultEntry: Count and description. - """ - assert isinstance(result_id, str) - return ResultEntry( - result_id, self._count.get(result_id, 0), self._desc.get(result_id, None) - ) - - @property - def total(self): - """Get total count of all results. - - Args: - None - - Returns: - int: Total result count. - """ - return sum(self._count.values()) - - -class ReadOnlyResultCounter(SimpleResultCounter): - def count(self, result_id, desc): - raise NotImplementedError("Read only!") # pragma: no cover - - @classmethod - def load(cls, db_file, time_limit=0): - """Load existing entries for database and populate a ReadOnlyResultCounter. - - Args: - db_file (Path): Database file. - time_limit (int): Used to filter older entries. - - Returns: - list: Loaded ReadOnlyResultCounter objects. 
- """ - assert time_limit >= 0 - with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: - cur = con.cursor() - try: - # collect entries - if time_limit: - cur.execute( - """SELECT pid, - result_id, - description, - count - FROM results - WHERE timestamp > ?;""", - (time() - time_limit,), - ) - else: - cur.execute( - """SELECT pid, result_id, description, count FROM results""" - ) - entries = cur.fetchall() - except OperationalError as exc: - if not str(exc).startswith("no such table:"): - raise # pragma: no cover - entries = () - - loaded = {} - for pid, result_id, desc, count in entries: - if pid not in loaded: - loaded[pid] = cls(pid) - loaded[pid]._desc[result_id] = desc # pylint: disable=protected-access - loaded[pid]._count[result_id] = count # pylint: disable=protected-access - - return list(loaded.values()) - - -class ResultCounter(SimpleResultCounter): - __slots__ = ("_db_file", "_frequent", "_limit", "last_found") - - def __init__(self, pid, db_file, life_time=RESULTS_EXPIRE, report_limit=0): - super().__init__(pid) - assert db_file - assert report_limit >= 0 - self._db_file = db_file - self._frequent = set() - # use zero to disable report limit - self._limit = report_limit - self.last_found = 0 - self._init_db(db_file, pid, life_time) - - @staticmethod - def _init_db(db_file, pid, life_time): - # prepare database - LOG.debug("resultcounter using db %s", db_file) - with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: - _db_version_check(con) - cur = con.cursor() - with con: - # create table if needed - cur.execute( - """CREATE TABLE IF NOT EXISTS results ( - count INTEGER NOT NULL, - description TEXT NOT NULL, - pid INTEGER NOT NULL, - result_id TEXT NOT NULL, - timestamp INTEGER NOT NULL, - PRIMARY KEY(pid, result_id));""" - ) - # remove expired entries - if life_time > 0: - cur.execute( - """DELETE FROM results WHERE timestamp <= ?;""", - (time() - life_time,), - ) - # avoid (unlikely) pid reuse collision - cur.execute("""DELETE FROM results WHERE pid = ?;""", (pid,)) - # remove results for jobs that have been removed - try: - cur.execute( - """DELETE FROM results - WHERE pid NOT IN (SELECT pid FROM status);""" - ) - except OperationalError as exc: - if not str(exc).startswith("no such table:"): - raise # pragma: no cover - - def count(self, result_id, desc): - """Count results and write results to the database. - - Args: - result_id (str): Result ID. - desc (str): User friendly description. - - Returns: - tuple (int, bool): Local count and initial report (includes - parallel instances) for given result_id. - """ - super().count(result_id, desc) - timestamp = time() - initial = False - with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con: - cur = con.cursor() - with con: - cur.execute( - """UPDATE results - SET timestamp = ?, - count = ? - WHERE pid = ? - AND result_id = ?;""", - (timestamp, self._count[result_id], self.pid, result_id), - ) - if cur.rowcount < 1: - cur.execute( - """SELECT pid FROM results WHERE result_id = ?;""", - (result_id,), - ) - initial = cur.fetchone() is None - cur.execute( - """INSERT INTO results( - pid, - result_id, - description, - timestamp, - count) - VALUES (?, ?, ?, ?, ?);""", - (self.pid, result_id, desc, timestamp, self._count[result_id]), - ) - self.last_found = timestamp - return self._count[result_id], initial - - def is_frequent(self, result_id): - """Scan all results including results from other running instances - to determine if the limit has been exceeded. Local count must be >1 before - limit is checked. 
-
-        Args:
-            result_id (str): Result ID.
-
-        Returns:
-            bool: True if limit has been exceeded otherwise False.
-        """
-        assert isinstance(result_id, str)
-        if self._limit < 1:
-            return False
-        if result_id in self._frequent:
-            return True
-        # get local total
-        total = self._count.get(result_id, 0)
-        # only check the db for parallel results if
-        # - result has been found locally more than once
-        # - limit has not been exceeded locally
-        if self._limit >= total > 1:
-            with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con:
-                cur = con.cursor()
-                # look up total count from all processes
-                cur.execute(
-                    """SELECT COALESCE(SUM(count), 0)
-                       FROM results WHERE result_id = ?;""",
-                    (result_id,),
-                )
-                global_total = cur.fetchone()[0]
-            assert global_total >= total
-            total = global_total
-        if total > self._limit:
-            self._frequent.add(result_id)
-            return True
-        return False
-
-    def mark_frequent(self, result_id):
-        """Mark given results ID as frequent locally.
-
-        Args:
-            result_id (str): Result ID.
-
-        Returns:
-            None
-        """
-        assert isinstance(result_id, str)
-        if result_id not in self._frequent:
-            self._frequent.add(result_id)
-
-
 ReductionStep = namedtuple(
     "ReductionStep", "name, duration, successes, attempts, size, iterations"
 )
@@ -767,22 +832,23 @@ class ReductionStatus:
 
     def __init__(
         self,
-        strategies=None,
-        testcase_size_cb=None,
-        crash_id=None,
-        db_file=None,
-        pid=None,
-        tool=None,
-        life_time=REPORTS_EXPIRE,
-    ):
+        strategies: Optional[List[str]] = None,
+        testcase_size_cb: Optional[Callable[[], int]] = None,
+        crash_id: Optional[int] = None,
+        db_file: Optional[Path] = None,
+        pid: Optional[int] = None,
+        tool: Optional[str] = None,
+        life_time: float = REPORTS_EXPIRE,
+    ) -> None:
         """Initialize a ReductionStatus instance.
 
         Arguments:
-            strategies (list(str)): List of strategies to be run.
-            testcase_size_cb (callable): Callback to get testcase size
-            crash_id (int): CrashManager ID of original testcase
-            db_file (Path): Database file containing data.
-            tool (str): The tool name used for reporting to FuzzManager.
+            strategies: List of strategies to be run.
+            testcase_size_cb: Callback to get testcase size.
+            crash_id: CrashManager ID of original testcase.
+            db_file: Database file containing data.
+            tool: The tool name used for reporting to FuzzManager.
+            life_time: Time in seconds before a report is considered expired.
         """
         self.analysis = {}
         self.attempts = 0
@@ -842,23 +908,23 @@ def __init__(
     @classmethod
     def start(
         cls,
-        db_file,
-        strategies=None,
-        testcase_size_cb=None,
-        crash_id=None,
-        tool=None,
-    ):
+        db_file: Path,
+        strategies: Optional[List[str]] = None,
+        testcase_size_cb: Optional[Callable[[], int]] = None,
+        crash_id: Optional[int] = None,
+        tool: Optional[str] = None,
+    ) -> "ReductionStatus":
         """Create a unique ReductionStatus object.
 
         Args:
-            db_file (Path): Database containing status data.
-            strategies (list(str)): List of strategies to be run.
-            testcase_size_cb (callable): Callback to get testcase size
-            crash_id (int): CrashManager ID of original testcase
-            tool (str): The tool name used for reporting to FuzzManager.
+            db_file: Database containing status data.
+            strategies: List of strategies to be run.
+            testcase_size_cb: Callback to get testcase size.
+            crash_id: CrashManager ID of original testcase.
+            tool: The tool name used for reporting to FuzzManager.
 
         Returns:
-            ReductionStatus: Active status report.
+            Active status report.
""" status = cls( crash_id=crash_id, @@ -871,17 +937,17 @@ def start( status.report(force=True) return status - def report(self, force=False, report_rate=REPORT_RATE): + def report(self, force: bool = False, report_rate: float = REPORT_RATE) -> bool: """Write status report to database. Reports are only written periodically. It is limited by `report_rate`. The specified number of seconds must elapse before another write will be performed unless `force` is True. Args: - force (bool): Ignore report frequently limiting. - report_rate (int): Minimum number of seconds between writes. + force: Ignore report frequently limiting. + report_rate: Minimum number of seconds between writes. Returns: - bool: Returns true if the report was successful otherwise false. + True if the report was successful otherwise false. """ now = time() if not force and now < (self.timestamp + report_rate): @@ -980,16 +1046,18 @@ def report(self, force=False, report_rate=REPORT_RATE): return True @classmethod - def load_all(cls, db_file, time_limit=300): + def load_all( + cls, db_file: Path, time_limit: float = 300 + ) -> Generator["ReductionStatus", None, None]: """Load all reduction status reports found in `db_file`. Args: - db_file (Path): Database containing status data. - time_limit (int): Only include entries with a timestamp that is within the - given number of seconds. Use zero for no limit. + db_file: Database containing status data. + time_limit: Only include entries with a timestamp that is within the + given number of seconds. Use zero for no limit. Yields: - Status: Successfully loaded read-only status objects. + Successfully loaded read-only status objects. """ assert time_limit >= 0 with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: @@ -1022,7 +1090,7 @@ def load_all(cls, db_file, time_limit=300): except OperationalError as exc: if not str(exc).startswith("no such table:"): raise # pragma: no cover - entries = () + entries = [] for entry in entries: pid = entry[0] @@ -1051,7 +1119,7 @@ def load_all(cls, db_file, time_limit=300): status.last_reports = loads(entry[15]) yield status - def _testcase_size(self): + def _testcase_size(self) -> int: if self._db_file is None: return self._current_size return self._testcase_size_cb() @@ -1112,13 +1180,13 @@ def original(self): def record( self, - name, - duration=None, - iterations=None, - attempts=None, - successes=None, - report=True, - ): + name: str, + duration: Optional[float] = None, + iterations: Optional[int] = None, + attempts: Optional[int] = None, + successes: Optional[int] = None, + report: bool = True, + ) -> None: """Record reduction status for a given point in time: - name of the milestone (eg. init, strategy name completed) @@ -1195,13 +1263,13 @@ def serialize(sub): return _MilestoneTimer() @contextmanager - def measure(self, name, report=True): + def measure(self, name: str, report: bool = True) -> Generator[None, None, None]: """Time and record the period leading up to a reduction milestone. eg. a strategy being run. Arguments: - name (str): name of milestone - report (bool): Automatically force a report. + name: name of milestone + report: Automatically force a report. Yields: None @@ -1222,23 +1290,25 @@ def measure(self, name, report=True): report=report, ) - def copy(self): + def copy(self) -> "ReductionStatus": """Create a deep copy of this instance. 
Arguments: None Returns: - ReductionStatus: Clone of self + Clone of self """ return deepcopy(self) - def add_to_reporter(self, reporter, expected=True): + def add_to_reporter( + self, reporter: FuzzManagerReporter, expected: bool = True + ) -> None: """Add the reducer status to reported metadata for the given reporter. Arguments: - reporter (FuzzManagerReporter): Reporter to update. - expected (bool): Add detailed stats. + reporter: Reporter to update. + expected: Add detailed stats. Returns: None diff --git a/grizzly/common/status_reporter.py b/grizzly/common/status_reporter.py index 0fe20264..7d8585ec 100644 --- a/grizzly/common/status_reporter.py +++ b/grizzly/common/status_reporter.py @@ -20,6 +20,7 @@ from re import match from re import sub as re_sub from time import gmtime, localtime, strftime +from typing import Dict, List, Optional, Set, Tuple from psutil import cpu_count, cpu_percent, disk_usage, virtual_memory @@ -39,6 +40,116 @@ LOG = getLogger(__name__) +class TracebackReport: + """Read Python tracebacks from log files and store it in a manner that is helpful + when generating reports. + """ + + MAX_LINES = 16 # should be no less than 6 + READ_LIMIT = 0x20000 # 128KB + + def __init__( + self, + log_file: Path, + lines: List[str], + is_kbi: bool = False, + prev_lines: Optional[List[str]] = None, + ) -> None: + self.is_kbi = is_kbi + self.lines = lines + self.log_file = log_file + self.prev_lines = prev_lines or [] + + @classmethod + def from_file( + cls, log_file: Path, max_preceding: int = 5, ignore_kbi: bool = False + ) -> Optional["TracebackReport"]: + """Create TracebackReport from a text file containing a Python traceback. + Only the first traceback in the file will be parsed. + + Args: + log_file: File to parse. + max_preceding: Number of lines to collect leading up to the traceback. + ignore_kbi: Skip/ignore KeyboardInterrupt. + + Returns: + TracebackReport: Contains data from log_file. 
+        """
+        token_bytes = b"Traceback (most recent call last):"
+        assert len(token_bytes) < cls.READ_LIMIT
+        try:
+            with log_file.open("rb") as in_fp:
+                for chunk in iter(partial(in_fp.read, cls.READ_LIMIT), b""):
+                    idx = chunk.find(token_bytes)
+                    if idx > -1:
+                        # calculate offset of data in the file
+                        pos = in_fp.tell() - len(chunk) + idx
+                        break
+                    if len(chunk) == cls.READ_LIMIT:
+                        # seek back to avoid missing beginning of token
+                        in_fp.seek(len(token_bytes) * -1, SEEK_CUR)
+                else:
+                    # no traceback here, move along
+                    return None
+                # seek back 2KB to collect preceding lines
+                in_fp.seek(max(pos - 2048, 0))
+                raw_data = in_fp.read(cls.READ_LIMIT)
+        except OSError:  # pragma: no cover
+            # in case the file goes away
+            return None
+
+        data = raw_data.decode(encoding="ascii", errors="ignore").splitlines()
+        token = token_bytes.decode()
+        is_kbi = False
+        tb_start = None
+        tb_end = None
+        line_count = len(data)
+        for line_num, log_line in enumerate(data):
+            if tb_start is None and token in log_line:
+                tb_start = line_num
+                continue
+            if tb_start is not None:
+                log_line = log_line.strip()
+                if not log_line:
+                    # stop at first empty line
+                    tb_end = min(line_num, line_count)
+                    break
+                if match(r"^\w+(\.\w+)*\:\s|^\w+(Interrupt|Error)$", log_line):
+                    is_kbi = log_line.startswith("KeyboardInterrupt")
+                    if is_kbi and ignore_kbi:
+                        # ignore this exception since it is a KeyboardInterrupt
+                        return None
+                    # stop after error message
+                    tb_end = min(line_num + 1, line_count)
+                    break
+        assert tb_start is not None
+        if max_preceding > 0:
+            prev_start = max(tb_start - max_preceding, 0)
+            prev_lines = data[prev_start:tb_start]
+        else:
+            prev_lines = None
+        if tb_end is None:
+            # limit if the end is not identified (failsafe)
+            tb_end = max(line_count, cls.MAX_LINES)
+        if tb_end - tb_start > cls.MAX_LINES:
+            # add first entry
+            lines = data[tb_start : tb_start + 3]
+            lines += ["<--- TRACEBACK TRIMMED--->"]
+            # add end entries
+            lines += data[tb_end - (cls.MAX_LINES - 3) : tb_end]
+        else:
+            lines = data[tb_start:tb_end]
+        return cls(log_file, lines, is_kbi=is_kbi, prev_lines=prev_lines)
+
+    def __len__(self) -> int:
+        return len(str(self))
+
+    def __str__(self) -> str:
+        return "\n".join(
+            [f"Log: '{self.log_file.name}'"] + self.prev_lines + self.lines
+        )
+
+
 class StatusReporter:
     """Read and merge Grizzly status reports, including tracebacks if found.
     Output is a single textual report, e.g. for submission to EC2SpotManager.
@@ -50,26 +161,31 @@ class StatusReporter:
     SUMMARY_LIMIT = 4095  # summary output must be no more than 4KB
     TIME_LIMIT = 120  # ignore older reports
 
-    def __init__(self, reports, tracebacks=None):
+    def __init__(
+        self,
+        reports: List[ReadOnlyStatus],
+        tracebacks: Optional[List[TracebackReport]] = None,
+    ) -> None:
         self.reports = reports
         self.tracebacks = tracebacks
 
     @property
-    def has_results(self):
-        return any(x.results.total for x in self.reports)
+    def has_results(self) -> bool:
+        return any(x.results.total for x in self.reports if x.results is not None)
 
     @classmethod
-    def load(cls, db_file, tb_path=None, time_limit=TIME_LIMIT):
+    def load(
+        cls,
+        db_file: Path,
+        tb_path: Optional[Path] = None,
+        time_limit: float = TIME_LIMIT,
+    ) -> "StatusReporter":
         """Read Grizzly status reports and create a StatusReporter object.
 
         Args:
-            db_file (str): Status data file to load.
-            tb_path (Path): Directory to scan for files containing Python tracebacks.
-            time_limit (int): Only include entries with a timestamp that is within the
-                given number of seconds. Use zero for no limit.
+            db_file: Status data file to load.
+            tb_path: Directory to scan for files containing Python tracebacks.
+            time_limit: Only include entries with a timestamp that is within the
+                given number of seconds. Use zero for no limit.
 
         Returns:
-            StatusReporter: Contains available status reports and traceback reports.
+            Available status reports and traceback reports.
         """
         return cls(
             list(ReadOnlyStatus.load_all(db_file, time_limit=time_limit)),
@@ -77,9 +193,9 @@ def load(cls, db_file, tb_path=None, time_limit=TIME_LIMIT):
         )
 
     @staticmethod
-    def format_entries(entries):
+    def format_entries(entries: List[Tuple[str, Optional[str]]]) -> str:
         """Generate formatted output from (label, body) pairs.
-        Each entry must have a label and an optional body.
+        Each entry has a label and an optional body.
 
         Example:
         entries = (
@@ -95,10 +211,10 @@ def format_entries(entries):
             third    : 3.0
 
         Args:
-            entries list(2-tuple(str, str)): Data to merge.
+            entries: Data to merge.
 
         Returns:
-            str: Formatted output.
+            Formatted output.
         """
         label_lengths = tuple(len(x[0]) for x in entries if x[1])
         max_len = max(label_lengths) if label_lengths else 0
@@ -110,18 +226,18 @@ def format_entries(entries):
             out.append(label)
         return "\n".join(out)
 
-    def results(self, max_len=85):
-        """Merged and generate formatted output from results.
+    def results(self, max_len: int = 85) -> str:
+        """Merge and generate formatted output from results.
 
         Args:
-            max_len (int): Maximum length of result description.
+            max_len: Maximum length of result description.
 
         Returns:
-            str: A formatted report.
+            A formatted report.
         """
-        blockers = set()
-        counts = defaultdict(int)
-        descs = {}
+        blockers: Set[str] = set()
+        counts: Dict[str, int] = defaultdict(int)
+        descs: Dict[str, str] = {}
         # calculate totals
         for report in self.reports:
             for result in report.results:
@@ -129,7 +245,7 @@ def results(self, max_len=85):
                 counts[result.rid] += result.count
             blockers.update(x.rid for x in report.results.blockers(report.iteration))
         # generate output
-        entries = []
+        entries: List[Tuple[str, Optional[str]]] = []
        for rid, count in sorted(counts.items(), key=lambda x: x[1], reverse=True):
             desc = descs[rid]
             # trim long descriptions
@@ -144,19 +260,19 @@ def results(self, max_len=85):
             entries.append(("", None))
         return self.format_entries(entries)
 
-    def specific(self, iters_per_result=100):
-        """Merged and generate formatted output from status reports.
+    def specific(self, iters_per_result: int = 100) -> str:
+        """Merge and generate formatted output from status reports.
 
         Args:
-            iters_per_result (int): Threshold for warning of potential blockers.
+            iters_per_result: Threshold for warning of potential blockers.
 
         Returns:
-            str: A formatted report.
+            A formatted report.
         """
         if not self.reports:
             return "No status reports available"
         self.reports.sort(key=lambda x: x.start_time)
-        entries = []
+        entries: List[Tuple[str, Optional[str]]] = []
         for report in self.reports:
             label = (
                 f"PID {report.pid} started at "
@@ -213,25 +329,25 @@ def specific(self, iters_per_result=100):
 
     def summary(
         self,
-        rate=True,
-        runtime=True,
-        sysinfo=False,
-        timestamp=False,
-        iters_per_result=100,
-    ):
+        rate: bool = True,
+        runtime: bool = True,
+        sysinfo: bool = False,
+        timestamp: bool = False,
+        iters_per_result: int = 100,
+    ) -> str:
         """Merge and generate a summary from status reports.
 
         Args:
-            rate (bool): Include iteration rate.
-            runtime (bool): Include total runtime in output.
-            sysinfo (bool): Include system info (CPU, disk, RAM... etc) in output.
-            timestamp (bool): Include time stamp in output.
-            iters_per_result (int): Threshold for warning of potential blockers.
+            rate: Include iteration rate.
+            runtime: Include total runtime in output.
+            sysinfo: Include system info (CPU, disk, RAM... etc) in output.
+            timestamp: Include time stamp in output.
+            iters_per_result: Threshold for warning of potential blockers.
 
         Returns:
-            str: A summary of merged reports.
+            A summary of merged reports.
         """
-        entries = []
+        entries: List[Tuple[str, Optional[str]]] = []
         # Job specific status
         if self.reports:
             # calculate totals
@@ -262,7 +378,7 @@ def summary(
             if total_iters:
                 total_results = sum(results)
                 result_pct = total_results / total_iters * 100
-                buckets = set()
+                buckets: Set[str] = set()
                 for report in self.reports:
                     buckets.update(x.rid for x in report.results)
                 disp = [f"{total_results} ({len(buckets)})"]
@@ -320,15 +436,15 @@ def summary(
         return msg
 
     @staticmethod
-    def _merge_tracebacks(tracebacks, size_limit):
-        """Merge traceback without exceeding size_limit.
+    def _merge_tracebacks(tracebacks: List[TracebackReport], size_limit: int) -> str:
+        """Merge tracebacks without exceeding size_limit.
 
         Args:
-            tracebacks (iterable): TracebackReport to merge.
-            size_limit (int): Maximum size in bytes of output.
+            tracebacks: TracebackReports to merge.
+            size_limit: Maximum size in bytes of output.
 
         Returns:
-            str: merged tracebacks.
+            Merged tracebacks.
         """
         txt = []
         txt.append(f"\n\nWARNING Tracebacks ({len(tracebacks)}) detected!")
@@ -341,14 +457,14 @@ def _merge_tracebacks(tracebacks, size_limit):
         return "\n".join(txt)
 
     @staticmethod
-    def _sys_info():
+    def _sys_info() -> List[Tuple[str, str]]:
         """Collect system information.
 
         Args:
             None
 
         Returns:
-            list(tuple): System information in tuples (label, display data).
+            System information in tuples (label, display data).
         """
         entries = []
 
@@ -369,7 +485,7 @@ def _sys_info():
         disp = []
         mem_usage = virtual_memory()
         if mem_usage.available < 1_073_741_824:  # < 1GB
-            disp.append(f"{int(mem_usage.available / 1_048_576)}MB")
+            disp.append(f"{mem_usage.available // 1_048_576}MB")
         else:
             disp.append(f"{mem_usage.available / 1_073_741_824:0.1f}GB")
         disp.append(f" of {mem_usage.total / 1_073_741_824:0.1f}GB free")
@@ -379,7 +495,7 @@ def _sys_info():
         disp = []
         usage = disk_usage("/")
         if usage.free < 1_073_741_824:  # < 1GB
-            disp.append(f"{int(usage.free / 1_048_576)}MB")
+            disp.append(f"{usage.free // 1_048_576}MB")
         else:
             disp.append(f"{usage.free / 1_073_741_824:0.1f}GB")
         disp.append(f" of {usage.total / 1_073_741_824:0.1f}GB free")
@@ -388,17 +504,18 @@ def _sys_info():
         return entries
 
     @staticmethod
-    def _tracebacks(path, ignore_kbi=True, max_preceding=5):
+    def _tracebacks(
+        path: Path, ignore_kbi: bool = True, max_preceding: int = 5
+    ) -> List[TracebackReport]:
         """Search screen logs for tracebacks.
 
         Args:
-            path (Path): Directory containing log files.
-            ignore_kbi (bool): Do not include KeyboardInterrupts in results
-            max_preceding (int): Maximum number of lines preceding traceback to
-                include.
+            path: Directory containing log files.
+            ignore_kbi: Do not include KeyboardInterrupts in results.
+            max_preceding: Maximum number of lines preceding traceback to include.
 
         Returns:
-            list: A list of TracebackReports.
+            TracebackReports.
         """
         tracebacks = []
         for screen_log in (x for x in path.glob("screenlog.*") if x.is_file()):
@@ -410,111 +527,6 @@ def _tracebacks(path, ignore_kbi=True, max_preceding=5):
         return tracebacks
 
 
-class TracebackReport:
-    """Read Python tracebacks from log files and store it in a manner that is helpful
-    when generating reports.
- """ - - MAX_LINES = 16 # should be no less than 6 - READ_LIMIT = 0x20000 # 128KB - - def __init__(self, log_file, lines, is_kbi=False, prev_lines=None): - assert isinstance(lines, list) - assert isinstance(log_file, Path) - assert isinstance(prev_lines, list) or prev_lines is None - self.is_kbi = is_kbi - self.lines = lines - self.log_file = log_file - self.prev_lines = prev_lines or [] - - @classmethod - def from_file(cls, log_file, max_preceding=5, ignore_kbi=False): - """Create TracebackReport from a text file containing a Python traceback. - Only the first traceback in the file will be parsed. - - Args: - log_file (Path): File to parse. - max_preceding (int): Number of lines to collect leading up to the traceback. - ignore_kbi (bool): Skip/ignore KeyboardInterrupt. - - Returns: - TracebackReport: Contains data from log_file. - """ - token = b"Traceback (most recent call last):" - assert len(token) < cls.READ_LIMIT - try: - with log_file.open("rb") as in_fp: - for chunk in iter(partial(in_fp.read, cls.READ_LIMIT), b""): - idx = chunk.find(token) - if idx > -1: - # calculate offset of data in the file - pos = in_fp.tell() - len(chunk) + idx - break - if len(chunk) == cls.READ_LIMIT: - # seek back to avoid missing beginning of token - in_fp.seek(len(token) * -1, SEEK_CUR) - else: - # no traceback here, move along - return None - # seek back 2KB to collect preceding lines - in_fp.seek(max(pos - 2048, 0)) - data = in_fp.read(cls.READ_LIMIT) - except OSError: # pragma: no cover - # in case the file goes away - return None - - data = data.decode("ascii", errors="ignore").splitlines() - token = token.decode() - is_kbi = False - tb_start = None - tb_end = None - line_count = len(data) - for line_num, log_line in enumerate(data): - if tb_start is None and token in log_line: - tb_start = line_num - continue - if tb_start is not None: - log_line = log_line.strip() - if not log_line: - # stop at first empty line - tb_end = min(line_num, line_count) - break - if match(r"^\w+(\.\w+)*\:\s|^\w+(Interrupt|Error)$", log_line): - is_kbi = log_line.startswith("KeyboardInterrupt") - if is_kbi and ignore_kbi: - # ignore this exception since it is a KeyboardInterrupt - return None - # stop after error message - tb_end = min(line_num + 1, line_count) - break - assert tb_start is not None - if max_preceding > 0: - prev_start = max(tb_start - max_preceding, 0) - prev_lines = data[prev_start:tb_start] - else: - prev_lines = None - if tb_end is None: - # limit if the end is not identified (failsafe) - tb_end = max(line_count, cls.MAX_LINES) - if tb_end - tb_start > cls.MAX_LINES: - # add first entry - lines = data[tb_start : tb_start + 3] - lines += ["<--- TRACEBACK TRIMMED--->"] - # add end entries - lines += data[tb_end - (cls.MAX_LINES - 3) : tb_end] - else: - lines = data[tb_start:tb_end] - return cls(log_file, lines, is_kbi=is_kbi, prev_lines=prev_lines) - - def __len__(self): - return len(str(self)) - - def __str__(self): - return "\n".join( - [f"Log: '{self.log_file.name}'"] + self.prev_lines + self.lines - ) - - class _TableFormatter: """Format data in a table.""" diff --git a/grizzly/common/test_status.py b/grizzly/common/test_status.py index 6ef65c62..c8abfd05 100644 --- a/grizzly/common/test_status.py +++ b/grizzly/common/test_status.py @@ -38,7 +38,6 @@ def test_basic_status_01(): assert status.ignored == 0 assert status.iteration == 0 assert status.log_size == 0 - assert status.results is None assert not status._profiles assert status.runtime > 0 assert status.rate == 0 @@ -141,7 +140,11 @@ def 
test_status_03(tmp_path): assert status.iteration == loaded.iteration assert status.log_size == loaded.log_size assert status.pid == loaded.pid - assert loaded.results.get("uid1") == ("uid1", 1, "sig1") + result = loaded.results.get("uid1") + assert result + assert result.rid == "uid1" + assert result.count == 1 + assert result.desc == "sig1" assert "test" in loaded._profiles @@ -183,7 +186,11 @@ def test_status_05(mocker, tmp_path): assert status.iteration == loaded.iteration assert status.log_size == loaded.log_size assert status.pid == loaded.pid - assert loaded.results.get("uid1") == ("uid1", 1, "sig1") + result = loaded.results.get("uid1") + assert result + assert result.rid == "uid1" + assert result.count == 1 + assert result.desc == "sig1" # NOTE: this function must be at the top level to work on Windows @@ -572,16 +579,22 @@ def test_report_counter_01(tmp_path, keys, counts, limit): db_path = tmp_path / "storage.db" counter = ResultCounter(1, db_path, report_limit=limit) for report_id, counted in zip(keys, counts): - assert counter.get(report_id) == (report_id, 0, None) + result = counter.get(report_id) + assert result + assert result.rid == report_id + assert result.count == 0 + assert result.desc is None assert not counter.is_frequent(report_id) # call count() with report_id 'counted' times for current in range(1, counted + 1): - assert counter.count(report_id, "desc") == (current, (current == 1)) + assert counter.count(report_id, "desc") == current + assert counter.is_initial() == (current == 1) # test get() - if sum(counts) > 0: - assert counter.get(report_id) == (report_id, counted, "desc") - else: - assert counter.get(report_id) == (report_id, counted, None) + result = counter.get(report_id) + assert result + assert result.rid == report_id + assert result.count == counted + assert result.desc == ("desc" if sum(counts) else None) # test is_frequent() if counted > limit > 0: assert counter.is_frequent(report_id) @@ -592,8 +605,8 @@ def test_report_counter_01(tmp_path, keys, counts, limit): assert counter.is_frequent(report_id) else: assert limit == 0 - for _report_id, counted, _desc in counter: - assert counted > 0 + for result in counter: + assert result.count > 0 assert counter.total == sum(counts) @@ -607,34 +620,41 @@ def test_report_counter_02(mocker, tmp_path): counter_c = ResultCounter(3, db_path, report_limit=2) # local counts are 0, global (all counters) count is 0 assert not counter_a.is_frequent("a") + assert not counter_a.is_initial() assert not counter_b.is_frequent("a") assert not counter_c.is_frequent("a") # local (counter_a, bucket a) count is 1, global (all counters) count is 1 - assert counter_a.count("a", "desc") == (1, True) + assert counter_a.count("a", "desc") == 1 + assert counter_a.is_initial() assert not counter_a.is_frequent("a") assert not counter_b.is_frequent("a") assert not counter_c.is_frequent("a") # local (counter_b, bucket a) count is 1, global (all counters) count is 2 - assert counter_b.count("a", "desc") == (1, False) + assert counter_b.count("a", "desc") == 1 + assert not counter_b.is_initial() assert not counter_a.is_frequent("a") assert not counter_b.is_frequent("a") assert not counter_c.is_frequent("a") # local (counter_b, bucket a) count is 2, global (all counters) count is 3 # locally exceeded - assert counter_b.count("a", "desc") == (2, False) + assert counter_b.count("a", "desc") == 2 + assert not counter_b.is_initial() assert counter_b.is_frequent("a") # local (counter_c, bucket a) count is 1, global (all counters) count is 4 - 
assert counter_c.count("a", "desc") == (1, False)
+    assert counter_c.count("a", "desc") == 1
+    assert not counter_c.is_initial()
     assert not counter_a.is_frequent("a")
     assert counter_b.is_frequent("a")
     assert not counter_c.is_frequent("a")
     # local (counter_a, bucket a) count is 2, global (all counters) count is 5
     # no limit
-    assert counter_a.count("a", "desc") == (2, False)
+    assert counter_a.count("a", "desc") == 2
+    assert not counter_a.is_initial()
     assert not counter_a.is_frequent("a")
     # local (counter_c, bucket a) count is 2, global (all counters) count is 6
     # locally not exceeded, globally exceeded
-    assert counter_c.count("a", "desc") == (2, False)
+    assert counter_c.count("a", "desc") == 2
+    assert not counter_c.is_initial()
     assert counter_c.is_frequent("a")
     # local (counter_a, bucket x) count is 0, global (all counters) count is 0
     assert not counter_a.is_frequent("x")
@@ -670,17 +690,35 @@ def test_report_counter_03(mocker, tmp_path):
     # last 2 seconds
     loaded = ReadOnlyResultCounter.load(db_path, 2)[0]
     assert loaded.total == 1
-    assert loaded.get("b") == ("b", 1, "desc_b")
+    result = loaded.get("b")
+    assert result
+    assert result.rid == "b"
+    assert result.count == 1
+    assert result.desc == "desc_b"
     # last 3 seconds
     loaded = ReadOnlyResultCounter.load(db_path, 3)[0]
-    assert loaded.get("a") == ("a", 2, "desc_a")
+    result = loaded.get("a")
+    assert result
+    assert result.rid == "a"
+    assert result.count == 2
+    assert result.desc == "desc_a"
     assert loaded.total == 3
     # increase time limit
     fake_time.return_value = 4
     loaded = ReadOnlyResultCounter.load(db_path, 10)[0]
     assert loaded.total == counter.total == 3
-    assert loaded.get("a") == ("a", 2, "desc_a")
-    assert loaded.get("b") == ("b", 1, "desc_b")
+    result = loaded.get("a")
+    assert result
+    assert result.rid == "a"
+    assert result.count == 2
+    assert result.desc == "desc_a"
+    result = loaded.get("b")
+    assert result
+    assert result.rid == "b"
+    assert result.count == 1
+    assert result.desc == "desc_b"
 
 
 def test_report_counter_04(mocker, tmp_path):
diff --git a/grizzly/session.py b/grizzly/session.py
index 24db3a65..b75db4ac 100644
--- a/grizzly/session.py
+++ b/grizzly/session.py
@@ -246,7 +246,7 @@ def run(
                 if result.status == Result.FOUND:
                     LOG.debug("result detected")
                     report = self.target.create_report(is_hang=result.timeout)
-                    seen, initial = self.status.results.count(
+                    seen = self.status.results.count(
                         report.crash_hash, report.short_signature
                     )
                     LOG.info(
@@ -256,6 +256,8 @@ def run(
                         report.minor[:8],
                         seen,
                     )
+                    # check if this is the first time the result has been seen
+                    # (including by parallel instances)
+                    initial = self.status.results.is_initial()
                     if initial or not self.status.results.is_frequent(report.crash_hash):
                         # add target info to test cases
                         for test in self.iomanager.tests:
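---

Usage notes (reviewer sketches; not part of the patch):

The `ResultCounter.count()` API changes from returning a `(count, initial)` tuple to returning only the local count, with the "first time seen anywhere" check moved to the new `is_initial()` method, as exercised by the `session.py` hunk above. A minimal sketch of the new call pattern, assuming an illustrative temporary database path and result ID:

```python
from os import getpid
from pathlib import Path
from tempfile import mkdtemp

from grizzly.common.status import ResultCounter

# illustrative path; Grizzly normally manages this database internally
db_file = Path(mkdtemp()) / "status.db"
counter = ResultCounter(getpid(), db_file, report_limit=10)

# count() now returns only the local count (an int)
seen = counter.count("deadbeef", "assertion failure in Foo::Bar()")
if counter.is_initial():
    # first time this result has been seen by any instance
    print(f"new result (local count={seen})")
if counter.is_frequent("deadbeef"):
    # total across all instances exceeded report_limit
    print("result is frequent, skipping report")
```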
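The typed `Status`/`ReadOnlyStatus` pair keeps the same writer/reader split: `ReadOnlyStatus.load_all()` is now annotated as a generator, and result counters are attached at construction time via the new `results` keyword instead of being patched in afterwards. A sketch of the round trip, under the same illustrative-path assumption:

```python
from pathlib import Path
from tempfile import mkdtemp

from grizzly.common.status import ReadOnlyStatus, Status

db_file = Path(mkdtemp()) / "status.db"

# writer side: record an iteration (with profiling) and persist it
status = Status.start(db_file, enable_profiling=True)
with status.measure("iteration"):
    status.iteration += 1
status.report(force=True)

# reader side: load reports written within the last 60 seconds
for loaded in ReadOnlyStatus.load_all(db_file, time_limit=60):
    print(loaded.pid, loaded.iteration, loaded.rate, loaded.results.total)
```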
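`ReadOnlyResultCounter.load()` returns one counter per reporting process (PID) found in the database; combined with `blockers()` it drives the blocker warnings in `StatusReporter`. A sketch, where the database path and iteration totals are assumed values:

```python
from pathlib import Path

from grizzly.common.status import ReadOnlyResultCounter

# one counter is returned per process that reported results
for counter in ReadOnlyResultCounter.load(Path("status.db"), time_limit=300):
    # flag results hit at least once per 100 iterations (assuming 10k iterations)
    for entry in counter.blockers(10_000, iters_per_result=100):
        print(f"pid={counter.pid} {entry.rid} x{entry.count}: {entry.desc}")
```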
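Finally, `TracebackReport.from_file()` is now annotated as returning `Optional["TracebackReport"]`, which makes the no-traceback case explicit for callers. A sketch, assuming an illustrative screen log name:

```python
from pathlib import Path

from grizzly.common.status_reporter import TracebackReport

# parse the first traceback in a screen log, skipping KeyboardInterrupts
report = TracebackReport.from_file(Path("screenlog.0"), ignore_kbi=True)
if report is None:
    print("no traceback found")
else:
    print(report)  # "Log: 'screenlog.0'" header, preceding lines, traceback
```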