diff --git a/src/modules/workflow.py b/src/modules/workflow.py
index f8ce9a8..e7d455a 100755
--- a/src/modules/workflow.py
+++ b/src/modules/workflow.py
@@ -9,7 +9,7 @@
 from tqdm import tqdm
 
 from utils import helper, github_util, report_util
-
+import os
 from modules import notebook_inspector, model_inspector
@@ -72,6 +72,7 @@ def orchestrator(repo_type: str = 'github', repo_url: str = None, github_clone_d
     scanning_status = True
     failed_scan_files = list()
     scanned_report_dictionary = {}
+    base_path = str(os.getcwd())
 
     try:
         print("Scanning Started ...")
@@ -85,7 +86,8 @@ def orchestrator(repo_type: str = 'github', repo_url: str = None, github_clone_d
                                                                scanning_id=scanning_id,
                                                                path=path,
                                                                branch_name=branch_name,
-                                                               depth=depth)
+                                                               depth=depth,
+                                                               base_path = base_path)
 
         # iterate to get response from each files
         for file in tqdm(to_be_scanned_files):
@@ -139,12 +141,13 @@ def orchestrator(repo_type: str = 'github', repo_url: str = None, github_clone_d
     except Exception as e:
         scanning_status = False
         print("Scanning Failed due to {}".format(str(e)))
-
-    # Clean up the local cloned directory either scanning failed or Completed
-    if repo_type.lower() == "github":
-        github_util.delete_github_repo(repo_dir=save_dir)
-    elif repo_type.lower() not in ["file", "folder"]:
-        helper.delete_directory([save_dir])
+
+    if save_dir != None:
+        # Clean up the local cloned directory either scanning failed or Completed
+        if repo_type.lower() == "github":
+            github_util.delete_github_repo(repo_dir=save_dir)
+        elif repo_type.lower() not in ["file", "folder"]:
+            helper.delete_directory(base_path,[save_dir])
 
     return report_path, scanning_status
diff --git a/src/utils/github_util.py b/src/utils/github_util.py
index a2c5224..7445823 100755
--- a/src/utils/github_util.py
+++ b/src/utils/github_util.py
@@ -81,4 +81,4 @@ def delete_github_repo(repo_dir):
         print("Locally cloned repository has been successfully removed")
     except Exception as e:
-        print("{} Failed to remove due to {}".format(repo_dir, str(e)))
+        print("{} Failed to remove due to {}, it is recommended to delete the directory manually".format(repo_dir, str(e)))
diff --git a/src/utils/helper.py b/src/utils/helper.py
index 467eb13..b063b83 100755
--- a/src/utils/helper.py
+++ b/src/utils/helper.py
@@ -14,7 +14,7 @@
 def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None, github_clone_dir: str = None,
                          aws_access_key_id: str = None, aws_secret_access_key: str = None, region: str = None,
-                         bucket_name: str = None, s3_download_dir: str = None,path: str = None, branch_name: str = 'main',depth: int=1):
+                         bucket_name: str = None, s3_download_dir: str = None,path: str = None, branch_name: str = 'main',depth: int=1,base_path: str=None):
     """
     Fetches files to be scanned based on the repository type and scanning ID.
@@ -53,7 +53,10 @@
         save_dir = github_clone_dir
         # Clone the gitHub repository in the local
-
+        # if not os.path.exists(save_dir):
+        #     os.makedirs(save_dir)
+
+        # print(os.path.dirname(save_dir))
         if repo_type.lower() == 'github':
             repo_url = repo_url
         elif repo_type.lower() == 'huggingface':
@@ -61,23 +64,22 @@
             repo_url = f'https://huggingface.co/{repo_url}'
 
-        github_util.clone_github_repo(repo_url, save_dir,branch_name,depth)
-
+
         # get all h5 files
-        h5_files = search_files(github_clone_dir, '.h5')
+        h5_files = search_files(base_path,github_clone_dir, '.h5')
         # get all .pb files
-        pb_files = search_files(github_clone_dir, '.pb')
+        pb_files = search_files(base_path,github_clone_dir, '.pb')
         # get all .pkl files
-        pkl_files = search_files(github_clone_dir, '.pkl')
+        pkl_files = search_files(base_path,github_clone_dir, '.pkl')
         # get all ipynb files
-        ipynb_files = search_files(github_clone_dir, '.ipynb')
+        ipynb_files = search_files(base_path,github_clone_dir, '.ipynb')
         # get requirements files
-        requirement_files = search_files(github_clone_dir, 'requirements.txt')
+        requirement_files = search_files(base_path,github_clone_dir, 'requirements.txt')
 
         to_be_scanned_files = h5_files + ipynb_files + pb_files + pkl_files + requirement_files
@@ -90,7 +92,7 @@ def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None,
         # Ensure local directory exists
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
-
+
         # create s3 object to interact with s3 buckets
         s3_object = aws_s3_util.AIShieldWatchtowerS3(aws_access_key_id, aws_secret_access_key, region,
                                                      bucket_name, save_dir)
@@ -108,18 +110,19 @@ def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None,
     if repo_type.lower() == 'folder':
         tar_dir = path  # Assuming file_path is the path to the folder
-        h5_files = search_files(tar_dir, '.h5')
-        pb_files = search_files(tar_dir, '.pb')
-        pkl_files = search_files(tar_dir, '.pkl')
-        ipynb_files = search_files(tar_dir, '.ipynb')
-        requirement_files = search_files(tar_dir, 'requirements.txt')
+        folder_base_path = os.path.dirname(tar_dir)
+        h5_files = search_files(folder_base_path,tar_dir, '.h5')
+        pb_files = search_files(folder_base_path,tar_dir, '.pb')
+        pkl_files = search_files(folder_base_path,tar_dir, '.pkl')
+        ipynb_files = search_files(folder_base_path,tar_dir, '.ipynb')
+        requirement_files = search_files(folder_base_path,tar_dir, 'requirements.txt')
 
         to_be_scanned_files = h5_files + ipynb_files + pb_files + pkl_files + requirement_files
 
     return to_be_scanned_files, save_dir
 
-def search_files(target_dir: str, file_extensions: str):
+def search_files(base_path:str, target_dir: str, file_extensions: str):
     """
     Finds all the files ending with a given extension in the specified directory and its sub-folders.
@@ -130,7 +133,14 @@
     Returns:
     - List of paths to files with the specified extension.
     """
+    if not target_dir:
+        raise Exception("Target directory is empty")
+    # Normalize the target directory and check if it is within the base_path
+    full_target_dir = os.path.normpath(os.path.join(base_path, target_dir))
+
+    if not os.path.abspath(full_target_dir).startswith(os.path.abspath(base_path)):
+        raise Exception("Target directory is outside the base path")
 
     # List to hold the paths of all files matching the given extension
     matching_files = []
@@ -166,7 +176,7 @@
     print("{} created successfully".format(path))
 
 
-def delete_directory(directory):
+def delete_directory(base_path :str,directory):
     """
     delete directory
@@ -181,6 +191,12 @@
     """
     for d in directory:
+
+        full_path = os.path.normpath(os.path.join(base_path, d))
+        if not os.path.abspath(full_path).startswith(os.path.abspath(base_path)):
+            print("Path '{}' is outside the base folder '{}' and cannot be deleted.".format(full_path, base_path))
+            return
+
         try:
             if os.path.isdir(d):
                 try:
diff --git a/src/utils/report_util.py b/src/utils/report_util.py
index e06877e..cdaf239 100755
--- a/src/utils/report_util.py
+++ b/src/utils/report_util.py
@@ -334,18 +334,18 @@
         output = eval(output)
         if len(output) != 0:
             for out in output:
-                key = out['severity'].lower()
-                if key == "info" or key == "minor":
-                    key = "Low"
-                elif key == "major":
-                    key = "Medium"
-                elif key == "blocker" or key == "critical":
-                    key = "High"
-                out['vulnerability_severity'] = key
-                if key in vulnerability_severity_map:
-                    vulnerability_severity_map[key] = int(vulnerability_severity_map[key]) + 1
+                whisper_sev_key = out['severity'].lower()
+                if whisper_sev_key == "info" or whisper_sev_key == "minor":
+                    vul_sev_key = "Low"
+                elif whisper_sev_key == "major":
+                    vul_sev_key = "Medium"
+                elif whisper_sev_key == "blocker" or whisper_sev_key == "critical":
+                    vul_sev_key = "High"
+                out['vulnerability_severity'] = vul_sev_key
+                if vul_sev_key in vulnerability_severity_map:
+                    vulnerability_severity_map[vul_sev_key] = int(vulnerability_severity_map[vul_sev_key]) + 1
                 else:
-                    vulnerability_severity_map[key] = 1
+                    vulnerability_severity_map[vul_sev_key] = 1
     except Exception as e:
         print("Failed to parse whisper_output {}".format(e))