Skip to content

Commit

Permalink
fix/refactor-scan-part-1 (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanpulver authored Jan 3, 2025
1 parent 9d3acde commit 3f5882f
Showing 1 changed file with 196 additions and 77 deletions.
273 changes: 196 additions & 77 deletions safety/scan/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


cli_apps_opts = {"rich_markup_mode": "rich", "cls": SafetyCLISubGroup}

scan_project_app = typer.Typer(**cli_apps_opts)
scan_system_app = typer.Typer(**cli_apps_opts)

Expand Down Expand Up @@ -258,9 +257,6 @@ def generate_cve_details(files: List[FileModel]) -> List[Dict[str, Any]]:
return sort_cve_data(cve_data)





def add_cve_details_to_report(report_to_output: str, files: List[FileModel]) -> str:
"""
Add CVE details to the JSON report output.
Expand Down Expand Up @@ -294,6 +290,185 @@ def generate_updates_arguments() -> List:
return fixes


def validate_authentication(ctx: typer.Context) -> None:
"""
Validates that the user is authenticated.
Args:
ctx (typer.Context): The Typer context object.
Raises:
SafetyError: If the user is not authenticated.
"""
if not ctx.obj.metadata.authenticated:
raise SafetyError("Authentication required. Please run 'safety auth login' to authenticate before using this command.")


def generate_fixes_target(apply_updates: bool) -> List:
"""
Generates a list of update targets if `apply_updates` is enabled.
Args:
apply_updates (bool): Whether to generate fixes target.
Returns:
List: A list of update targets if enabled, otherwise an empty list.
"""
return generate_updates_arguments() if apply_updates else []


def validate_save_as(ctx: typer.Context, save_as: Optional[Tuple[ScanExport, Path]]) -> None:
"""
Ensures the `save_as` parameters are valid.
Args:
ctx (typer.Context): The Typer context object.
save_as (Optional[Tuple[ScanExport, Path]]): The save-as parameters.
"""
if not all(save_as):
ctx.params["save_as"] = None


def initialize_file_finder(ctx: typer.Context, target: Path, console: Console, ecosystems: List[Ecosystem]) -> FileFinder:
"""
Initializes the FileFinder for scanning files in the target directory.
Args:
ctx (typer.Context): The Typer context object.
target (Path): The target directory to scan.
console (Console): The console object for logging.
ecosystems (List[Ecosystem]): The list of scannable ecosystems.
Returns:
FileFinder: An initialized FileFinder object.
"""
to_include = {
file_type: paths
for file_type, paths in ctx.obj.config.scan.include_files.items()
if file_type.ecosystem in ecosystems
}

file_finder = FileFinder(
target=target,
ecosystems=ecosystems,
max_level=ctx.obj.config.scan.max_depth,
exclude=ctx.obj.config.scan.ignore,
include_files=to_include,
console=console,
)

# Download necessary assets for each handler
for handler in file_finder.handlers:
if handler.ecosystem:
wait_msg = "Fetching Safety's vulnerability database..."
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
handler.download_required_assets(ctx.obj.auth.client)

return file_finder


def scan_project_directory(file_finder: FileFinder, console: Console) -> Tuple[Path, Dict]:
"""
Scans the project directory and identifies relevant files for analysis.
Args:
file_finder (FileFinder): Initialized file finder object.
console (Console): Console for logging output.
Returns:
Tuple[Path, Dict]: The base path of the project and a dictionary of file paths grouped by type.
"""
wait_msg = "Scanning project directory"
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
path, file_paths = file_finder.search()
print_detected_ecosystems_section(console, file_paths, include_safety_prjs=True)
return path, file_paths


def detect_dependency_vulnerabilities(console, dependency_vuln_detected):
"""
Prints a message indicating that dependency vulnerabilities were detected.
"""
if not dependency_vuln_detected:
console.print()
console.print("Dependency vulnerabilities detected:")
return True
return dependency_vuln_detected


def print_file_info(console, path, target):
"""
Prints the file information for vulnerabilities.
"""
console.print()
msg = f":pencil: [file_title]{path.relative_to(target)}:[/file_title]"
console.print(msg, emoji=True)


def sort_and_filter_vulnerabilities(vulnerabilities, key_func, reverse=True):
"""
Sorts and filters vulnerabilities.
"""
return sorted(
[vuln for vuln in vulnerabilities if not vuln.ignored],
key=key_func,
reverse=reverse
)


def count_critical_vulnerabilities(vulnerabilities: List[Vulnerability]) -> int:
"""
Count the number of critical vulnerabilities in a list of vulnerabilities.
Args:
vulnerabilities (List[Vulnerability]): List of vulnerabilities to evaluate.
Returns:
int: The number of vulnerabilities with a critical severity level.
"""
return sum(
1 for vuln in vulnerabilities
if vuln.severity
and vuln.severity.cvssv3
and vuln.severity.cvssv3.get("base_severity", "none").lower() == VulnerabilitySeverityLabels.CRITICAL.value.lower()
)


def generate_vulnerability_message(spec_name: str, spec_raw: str, vulns_found: int, critical_vulns_count: int, vuln_word: str) -> str:
"""
Generate a formatted message for vulnerabilities in a specific dependency.
Args:
spec_name (str): Name of the dependency.
spec_raw (str): Raw specification string of the dependency.
vulns_found (int): Number of vulnerabilities found.
critical_vulns_count (int): Number of critical vulnerabilities found.
vuln_word (str): Pluralized form of the word "vulnerability."
Returns:
str: Formatted vulnerability message.
"""
msg = f"[dep_name]{spec_name}[/dep_name][specifier]{spec_raw.replace(spec_name, '')}[/specifier] [{vulns_found} {vuln_word} found"
if vulns_found > 3 and critical_vulns_count > 0:
msg += f", [brief_severity]including {critical_vulns_count} critical severity {pluralize('vulnerability', critical_vulns_count)}[/brief_severity]"
return msg


def render_vulnerabilities(vulns_to_report: List[Vulnerability], console: Console, detailed_output: bool) -> None:
"""
Render vulnerabilities to the console.
Args:
vulns_to_report (List[Vulnerability]): List of vulnerabilities to render.
console (Console): Console object for printing.
detailed_output (bool): Whether to display detailed output.
"""
for vuln in vulns_to_report:
render_to_console(
vuln, console, rich_kwargs={"emoji": True, "overflow": "crop"}, detailed_output=detailed_output
)


@scan_project_app.command(
cls=SafetyCLICommand,
help=CLI_SCAN_COMMAND_HELP,
Expand Down Expand Up @@ -371,71 +546,35 @@ def scan(ctx: typer.Context,
Scans a project (defaulted to the current directory) for supply-chain security and configuration issues
"""

if not ctx.obj.metadata.authenticated:
raise SafetyError("Authentication required. Please run 'safety auth login' to authenticate before using this command.")

# Generate update arguments if apply updates option is enabled
fixes_target = []
if apply_updates:
fixes_target = generate_updates_arguments()

# Ensure save_as params are correctly set
if not all(save_as):
ctx.params["save_as"] = None
validate_authentication(ctx)
fixes_target = generate_fixes_target(apply_updates)
validate_save_as(ctx, save_as)

console = ctx.obj.console
ecosystems = [Ecosystem(member.value) for member in list(ScannableEcosystems)]
to_include = {file_type: paths for file_type, paths in ctx.obj.config.scan.include_files.items() if file_type.ecosystem in ecosystems}

# Initialize file finder
file_finder = FileFinder(target=target, ecosystems=ecosystems,
max_level=ctx.obj.config.scan.max_depth,
exclude=ctx.obj.config.scan.ignore,
include_files=to_include,
console=console)

# Download necessary assets for each handler
for handler in file_finder.handlers:
if handler.ecosystem:
wait_msg = "Fetching Safety's vulnerability database..."
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
handler.download_required_assets(ctx.obj.auth.client)

# Start scanning the project directory
wait_msg = "Scanning project directory"

path = None
file_paths = {}

with console.status(wait_msg, spinner=DEFAULT_SPINNER):
path, file_paths = file_finder.search()
print_detected_ecosystems_section(console, file_paths,
include_safety_prjs=True)
file_finder = initialize_file_finder(ctx, target, console, ecosystems)
path, file_paths = scan_project_directory(file_finder, console)

target_ecosystems = ", ".join([member.value for member in ecosystems])
wait_msg = f"Analyzing {target_ecosystems} files and environments for security findings"

files: List[FileModel] = []

to_fix_files = []
ignored_vulns_data = iter([])
config = ctx.obj.config

count = 0

affected_count = 0
dependency_vuln_detected = False

ignored_vulns_data = iter([])

exit_code = 0
fixes_count = 0
total_resolved_vulns = 0
to_fix_files = []
fix_file_types = [fix_target[0] if isinstance(fix_target[0], str) else fix_target[0].value for fix_target in fixes_target]
dependency_vuln_detected = False
requirements_txt_found = False
display_apply_fix_suggestion = False

# Process each file for dependencies and vulnerabilities
with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status:
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
for path, analyzed_file in process_files(paths=file_paths,
config=config, use_server_matching=use_server_matching, obj=ctx.obj, target=target):
count += len(analyzed_file.dependency_results.dependencies)
Expand All @@ -450,49 +589,30 @@ def scan(ctx: typer.Context,
def sort_vulns_by_score(vuln: Vulnerability) -> int:
if vuln.severity and vuln.severity.cvssv3:
return vuln.severity.cvssv3.get("base_score", 0)

return 0

to_fix_spec = []
file_matched_for_fix = analyzed_file.file_type.value in fix_file_types

if any(affected_specifications):
if not dependency_vuln_detected:
console.print()
console.print("Dependency vulnerabilities detected:")
dependency_vuln_detected = True
dependency_vuln_detected = detect_dependency_vulnerabilities(console, dependency_vuln_detected)
print_file_info(console, path, target)

console.print()
msg = f":pencil: [file_title]{path.relative_to(target)}:[/file_title]"
console.print(msg, emoji=True)
for spec in affected_specifications:
if file_matched_for_fix:
to_fix_spec.append(spec)

console.print()
vulns_to_report = sorted(
[vuln for vuln in spec.vulnerabilities if not vuln.ignored],
key=sort_vulns_by_score,
reverse=True)
critical_vulns_count = sum(1 for vuln in vulns_to_report if vuln.severity and vuln.severity.cvssv3 and vuln.severity.cvssv3.get("base_severity", "none").lower() == VulnerabilitySeverityLabels.CRITICAL.value.lower())

vulns_to_report = sort_and_filter_vulnerabilities(spec.vulnerabilities, key_func=sort_vulns_by_score)
critical_vulns_count = count_critical_vulnerabilities(vulns_to_report)
vulns_found = len(vulns_to_report)
vuln_word = pluralize("vulnerability", vulns_found)

msg = f"[dep_name]{spec.name}[/dep_name][specifier]{spec.raw.replace(spec.name, '')}[/specifier] [{vulns_found} {vuln_word} found"

if vulns_found > 3 and critical_vulns_count > 0:
msg += f", [brief_severity]including {critical_vulns_count} critical severity {pluralize('vulnerability', critical_vulns_count)}[/brief_severity]"

console.print(Padding(f"{msg}]", (0, 0, 0, 1)), emoji=True,
overflow="crop")
msg = generate_vulnerability_message(spec.name, spec.raw, vulns_found, critical_vulns_count, vuln_word)
console.print(Padding(f"{msg}]", (0, 0, 0, 1)), emoji=True, overflow="crop")

if detailed_output or vulns_found < 3:
for vuln in vulns_to_report:
render_to_console(vuln, console,
rich_kwargs={"emoji": True,
"overflow": "crop"},
detailed_output=detailed_output)
render_vulnerabilities(vulns_to_report, console, detailed_output)

lines = []

Expand Down Expand Up @@ -591,7 +711,6 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int:
**{k: v for k, v in ctx.params.items() if k not in {"detailed_output", "output", "save_as", "filter_keys"}}
)


project_url = f"{SAFETY_PLATFORM_URL}{ctx.obj.project.url_path}"

if apply_updates:
Expand Down Expand Up @@ -630,7 +749,7 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int:

if not no_output:
console.print("-" * console.size.width)

if output is ScanOutput.SCREEN:
run_easter_egg(console, exit_code)

Expand Down

0 comments on commit 3f5882f

Please sign in to comment.