diff --git a/app/service/file_svc.py b/app/service/file_svc.py index 01cf2b0ec..0e520178d 100644 --- a/app/service/file_svc.py +++ b/app/service/file_svc.py @@ -4,6 +4,7 @@ import copy import json import os +import re import subprocess import sys @@ -19,6 +20,13 @@ from app.utility.payload_encoder import xor_file, xor_bytes FILE_ENCRYPTION_FLAG = '%encrypted%' +URL_SANITIZATION_REGEX = re.compile(r'^[\w\-\.:%+/]+$') +ALLOWED_DEFAULT_LDFLAG_REGEX = re.compile(r'^[\w\-\.]+$') +ALLOWED_LDFLAG_REGEXES = { + 'server': URL_SANITIZATION_REGEX, + 'http': URL_SANITIZATION_REGEX, + 'socket': re.compile(r'^[\w\-\.:]+$') +} class FileSvc(FileServiceInterface, BaseService): @@ -172,6 +180,17 @@ async def compile_go(self, platform, output, src_fle, arch='amd64', ldflags='-s except subprocess.CalledProcessError as e: self.log.warning('Problem building golang executable {}: {} '.format(src_fle, e)) + @staticmethod + def sanitize_ldflag_value(param, value): + """ + Validate that the specified LDFLAG value for the given parameter + only contains safe characters. + Raises a ValueError if disallowed characters are found. + """ + if not ALLOWED_LDFLAG_REGEXES.get(param, ALLOWED_DEFAULT_LDFLAG_REGEX).fullmatch(value): + raise ValueError('Invalid characters in %s LDFLAG value: %s' % (param, value)) + return value + def get_payload_name_from_uuid(self, payload): for t in ['standard_payloads', 'special_payloads']: for k, v in self.get_config(prop=t, name='payloads').items(): diff --git a/tests/services/test_file_svc.py b/tests/services/test_file_svc.py index bf007269b..fe7c6c16b 100644 --- a/tests/services/test_file_svc.py +++ b/tests/services/test_file_svc.py @@ -242,6 +242,91 @@ def test_is_extension_xored_false(self, file_svc): ret = file_svc.is_extension_xored(test_value) assert ret is False + def test_sanitize_ldflag_value(self, file_svc): + safe_values = [ + 'safevalue', + 'SAFE29VALUE', + '_safe_', + 's-a-f-e.s_a_f_e.2', + '1234567890' + ] + for value in safe_values: + assert value == file_svc.sanitize_ldflag_value('contact', value) + assert value == file_svc.sanitize_ldflag_value('group', value) + assert value == file_svc.sanitize_ldflag_value('genericparam', value) + + safe_server_values = [ + 'http://localhost', + 'https://localhost:8443', + 'https://127.0.0.1:8443/home.html', + 'https://some.domain.net:8443/home%20test.html', + 'https://_underscore.domain-with-dash.net:8443/home+test.html', + ] + for value in safe_server_values: + assert value == file_svc.sanitize_ldflag_value('server', value) + assert value == file_svc.sanitize_ldflag_value('http', value) + + safe_socket_values = [ + 'localhost:1234', + '10.10.10.10.:8888', + 'f.q.d.n:443', + 'domain-with-dash.net:443', + ] + for value in safe_socket_values: + assert value == file_svc.sanitize_ldflag_value('socket', value) + + unsafe_values = [ + 'unsafe with spaces', + 'unsafe,comma', + 'unsafe;semicolon', + 'unsafe!', + 'unsafe&&test', + 'unsafe||test', + 'unsafe>test', + 'unsafe