diff --git a/pyupgrade/_data.py b/pyupgrade/_data.py index 15e03a72..ee54b3f3 100644 --- a/pyupgrade/_data.py +++ b/pyupgrade/_data.py @@ -38,6 +38,7 @@ class State(NamedTuple): RECORD_FROM_IMPORTS = frozenset(( '__future__', + 'asyncio', 'functools', 'mmap', 'os', diff --git a/pyupgrade/_plugins/exceptions.py b/pyupgrade/_plugins/exceptions.py new file mode 100644 index 00000000..4c80c80f --- /dev/null +++ b/pyupgrade/_plugins/exceptions.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import ast +import functools +from typing import Iterable +from typing import NamedTuple + +from tokenize_rt import Offset +from tokenize_rt import Token + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._data import Version +from pyupgrade._token_helpers import arg_str +from pyupgrade._token_helpers import find_op +from pyupgrade._token_helpers import parse_call_args +from pyupgrade._token_helpers import replace_name + + +class _Target(NamedTuple): + target: str + module: str | None + name: str + min_version: Version + + +_TARGETS = ( + _Target('OSError', 'mmap', 'error', (3,)), + _Target('OSError', 'os', 'error', (3,)), + _Target('OSError', 'select', 'error', (3,)), + _Target('OSError', 'socket', 'error', (3,)), + _Target('OSError', None, 'IOError', (3,)), + _Target('OSError', None, 'EnvironmentError', (3,)), + _Target('OSError', None, 'WindowsError', (3,)), + _Target('TimeoutError', 'socket', 'timeout', (3, 10)), + _Target('TimeoutError', 'asyncio', 'TimeoutError', (3, 11)), +) + + +def _fix_except( + i: int, + tokens: list[Token], + *, + at_idx: dict[int, _Target], +) -> None: + # find all the arg strs in the tuple + except_index = i + while tokens[except_index].src != 'except': + except_index -= 1 + start = find_op(tokens, except_index, '(') + func_args, end = parse_call_args(tokens, start) + + # save the exceptions and remove the block + arg_strs = [arg_str(tokens, *arg) for arg in func_args] + del tokens[start:end] + + # rewrite the block without dupes + args = [] + for i, arg in enumerate(arg_strs): + target = at_idx.get(i) + if target is not None: + args.append(target.target) + else: + args.append(arg) + + unique_args = tuple(dict.fromkeys(args)) + + if len(unique_args) > 1: + joined = '({})'.format(', '.join(unique_args)) + elif tokens[start - 1].name != 'UNIMPORTANT_WS': + joined = f' {unique_args[0]}' + else: + joined = unique_args[0] + + new = Token('CODE', joined) + tokens.insert(start, new) + + +def _get_rewrite( + node: ast.AST, + state: State, + targets: list[_Target], +) -> _Target | None: + for target in targets: + if ( + target.module is None and + isinstance(node, ast.Name) and + node.id == target.name + ): + return target + elif ( + target.module is not None and + isinstance(node, ast.Name) and + node.id == target.name and + node.id in state.from_imports[target.module] + ): + return target + elif ( + target.module is not None and + isinstance(node, ast.Attribute) and + isinstance(node.value, ast.Name) and + node.attr == target.name and + node.value.id == target.module + ): + return target + else: + return None + + +def _alias_cbs( + node: ast.expr, + state: State, + targets: list[_Target], +) -> Iterable[tuple[Offset, TokenFunc]]: + target = _get_rewrite(node, state, targets) + if target is not None: + func = functools.partial( + replace_name, + name=target.name, + new=target.target, + ) + yield ast_to_offset(node), func + + +@register(ast.Raise) +def visit_Raise( + state: State, + node: ast.Raise, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + targets = [ + target for target in _TARGETS + if state.settings.min_version >= target.min_version + ] + if node.exc is not None: + yield from _alias_cbs(node.exc, state, targets) + if isinstance(node.exc, ast.Call): + yield from _alias_cbs(node.exc.func, state, targets) + + +@register(ast.Try) +def visit_Try( + state: State, + node: ast.Try, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + targets = [ + target for target in _TARGETS + if state.settings.min_version >= target.min_version + ] + for handler in node.handlers: + if isinstance(handler.type, ast.Tuple): + at_idx = {} + for i, elt in enumerate(handler.type.elts): + target = _get_rewrite(elt, state, targets) + if target is not None: + at_idx[i] = target + + if at_idx: + func = functools.partial(_fix_except, at_idx=at_idx) + yield ast_to_offset(handler.type), func + elif handler.type is not None: + yield from _alias_cbs(handler.type, state, targets) diff --git a/pyupgrade/_plugins/oserror_aliases.py b/pyupgrade/_plugins/oserror_aliases.py deleted file mode 100644 index c11cadfd..00000000 --- a/pyupgrade/_plugins/oserror_aliases.py +++ /dev/null @@ -1,136 +0,0 @@ -from __future__ import annotations - -import ast -import functools -from typing import Iterable - -from tokenize_rt import Offset -from tokenize_rt import Token - -from pyupgrade._ast_helpers import ast_to_offset -from pyupgrade._data import register -from pyupgrade._data import State -from pyupgrade._data import TokenFunc -from pyupgrade._token_helpers import arg_str -from pyupgrade._token_helpers import find_op -from pyupgrade._token_helpers import parse_call_args -from pyupgrade._token_helpers import replace_name - -ERROR_NAMES = frozenset(('EnvironmentError', 'IOError', 'WindowsError')) -ERROR_MODULES = frozenset(('mmap', 'select', 'socket', 'os')) - - -def _fix_oserror_except( - i: int, - tokens: list[Token], - *, - from_imports: dict[str, set[str]], -) -> None: - # find all the arg strs in the tuple - except_index = i - while tokens[except_index].src != 'except': - except_index -= 1 - start = find_op(tokens, except_index, '(') - func_args, end = parse_call_args(tokens, start) - - # save the exceptions and remove the block - arg_strs = [arg_str(tokens, *arg) for arg in func_args] - del tokens[start:end] - - # rewrite the block without dupes - args = [] - for arg in arg_strs: - left, part, right = arg.partition('.') - if left in ERROR_MODULES and part == '.' and right == 'error': - args.append('OSError') - elif left in ERROR_NAMES and part == right == '': - args.append('OSError') - elif ( - left == 'error' and - part == right == '' and - any('error' in from_imports[mod] for mod in ERROR_MODULES) - ): - args.append('OSError') - else: - args.append(arg) - - unique_args = tuple(dict.fromkeys(args)) - - if len(unique_args) > 1: - joined = '({})'.format(', '.join(unique_args)) - elif tokens[start - 1].name != 'UNIMPORTANT_WS': - joined = f' {unique_args[0]}' - else: - joined = unique_args[0] - - new = Token('CODE', joined) - tokens.insert(start, new) - - -def _is_oserror_alias( - node: ast.AST, - from_imports: dict[str, set[str]], -) -> tuple[Offset, str] | None: - if isinstance(node, ast.Name) and node.id in ERROR_NAMES: - return ast_to_offset(node), node.id - elif ( - isinstance(node, ast.Name) and - node.id == 'error' and - any(node.id in from_imports[mod] for mod in ERROR_MODULES) - ): - return ast_to_offset(node), node.id - elif ( - isinstance(node, ast.Attribute) and - isinstance(node.value, ast.Name) and - node.value.id in ERROR_MODULES and - node.attr == 'error' - ): - return ast_to_offset(node), node.attr - else: - return None - - -def _oserror_alias_cbs( - node: ast.AST, - from_imports: dict[str, set[str]], -) -> Iterable[tuple[Offset, TokenFunc]]: - offset_name = _is_oserror_alias(node, from_imports) - if offset_name is not None: - offset, name = offset_name - func = functools.partial(replace_name, name=name, new='OSError') - yield offset, func - - -@register(ast.Raise) -def visit_Raise( - state: State, - node: ast.Raise, - parent: ast.AST, -) -> Iterable[tuple[Offset, TokenFunc]]: - if node.exc is not None: - yield from _oserror_alias_cbs(node.exc, state.from_imports) - if isinstance(node.exc, ast.Call): - yield from _oserror_alias_cbs(node.exc.func, state.from_imports) - - -@register(ast.Try) -def visit_Try( - state: State, - node: ast.Try, - parent: ast.AST, -) -> Iterable[tuple[Offset, TokenFunc]]: - for handler in node.handlers: - if ( - isinstance(handler.type, ast.Tuple) and - any( - _is_oserror_alias(elt, state.from_imports) - for elt in handler.type.elts - ) - ): - func = functools.partial( - _fix_oserror_except, - from_imports=state.from_imports, - ) - yield ast_to_offset(handler.type), func - elif handler.type is not None: - yield from _oserror_alias_cbs(handler.type, state.from_imports) diff --git a/tests/features/exceptions_test.py b/tests/features/exceptions_test.py new file mode 100644 index 00000000..e855bf2a --- /dev/null +++ b/tests/features/exceptions_test.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import pytest + +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'try: ...\n' + 'except Exception:\n' + ' raise', + id='empty raise', + ), + pytest.param( + 'try: ...\n' + 'except: ...\n', + id='empty try-except', + ), + pytest.param( + 'try: ...\n' + 'except AssertionError: ...\n', + id='unrelated exception type as name', + ), + pytest.param( + 'try: ...\n' + 'except (AssertionError,): ...\n', + id='unrelated exception type as tuple', + ), + pytest.param( + 'try: ...\n' + 'except OSError: ...\n', + id='already rewritten name', + ), + pytest.param( + 'try: ...\n' + 'except (TypeError, OSError): ...\n', + id='already rewritten tuple', + ), + pytest.param( + 'from .os import error\n' + 'raise error(1)\n', + id='same name as rewrite but relative import', + ), + pytest.param( + 'from os import error\n' + 'def f():\n' + ' error = 3\n' + ' return error\n', + id='not rewriting outside of raise or except', + ), + pytest.param( + 'from os import error as the_roof\n' + 'raise the_roof()\n', + id='ignoring imports with aliases', + ), + # TODO: could probably rewrite these but leaving for now + pytest.param( + 'import os\n' + 'try: ...\n' + 'except (os).error: ...\n', + id='weird parens', + ), + ), +) +def test_fix_exceptions_noop(s): + assert _fix_plugins(s, settings=Settings()) == s + + +@pytest.mark.parametrize( + ('s', 'version'), + ( + pytest.param( + 'raise socket.timeout()', + (3, 9), + id='raise socket.timeout is noop <3.10', + ), + pytest.param( + 'try: ...\n' + 'except socket.timeout: ...\n', + (3, 9), + id='except socket.timeout is noop <3.10', + ), + pytest.param( + 'raise asyncio.TimeoutError()', + (3, 10), + id='raise asyncio.TimeoutError() is noop <3.11', + ), + pytest.param( + 'try: ...\n' + 'except asyncio.TimeoutError: ...\n', + (3, 10), + id='except asyncio.TimeoutError() is noop <3.11', + ), + ), +) +def test_fix_exceptions_version_specific_noop(s, version): + assert _fix_plugins(s, settings=Settings(min_version=version)) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'raise mmap.error(1)\n', + 'raise OSError(1)\n', + id='mmap.error', + ), + pytest.param( + 'raise os.error(1)\n', + 'raise OSError(1)\n', + id='os.error', + ), + pytest.param( + 'raise select.error(1)\n', + 'raise OSError(1)\n', + id='select.error', + ), + pytest.param( + 'raise socket.error(1)\n', + 'raise OSError(1)\n', + id='socket.error', + ), + pytest.param( + 'raise IOError(1)\n', + 'raise OSError(1)\n', + id='IOError', + ), + pytest.param( + 'raise EnvironmentError(1)\n', + 'raise OSError(1)\n', + id='EnvironmentError', + ), + pytest.param( + 'raise WindowsError(1)\n', + 'raise OSError(1)\n', + id='WindowsError', + ), + pytest.param( + 'raise os.error\n', + 'raise OSError\n', + id='raise exception type without call', + ), + pytest.param( + 'from os import error\n' + 'raise error(1)\n', + 'from os import error\n' + 'raise OSError(1)\n', + id='raise via from import', + ), + pytest.param( + 'try: ...\n' + 'except WindowsError: ...\n', + + 'try: ...\n' + 'except OSError: ...\n', + + id='except of name', + ), + pytest.param( + 'try: ...\n' + 'except os.error: ...\n', + + 'try: ...\n' + 'except OSError: ...\n', + + id='except of dotted name', + ), + pytest.param( + 'try: ...\n' + 'except (WindowsError,): ...\n', + + 'try: ...\n' + 'except OSError: ...\n', + + id='except of name in tuple', + ), + pytest.param( + 'try: ...\n' + 'except (os.error,): ...\n', + + 'try: ...\n' + 'except OSError: ...\n', + + id='except of dotted name in tuple', + ), + pytest.param( + 'try: ...\n' + 'except (WindowsError, KeyError, OSError): ...\n', + + 'try: ...\n' + 'except (OSError, KeyError): ...\n', + + id='deduplicates exception types', + ), + pytest.param( + 'try: ...\n' + 'except (os.error, WindowsError, OSError): ...\n', + + 'try: ...\n' + 'except OSError: ...\n', + + id='deduplicates to a single type', + ), + pytest.param( + 'try: ...\n' + 'except(os.error, WindowsError, OSError): ...\n', + + 'try: ...\n' + 'except OSError: ...\n', + + id='deduplicates to a single type without whitespace', + ), + pytest.param( + 'from wat import error\n' + 'try: ...\n' + 'except (WindowsError, error): ...\n', + + 'from wat import error\n' + 'try: ...\n' + 'except (OSError, error): ...\n', + + id='leave unrelated error names alone', + ), + ), +) +def test_fix_exceptions(s, expected): + assert _fix_plugins(s, settings=Settings()) == expected + + +@pytest.mark.parametrize( + ('s', 'expected', 'version'), + ( + pytest.param( + 'raise socket.timeout(1)\n', + 'raise TimeoutError(1)\n', + (3, 10), + id='socket.timeout', + ), + pytest.param( + 'raise asyncio.TimeoutError(1)\n', + 'raise TimeoutError(1)\n', + (3, 11), + id='asyncio.TimeoutError', + ), + ), +) +def test_fix_exceptions_versioned(s, expected, version): + assert _fix_plugins(s, settings=Settings(min_version=version)) == expected + + +def test_can_rewrite_disparate_names(): + s = '''\ +try: ... +except (asyncio.TimeoutError, WindowsError): ... +''' + expected = '''\ +try: ... +except (TimeoutError, OSError): ... +''' + + assert _fix_plugins(s, settings=Settings(min_version=(3, 11))) == expected diff --git a/tests/features/oserror_aliases_test.py b/tests/features/oserror_aliases_test.py deleted file mode 100644 index 12798a65..00000000 --- a/tests/features/oserror_aliases_test.py +++ /dev/null @@ -1,516 +0,0 @@ -from __future__ import annotations - -import pytest - -from pyupgrade._data import Settings -from pyupgrade._main import _fix_plugins -from pyupgrade._plugins.oserror_aliases import ERROR_MODULES -from pyupgrade._plugins.oserror_aliases import ERROR_NAMES - - -@pytest.mark.parametrize('alias', ERROR_NAMES) -@pytest.mark.parametrize( - ('tpl', 'expected'), - ( - ( - 'try:\n' - ' pass\n' - 'except {alias}:\n' - ' pass\n', - - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'try:\n' - ' pass\n' - 'except ({alias},):\n' - ' pass\n', - - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'try:\n' - ' pass\n' - 'except ({alias}, KeyError, OSError):\n' - ' pass\n', - - 'try:\n' - ' pass\n' - 'except (OSError, KeyError):\n' - ' pass\n', - ), - ( - 'try:\n' - ' pass\n' - 'except ({alias}, OSError, IOError):\n' - ' pass\n', - - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'try:\n' - ' pass\n' - 'except({alias}, OSError, IOError):\n' - ' pass\n', - - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - pytest.param( - 'from wat import error\n' - 'try:\n' - ' pass\n' - 'except ({alias}, error):\n' - ' pass\n', - - 'from wat import error\n' - 'try:\n' - ' pass\n' - 'except (OSError, error):\n' - ' pass\n', - - id='preserve unrelated .error class', - ), - ), -) -def test_fix_oserror_aliases_try(alias, tpl, expected): - s = tpl.format(alias=alias) - ret = _fix_plugins(s, settings=Settings()) - assert ret == expected - - -@pytest.mark.parametrize( - 's', - ( - pytest.param('raise\n', id='empty raise'), - # empty try-except - 'try:\n' - ' pass\n' - 'except:\n' - ' pass\n', - # no exception to rewrite - 'try:\n' - ' pass\n' - 'except AssertionError:\n' - ' pass\n', - # no exception to rewrite - 'try:\n' - ' pass\n' - 'except (' - ' AssertionError,' - '):\n' - ' pass\n', - # already correct - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - # already correct - 'try:\n' - ' pass\n' - 'except (OSError, KeyError):\n' - ' pass\n', - pytest.param( - 'import mmap\n' - 'try:\n' - ' pass\n' - 'except (mmap).error:\n' - ' pass\n', - id='weird parens', - ), - pytest.param( - 'from .mmap import error\n' - 'raise error("hi")\n', - id='relative imports', - ), - ), -) -def test_fix_oserror_aliases_noop(s): - assert _fix_plugins(s, settings=Settings()) == s - - -@pytest.mark.parametrize('imp', ERROR_MODULES) -@pytest.mark.parametrize( - 'tpl', - ( - # if the error isn't in a try or except it shouldn't be rewritten - # to avoid false positives - 'from {imp} import error\n\n' - 'def foo():\n' - ' error = 3\n', - ' return error\n', - # renaming things for weird reasons - 'from {imp} import error as the_roof\n' - 'raise the_roof()\n', - ), -) -def test_fix_oserror_aliases_noop_tpl(imp, tpl): - s = tpl.format(imp=imp) - assert _fix_plugins(s, settings=Settings()) == s - - -@pytest.mark.parametrize('imp', ERROR_MODULES) -@pytest.mark.parametrize( - ('tpl', 'expected_tpl'), - ( - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except {imp}.error:\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except ({imp}.error,):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except ({imp}.error, KeyError, OSError):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except (OSError, KeyError):\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except ({imp}.error, OSError, IOError):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except (OSError, {imp}.error, IOError):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except (OSError, {imp}.error, IOError):\n' - ' pass\n' - 'except (OSError, {imp}.error, KeyError):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n' - 'except (OSError, KeyError):\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except({imp}.error, OSError, IOError):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except(' - ' {imp}.error,' - ' OSError,' - ' IOError,' - '):\n' - ' pass\n', - - 'import {imp}\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except error:\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except (error,):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except (error, KeyError, OSError):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except (OSError, KeyError):\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except (error, OSError, IOError):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except (OSError, error, OSError):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except (OSError, error, OSError):\n' - ' pass\n' - 'except (OSError, error, KeyError):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n' - 'except (OSError, KeyError):\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except(error, OSError, IOError):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ( - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except(' - ' error,' - ' OSError,' - ' IOError,' - '):\n' - ' pass\n', - - 'from {imp} import error\n\n' - 'try:\n' - ' pass\n' - 'except OSError:\n' - ' pass\n', - ), - ), -) -def test_fix_oserror_complex_aliases_try(imp, tpl, expected_tpl): - s, expected = tpl.format(imp=imp), expected_tpl.format(imp=imp) - ret = _fix_plugins(s, settings=Settings()) - assert ret == expected - - -@pytest.mark.parametrize('alias', ERROR_NAMES) -@pytest.mark.parametrize( - ('tpl', 'expected'), - ( - ('raise {alias}', 'raise OSError'), - ('raise {alias}()', 'raise OSError()'), - ('raise {alias}(1)', 'raise OSError(1)'), - ('raise {alias}(1, 2)', 'raise OSError(1, 2)'), - ( - 'raise {alias}(\n' - ' 1,\n' - ' 2,\n' - ')', - 'raise OSError(\n' - ' 1,\n' - ' 2,\n' - ')', - ), - ), -) -def test_fix_oserror_aliases_raise(alias, tpl, expected): - s = tpl.format(alias=alias) - ret = _fix_plugins(s, settings=Settings()) - assert ret == expected - - -@pytest.mark.parametrize('imp', ERROR_MODULES) -@pytest.mark.parametrize( - ('tpl', 'expected_tpl'), - ( - ( - 'import {imp}\n\n' - 'raise {imp}.error\n', - - 'import {imp}\n\n' - 'raise OSError\n', - ), - ( - 'import {imp}\n\n' - 'raise {imp}.error()\n', - - 'import {imp}\n\n' - 'raise OSError()\n', - ), - ( - 'import {imp}\n\n' - 'raise {imp}.error(1)\n', - - 'import {imp}\n\n' - 'raise OSError(1)\n', - ), - ( - 'import {imp}\n\n' - 'raise {imp}.error(1, 2)\n', - - 'import {imp}\n\n' - 'raise OSError(1, 2)\n', - ), - ( - 'import {imp}\n\n' - 'raise {imp}.error(\n' - ' 1,\n' - ' 2,\n' - ')', - - 'import {imp}\n\n' - 'raise OSError(\n' - ' 1,\n' - ' 2,\n' - ')', - ), - ( - 'from {imp} import error\n\n' - 'raise error\n', - - 'from {imp} import error\n\n' - 'raise OSError\n', - ), - ( - 'from {imp} import error\n\n' - 'raise error()\n', - - 'from {imp} import error\n\n' - 'raise OSError()\n', - ), - ( - 'from {imp} import error\n\n' - 'raise error(1)\n', - - 'from {imp} import error\n\n' - 'raise OSError(1)\n', - ), - ( - 'from {imp} import error\n\n' - 'raise error(1, 2)\n', - - 'from {imp} import error\n\n' - 'raise OSError(1, 2)\n', - ), - ( - 'from {imp} import error\n\n' - 'raise error(\n' - ' 1,\n' - ' 2,\n' - ')', - - 'from {imp} import error\n\n' - 'raise OSError(\n' - ' 1,\n' - ' 2,\n' - ')', - ), - ), -) -def test_fix_oserror_complex_aliases_raise(imp, tpl, expected_tpl): - s, expected = tpl.format(imp=imp), expected_tpl.format(imp=imp) - ret = _fix_plugins(s, settings=Settings()) - assert ret == expected