diff --git a/codemodder/codemods/api/helpers.py b/codemodder/codemods/api/helpers.py index 3417b01e..70b32b3d 100644 --- a/codemodder/codemods/api/helpers.py +++ b/codemodder/codemods/api/helpers.py @@ -8,13 +8,13 @@ def remove_unused_import(self, original_node): # pylint: disable=no-member RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node) - def add_needed_import(self, import_name): + def add_needed_import(self, module, obj=None): # TODO: do we need to check if this import already exists? AddImportsVisitor.add_needed_import( - self.context, import_name # pylint: disable=no-member + self.context, module, obj # pylint: disable=no-member ) - def update_call_target(self, original_node, new_target): + def update_call_target(self, original_node, new_target, replacement_args=None): # TODO: is an assertion the best way to handle this? # Or should we just return the original node if it's not a Call? assert isinstance(original_node, cst.Call) @@ -23,7 +23,7 @@ def update_call_target(self, original_node, new_target): value=cst.parse_expression(new_target), attr=cst.Name(value=get_call_name(original_node)), ), - args=original_node.args, + args=replacement_args if replacement_args else original_node.args, ) def update_arg_target(self, updated_node, new_arg): diff --git a/codemodder/codemods/process_creation_sandbox.py b/codemodder/codemods/process_creation_sandbox.py index eba2661e..3b4bfe12 100644 --- a/codemodder/codemods/process_creation_sandbox.py +++ b/codemodder/codemods/process_creation_sandbox.py @@ -1,62 +1,38 @@ import libcst as cst -from libcst.codemod import CodemodContext -from codemodder.file_context import FileContext from codemodder.dependency_manager import DependencyManager -from libcst.codemod.visitors import AddImportsVisitor -from codemodder.codemods.change import Change -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, - ReviewGuidance, -) -from codemodder.codemods.base_visitor import BaseTransformer -from codemodder.codemods.utils import get_call_name +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SemgrepCodemod -replacement_import = "safe_command" - -class ProcessSandbox(SemgrepCodemod, BaseTransformer): - METADATA = CodemodMetadata( - DESCRIPTION=( - "Replaces subprocess.{func} with more secure safe_command library functions." - ), - NAME="process-sandbox", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, +class ProcessSandbox(SemgrepCodemod): + NAME = "sandbox-process-creation" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW + DESCRIPTION = ( + "Replaces subprocess.{func} with more secure safe_command library functions." ) - CHANGE_DESCRIPTION = "Switch use of subprocess for security.safe_command" - YAML_FILES = [ - "sandbox_process_creation.yaml", - ] - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - SemgrepCodemod.__init__(self, file_context) - BaseTransformer.__init__( - self, - codemod_context, - self._results, - file_context.line_exclude, - file_context.line_include, - ) + @classmethod + def rule(cls): + return """ + rules: + - pattern-either: + - patterns: + - pattern: subprocess.run(...) + - pattern-inside: | + import subprocess + ... + - patterns: + - pattern: subprocess.call(...) + - pattern-inside: | + import subprocess + ... + """ - def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: - pos_to_match = self.node_position(original_node) - if self.filter_by_result( - pos_to_match - ) and self.filter_by_path_includes_or_excludes(pos_to_match): - line_number = pos_to_match.start.line - self.CHANGES_IN_FILE.append( - Change(str(line_number), self.CHANGE_DESCRIPTION).to_json() - ) - AddImportsVisitor.add_needed_import( - self.context, "security", "safe_command" - ) - DependencyManager().add(["security==1.0.1"]) - new_call = cst.Call( - func=cst.Attribute( - value=cst.Name(value=replacement_import), - attr=cst.Name(value=get_call_name(original_node)), - ), - args=[cst.Arg(original_node.func), *original_node.args], - ) - return new_call - return updated_node + def on_result_found(self, original_node, updated_node): + self.add_needed_import("security", "safe_command") + DependencyManager().add(["security==1.0.1"]) + return self.update_call_target( + updated_node, + "safe_command", + replacement_args=[cst.Arg(original_node.func), *original_node.args], + ) diff --git a/integration_tests/test_multiple.py b/integration_tests/test_multiple.py index c555f5f4..d26392d5 100644 --- a/integration_tests/test_multiple.py +++ b/integration_tests/test_multiple.py @@ -69,14 +69,7 @@ def _assert_codetf_output(self): ) assert len(limit_readline["changeset"][0]["changes"]) == 1 - process_sandbox = sorted_results[5] - assert len(process_sandbox["changeset"]) == 1 - assert ( - process_sandbox["changeset"][0]["path"] == "tests/samples/make_process.py" - ) - assert len(process_sandbox["changeset"][0]["changes"]) == 4 - - unnecessary_f_str = sorted_results[6] + unnecessary_f_str = sorted_results[5] assert len(unnecessary_f_str["changeset"]) == 1 assert ( unnecessary_f_str["changeset"][0]["path"] @@ -84,6 +77,13 @@ def _assert_codetf_output(self): ) assert len(unnecessary_f_str["changeset"][0]["changes"]) == 1 + process_sandbox = sorted_results[6] + assert len(process_sandbox["changeset"]) == 1 + assert ( + process_sandbox["changeset"][0]["path"] == "tests/samples/make_process.py" + ) + assert len(process_sandbox["changeset"][0]["changes"]) == 4 + secure_random = sorted_results[7] assert len(secure_random["changeset"]) == 1 assert (