From 153701b8b3f35a9d9762720a72026d91c863e0f8 Mon Sep 17 00:00:00 2001 From: Aditya Pillai Date: Fri, 6 Oct 2023 10:57:15 -0700 Subject: [PATCH] refactor load_compiled_module_from_source Summary: Refactors `load_compiled_module_from_source` to only run the strict analysis that requires the `mod` object from `loader.check_source` in the case where either static or strict flags are set. I wanted to isolate out this behavior in this diff before I separate out static and strict analysis to make that change a 1-liner to ensure that I'm not creating any bugs down the line. The diff to separate strict/static analysis will come in the next diff. Reviewed By: carljm Differential Revision: D49735184 fbshipit-source-id: 0d7c88c5ff9d1e2758dbf19bcf09edd5c2d29c61 --- Lib/compiler/strict/compiler.py | 105 ++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 40 deletions(-) diff --git a/Lib/compiler/strict/compiler.py b/Lib/compiler/strict/compiler.py index 46b8f71b3dd..0cdaa15ed3d 100644 --- a/Lib/compiler/strict/compiler.py +++ b/Lib/compiler/strict/compiler.py @@ -123,7 +123,9 @@ def import_module(self, name: str, optimize: int) -> Optional[ModuleTable]: stubKind = mod.stub_kind if STUB_KIND_MASK_TYPING & stubKind: root = remove_annotations(root) - root = self._get_rewritten_ast(name, mod, root, optimize) + root = self._get_rewritten_ast( + name, root, getSymbolTable(mod), mod.file_name, optimize + ) log = self.log_time_func ctx = ( log()(name, mod.file_name, "declaration_visit") @@ -138,13 +140,17 @@ def import_module(self, name: str, optimize: int) -> Optional[ModuleTable]: return self.modules.get(name) def _get_rewritten_ast( - self, name: str, mod: StrictAnalysisResult, root: ast.Module, optimize: int + self, + name: str, + root: ast.Module, + symbols: PythonSymbolTable, + filename: str, + optimize: int, ) -> ast.Module: - symbols = getSymbolTable(mod) return rewrite( root, symbols, - mod.file_name, + filename, name, optimize=optimize, is_static=True, @@ -169,38 +175,58 @@ def load_compiled_module_from_source( if override_flags and override_flags.is_strict: self.logger.debug(f"Forcibly treating module {name} as strict") self.loader.set_force_strict_by_name(name) - # TODO(pilleye): Only call this when no side effect analysis is requested + pyast = ast.parse(source) + symbols = symtable.symtable(source, filename, "exec") + flags = FlagExtractor().get_flags(pyast).merge(override_flags) + + if not flags.is_static and not flags.is_strict: + code = self._compile_basic(name, pyast, filename, optimize) + return (code, False) + + # TODO: Remove the check when static is enabled in the next diff to isolate errors + is_valid_strict = False + if flags.is_strict or flags.is_static: + is_valid_strict = self._strict_analyze( + source, flags, symbols, filename, name, submodule_search_locations + ) + + if flags.is_static: + code = self._compile_static(pyast, symbols, filename, name, optimize) + return (code, is_valid_strict) + else: + code = self._compile_strict(pyast, symbols, filename, name, optimize) + return (code, is_valid_strict) + + def _strict_analyze( + self, + source: str | bytes, + flags: Flags, + symbols: PythonSymbolTable, + filename: str, + name: str, + submodule_search_locations: Optional[List[str]] = None, + ) -> bool: mod = self.loader.check_source( source, filename, name, submodule_search_locations or [] ) - flags = FlagExtractor().get_flags(pyast).merge(override_flags) - errors = mod.errors - is_valid_strict = ( - mod.is_valid and len(errors) == 0 and (flags.is_static or flags.is_strict) - ) - if errors and self.raise_on_error: - # if raise on error, just raise the first error - error = errors[0] + is_valid_strict = mod.is_valid and len(mod.errors) == 0 + + if mod.errors and self.raise_on_error: + error = mod.errors[0] raise StrictModuleError(error[0], error[1], error[2], error[3]) - elif is_valid_strict: - symbols = symtable.symtable(source, filename, "exec") - try: - check_class_conflict(pyast, filename, symbols) - except StrictModuleError as e: - if self.raise_on_error: - raise - mod.errors.append((e.msg, e.filename, e.lineno, e.col)) - - if not is_valid_strict: - code = self._compile_basic(name, pyast, filename, optimize) - elif flags.is_static: - code = self._compile_static(mod, filename, name, optimize) - else: - code = self._compile_strict(mod, filename, name, optimize) - return code, is_valid_strict + # TODO: Figure out if we need to run this analysis. This should be done only for + # static analysis and not necessarily for strict modules. Keeping it for now since + # it is currently running with the strict compiler. + try: + check_class_conflict(mod.ast, filename, symbols) + except StrictModuleError as e: + if self.raise_on_error: + raise + + return is_valid_strict def _compile_basic( self, name: str, root: ast.Module, filename: str, optimize: int @@ -215,14 +241,14 @@ def _compile_basic( def _compile_strict( self, - mod: StrictAnalysisResult, + root: ast.Module, + symbols: PythonSymbolTable, filename: str, name: str, optimize: int, ) -> CodeType: - symbols = getSymbolTable(mod) tree = rewrite( - mod.ast, + root, symbols, filename, name, @@ -233,21 +259,20 @@ def _compile_strict( def _compile_static( self, - mod: StrictAnalysisResult, + root: ast.Module, + symbols: PythonSymbolTable, filename: str, name: str, optimize: int, ) -> CodeType | None: - root = self.ast_cache.get(name) - if root is None: - root = self._get_rewritten_ast(name, mod, mod.ast, optimize) - code = None - + root = self.ast_cache.get(name) or self._get_rewritten_ast( + name, root, symbols, filename, optimize + ) try: log = self.log_time_func ctx = log()(name, filename, "compile") if log else nullcontext() with ctx: - code = self.compile( + return self.compile( name, filename, root, @@ -266,4 +291,4 @@ def _compile_static( if self.raise_on_error: raise err - return code + return None