Skip to content

Commit

Permalink
Merge pull request #3129 from mitre/sanitize-param
Browse files Browse the repository at this point in the history
sanitize user-provided LDFLAG parameters
  • Loading branch information
elegantmoose authored Feb 17, 2025
2 parents 86ac639 + f0688ab commit 32adbf5
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
19 changes: 19 additions & 0 deletions app/service/file_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import json
import os
import re
import subprocess
import sys

Expand All @@ -19,6 +20,13 @@
from app.utility.payload_encoder import xor_file, xor_bytes

FILE_ENCRYPTION_FLAG = '%encrypted%'
URL_SANITIZATION_REGEX = re.compile(r'^[\w\-\.:%+/]+$')
ALLOWED_DEFAULT_LDFLAG_REGEX = re.compile(r'^[\w\-\.]+$')
ALLOWED_LDFLAG_REGEXES = {
'server': URL_SANITIZATION_REGEX,
'http': URL_SANITIZATION_REGEX,
'socket': re.compile(r'^[\w\-\.:]+$')
}


class FileSvc(FileServiceInterface, BaseService):
Expand Down Expand Up @@ -172,6 +180,17 @@ async def compile_go(self, platform, output, src_fle, arch='amd64', ldflags='-s
except subprocess.CalledProcessError as e:
self.log.warning('Problem building golang executable {}: {} '.format(src_fle, e))

@staticmethod
def sanitize_ldflag_value(param, value):
"""
Validate that the specified LDFLAG value for the given parameter
only contains safe characters.
Raises a ValueError if disallowed characters are found.
"""
if not ALLOWED_LDFLAG_REGEXES.get(param, ALLOWED_DEFAULT_LDFLAG_REGEX).fullmatch(value):
raise ValueError('Invalid characters in %s LDFLAG value: %s' % (param, value))
return value

def get_payload_name_from_uuid(self, payload):
for t in ['standard_payloads', 'special_payloads']:
for k, v in self.get_config(prop=t, name='payloads').items():
Expand Down
85 changes: 85 additions & 0 deletions tests/services/test_file_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,91 @@ def test_is_extension_xored_false(self, file_svc):
ret = file_svc.is_extension_xored(test_value)
assert ret is False

def test_sanitize_ldflag_value(self, file_svc):
safe_values = [
'safevalue',
'SAFE29VALUE',
'_safe_',
's-a-f-e.s_a_f_e.2',
'1234567890'
]
for value in safe_values:
assert value == file_svc.sanitize_ldflag_value('contact', value)
assert value == file_svc.sanitize_ldflag_value('group', value)
assert value == file_svc.sanitize_ldflag_value('genericparam', value)

safe_server_values = [
'http://localhost',
'https://localhost:8443',
'https://127.0.0.1:8443/home.html',
'https://some.domain.net:8443/home%20test.html',
'https://_underscore.domain-with-dash.net:8443/home+test.html',
]
for value in safe_server_values:
assert value == file_svc.sanitize_ldflag_value('server', value)
assert value == file_svc.sanitize_ldflag_value('http', value)

safe_socket_values = [
'localhost:1234',
'10.10.10.10.:8888',
'f.q.d.n:443',
'domain-with-dash.net:443',
]
for value in safe_socket_values:
assert value == file_svc.sanitize_ldflag_value('socket', value)

unsafe_values = [
'unsafe with spaces',
'unsafe,comma',
'unsafe;semicolon',
'unsafe!',
'unsafe&&test',
'unsafe||test',
'unsafe>test',
'unsafe<test',
'unsafe$(test)',
'unsafe~/test',
'unsafe%test+',
]
for value in unsafe_values:
with pytest.raises(Exception) as e_info:
file_svc.sanitize_ldflag_value('group', value)
assert str(e_info.value) == 'Invalid characters in group LDFLAG value: {}'.format(value)

unsafe_server_values = [
'http://localhost||test',
'https://localhost:8443 space',
'https://localhost:8443@',
'https://localhost:8443"test',
'https://localhost:8443\'test',
'https://127.0.0.1:8443/home.html$(test)',
'https://some.domain.net:8443/home%20test.html && test',
'https://_underscore.domain-with-dash.net:8443/home+test.html; test',
]
for value in unsafe_server_values:
with pytest.raises(Exception) as e_info:
file_svc.sanitize_ldflag_value('server', value)
assert str(e_info.value) == 'Invalid characters in server LDFLAG value: {}'.format(value)

with pytest.raises(Exception) as e_info:
file_svc.sanitize_ldflag_value('http', value)
assert str(e_info.value) == 'Invalid characters in http LDFLAG value: {}'.format(value)

unsafe_socket_values = [
'localhost:8888||test',
'127.0.0.1:8443 space',
'domain.com:8443@',
'localhost:8443"test',
'localhost:8443\'test',
'127.0.0.1:8443$(test)',
'some.domain.net:8443 && test',
'domain-with-dash.net:8443; test',
]
for value in unsafe_socket_values:
with pytest.raises(Exception) as e_info:
file_svc.sanitize_ldflag_value('socket', value)
assert str(e_info.value) == 'Invalid characters in socket LDFLAG value: {}'.format(value)

@staticmethod
def _test_download_file_with_encoding(event_loop, file_svc, data_svc, encoding, original_content, encoded_content):
filename = 'testencodedpayload.txt'
Expand Down

0 comments on commit 32adbf5

Please sign in to comment.