diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 989f09e70..ba7fa22bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: ] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 24.4.2 hooks: - id: black @@ -18,10 +18,13 @@ repos: hooks: - id: isort - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + - repo: local hooks: - id: mypy + name: mypy + entry: mypy + language: system + types: [python] args: [ "--config", "pyproject.toml" diff --git a/esrally/client/asynchronous.py b/esrally/client/asynchronous.py index 2d869390d..9c32efc44 100644 --- a/esrally/client/asynchronous.py +++ b/esrally/client/asynchronous.py @@ -77,7 +77,7 @@ async def send(self, conn: "Connection") -> "ClientResponse": self.response = self.response_class( self.method, self.original_url, - writer=self._writer, + writer=self._writer, # type: ignore[arg-type] # TODO remove this ignore when introducing type hints continue100=self._continue, timer=self._timer, request_info=self.request_info, @@ -223,7 +223,7 @@ def __init__(self, config): self._loop = None self.client_id = None self.trace_configs = None - self.enable_cleanup_closed = None + self.enable_cleanup_closed = False self._static_responses = None self._request_class = aiohttp.ClientRequest self._response_class = aiohttp.ClientResponse diff --git a/esrally/driver/scheduler.py b/esrally/driver/scheduler.py index 4cd9c7b4b..05c591ad3 100644 --- a/esrally/driver/scheduler.py +++ b/esrally/driver/scheduler.py @@ -166,8 +166,7 @@ def remove_scheduler(name): class SimpleScheduler(ABC): @abstractmethod - def next(self, current): - ... + def next(self, current): ... class Scheduler(ABC): @@ -178,8 +177,7 @@ def after_request(self, now, weight, unit, request_meta_data): pass @abstractmethod - def next(self, current): - ... + def next(self, current): ... 
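One motivation for the `enable_cleanup_closed` change above: the attribute is eventually handed to aiohttp, whose connector declares the flag as a plain bool, so defaulting it to `None` trips mypy's arg-type check even though it happens to work at runtime. A minimal sketch of the failure mode, using a hypothetical stand-in function rather than the real factory:

def configure_connector(enable_cleanup_closed: bool) -> str:
    # hypothetical stand-in for aiohttp.TCPConnector, which types the flag as bool
    return f"cleanup_closed={enable_cleanup_closed}"

configure_connector(False)  # accepted by mypy
configure_connector(None)   # mypy: incompatible type "None"; expected "bool"  [arg-type]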
# Deprecated diff --git a/esrally/mechanic/launcher.py b/esrally/mechanic/launcher.py index 18d21f90b..b23a0682f 100644 --- a/esrally/mechanic/launcher.py +++ b/esrally/mechanic/launcher.py @@ -242,7 +242,7 @@ def stop(self, nodes, metrics_store): stop_watch.start() try: es.terminate() - es.wait(10.0) + es.wait(10) stopped_nodes.append(node) except psutil.NoSuchProcess: self.logger.warning("No process found with PID [%s] for node [%s].", es.pid, node_name) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 6ea7cabc1..ddaa48d85 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -19,6 +19,19 @@ import logging import os from enum import Enum +from types import ModuleType +from typing import ( + Any, + Callable, + Collection, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) import tabulate @@ -28,14 +41,14 @@ TEAM_FORMAT_VERSION = 1 -def _path_for(team_root_path, team_member_type): +def _path_for(team_root_path: str, team_member_type: str) -> str: root_path = os.path.join(team_root_path, team_member_type, f"v{TEAM_FORMAT_VERSION}") if not os.path.exists(root_path): raise exceptions.SystemSetupError(f"Path {root_path} for {team_member_type} does not exist.") return root_path -def list_cars(cfg: types.Config): +def list_cars(cfg: types.Config) -> None: loader = CarLoader(team_path(cfg)) cars = [] for name in loader.car_names(): @@ -47,17 +60,17 @@ def list_cars(cfg: types.Config): console.println(tabulate.tabulate([[c.name, c.type, c.description] for c in cars], headers=["Name", "Type", "Description"])) -def load_car(repo, name, car_params=None): +def load_car(repo: str, name: Collection[str], car_params: Optional[Mapping] = None) -> "Car": class Component: - def __init__(self, root_path, entry_point): + def __init__(self, root_path: str, entry_point: str): self.root_path = root_path self.entry_point = entry_point - root_path = None + root_paths = [] # preserve order as we append to existing config files later during provisioning. all_config_paths = [] - all_config_base_vars = {} - all_car_vars = {} + all_config_base_vars: MutableMapping[str, str] = {} + all_car_vars: MutableMapping[str, str] = {} for n in name: descriptor = CarLoader(repo).load_car(n, car_params) @@ -67,25 +80,22 @@ def __init__(self, root_path, entry_point): for p in descriptor.root_paths: # probe whether we have a root path if BootstrapHookHandler(Component(root_path=p, entry_point=Car.entry_point)).can_load(): - if not root_path: - root_path = p - # multiple cars are based on the same hook - elif root_path != p: - raise exceptions.SystemSetupError(f"Invalid car: {name}. 
Multiple bootstrap hooks are forbidden.") + if p not in root_paths: + root_paths.append(p) all_config_base_vars.update(descriptor.config_base_variables) all_car_vars.update(descriptor.variables) if len(all_config_paths) == 0: raise exceptions.SystemSetupError(f"At least one config base is required for car {name}") - variables = {} + variables: MutableMapping[str, str] = {} # car variables *always* take precedence over config base variables variables.update(all_config_base_vars) variables.update(all_car_vars) - return Car(name, root_path, all_config_paths, variables) + return Car(name, root_paths, all_config_paths, variables) -def list_plugins(cfg: types.Config): +def list_plugins(cfg: types.Config) -> None: plugins = PluginLoader(team_path(cfg)).plugins() if plugins: console.println("Available Elasticsearch plugins:\n") @@ -94,12 +104,16 @@ def list_plugins(cfg: types.Config): console.println("No Elasticsearch plugins are available.\n") -def load_plugin(repo, name, config_names, plugin_params=None): +def load_plugin( + repo: str, name: str, config_names: Optional[Collection[str]], plugin_params: Optional[Mapping[str, str]] = None +) -> "PluginDescriptor": return PluginLoader(repo).load_plugin(name, config_names, plugin_params) -def load_plugins(repo, plugin_names, plugin_params=None): - def name_and_config(p): +def load_plugins( + repo: str, plugin_names: Collection[str], plugin_params: Optional[Mapping[str, str]] = None +) -> Collection["PluginDescriptor"]: + def name_and_config(p: str) -> Tuple[str, Optional[Collection[str]]]: plugin_spec = p.split(":") if len(plugin_spec) == 1: return plugin_spec[0], None @@ -116,7 +130,7 @@ def name_and_config(p): return plugins -def team_path(cfg: types.Config): +def team_path(cfg: types.Config) -> str: root_path = cfg.opts("mechanic", "team.path", mandatory=False) if root_path: return root_path @@ -141,35 +155,44 @@ def team_path(cfg: types.Config): class CarLoader: - def __init__(self, team_root_path): + def __init__(self, team_root_path: str): self.cars_dir = _path_for(team_root_path, "cars") self.logger = logging.getLogger(__name__) - def car_names(self): - def __car_name(path): + def car_names(self) -> Iterator[str]: + def __car_name(path: str) -> str: p, _ = io.splitext(path) return io.basename(p) - def __is_car(path): + def __is_car(path: str) -> bool: _, extension = io.splitext(path) return extension == ".ini" return map(__car_name, filter(__is_car, os.listdir(self.cars_dir))) - def _car_file(self, name): + def _car_file(self, name: str) -> str: return os.path.join(self.cars_dir, f"{name}.ini") - def load_car(self, name, car_params=None): + def load_car(self, name: str, car_params: Optional[Mapping[str, Any]] = None) -> "CarDescriptor": car_config_file = self._car_file(name) if not io.exists(car_config_file): raise exceptions.SystemSetupError(f"Unknown car [{name}]. List the available cars with {PROGRAM_NAME} list cars.") config = self._config_loader(car_config_file) - root_paths = [] - config_paths = [] - config_base_vars = {} + root_paths: List[str] = [] + config_paths: List[str] = [] + config_base_vars: MutableMapping[str, Any] = {} + description = self._value(config, ["meta", "description"], default="") + assert isinstance(description, str), f"Car [{name}] defines an invalid description [{description}]." + car_type = self._value(config, ["meta", "type"], default="car") - config_bases = self._value(config, ["config", "base"], default="").split(",") + assert isinstance(car_type, str), f"Car [{name}] defines an invalid type [{car_type}]." 
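Note the behavioral change in `load_car` above: instead of rejecting a second bootstrap-hook directory with a SystemSetupError, root paths are now collected as an ordered, de-duplicated list, so several cars can each contribute a hook directory. A small sketch of that aggregation logic, with hypothetical team paths:

from typing import List

# hypothetical hook directories contributed by two cars that share a base
candidate_paths = [
    "/teams/cars/v1/vanilla",
    "/teams/cars/v1/vanilla",  # duplicate contributed by the second car
    "/teams/cars/v1/trial-license",
]
root_paths: List[str] = []
for p in candidate_paths:
    if p not in root_paths:  # keep the first occurrence, preserve order
        root_paths.append(p)
assert root_paths == ["/teams/cars/v1/vanilla", "/teams/cars/v1/trial-license"]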
+ + config_base = self._value(config, ["config", "base"], default="") + assert config_base is not None, f"Car [{name}] does not define a config base." + assert isinstance(config_base, str), f"Car [{name}] defines an invalid config base [{config_base}]." + config_bases = config_base.split(",") + for base in config_bases: if base: root_path = os.path.join(self.cars_dir, base) @@ -190,24 +213,26 @@ def load_car(self, name, car_params=None): return CarDescriptor(name, description, car_type, root_paths, config_paths, config_base_vars, variables) - def _config_loader(self, file_name): + def _config_loader(self, file_name: str) -> "configparser.ConfigParser": config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) # Do not modify the case of option keys but read them as is - config.optionxform = lambda option: option + config.optionxform = lambda optionstr: optionstr # type: ignore[method-assign] config.read(file_name) return config - def _value(self, cfg, section_path, default=None): - path = [section_path] if (isinstance(section_path, str)) else section_path - current_cfg = cfg + def _value( + self, cfg: "configparser.ConfigParser", section_path: Union[str, Collection[str]], default: Optional[str] = None + ) -> Optional[Union[str, Mapping[str, Any]]]: + path: Collection[str] = [section_path] if (isinstance(section_path, str)) else section_path + current_cfg: Union["configparser.ConfigParser", Mapping[str, Any], str] = cfg for k in path: - if k in current_cfg: + if not isinstance(current_cfg, str) and k in current_cfg: current_cfg = current_cfg[k] else: return default return current_cfg - def _copy_section(self, cfg, section, target): + def _copy_section(self, cfg: "configparser.ConfigParser", section: str, target: MutableMapping[str, Any]) -> MutableMapping[str, Any]: if section in cfg.sections(): for k, v in cfg[section].items(): target[k] = v @@ -215,7 +240,16 @@ def _copy_section(self, cfg, section, target): class CarDescriptor: - def __init__(self, name, description, type, root_paths, config_paths, config_base_variables, variables): + def __init__( + self, + name: str, + description: str, + type: str, + root_paths: Collection[str], + config_paths: Collection[str], + config_base_variables: Mapping[str, str], + variables: Mapping[str, str], + ): self.name = name self.description = description self.type = type @@ -224,10 +258,10 @@ def __init__(self, name, description, type, root_paths, config_paths, config_bas self.config_base_variables = config_base_variables self.variables = variables - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and self.name == other.name @@ -235,59 +269,71 @@ class Car: # name of the initial Python file to load for cars. entry_point = "config" - def __init__(self, names, root_path, config_paths, variables=None): + def __init__( + self, + names: Collection[str], + root_path: Union[None, str, Collection[str]], + config_paths: Collection[str], + variables: Optional[Mapping[str, Any]] = None, + ): """ Creates new settings for a benchmark candidate. :param names: Descriptive name(s) for this car. - :param root_path: The root path from which bootstrap hooks should be loaded if any. May be ``None``. + :param root_path: The root path(s) from which bootstrap hooks should be loaded if any. May be ``[]``. :param config_paths: A non-empty list of paths where the raw config can be found. 
:param variables: A dict containing variable definitions that need to be replaced. """ if variables is None: variables = {} if isinstance(names, str): - self.names = [names] + self.names: Collection[str] = [names] else: self.names = names - self.root_path = root_path + + if root_path is None: + self.root_path: Collection[str] = [] + elif isinstance(root_path, str): + self.root_path = [root_path] + else: + self.root_path = root_path self.config_paths = config_paths self.variables = variables - def mandatory_var(self, name): + def mandatory_var(self, name: str) -> str: try: return self.variables[name] except KeyError: raise exceptions.SystemSetupError(f'Car "{self.name}" requires config key "{name}"') @property - def name(self): + def name(self) -> str: return "+".join(self.names) # Adapter method for BootstrapHookHandler @property - def config(self): + def config(self) -> str: return self.name @property - def safe_name(self): + def safe_name(self) -> str: return "_".join(self.names) - def __str__(self): + def __str__(self) -> str: return self.name class PluginLoader: - def __init__(self, team_root_path): + def __init__(self, team_root_path: str): self.plugins_root_path = _path_for(team_root_path, "plugins") self.logger = logging.getLogger(__name__) - def plugins(self, variables=None): + def plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: known_plugins = self._core_plugins(variables) + self._configured_plugins(variables) sorted(known_plugins, key=lambda p: p.name) return known_plugins - def _core_plugins(self, variables=None): + def _core_plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: core_plugins = [] core_plugins_path = os.path.join(self.plugins_root_path, "core-plugins.txt") if os.path.exists(core_plugins_path): @@ -299,7 +345,7 @@ def _core_plugins(self, variables=None): core_plugins.append(PluginDescriptor(name=values[0], core_plugin=True, variables=variables)) return core_plugins - def _configured_plugins(self, variables=None): + def _configured_plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: configured_plugins = [] # each directory is a plugin, each .ini is a config (just go one level deep) for entry in os.listdir(self.plugins_root_path): @@ -313,10 +359,10 @@ def _configured_plugins(self, variables=None): configured_plugins.append(PluginDescriptor(name=plugin_name, config=config, variables=variables)) return configured_plugins - def _plugin_file(self, name, config): + def _plugin_file(self, name: str, config: str) -> str: return os.path.join(self._plugin_root_path(name), "%s.ini" % config) - def _plugin_root_path(self, name): + def _plugin_root_path(self, name: str) -> str: return os.path.join(self.plugins_root_path, self._plugin_name_to_file(name)) # As we allow to store Python files in the plugin directory and the plugin directory also serves as the root path of the corresponding # module, we need to switch from underscores to hyphens and vice versa. # # We are implicitly assuming that plugin names stick to the convention of hyphen separation to simplify implementation and usage a bit.
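The hyphen/underscore convention described in the comment above is symmetric, so plugin names and their directory names round-trip losslessly through the two helpers defined just below. For illustration:

def _plugin_name_to_file(plugin_name: str) -> str:
    return plugin_name.replace("-", "_")

def _file_to_plugin_name(file_name: str) -> str:
    return file_name.replace("_", "-")

# "repository-s3" lives in a "repository_s3" directory and maps back cleanly
assert _plugin_name_to_file("repository-s3") == "repository_s3"
assert _file_to_plugin_name("repository_s3") == "repository-s3"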
- def _file_to_plugin_name(self, file_name): + def _file_to_plugin_name(self, file_name: str) -> str: return file_name.replace("_", "-") - def _plugin_name_to_file(self, plugin_name): + def _plugin_name_to_file(self, plugin_name: str) -> str: return plugin_name.replace("-", "_") - def _core_plugin(self, name, variables=None): + def _core_plugin(self, name: str, variables: Optional[Mapping[str, str]] = None) -> Optional["PluginDescriptor"]: return next((p for p in self._core_plugins(variables) if p.name == name and p.config is None), None) - def load_plugin(self, name, config_names, plugin_params=None): + def load_plugin( + self, name: str, config_names: Optional[Collection[str]], plugin_params: Optional[Mapping[str, str]] = None + ) -> "PluginDescriptor": if config_names is not None: self.logger.info("Loading plugin [%s] with configuration(s) [%s].", name, config_names) else: @@ -382,7 +430,7 @@ def load_plugin(self, name, config_names, plugin_params=None): config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) # Do not modify the case of option keys but read them as is - config.optionxform = lambda option: option + config.optionxform = lambda optionstr: optionstr # type: ignore[method-assign] config.read(config_file) if "config" in config and "base" in config["config"]: config_bases = config["config"]["base"].split(",") @@ -415,7 +463,15 @@ class PluginDescriptor: # name of the initial Python file to load for plugins. entry_point = "plugin" - def __init__(self, name, core_plugin=False, config=None, root_path=None, config_paths=None, variables=None): + def __init__( + self, + name: str, + core_plugin: bool = False, + config: Optional[Collection[str]] = None, + root_path: Optional[str] = None, + config_paths: Optional[Collection[str]] = None, + variables: Optional[Mapping[str, Any]] = None, + ): if config_paths is None: config_paths = [] if variables is None: @@ -427,27 +483,27 @@ def __init__(self, name, core_plugin=False, config=None, root_path=None, config_ self.config_paths = config_paths self.variables = variables - def __str__(self): - return "Plugin descriptor for [%s]" % self.name + def __str__(self) -> str: + return f"Plugin descriptor for [{self.name}]" - def __repr__(self): + def __repr__(self) -> str: r = [] for prop, value in vars(self).items(): r.append("%s = [%s]" % (prop, repr(value))) return ", ".join(r) @property - def moved_to_module(self): + def moved_to_module(self) -> bool: # For a BWC escape hatch we first check if the plugin is listed in rally-teams' "core-plugin.txt", # thus allowing users to override the teams path or revision to include the repository-s3/azure/gcs plugins in # "core-plugin.txt" # TODO: https://github.com/elastic/rally/issues/1622 return self.name in ["repository-s3", "repository-gcs", "repository-azure"] and not self.core_plugin - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) ^ hash(self.config) ^ hash(self.core_plugin) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and (self.name, self.config, self.core_plugin) == (other.name, other.config, other.core_plugin) @@ -455,14 +511,14 @@ class BootstrapPhase(Enum): post_install = 10 @classmethod - def valid(cls, name): + def valid(cls, name: str) -> bool: for n in BootstrapPhase.names(): if n == name: return True return False @classmethod - def names(cls): + def names(cls) -> Collection[str]: return [p.name for p in list(BootstrapPhase)] @@ -471,7 +527,7 @@ class BootstrapHookHandler: 
Responsible for loading and executing component-specific initialization code. """ - def __init__(self, component, loader_class=modules.ComponentLoader): + def __init__(self, component: Any, loader_class: Callable = modules.ComponentLoader): """ Creates a new BootstrapHookHandler. @@ -481,18 +537,23 @@ def __init__(self, component, loader_class=modules.ComponentLoader): self.component = component # Don't allow the loader to recurse. The subdirectories may contain Elasticsearch specific files which we do not want to add to # Rally's Python load path. We may need to define a more advanced strategy in the future. - self.loader = loader_class(root_path=self.component.root_path, component_entry_point=self.component.entry_point, recurse=False) - self.hooks = {} + if isinstance(self.component.root_path, list): + root_path = self.component.root_path + else: + root_path = [self.component.root_path] + self.loader = loader_class(root_path=root_path, component_entry_point=self.component.entry_point, recurse=False) + self.hooks: MutableMapping[str, List[Callable]] = {} self.logger = logging.getLogger(__name__) - def can_load(self): + def can_load(self) -> bool: return self.loader.can_load() - def load(self): - root_module = self.loader.load() + def load(self) -> None: + root_modules: Collection[ModuleType] = self.loader.load() try: # every module needs to have a register() method - root_module.register(self) + for module in root_modules: + module.register(self) except exceptions.RallyError: # just pass our own exceptions transparently. raise @@ -501,15 +562,16 @@ def load(self): self.logger.exception(msg) raise exceptions.SystemSetupError(msg) - def register(self, phase, hook): + def register(self, phase: str, hook: Callable) -> None: self.logger.info("Registering bootstrap hook [%s] for phase [%s] in component [%s]", hook.__name__, phase, self.component.name) if not BootstrapPhase.valid(phase): raise exceptions.SystemSetupError(f"Unknown bootstrap phase [{phase}]. Valid phases are: {BootstrapPhase.names()}.") if phase not in self.hooks: - self.hooks[phase] = [] + empty: List[Callable] = [] + self.hooks[phase] = empty self.hooks[phase].append(hook) - def invoke(self, phase, **kwargs): + def invoke(self, phase: str, **kwargs: Mapping[str, Any]) -> None: if phase in self.hooks: self.logger.info("Invoking phase [%s] for component [%s] in config [%s]", phase, self.component.name, self.component.config) for hook in self.hooks[phase]: diff --git a/esrally/track/loader.py b/esrally/track/loader.py index 56a299d13..6c6f947e1 100644 --- a/esrally/track/loader.py +++ b/esrally/track/loader.py @@ -1226,10 +1226,11 @@ def load(self): # get dependent libraries installed in a prior step. ensure dir exists to make sure loading works correctly. os.makedirs(paths.libs(), exist_ok=True) sys.path.insert(0, paths.libs()) - root_module = self.loader.load() + root_modules = self.loader.load() try: # every module needs to have a register() method - root_module.register(self) + for module in root_modules: + module.register(self) except BaseException: msg = "Could not register track plugin at [%s]" % self.loader.root_path logging.getLogger(__name__).exception(msg) diff --git a/esrally/types.py b/esrally/types.py index d64dbc896..bb6e8c2b1 100644 --- a/esrally/types.py +++ b/esrally/types.py @@ -162,17 +162,12 @@ class Config(Protocol): - def add(self, scope, section: Section, key: Key, value: Any) -> None: - ... + def add(self, scope, section: Section, key: Key, value: Any) -> None: ...
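Because `load()` above now iterates over every root module the loader returns, the contract for hook modules is unchanged: each module must expose a module-level `register()` function that attaches its hooks to the handler. A minimal sketch of such a module; the hook name and body are hypothetical, while `post_install` is a valid `BootstrapPhase` from this diff:

# hypothetical bootstrap-hook module, e.g. a car's config.py entry point
def install_keystore(config_names, variables, **kwargs):
    # hypothetical hook body; Rally passes component-specific kwargs
    variables["keystore.installed"] = True

def register(handler):
    handler.register("post_install", install_keystore)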
- def add_all(self, source: _Config, section: Section) -> None: - ... + def add_all(self, source: _Config, section: Section) -> None: ... - def opts(self, section: Section, key: Key, default_value=None, mandatory: bool = True) -> Any: - ... + def opts(self, section: Section, key: Key, default_value=None, mandatory: bool = True) -> Any: ... - def all_opts(self, section: Section) -> dict: - ... + def all_opts(self, section: Section) -> dict: ... - def exists(self, section: Section, key: Key) -> bool: - ... + def exists(self, section: Section, key: Key) -> bool: ... diff --git a/esrally/utils/io.py b/esrally/utils/io.py index 7452e2b71..e4e83b2cd 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -26,40 +26,66 @@ import subprocess import tarfile import zipfile +from types import TracebackType +from typing import ( + IO, + Any, + AnyStr, + Callable, + Collection, + Generic, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + overload, +) import zstandard +# This was introduced in Python 3.11 to `typing`; older versions need `typing_extensions` +# but they are treated the same by mypy, so I'm not going to use conditional imports here +from typing_extensions import Self + from esrally.utils import console SUPPORTED_ARCHIVE_FORMATS = [".zip", ".bz2", ".gz", ".tar", ".tar.gz", ".tgz", ".tar.bz2", ".zst"] -class FileSource: +class FileSource(Generic[AnyStr]): """ FileSource is a wrapper around a plain file which simplifies testing of file I/O calls. """ - def __init__(self, file_name, mode, encoding="utf-8"): + def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f = None + self.f: Optional[IO[AnyStr]] = None - def open(self): + def open(self) -> Self: self.f = open(self.file_name, mode=self.mode, encoding=self.encoding) # allow for chaining return self - def seek(self, offset): + def seek(self, offset: int) -> None: + assert self.f is not None, "File is not open" self.f.seek(offset) - def read(self): + def read(self) -> AnyStr: + assert self.f is not None, "File is not open" return self.f.read() - def readline(self): + def readline(self) -> AnyStr: + assert self.f is not None, "File is not open" return self.f.readline() - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[AnyStr]: + assert self.f is not None, "File is not open" lines = [] f = self.f for _ in range(num_lines): @@ -69,19 +95,22 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: + assert self.f is not None, "File is not open" self.f.close() self.f = None - def __enter__(self): + def __enter__(self) -> Self: self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return self.file_name @@ -90,14 +119,14 @@ class MmapSource: MmapSource is a wrapper around a memory-mapped file which simplifies testing of file I/O calls. 
""" - def __init__(self, file_name, mode, encoding="utf-8"): + def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f = None - self.mm = None + self.f: Optional[IO[bytes]] = None + self.mm: Optional[mmap.mmap] = None - def open(self): + def open(self) -> Self: self.f = open(self.file_name, mode="r+b") self.mm = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_READ) self.mm.madvise(mmap.MADV_SEQUENTIAL) @@ -105,16 +134,20 @@ def open(self): # allow for chaining return self - def seek(self, offset): + def seek(self, offset: int) -> None: + assert self.mm is not None, "Source is not open" self.mm.seek(offset) - def read(self): + def read(self) -> bytes: + assert self.mm is not None, "Source is not open" return self.mm.read() - def readline(self): + def readline(self) -> bytes: + assert self.mm is not None, "Source is not open" return self.mm.readline() - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[bytes]: + assert self.mm is not None, "Source is not open" lines = [] mm = self.mm for _ in range(num_lines): @@ -124,21 +157,25 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: + assert self.mm is not None, "Source is not open" self.mm.close() self.mm = None + assert self.f is not None, "File is not open" self.f.close() self.f = None - def __enter__(self): + def __enter__(self) -> Self: self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return self.file_name @@ -149,10 +186,10 @@ class DictStringFileSourceFactory: It is intended for scenarios where multiple files may be read by client code. """ - def __init__(self, name_to_contents): + def __init__(self, name_to_contents: Mapping[str, Sequence[str]]): self.name_to_contents = name_to_contents - def __call__(self, name, mode, encoding="utf-8"): + def __call__(self, name: str, mode: str, encoding: str = "utf-8") -> "StringAsFileSource": return StringAsFileSource(self.name_to_contents[name], mode, encoding) @@ -162,7 +199,7 @@ class StringAsFileSource: be used in production code. """ - def __init__(self, contents, mode, encoding="utf-8"): + def __init__(self, contents: Sequence[str], mode: str, encoding: str = "utf-8"): """ :param contents: The file contents as an array of strings. Each item in the array should correspond to one line. :param mode: The file mode. It is ignored in this implementation but kept to implement the same interface as ``FileSource``. 
@@ -172,20 +209,20 @@ def __init__(self, contents, mode, encoding="utf-8"): self.current_index = 0 self.opened = False - def open(self): + def open(self) -> Self: self.opened = True return self - def seek(self, offset): + def seek(self, offset: int) -> None: self._assert_opened() if offset != 0: raise AssertionError("StringAsFileSource does not support random seeks") - def read(self): + def read(self) -> str: self._assert_opened() return "\n".join(self.contents) - def readline(self): + def readline(self) -> str: self._assert_opened() if self.current_index >= len(self.contents): return "" @@ -193,7 +230,7 @@ def readline(self): self.current_index += 1 return line - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[str]: lines = [] for _ in range(num_lines): line = self.readline() @@ -202,23 +239,25 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: self._assert_opened() - self.contents = None + self.contents = [] self.opened = False - def _assert_opened(self): + def _assert_opened(self) -> None: assert self.opened - def __enter__(self): + def __enter__(self) -> Self: self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return "StringAsFileSource" @@ -227,20 +266,20 @@ class ZstAdapter: Adapter class to make the zstandard API work with Rally's decompression abstractions """ - def __init__(self, path): + def __init__(self, path: str): self.fh = open(path, "rb") self.dctx = zstandard.ZstdDecompressor() self.reader = self.dctx.stream_reader(self.fh) - def read(self, size): + def read(self, size: int) -> bytes: return self.reader.read(size) - def close(self): + def close(self) -> None: self.reader.close() self.fh.close() -def ensure_dir(directory, mode=0o777): +def ensure_dir(directory: str, mode: int = 0o777) -> None: """ Ensure that the provided directory and all of its parent directories exist. This function is safe to execute on existing directories (no op). @@ -252,7 +291,7 @@ def ensure_dir(directory, mode=0o777): os.makedirs(directory, mode, exist_ok=True) -def _zipdir(source_directory, archive): +def _zipdir(source_directory: str, archive: zipfile.ZipFile) -> None: for root, _, files in os.walk(source_directory): for file in files: archive.write( @@ -261,7 +300,7 @@ def _zipdir(source_directory, archive): ) -def is_archive(name): +def is_archive(name: str) -> bool: """ :param name: File name to check. Can be either just the file name or optionally also an absolute path. :return: True iff the given file name is an archive that is also recognized for decompression by Rally. @@ -270,7 +309,7 @@ def is_archive(name): return ext in SUPPORTED_ARCHIVE_FORMATS -def is_executable(name): +def is_executable(name: str) -> bool: """ :param name: File name to check. :return: True iff given file name is executable and in PATH, all other cases False. @@ -279,7 +318,7 @@ def is_executable(name): return shutil.which(name) is not None -def compress(source_directory, archive_name): +def compress(source_directory: str, archive_name: str) -> None: """ Compress a directory tree. 
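`StringAsFileSource.close()` above now resets `contents` to `[]` instead of `None`, keeping the attribute's type stable. Its read semantics are worth spelling out, since `read()` ignores the line cursor that `readline()` advances; a behavioral sketch:

from esrally.utils.io import StringAsFileSource

source = StringAsFileSource(["line-1", "line-2"], mode="rt").open()
assert source.readline() == "line-1"      # advances the line cursor
assert source.readlines(5) == ["line-2"]  # stops at the end of contents
assert source.read() == "line-1\nline-2"  # joins everything, cursor ignored
source.close()                            # contents is now [], not None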
@@ -290,7 +329,7 @@ def compress(source_directory, archive_name): _zipdir(source_directory, archive) -def decompress(zip_name, target_directory): +def decompress(zip_name: str, target_directory: str) -> None: """ Decompresses the provided archive to the target directory. The following file extensions are supported: @@ -314,23 +353,23 @@ def decompress(zip_name, target_directory): _do_decompress(target_directory, zipfile.ZipFile(zip_name)) elif extension == ".bz2": decompressor_args = ["pbzip2", "-d", "-k", "-m10000", "-c"] - decompressor_lib = bz2.open - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_bz2 = bz2.open + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_bz2) elif extension == ".zst": decompressor_args = ["pzstd", "-f", "-d", "-c"] - decompressor_lib = ZstAdapter - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_zst = ZstAdapter + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_zst) elif extension == ".gz": decompressor_args = ["pigz", "-d", "-k", "-c"] - decompressor_lib = gzip.open - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_gzip = gzip.open + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_gzip) elif extension in [".tar", ".tar.gz", ".tgz", ".tar.bz2"]: _do_decompress(target_directory, tarfile.open(zip_name)) else: raise RuntimeError("Unsupported file extension [%s]. Cannot decompress [%s]" % (extension, zip_name)) -def _do_decompress_manually(target_directory, filename, decompressor_args, decompressor_lib): +def _do_decompress_manually(target_directory: str, filename: str, decompressor_args: List[str], decompressor_lib: Callable) -> None: decompressor_bin = decompressor_args[0] base_path_without_extension = basename(splitext(filename)[0]) @@ -345,7 +384,9 @@ def _do_decompress_manually(target_directory, filename, decompressor_args, decom _do_decompress_manually_with_lib(target_directory, filename, decompressor_lib(filename)) -def _do_decompress_manually_external(target_directory, filename, base_path_without_extension, decompressor_args): +def _do_decompress_manually_external( + target_directory: str, filename: str, base_path_without_extension: str, decompressor_args: List[str] +) -> bool: with open(os.path.join(target_directory, base_path_without_extension), "wb") as new_file: try: subprocess.run(decompressor_args + [filename], stdout=new_file, stderr=subprocess.PIPE, check=True) @@ -357,7 +398,7 @@ def _do_decompress_manually_external(target_directory, filename, base_path_witho return True -def _do_decompress_manually_with_lib(target_directory, filename, compressed_file): +def _do_decompress_manually_with_lib(target_directory: str, filename: str, compressed_file: IO[bytes]) -> None: path_without_extension = basename(splitext(filename)[0]) ensure_dir(target_directory) @@ -369,29 +410,43 @@ def _do_decompress_manually_with_lib(target_directory, filename, compressed_file compressed_file.close() -def _do_decompress(target_directory, compressed_file): +def _do_decompress(target_directory: str, compressed_file: Union[zipfile.ZipFile, tarfile.TarFile]) -> None: try: compressed_file.extractall(path=target_directory) except BaseException: - raise RuntimeError("Could not decompress provided archive [%s]" % compressed_file.filename) + if isinstance(compressed_file, zipfile.ZipFile): + 
raise RuntimeError( + f"Could not decompress provided archive [{compressed_file.filename}]. Please check if it is a valid zip file." + ) + if isinstance(compressed_file, tarfile.TarFile): + raise RuntimeError(f"Could not decompress provided archive [{compressed_file.name!r}]. Please check if it is a valid tar file.") finally: compressed_file.close() # just in a dedicated method to ease mocking -def dirname(path): +def dirname(path: AnyStr) -> AnyStr: return os.path.dirname(path) -def basename(path): +def basename(path: AnyStr) -> AnyStr: return os.path.basename(path) -def exists(path): +def exists(path: AnyStr) -> bool: return os.path.exists(path) +@overload +def normalize_path(path: str) -> str: ... +@overload +def normalize_path(path: str, cwd: str = ".") -> str: ... +@overload +def normalize_path(path: bytes) -> bytes: ... +@overload +def normalize_path(path: bytes, cwd: bytes = b".") -> bytes: ... def normalize_path(path, cwd="."): + # This is a bug in mypy, see https://github.com/python/mypy/issues/3737 """ Normalizes a path by removing redundant "../" and also expanding the "~" character to the user home directory. :param path: A possibly non-normalized path. @@ -406,7 +461,7 @@ def normalize_path(path, cwd="."): return normalized -def escape_path(path): +def escape_path(path: str) -> str: """ Escapes any characters that might be problematic in shell interactions. @@ -416,7 +471,7 @@ def escape_path(path): return path.replace("\\", "\\\\") -def splitext(file_name): +def splitext(file_name: str) -> Tuple[str, str]: if file_name.endswith(".tar.gz"): return file_name[0:-7], file_name[-7:] elif file_name.endswith(".tar.bz2"): @@ -425,7 +480,7 @@ def splitext(file_name): return os.path.splitext(file_name) -def has_extension(file_name, extension): +def has_extension(file_name: str, extension: str) -> bool: """ Checks whether the given file name has the given extension. @@ -443,7 +498,7 @@ class FileOffsetTable: data file. This helps bulk-indexing clients to advance quickly to a certain position in a large data file. """ - def __init__(self, data_file_path, offset_table_path, mode): + def __init__(self, data_file_path: str, offset_table_path: str, mode: str): """ Creates a new FileOffsetTable instance. The constructor should not be called directly but instead the respective factory methods should be used. @@ -456,34 +511,35 @@ def __init__(self, data_file_path, offset_table_path, mode): self.data_file_path = data_file_path self.offset_table_path = offset_table_path self.mode = mode - self.offset_file = None + self.offset_file: Optional[IO[str]] = None - def exists(self): + def exists(self) -> bool: """ :return: True iff the file offset table already exists. """ return os.path.exists(self.offset_table_path) - def is_valid(self): + def is_valid(self) -> bool: """ :return: True iff the file offset table exists and it is up-to-date. """ return self.exists() and os.path.getmtime(self.offset_table_path) >= os.path.getmtime(self.data_file_path) - def __enter__(self): + def __enter__(self) -> Self: self.offset_file = open(self.offset_table_path, self.mode) return self - def add_offset(self, line_number, offset): + def add_offset(self, line_number: int, offset: int) -> None: """ Adds a new offset mapping to the file offset table. This method has to be called inside a context-manager block. :param line_number: A line number to add. :param offset: The corresponding offset in bytes. """ + assert self.offset_file is not None, "File offset table must be opened in a context manager block." 
print(f"{line_number};{offset}", file=self.offset_file) - def find_closest_offset(self, target_line_number): + def find_closest_offset(self, target_line_number: int) -> Tuple[int, int]: """ Determines the offset in bytes for the line L in the corresponding data file with the following properties: @@ -497,6 +553,7 @@ def find_closest_offset(self, target_line_number): prior_offset = 0 prior_remaining_lines = target_line_number + assert self.offset_file is not None, "File offset table must be opened in a context manager block." for line in self.offset_file: line_number, offset_in_bytes = (int(i) for i in line.strip().split(";")) if line_number <= target_line_number: @@ -507,13 +564,16 @@ def find_closest_offset(self, target_line_number): return prior_offset, prior_remaining_lines - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: + assert self.offset_file is not None, "File offset table must be opened in a context manager block." self.offset_file.close() self.offset_file = None return False @classmethod - def create_for_data_file(cls, data_file_path): + def create_for_data_file(cls, data_file_path: str) -> Self: """ Factory method to create a new file offset table. @@ -522,7 +582,7 @@ def create_for_data_file(cls, data_file_path): return cls(data_file_path, f"{data_file_path}.offset", "wt") @classmethod - def read_for_data_file(cls, data_file_path): + def read_for_data_file(cls, data_file_path: str) -> Self: """ Factory method to read from an existing file offset table. @@ -532,7 +592,7 @@ def read_for_data_file(cls, data_file_path): return cls(data_file_path, f"{data_file_path}.offset", "rt") @staticmethod - def remove(data_file_path): + def remove(data_file_path: str) -> None: """ Removes a file offset table for the provided data path. @@ -541,7 +601,7 @@ def remove(data_file_path): os.remove(f"{data_file_path}.offset") -def prepare_file_offset_table(data_file_path): +def prepare_file_offset_table(data_file_path: str) -> Optional[int]: """ Creates a file that contains a mapping from line numbers to file offsets for the provided path. This file is used internally by #skip_lines(data_file_path, data_file) to speed up line skipping. @@ -568,7 +628,7 @@ def prepare_file_offset_table(data_file_path): return None -def remove_file_offset_table(data_file_path): +def remove_file_offset_table(data_file_path: str) -> None: """ Attempts to remove the file offset table for the provided data path. @@ -578,7 +638,7 @@ def remove_file_offset_table(data_file_path): FileOffsetTable.remove(data_file_path) -def skip_lines(data_file_path, data_file, number_of_lines_to_skip): +def skip_lines(data_file_path: str, data_file: IO[AnyStr], number_of_lines_to_skip: int) -> None: """ Skips the first `number_of_lines_to_skip` lines in `data_file` as a side effect. @@ -606,7 +666,7 @@ def skip_lines(data_file_path, data_file, number_of_lines_to_skip): data_file.readline() -def get_size(start_path="."): +def get_size(start_path: str = ".") -> int: total_size = 0 for dirpath, _, filenames in os.walk(start_path): for f in filenames: diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 728b55d3a..69ec87d22 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. 
-import importlib.machinery +import importlib.util import logging import os import sys +from types import ModuleType +from typing import Collection, Iterator, Tuple, Union from esrally import exceptions from esrally.utils import io @@ -34,36 +36,44 @@ class ComponentLoader: """ - def __init__(self, root_path, component_entry_point, recurse=True): + def __init__(self, root_path: Union[str, Collection[str]], component_entry_point: str, recurse: bool = True): """ Creates a new component loader. - :param root_path: An absolute path to a directory which contains the component entry point. + :param root_path: An absolute path or list of paths to a directory which contains the component entry point. :param component_entry_point: The name of the component entry point. A corresponding file with the extension ".py" must exist in the ``root_path``. :param recurse: Search recursively for modules but ignore modules starting with "_" (Default: ``True``). """ - self.root_path = root_path + self.root_path: Collection[str] = root_path if isinstance(root_path, list) else [str(root_path)] self.component_entry_point = component_entry_point self.recurse = recurse self.logger = logging.getLogger(__name__) - def _modules(self, module_paths, component_name): + def _modules(self, module_paths: Collection[str], component_name: str, root_path: str) -> Iterator[Tuple[str, str]]: for path in module_paths: for filename in os.listdir(path): name, ext = os.path.splitext(filename) if ext.endswith(".py"): - root_relative_path = os.path.join(path, name)[len(self.root_path) + len(os.path.sep) :] + file_absolute_path = os.path.join(path, filename) + root_absolute_path = os.path.join(path, name) + root_relative_path = root_absolute_path[len(root_path) + len(os.path.sep) :] module_name = "%s.%s" % (component_name, root_relative_path.replace(os.path.sep, ".")) - yield module_name + yield module_name, file_absolute_path - def _load_component(self, component_name, module_dirs): + def _load_component(self, component_name: str, module_dirs: Collection[str], root_path: str) -> ModuleType: # precondition: A module with this name has to exist provided that the caller has called #can_load() before. root_module_name = "%s.%s" % (component_name, self.component_entry_point) - for p in self._modules(module_dirs, component_name): - self.logger.debug("Loading module [%s]", p) - m = importlib.import_module(p) - if p == root_module_name: + for name, p in self._modules(module_dirs, component_name, root_path): + self.logger.debug("Loading module [%s]: %s", name, p) + # Use the util methods instead of `importlib.import_module` to allow for more fine-grained control over the import process. + # in particular, we want to be able to import multiple modules that use the same name, but are from different directories. + spec = importlib.util.spec_from_file_location(name, p) + if spec is None or spec.loader is None: + raise exceptions.SystemSetupError(f"Could not load module [{name}]") + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + if name == root_module_name: root_module = m return root_module @@ -71,41 +81,46 @@ def can_load(self): """ :return: True iff the component entry point could be found. 
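The loader above swaps `importlib.import_module` for explicit spec-based loading precisely so that two components can ship same-named modules from different directories without colliding in the import machinery. The underlying stdlib pattern, shown with a hypothetical module name and path:

import importlib.util

# hypothetical module name and source path
spec = importlib.util.spec_from_file_location("track.custom_runners", "/tracks/custom/custom_runners.py")
if spec is None or spec.loader is None:
    raise RuntimeError("could not build a module spec")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)  # executes the source in the fresh module's namespace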
""" - return self.root_path and os.path.exists(os.path.join(self.root_path, "%s.py" % self.component_entry_point)) + return all(self.root_path) and all( + os.path.exists(os.path.join(root_path, "%s.py" % self.component_entry_point)) for root_path in self.root_path + ) - def load(self): + def load(self) -> Collection[ModuleType]: """ - Loads a component with the given component entry point. + Loads components with the given component entry point. Precondition: ``ComponentLoader#can_load() == True``. - :return: The root module. + :return: The root modules. """ - component_name = io.basename(self.root_path) - self.logger.info("Loading component [%s] from [%s]", component_name, self.root_path) - module_dirs = [] - # search all paths within this directory for modules but exclude all directories starting with "_" - if self.recurse: - for dirpath, dirs, _ in os.walk(self.root_path): - module_dirs.append(dirpath) - ignore = [] - for d in dirs: - if d.startswith("_"): - self.logger.debug("Removing [%s] from load path.", d) - ignore.append(d) - for d in ignore: - dirs.remove(d) - else: - module_dirs.append(self.root_path) - # load path is only the root of the package hierarchy - component_root_path = os.path.abspath(os.path.join(self.root_path, os.pardir)) - self.logger.debug("Adding [%s] to Python load path.", component_root_path) - # needs to be at the beginning of the system path, otherwise import machinery tries to load application-internal modules - sys.path.insert(0, component_root_path) - try: - root_module = self._load_component(component_name, module_dirs) - return root_module - except BaseException: - msg = f"Could not load component [{component_name}]" - self.logger.exception(msg) - raise exceptions.SystemSetupError(msg) + root_modules = [] + for root_path in self.root_path: + component_name = io.basename(root_path) + self.logger.info("Loading component [%s] from [%s]", component_name, root_path) + module_dirs = [] + # search all paths within this directory for modules but exclude all directories starting with "_" + if self.recurse: + for dirpath, dirs, _ in os.walk(root_path): + module_dirs.append(dirpath) + ignore = [] + for d in dirs: + if d.startswith("_"): + self.logger.debug("Removing [%s] from load path.", d) + ignore.append(d) + for d in ignore: + dirs.remove(d) + else: + module_dirs.append(root_path) + # load path is only the root of the package hierarchy + component_root_path = os.path.abspath(os.path.join(root_path, os.pardir)) + self.logger.debug("Adding [%s] to Python load path.", component_root_path) + # needs to be at the beginning of the system path, otherwise import machinery tries to load application-internal modules + sys.path.insert(0, component_root_path) + try: + root_module = self._load_component(component_name, module_dirs, root_path) + root_modules.append(root_module) + except BaseException: + msg = f"Could not load component [{component_name}]" + self.logger.exception(msg) + raise exceptions.SystemSetupError(msg) + return root_modules diff --git a/esrally/utils/process.py b/esrally/utils/process.py index 6a283a723..b6d8d9f1c 100644 --- a/esrally/utils/process.py +++ b/esrally/utils/process.py @@ -20,7 +20,7 @@ import shlex import subprocess import time -from typing import Callable, Dict, List +from typing import IO, Callable, List, Mapping, Optional, Union import psutil @@ -38,7 +38,7 @@ def run_subprocess(command_line: str) -> int: return subprocess.call(command_line, shell=True) -def run_subprocess_with_output(command_line: str, env: Dict[str, str] = None) -> 
List[str]: +def run_subprocess_with_output(command_line: str, env: Optional[Mapping[str, str]] = None) -> List[str]: logger = logging.getLogger(__name__) logger.debug("Running subprocess [%s] with output.", command_line) command_line_args = shlex.split(command_line) @@ -46,6 +46,7 @@ def run_subprocess_with_output(command_line: str, env: Dict[str, str] = None) -> has_output = True lines = [] while has_output: + assert command_line_process.stdout is not None, "stdout is None" line = command_line_process.stdout.readline() if line: lines.append(line.decode("UTF-8").strip()) @@ -72,10 +73,10 @@ def exit_status_as_bool(runnable: Callable[[], int], quiet: bool = False) -> boo def run_subprocess_with_logging( command_line: str, - header: str = None, + header: Optional[str] = None, level: LogLevel = logging.INFO, - stdin: FileId = None, - env: Dict[str, str] = None, + stdin: Optional[Union[FileId, IO[bytes]]] = None, + env: Optional[Mapping[str, str]] = None, detach: bool = False, ) -> int: """ @@ -117,10 +118,10 @@ def run_subprocess_with_logging( def run_subprocess_with_logging_and_output( command_line: str, - header: str = None, + header: Optional[str] = None, level: LogLevel = logging.INFO, - stdin: FileId = None, - env: Dict[str, str] = None, + stdin: Optional[Union[FileId, IO[bytes]]] = None, + env: Optional[Mapping[str, str]] = None, detach: bool = False, ) -> subprocess.CompletedProcess: """ @@ -142,7 +143,6 @@ def run_subprocess_with_logging_and_output( if header is not None: logger.info(header) - # pylint: disable=subprocess-popen-preexec-fn completed = subprocess.run( command_line_args, stdout=subprocess.PIPE, @@ -174,7 +174,7 @@ def is_rally_process(p: psutil.Process) -> bool: def find_all_other_rally_processes() -> List[psutil.Process]: - others = [] + others: List[psutil.Process] = [] for_all_other_processes(is_rally_process, others.append) return others @@ -188,7 +188,7 @@ def redact_cmdline(cmdline: list) -> List[str]: def kill_all(predicate: Callable[[psutil.Process], bool]) -> None: - def kill(p: psutil.Process): + def kill(p: psutil.Process) -> None: logging.getLogger(__name__).info( "Killing lingering process with PID [%s] and command line [%s].", p.pid, redact_cmdline(p.cmdline()) ) diff --git a/pyproject.toml b/pyproject.toml index 1331874bd..98694685c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,13 +10,13 @@ path = "esrally/_version.py" name = "esrally" dynamic = ["version"] authors = [ - {name="Daniel Mitterdorfer", email="daniel.mitterdorfer@gmail.com"}, + { name = "Daniel Mitterdorfer", email = "daniel.mitterdorfer@gmail.com" }, ] description = "Macrobenchmarking framework for Elasticsearch" readme = "README.md" -license = {text = "Apache License 2.0"} +license = { text = "Apache License 2.0" } requires-python = ">=3.8" -classifiers=[ +classifiers = [ "Topic :: System :: Benchmark", "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", @@ -81,7 +81,9 @@ dependencies = [ # License: Apache 2.0 "google-auth==1.22.1", # License: BSD - "zstandard==0.21.0" + "zstandard==0.21.0", + # License: Python Software Foundation License + "typing-extensions==4.12.2", ] [project.optional-dependencies] @@ -112,6 +114,14 @@ develop = [ "pylint==3.1.0", "trustme==0.9.0", "GitPython==3.1.30", + # mypy + "boto3-stubs==1.26.125", + "mypy==1.10.1", + "types-psutil==5.9.4", + "types-tabulate==0.8.9", + "types-urllib3==1.26.19", + "types-requests<2.32.0", + "types-jsonschema==3.2.0", ] [project.scripts] @@ -143,12 +153,25 @@ junit_logging = 
"all" asyncio_mode = "strict" xfail_strict = true -# With rare exceptions, Rally does not use type hints. The intention of the -# following largely reduced mypy configuration scope is verification of argument -# types in config.Config methods while introducing configuration properties -# (props). The error we are after here is "arg-type". +# With exceptions specified in mypy override section, Rally does not use type +# hints (they were a novelty when Rally came to be). Hints are being slowly and +# opportunistically introduced whenever we revisit a group of modules. +# +# The intention of the following largely reduced global config scope is +# verification of argument types in config.Config methods while introducing +# configuration properties (props). The intention of "disable_error_code" option +# is to keep "arg-type" error code, while disabling other error codes. +# Ref: https://github.com/elastic/rally/pull/1798 [tool.mypy] -python_version = 3.8 +python_version = "3.8" +# subset of "strict", kept at global config level as some of the options are +# supported only at this level +# https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +strict_equality = true +extra_checks = true check_untyped_defs = true disable_error_code = [ "assignment", @@ -168,3 +191,32 @@ disable_error_code = [ "union-attr", "var-annotated", ] +files = ["esrally/", "it/", "tests/"] + +[[tool.mypy.overrides]] +module = [ + "esrally.mechanic.team", + "esrally.utils.modules", + "esrally.utils.io", + "esrally.utils.process", +] +disallow_incomplete_defs = true +# this should be a copy of disabled_error_code from above +enable_error_code = [ + "assignment", + "attr-defined", + "call-arg", + "call-overload", + "dict-item", + "import-not-found", + "import-untyped", + "index", + "list-item", + "misc", + "name-defined", + "operator", + "str-bytes-safe", + "syntax", + "union-attr", + "var-annotated", +] diff --git a/tests/client/factory_test.py b/tests/client/factory_test.py index 791423365..fd966d30f 100644 --- a/tests/client/factory_test.py +++ b/tests/client/factory_test.py @@ -29,7 +29,7 @@ import pytest import trustme import urllib3.exceptions -from elastic_transport import ApiResponseMeta +from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig from pytest_httpserver import HTTPServer from esrally import client, doc_link, exceptions @@ -38,7 +38,17 @@ def _api_error(status, message): - return elasticsearch.ApiError(message, ApiResponseMeta(status=status, http_version="1.1", headers={}, duration=0.0, node=None), None) + return elasticsearch.ApiError( + message, + ApiResponseMeta( + status=status, + http_version="1.1", + headers=HttpHeaders(), + duration=0.0, + node=NodeConfig(scheme="https", host="localhost", port=9200), + ), + None, + ) class TestEsClientFactory: @@ -518,7 +528,7 @@ def test_connection_ssl_error(self, es): def test_connection_protocol_error(self, es): es.cluster.health.side_effect = elasticsearch.ConnectionError( message="N/A", - errors=[urllib3.exceptions.ProtocolError("Connection aborted.")], + errors=[urllib3.exceptions.ProtocolError("Connection aborted.")], # type: ignore[arg-type] ) with pytest.raises( exceptions.SystemSetupError, diff --git a/tests/driver/driver_test.py b/tests/driver/driver_test.py index e9360d849..10f753793 100644 --- a/tests/driver/driver_test.py +++ b/tests/driver/driver_test.py @@ -1894,7 +1894,13 @@ async def 
test_execute_single_with_connection_error_always_aborts(self, on_error async def test_execute_single_with_http_400_aborts_when_specified(self): es = None params = None - error_meta = elastic_transport.ApiResponseMeta(status=404, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock( side_effect=elasticsearch.NotFoundError(message="not found", meta=error_meta, body="the requested document could not be found") ) @@ -1912,7 +1918,13 @@ async def test_execute_single_with_http_400_with_empty_raw_response_body(self): params = None empty_body = io.BytesIO(b"") str_literal_empty_body = str(empty_body) - error_meta = elastic_transport.ApiResponseMeta(status=413, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=413, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock(side_effect=elasticsearch.ApiError(message=str_literal_empty_body, meta=error_meta, body=empty_body)) with pytest.raises(exceptions.RallyAssertionError) as exc: @@ -1925,7 +1937,13 @@ async def test_execute_single_with_http_400_with_raw_response_body(self): params = None body = io.BytesIO(b"Huge error") str_literal = str(body) - error_meta = elastic_transport.ApiResponseMeta(status=499, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=499, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock(side_effect=elasticsearch.ApiError(message=str_literal, meta=error_meta, body=body)) with pytest.raises(exceptions.RallyAssertionError) as exc: @@ -1936,7 +1954,13 @@ async def test_execute_single_with_http_400_with_raw_response_body(self): async def test_execute_single_with_http_400(self): es = None params = None - error_meta = elastic_transport.ApiResponseMeta(status=404, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock( side_effect=elasticsearch.NotFoundError(message="not found", meta=error_meta, body="the requested document could not be found") ) @@ -1956,7 +1980,13 @@ async def test_execute_single_with_http_400(self): async def test_execute_single_with_http_413(self): es = None params = None - error_meta = elastic_transport.ApiResponseMeta(status=413, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=413, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock(side_effect=elasticsearch.NotFoundError(message="", meta=error_meta, body="")) ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") diff --git a/tests/driver/runner_test.py b/tests/driver/runner_test.py index 
diff --git a/tests/driver/runner_test.py b/tests/driver/runner_test.py
index 910c1de95..561404752 100644
--- a/tests/driver/runner_test.py
+++ b/tests/driver/runner_test.py
@@ -3905,8 +3905,14 @@ async def test_create_ml_datafeed(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_create_ml_datafeed_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.put_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.put_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()
         datafeed_id = "some-data-feed"
         body = {"job_id": "total-requests", "indices": ["server-metrics"]}
@@ -3935,8 +3941,16 @@ async def test_delete_ml_datafeed(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_delete_ml_datafeed_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.delete_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.delete_datafeed = mock.AsyncMock(
+            side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")
+        )
         es.perform_request = mock.AsyncMock()

         datafeed_id = "some-data-feed"
@@ -3969,8 +3983,14 @@ async def test_start_ml_datafeed_with_body(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_start_ml_datafeed_with_body_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.start_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.start_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()
         body = {"end": "now"}
         params = {"datafeed-id": "some-data-feed", "body": body}
@@ -4018,8 +4038,14 @@ async def test_stop_ml_datafeed(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_stop_ml_datafeed_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.stop_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.stop_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()

         params = {
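Two things change in each of these fallback hunks: the response metadata is fully populated (as in driver_test.py above), and the `BadRequestError` message switches from the integer `400` to the string `"400"`, matching the apparently str-typed `message` parameter of the 8.x exception constructors. Roughly, this is mypy's "arg-type" check at work; an illustration only, not part of the diff:

    import elastic_transport
    import elasticsearch

    meta = elastic_transport.ApiResponseMeta(
        status=400,
        http_version="1.1",
        headers=elastic_transport.HttpHeaders(),
        duration=0,
        node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
    )
    # passes mypy: message is a string
    ok = elasticsearch.BadRequestError(message="400", meta=meta, body="Bad Request")
    # flagged by mypy with an [arg-type] error: message is an int
    bad = elasticsearch.BadRequestError(message=400, meta=meta, body="Bad Request")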
@@ -4070,8 +4096,14 @@ async def test_create_ml_job(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_create_ml_job_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.put_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.put_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()

         body = {
@@ -4113,8 +4145,14 @@ async def test_delete_ml_job(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_delete_ml_job_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.delete_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.delete_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()

         job_id = "an-ml-job"
@@ -4145,8 +4183,14 @@ async def test_open_ml_job(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_open_ml_job_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.open_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.open_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()

         job_id = "an-ml-job"
@@ -4177,8 +4221,14 @@ async def test_close_ml_job(self, es):
     @mock.patch("elasticsearch.Elasticsearch")
     @pytest.mark.asyncio
     async def test_close_ml_job_fallback(self, es):
-        error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None)
-        es.ml.close_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request"))
+        error_meta = elastic_transport.ApiResponseMeta(
+            status=400,
+            http_version="1.1",
+            headers=elastic_transport.HttpHeaders(),
+            duration=0,
+            node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+        )
+        es.ml.close_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request"))
         es.perform_request = mock.AsyncMock()

         params = {
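All of these *_fallback tests follow one pattern: the typed ML API is stubbed to raise `BadRequestError`, and a stubbed `es.perform_request` is expected to take over. A hypothetical sketch of that shape, for orientation only (names, arguments, and the endpoint path are illustrative; the real runners live in esrally/driver/runner.py):

    import elasticsearch


    async def create_ml_job(es, job_id, body):
        # Hypothetical fallback shape, not Rally's actual implementation.
        try:
            await es.ml.put_job(job_id=job_id, body=body)
        except elasticsearch.BadRequestError:
            # retried as a raw request when the typed API call is rejected
            await es.perform_request(
                "PUT",
                f"/_ml/anomaly_detectors/{job_id}",
                body=body,
            )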
@@ -7399,7 +7449,7 @@ async def test_is_transparent_on_success_when_no_retries(self):

     @pytest.mark.asyncio
     async def test_is_transparent_on_exception_when_no_retries(self):
-        delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host"))
+        delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionError(message="no route to host"))
         es = None
         params = {
             # no retries
@@ -7537,7 +7587,7 @@ async def test_retries_mixed_timeout_and_application_errors(self):

     @pytest.mark.asyncio
     async def test_does_not_retry_on_timeout_if_not_wanted(self):
-        delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionTimeout(408, "timed out"))
+        delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionTimeout(message="timed out"))
         es = None
         params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": False, "retry-on-error": True}
         retrier = runner.Retry(delegate)
diff --git a/tests/mechanic/launcher_test.py b/tests/mechanic/launcher_test.py
index c6758880e..9bdb84c6c 100644
--- a/tests/mechanic/launcher_test.py
+++ b/tests/mechanic/launcher_test.py
@@ -72,7 +72,7 @@ def __init__(self, client_options):

     def info(self):
         if self.client_options.get("raise-error-on-info", False):
-            raise elasticsearch.TransportError(401, "Unauthorized")
+            raise elasticsearch.TransportError(message="Unauthorized")
         return self._info

     def search(self, *args, **kwargs):
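The same migration shows up in all three doubles above: elasticsearch-py 8.x exceptions are built from a `message` (plus optional `meta`/`body`/`errors`) rather than the old positional (status, message) pair. Side by side, with the old forms kept as comments since they no longer match the 8.x signatures:

    import elasticsearch

    # 7.x-era constructions, no longer valid against the 8.x signatures:
    #   elasticsearch.ConnectionError("N/A", "no route to host")
    #   elasticsearch.ConnectionTimeout(408, "timed out")
    #   elasticsearch.TransportError(401, "Unauthorized")

    conn_err = elasticsearch.ConnectionError(message="no route to host")
    timeout_err = elasticsearch.ConnectionTimeout(message="timed out")
    transport_err = elasticsearch.TransportError(message="Unauthorized")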
"clean_command": "./gradlew clean", "data_paths": ["/mnt/disk0", "/mnt/disk1"]} def test_load_car_with_multiple_bases_referring_same_install_hook(self): @@ -88,7 +88,7 @@ def test_load_car_with_multiple_bases_referring_same_install_hook(self): os.path.join(current_dir, "data", "cars", "v1", "with_hook", "templates"), os.path.join(current_dir, "data", "cars", "v1", "verbose_logging", "templates"), ] - assert car.root_path == os.path.join(current_dir, "data", "cars", "v1", "with_hook") + assert car.root_path == [os.path.join(current_dir, "data", "cars", "v1", "with_hook")] assert car.variables == {"heap_size": "16g", "clean_command": "./gradlew clean", "verbose_logging": "true"} def test_raises_error_on_unknown_car(self): @@ -112,12 +112,10 @@ def test_raises_error_on_missing_config_base(self): ): team.load_car(self.team_dir, ["missing_cfg_base"]) - def test_raises_error_if_more_than_one_different_install_hook(self): - with pytest.raises( - exceptions.SystemSetupError, - match=r"Invalid car: \['multi_hook'\]. Multiple bootstrap hooks are forbidden.", - ): - team.load_car(self.team_dir, ["multi_hook"]) + def test_doesnt_raise_error_if_more_than_one_different_install_hook(self): + car = team.load_car(self.team_dir, ["multi_hook"]) + assert isinstance(car.root_path, list) + assert len(car.root_path) == 2 class TestPluginLoader: @@ -229,7 +227,7 @@ def test_loads_module(self): hook = self.UnitTestHook() handler = team.BootstrapHookHandler(plugin, loader_class=self.UnitTestComponentLoader) - handler.loader.registration_function = hook + handler.loader.registration_function = [hook] handler.load() handler.invoke("post_install", variables={"increment": 4}) @@ -242,7 +240,7 @@ def test_cannot_register_for_unknown_phase(self): hook = self.UnitTestHook(phase="this_is_an_unknown_install_phase") handler = team.BootstrapHookHandler(plugin, loader_class=self.UnitTestComponentLoader) - handler.loader.registration_function = hook + handler.loader.registration_function = [hook] with pytest.raises(exceptions.SystemSetupError) as exc: handler.load() assert exc.value.args[0] == "Unknown bootstrap phase [this_is_an_unknown_install_phase]. Valid phases are: ['post_install']." 
diff --git a/tests/metrics_test.py b/tests/metrics_test.py
index b299172c9..adf9b1f18 100644
--- a/tests/metrics_test.py
+++ b/tests/metrics_test.py
@@ -234,7 +234,12 @@ def logging_statements(self, retries):
                 return logging_statements

             def raise_error(self):
-                err = elasticsearch.exceptions.ApiError("unit-test", meta=TestEsClient.ApiResponseMeta(status=self.status_code), body={})
+                err = elasticsearch.exceptions.ApiError(
+                    "unit-test",
+                    # TODO remove this ignore when introducing type hints
+                    meta=TestEsClient.ApiResponseMeta(status=self.status_code),  # type: ignore[arg-type]
+                    body={},
+                )
                 raise err

             class BulkIndexError:
@@ -321,7 +326,9 @@ def raise_error(self):
     def test_raises_sytem_setup_error_on_authentication_problems(self):
         def raise_authentication_error():
-            raise elasticsearch.exceptions.AuthenticationException(meta=None, body=None, message="unit-test")
+            raise elasticsearch.exceptions.AuthenticationException(
+                meta=None, body=None, message="unit-test"  # type: ignore[arg-type] # TODO remove this ignore when introducing type hints
+            )

         client = metrics.EsClient(self.ClientMock([{"host": "127.0.0.1", "port": "9243"}]))
@@ -334,7 +341,9 @@ def raise_authentication_error():
     def test_raises_sytem_setup_error_on_authorization_problems(self):
         def raise_authorization_error():
-            raise elasticsearch.exceptions.AuthorizationException(meta=None, body=None, message="unit-test")
+            raise elasticsearch.exceptions.AuthorizationException(
+                meta=None, body=None, message="unit-test"  # type: ignore[arg-type] # TODO remove this ignore when introducing type hints
+            )

         client = metrics.EsClient(self.ClientMock([{"host": "127.0.0.1", "port": "9243"}]))
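All of the suppressions in this file follow the same convention: a code-qualified `# type: ignore[arg-type]` plus a TODO to remove it once real hints land. The narrow form matters because it keeps mypy's other checks active on the same line; a toy example, not taken from Rally:

    def takes_str(value: str) -> None:
        ...

    takes_str(123)  # type: ignore[arg-type]  # narrow: only arg-type is silenced
    takes_str(123)  # type: ignore            # broad: hides every error code here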
diff --git a/tests/telemetry_test.py b/tests/telemetry_test.py
index c14b79c26..001a42e7c 100644
--- a/tests/telemetry_test.py
+++ b/tests/telemetry_test.py
@@ -24,6 +24,7 @@
 from unittest import mock
 from unittest.mock import call

+import elastic_transport
 import elasticsearch
 import pytest

@@ -275,7 +276,8 @@ class ApiResponseMeta:

     def __call__(self, status=None, body=None, message=None):
         return elasticsearch.ApiError(
-            meta=self.ApiResponseMeta(status=status),
+            # TODO remove this ignore when introducing type hints
+            meta=self.ApiResponseMeta(status=status),  # type: ignore[arg-type]
             body=body,
             message=message,
         )
@@ -1751,7 +1753,18 @@ def test_no_metrics_if_no_searchable_snapshots_stats(self, metrics_store_put_doc
         metrics_store = metrics.EsMetricsStore(cfg)
         client = Client(
             transport_client=TransportClient(
-                force_error=True, error=elasticsearch.NotFoundError("", "", {"error": {"reason": "No searchable snapshots indices found"}})
+                force_error=True,
+                error=elasticsearch.NotFoundError(
+                    message="",
+                    meta=elastic_transport.ApiResponseMeta(
+                        status=404,
+                        http_version="1.1",
+                        headers=elastic_transport.HttpHeaders(),
+                        duration=0,
+                        node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+                    ),
+                    body={"error": {"reason": "No searchable snapshots indices found"}},
+                ),
             )
         )
         recorder = telemetry.SearchableSnapshotsStatsRecorder(
@@ -4885,7 +4898,17 @@ def test_uses_indices_param_if_specified_instead_of_data_stream_names(self, es):
     def test_error_on_retrieval_does_not_store_metrics(self, es, metrics_store_cluster_level, caplog):
         cfg = create_config()
         metrics_store = metrics.EsMetricsStore(cfg)
-        es.indices.disk_usage.side_effect = elasticsearch.RequestError(message="error", meta=None, body=None)
+        es.indices.disk_usage.side_effect = elasticsearch.RequestError(
+            message="error",
+            meta=elastic_transport.ApiResponseMeta(
+                status=400,
+                http_version="1.1",
+                headers=elastic_transport.HttpHeaders(),
+                duration=0,
+                node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+            ),
+            body=None,
+        )
         device = telemetry.DiskUsageStats({}, es, metrics_store, index_names=["foo"], data_stream_names=[])
         t = telemetry.Telemetry(enabled_devices=[device.command], devices=[device])
         t.on_benchmark_start()
@@ -4916,7 +4939,17 @@ def test_no_indices_fails(self, es, metrics_store_cluster_level, caplog):
     def test_missing_all_fails(self, es, metrics_store_cluster_level, caplog):
         cfg = create_config()
         metrics_store = metrics.EsMetricsStore(cfg)
-        es.indices.disk_usage.side_effect = elasticsearch.NotFoundError(message="error", meta=None, body=None)
+        es.indices.disk_usage.side_effect = elasticsearch.NotFoundError(
+            message="error",
+            meta=elastic_transport.ApiResponseMeta(
+                status=404,
+                http_version="1.1",
+                headers=elastic_transport.HttpHeaders(),
+                duration=0,
+                node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+            ),
+            body=None,
+        )
         device = telemetry.DiskUsageStats({}, es, metrics_store, index_names=["foo", "bar"], data_stream_names=[])
         t = telemetry.Telemetry(enabled_devices=[device.command], devices=[device])
         t.on_benchmark_start()
@@ -4933,7 +4966,17 @@ def test_missing_all_fails(self, es, metrics_store_cluster_level, caplog):
     def test_some_mising_succeeds(self, es, metrics_store_cluster_level, caplog):
         cfg = create_config()
         metrics_store = metrics.EsMetricsStore(cfg)
-        not_found_response = elasticsearch.NotFoundError(message="error", meta=None, body=None)
+        not_found_response = elasticsearch.NotFoundError(
+            message="error",
+            meta=elastic_transport.ApiResponseMeta(
+                status=404,
+                http_version="1.1",
+                headers=elastic_transport.HttpHeaders(),
+                duration=0,
+                node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200),
+            ),
+            body=None,
+        )
         successful_response = {
             "_shards": {"failed": 0},
             "foo": {
diff --git a/tests/track/loader_test.py b/tests/track/loader_test.py
index 20b9f64a3..56cc65911 100644
--- a/tests/track/loader_test.py
+++ b/tests/track/loader_test.py
@@ -2758,7 +2758,7 @@ def test_parse_valid_without_types(self):
             "indices": [
                 {
                     "name": "index-historical",
-                    "body": "body.json"
+                    "body": "body.json",  # no type information here
                 }
             ],
diff --git a/tests/utils/collections_test.py b/tests/utils/collections_test.py
index e262b52f8..0c451bec6 100644
--- a/tests/utils/collections_test.py
+++ b/tests/utils/collections_test.py
@@ -18,7 +18,7 @@
 import random
 from typing import Any, Mapping

-import pytest  # type: ignore
+import pytest

 from esrally.utils import collections
diff --git a/tests/utils/net_test.py b/tests/utils/net_test.py
index 4fa912c9d..733acbbbd 100644
--- a/tests/utils/net_test.py
+++ b/tests/utils/net_test.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import random
+from typing import Mapping, Union
 from unittest import mock

 import pytest
@@ -135,7 +136,8 @@ def raise_error(seconds):

 def test_download_http_retry_incomplete_read_retry_failure(httpserver, tmp_path):
     data = b"x" * 10
-    short_resp = Response(headers={"Content-Length": 100, "foo": "bar"})
+    headers: Mapping[str, Union[str, int]] = {"Content-Length": 100, "foo": "bar"}
+    short_resp = Response(headers=headers)
     short_resp.automatically_set_content_length = False
     short_resp.set_data(data)

@@ -154,7 +156,8 @@ def sleep(seconds):

 def test_download_http_retry_incomplete_read_retry_success(httpserver, tmp_path):
     data = b"x" * 10
-    short_resp = Response(headers={"Content-Length": 100, "foo": "bar"})
+    headers: Mapping[str, Union[str, int]] = {"Content-Length": 100, "foo": "bar"}
+    short_resp = Response(headers=headers)
     short_resp.automatically_set_content_length = False
     short_resp.set_data(data)

diff --git a/tests/utils/versions_test.py b/tests/utils/versions_test.py
index 845ceb311..f03cb0329 100644
--- a/tests/utils/versions_test.py
+++ b/tests/utils/versions_test.py
@@ -18,7 +18,7 @@
 import random
 import re

-import pytest  # type: ignore
+import pytest

 from esrally import exceptions
 from esrally.utils import versions
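Finally, the dropped `# type: ignore` comments on the pytest imports (here and in collections_test.py above) are not mere cleanup: pytest ships its own type information, so the suppression no longer suppresses anything, and with `warn_unused_ignores = true` in the global config a stale ignore is itself reported. Roughly, as a toy example:

    # With warn_unused_ignores = true, mypy reports on the next line,
    # approximately: error: Unused "type: ignore" comment
    import pytest  # type: ignore

    print(pytest.__version__)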