-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into order-imports-codemod
- Loading branch information
Showing
16 changed files
with
344 additions
and
390 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import io | ||
import os | ||
import tempfile | ||
|
||
import libcst as cst | ||
from libcst.codemod import ( | ||
CodemodContext, | ||
) | ||
import yaml | ||
|
||
from codemodder.codemods.base_codemod import ( | ||
CodemodMetadata, | ||
BaseCodemod as _BaseCodemod, | ||
SemgrepCodemod as _SemgrepCodemod, | ||
) | ||
|
||
from codemodder.codemods.base_visitor import BaseTransformer | ||
from codemodder.codemods.change import Change | ||
from codemodder.file_context import FileContext | ||
from codemodder.semgrep import rule_ids_from_yaml_files | ||
from .helpers import Helpers | ||
|
||
|
||
def _populate_yaml(rule: str, metadata: CodemodMetadata) -> str: | ||
config = yaml.safe_load(io.StringIO(rule)) | ||
# TODO: handle more than rule per config? | ||
assert len(config["rules"]) == 1 | ||
config["rules"][0].setdefault("id", metadata.NAME) | ||
config["rules"][0].setdefault("message", "Semgrep found a match") | ||
config["rules"][0].setdefault("severity", "WARNING") | ||
config["rules"][0].setdefault("languages", ["python"]) | ||
return yaml.safe_dump(config) | ||
|
||
|
||
def _create_temp_yaml_file(orig_cls, metadata: CodemodMetadata): | ||
fd, path = tempfile.mkstemp() | ||
with os.fdopen(fd, "w") as ff: | ||
ff.write(_populate_yaml(orig_cls.rule(), metadata)) | ||
|
||
return [path] | ||
|
||
|
||
class _CodemodSubclassWithMetadata: | ||
def __init_subclass__(cls): | ||
# This is a pretty yucky workaround. | ||
# But it is necessary to get around the fact that these fields are | ||
# checked by __init_subclass__ of the other parents of SemgrepCodemod | ||
# first. | ||
if cls.__name__ not in ("BaseCodemod", "SemgrepCodemod"): | ||
# TODO: if we intend to continue to check class-level attributes | ||
# using this mechanism, we should add checks (or defaults) for | ||
# NAME, DESCRIPTION, and REVIEW_GUIDANCE here. | ||
|
||
cls.METADATA = CodemodMetadata( | ||
cls.DESCRIPTION, # pylint: disable=no-member | ||
cls.NAME, # pylint: disable=no-member | ||
cls.REVIEW_GUIDANCE, # pylint: disable=no-member | ||
) | ||
|
||
# This is a little bit hacky, but it also feels like the right solution? | ||
cls.CHANGE_DESCRIPTION = cls.DESCRIPTION # pylint: disable=no-member | ||
|
||
return cls | ||
|
||
|
||
class BaseCodemod( | ||
_CodemodSubclassWithMetadata, | ||
_BaseCodemod, | ||
BaseTransformer, | ||
Helpers, | ||
): | ||
CHANGESET_ALL_FILES: list = [] | ||
CHANGES_IN_FILE: list = [] | ||
|
||
def report_change(self, original_node): | ||
line_number = self.lineno_for_node(original_node) | ||
self.CHANGES_IN_FILE.append( | ||
Change(str(line_number), self.CHANGE_DESCRIPTION).to_json() | ||
) | ||
|
||
|
||
# NOTE: this shadows base_codemod.SemgrepCodemod but I can't think of a better name right now | ||
# At least it is namespaced but we might want to deconflict these things in the long term | ||
class SemgrepCodemod( | ||
BaseCodemod, | ||
_CodemodSubclassWithMetadata, | ||
_SemgrepCodemod, | ||
BaseTransformer, | ||
): | ||
def __init_subclass__(cls): | ||
super().__init_subclass__() | ||
cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA) | ||
cls.RULE_IDS = rule_ids_from_yaml_files(cls.YAML_FILES) | ||
|
||
def __init__(self, codemod_context: CodemodContext, file_context: FileContext): | ||
_SemgrepCodemod.__init__(self, file_context) | ||
BaseTransformer.__init__( | ||
self, | ||
codemod_context, | ||
self._results, | ||
file_context.line_exclude, | ||
file_context.line_include, | ||
) | ||
|
||
# TODO: there needs to be a way to generalize this so that it applies | ||
# more broadly than to just a specific kind of node. There's probably a | ||
# decent way to do this with metaprogramming. We could either apply it | ||
# broadly to every known method (which would probably have a big | ||
# performance impact). Or we could allow users to register the handler | ||
# for a specific node or nodes by means of a decorator or something | ||
# similar when they define their `on_result_found` method. | ||
# Right now this is just to demonstrate a particular use case. | ||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): | ||
pos_to_match = self.node_position(original_node) | ||
if self.filter_by_result( | ||
pos_to_match | ||
) and self.filter_by_path_includes_or_excludes(pos_to_match): | ||
self.report_change(original_node) | ||
if (attr := getattr(self, "on_result_found", None)) is not None: | ||
# pylint: disable=not-callable | ||
new_node = attr(original_node, updated_node) | ||
return new_node | ||
|
||
return updated_node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import libcst as cst | ||
from libcst import matchers | ||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor | ||
from codemodder.codemods.utils import get_call_name | ||
|
||
|
||
class Helpers: | ||
def remove_unused_import(self, original_node): | ||
# pylint: disable=no-member | ||
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) | ||
|
||
def add_needed_import(self, module, obj=None): | ||
# TODO: do we need to check if this import already exists? | ||
AddImportsVisitor.add_needed_import( | ||
self.context, module, obj # pylint: disable=no-member | ||
) | ||
|
||
def update_call_target(self, original_node, new_target, replacement_args=None): | ||
# TODO: is an assertion the best way to handle this? | ||
# Or should we just return the original node if it's not a Call? | ||
assert isinstance(original_node, cst.Call) | ||
return cst.Call( | ||
func=cst.Attribute( | ||
value=cst.parse_expression(new_target), | ||
attr=cst.Name(value=get_call_name(original_node)), | ||
), | ||
args=replacement_args if replacement_args else original_node.args, | ||
) | ||
|
||
def update_arg_target(self, updated_node, new_args: list): | ||
return updated_node.with_changes( | ||
args=[new if isinstance(new, cst.Arg) else cst.Arg(new) for new in new_args] | ||
) | ||
|
||
def parse_expression(self, expression: str): | ||
return cst.parse_expression(expression) | ||
|
||
def replace_arg(self, original_node, target_arg_name, target_arg_replacement_val): | ||
"""Given a node, return its args with one arg's value changed""" | ||
assert hasattr(original_node, "args") | ||
new_args = [] | ||
for arg in original_node.args: | ||
if matchers.matches(arg.keyword, matchers.Name(target_arg_name)): | ||
new = cst.Arg( | ||
keyword=cst.parse_expression(target_arg_name), | ||
value=cst.parse_expression(target_arg_replacement_val), | ||
equal=arg.equal, | ||
) | ||
else: | ||
new = arg | ||
new_args.append(new) | ||
return new_args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,42 @@ | ||
import libcst as cst | ||
from libcst.codemod import CodemodContext | ||
import yaml | ||
|
||
from codemodder.codemods.base_codemod import ( | ||
SemgrepCodemod, | ||
CodemodMetadata, | ||
ReviewGuidance, | ||
) | ||
from codemodder.codemods.base_visitor import BaseTransformer | ||
from codemodder.codemods.change import Change | ||
from codemodder.file_context import FileContext | ||
|
||
|
||
UNSAFE_LOADERS = yaml.loader.__all__.copy() # type: ignore | ||
UNSAFE_LOADERS.remove("SafeLoader") | ||
|
||
|
||
class HardenPyyaml(SemgrepCodemod, BaseTransformer): | ||
METADATA = CodemodMetadata( | ||
DESCRIPTION=("Ensures all calls to yaml.load use `SafeLoader`."), | ||
NAME="harden-pyyaml", | ||
REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, | ||
) | ||
CHANGE_DESCRIPTION = "Ensures all calls to yaml.load use `SafeLoader`." | ||
YAML_FILES = [ | ||
"harden-pyyaml.yaml", | ||
] | ||
|
||
CHANGE_DESCRIPTION = "Adds `yaml.SafeLoader` to yaml.load calls" | ||
|
||
def __init__(self, codemod_context: CodemodContext, file_context: FileContext): | ||
SemgrepCodemod.__init__(self, file_context) | ||
BaseTransformer.__init__( | ||
self, | ||
codemod_context, | ||
self._results, | ||
file_context.line_exclude, | ||
file_context.line_include, | ||
) | ||
|
||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): | ||
pos_to_match = self.node_position(original_node) | ||
if self.filter_by_result( | ||
pos_to_match | ||
) and self.filter_by_path_includes_or_excludes(pos_to_match): | ||
line_number = pos_to_match.start.line | ||
self.CHANGES_IN_FILE.append( | ||
Change(str(line_number), self.CHANGE_DESCRIPTION).to_json() | ||
) | ||
second_arg = cst.Arg(value=cst.parse_expression("yaml.SafeLoader")) | ||
return updated_node.with_changes(args=[*updated_node.args[:1], second_arg]) | ||
return updated_node | ||
from codemodder.codemods.base_codemod import ReviewGuidance | ||
from codemodder.codemods.api import SemgrepCodemod | ||
|
||
|
||
class HardenPyyaml(SemgrepCodemod): | ||
NAME = "harden-pyyaml" | ||
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW | ||
DESCRIPTION = "Ensures all calls to yaml.load use `SafeLoader`." | ||
|
||
@classmethod | ||
def rule(cls): | ||
return """ | ||
rules: | ||
- pattern-either: | ||
- patterns: | ||
- pattern: yaml.load(...) | ||
- pattern-inside: | | ||
import yaml | ||
... | ||
yaml.load(...,$ARG) | ||
- metavariable-pattern: | ||
metavariable: $ARG | ||
patterns: | ||
- pattern-not: | ||
pattern: yaml.SafeLoader | ||
- patterns: | ||
- pattern: yaml.load(...) | ||
- pattern-inside: | | ||
import yaml | ||
... | ||
yaml.load(...,Loader=$ARG) | ||
- metavariable-pattern: | ||
metavariable: $ARG | ||
patterns: | ||
- pattern-not: | ||
pattern: yaml.SafeLoader | ||
""" | ||
|
||
def on_result_found(self, _, updated_node): | ||
new_args = [*updated_node.args[:1], self.parse_expression("yaml.SafeLoader")] | ||
return self.update_arg_target(updated_node, new_args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,30 @@ | ||
import libcst as cst | ||
from libcst.codemod import CodemodContext | ||
from libcst import matchers | ||
from codemodder.codemods.base_codemod import ( | ||
SemgrepCodemod, | ||
CodemodMetadata, | ||
ReviewGuidance, | ||
) | ||
from codemodder.codemods.base_visitor import BaseTransformer | ||
from codemodder.codemods.change import Change | ||
from codemodder.file_context import FileContext | ||
from codemodder.codemods.base_codemod import ReviewGuidance | ||
from codemodder.codemods.api import SemgrepCodemod | ||
|
||
|
||
class HardenRuamel(SemgrepCodemod, BaseTransformer): | ||
METADATA = CodemodMetadata( | ||
DESCRIPTION=("Ensures all unsafe calls to ruamel.yaml.YAML use `typ='safe'`."), | ||
NAME="harden-ruamel", | ||
REVIEW_GUIDANCE=ReviewGuidance.MERGE_WITHOUT_REVIEW, | ||
) | ||
CHANGE_DESCRIPTION = METADATA.DESCRIPTION | ||
YAML_FILES = [ | ||
"harden-ruamel.yaml", | ||
] | ||
class HardenRuamel(SemgrepCodemod): | ||
NAME = "harden-ruamel" | ||
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW | ||
DESCRIPTION = "Ensures all unsafe calls to ruamel.yaml.YAML use `typ='safe'`." | ||
|
||
def __init__(self, codemod_context: CodemodContext, file_context: FileContext): | ||
SemgrepCodemod.__init__(self, file_context) | ||
BaseTransformer.__init__( | ||
self, | ||
codemod_context, | ||
self._results, | ||
file_context.line_exclude, | ||
file_context.line_include, | ||
) | ||
@classmethod | ||
def rule(cls): | ||
return """ | ||
rules: | ||
- pattern-either: | ||
- patterns: | ||
- pattern: ruamel.yaml.YAML(typ="unsafe", ...) | ||
- pattern-inside: | | ||
import ruamel | ||
... | ||
- patterns: | ||
- pattern: ruamel.yaml.YAML(typ="base", ...) | ||
- pattern-inside: | | ||
import ruamel | ||
... | ||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): | ||
pos_to_match = self.node_position(original_node) | ||
if self.filter_by_result( | ||
pos_to_match | ||
) and self.filter_by_path_includes_or_excludes(pos_to_match): | ||
line_number = pos_to_match.start.line | ||
self.CHANGES_IN_FILE.append( | ||
Change(str(line_number), self.CHANGE_DESCRIPTION).to_json() | ||
) | ||
new_args = update_arg_target(original_node.args, target_arg="typ") | ||
return updated_node.with_changes(args=new_args) | ||
return updated_node | ||
""" | ||
|
||
|
||
def update_arg_target(original_args, target_arg): | ||
new_args = [] | ||
for arg in original_args: | ||
if matchers.matches(arg.keyword, matchers.Name(target_arg)): | ||
new = cst.Arg( | ||
keyword=cst.parse_expression("typ"), | ||
value=cst.parse_expression('"safe"'), | ||
equal=arg.equal, | ||
) | ||
else: | ||
new = arg | ||
new_args.append(new) | ||
return new_args | ||
def on_result_found(self, original_node, updated_node): | ||
new_args = self.replace_arg(original_node, "typ", '"safe"') | ||
return self.update_arg_target(updated_node, new_args) |
Oops, something went wrong.