Skip to content

Commit

Permalink
[#511] Remove usage of ThreadPoolExecutor to reduce memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
nabla-c0d3 committed Mar 27, 2021
1 parent f79cd37 commit f1dd803
Show file tree
Hide file tree
Showing 23 changed files with 735 additions and 518 deletions.
2 changes: 0 additions & 2 deletions sslyze/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

def sigint_handler(signum: int, frame: Any) -> None:
print("Scan interrupted... shutting down.")
if global_scanner:
global_scanner.emergency_shutdown()
sys.exit()


Expand Down
19 changes: 12 additions & 7 deletions sslyze/plugins/certificate_info/implementation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures import Future
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Dict, Tuple
Expand All @@ -14,7 +13,13 @@
from sslyze.plugins.certificate_info._get_cert_chain import get_certificate_chain, ArgumentsToGetCertificateChain
from sslyze.plugins.certificate_info.trust_stores.trust_store import TrustStore
from sslyze.plugins.certificate_info.trust_stores.trust_store_repository import TrustStoresRepository
from sslyze.plugins.plugin_base import ScanCommandImplementation, ScanJob, ScanCommandResult, ScanCommandExtraArguments
from sslyze.plugins.plugin_base import (
ScanCommandImplementation,
ScanJob,
ScanCommandResult,
ScanCommandExtraArguments,
ScanJobResult,
)
from sslyze.server_connectivity import ServerConnectivityInfo, TlsVersionEnum


Expand Down Expand Up @@ -86,18 +91,18 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> CertificateInfoScanResult:
if len(completed_scan_jobs) != 3:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 3:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

# Only keep certificate deployments that are different
# Leaf certificate => certificate chain, OCSP response
all_configured_certificate_chains: Dict[str, Tuple[List[str], Optional[nassl._nassl.OCSP_RESPONSE]]] = {}
custom_ca_file = None
for completed_job in completed_scan_jobs:
for completed_job in scan_job_results:
try:
received_chain_as_pem, ocsp_response, custom_ca_file = completed_job.result()
received_chain_as_pem, ocsp_response, custom_ca_file = completed_job.get_result()
except TlsHandshakeFailed:
# Can happen when trying to connect with specific cipher suites (such as RSA or non-RSA)
continue
Expand Down
10 changes: 5 additions & 5 deletions sslyze/plugins/compression_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures._base import Future
from dataclasses import dataclass

from nassl.legacy_ssl_client import LegacySslClient
Expand All @@ -10,6 +9,7 @@
ScanCommandExtraArguments,
ScanCommandWrongUsageError,
ScanCommandCliConnector,
ScanJobResult,
)
from typing import List, Optional

Expand Down Expand Up @@ -60,12 +60,12 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> CompressionScanResult:
if len(completed_scan_jobs) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

return CompressionScanResult(supports_compression=completed_scan_jobs[0].result())
return CompressionScanResult(supports_compression=scan_job_results[0].get_result())


def _test_compression_support(server_info: ServerConnectivityInfo) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions sslyze/plugins/early_data_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures._base import Future
from dataclasses import dataclass
from typing import List, Optional

Expand All @@ -12,6 +11,7 @@
ScanJob,
ScanCommandWrongUsageError,
ScanCommandCliConnector,
ScanJobResult,
)
from sslyze.server_connectivity import ServerConnectivityInfo, TlsVersionEnum
from sslyze.errors import ServerRejectedTlsHandshake, TlsHandshakeTimedOut
Expand Down Expand Up @@ -63,12 +63,12 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> EarlyDataScanResult:
if len(completed_scan_jobs) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

return EarlyDataScanResult(supports_early_data=completed_scan_jobs[0].result())
return EarlyDataScanResult(supports_early_data=scan_job_results[0].get_result())


def _test_early_data_support(server_info: ServerConnectivityInfo) -> bool:
Expand Down
14 changes: 7 additions & 7 deletions sslyze/plugins/elliptic_curves_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures._base import Future
from dataclasses import dataclass
from operator import attrgetter
from typing import List, Optional
Expand All @@ -16,6 +15,7 @@
ScanCommandExtraArguments,
ScanJob,
ScanCommandWrongUsageError,
ScanJobResult,
)
from sslyze.server_connectivity import enable_ecdh_cipher_suites

Expand Down Expand Up @@ -109,21 +109,21 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> SupportedEllipticCurvesScanResult:
if len(completed_scan_jobs) < 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) < 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

if len(completed_scan_jobs) == 1:
if len(scan_job_results) == 1:
try:
completed_scan_jobs[0].result()
scan_job_results[0].get_result()
raise RuntimeError("Should never happen")
except _EllipticCurveNotSupported:
return SupportedEllipticCurvesScanResult(
supports_ecdh_key_exchange=False, supported_curves=None, rejected_curves=None,
)
else:
all_ecdh_results = [scan_job.result() for scan_job in completed_scan_jobs]
all_ecdh_results = [scan_job.get_result() for scan_job in scan_job_results]
return SupportedEllipticCurvesScanResult(
supports_ecdh_key_exchange=True,
supported_curves=[
Expand Down
10 changes: 5 additions & 5 deletions sslyze/plugins/fallback_scsv_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures._base import Future
from dataclasses import dataclass
from typing import List, Optional
from nassl import _nassl
Expand All @@ -10,6 +9,7 @@
ScanJob,
ScanCommandWrongUsageError,
ScanCommandCliConnector,
ScanJobResult,
)
from sslyze.server_connectivity import ServerConnectivityInfo, TlsVersionEnum
from sslyze.errors import ServerRejectedTlsHandshake, TlsHandshakeTimedOut
Expand Down Expand Up @@ -58,12 +58,12 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> FallbackScsvScanResult:
if len(completed_scan_jobs) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

return FallbackScsvScanResult(supports_fallback_scsv=completed_scan_jobs[0].result())
return FallbackScsvScanResult(supports_fallback_scsv=scan_job_results[0].get_result())


def _test_scsv(server_info: ServerConnectivityInfo) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions sslyze/plugins/heartbleed_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import socket
import types
from concurrent.futures._base import Future
from dataclasses import dataclass
from typing import List, Optional

Expand All @@ -13,6 +12,7 @@
ScanCommandExtraArguments,
ScanCommandWrongUsageError,
ScanCommandCliConnector,
ScanJobResult,
)
from tls_parser.alert_protocol import TlsAlertRecord
from tls_parser.exceptions import NotEnoughData, UnknownTlsVersionByte
Expand Down Expand Up @@ -69,12 +69,12 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> HeartbleedScanResult:
if len(completed_scan_jobs) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

return HeartbleedScanResult(is_vulnerable_to_heartbleed=completed_scan_jobs[0].result())
return HeartbleedScanResult(is_vulnerable_to_heartbleed=scan_job_results[0].get_result())


def _test_heartbleed(server_info: ServerConnectivityInfo) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions sslyze/plugins/http_headers_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import socket
from concurrent.futures._base import Future
from http.client import HTTPResponse

from dataclasses import dataclass
Expand All @@ -14,6 +13,7 @@
ScanCommandResult,
ScanCommandWrongUsageError,
ScanCommandCliConnector,
ScanJobResult,
)
from sslyze.server_connectivity import ServerConnectivityInfo
from sslyze.connection_helpers.http_request_generator import HttpRequestGenerator
Expand Down Expand Up @@ -180,12 +180,12 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> HttpHeadersScanResult:
if len(completed_scan_jobs) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

return completed_scan_jobs[0].result()
return scan_job_results[0].get_result()


def _retrieve_and_analyze_http_response(server_info: ServerConnectivityInfo) -> HttpHeadersScanResult:
Expand Down
10 changes: 5 additions & 5 deletions sslyze/plugins/openssl_ccs_injection_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import socket
import types
from concurrent.futures._base import Future
from dataclasses import dataclass
from typing import List, Optional

Expand All @@ -13,6 +12,7 @@
ScanJob,
ScanCommandWrongUsageError,
ScanCommandCliConnector,
ScanJobResult,
)
from tls_parser.alert_protocol import TlsAlertRecord
from tls_parser.application_data_protocol import TlsApplicationDataRecord
Expand Down Expand Up @@ -70,12 +70,12 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> OpenSslCcsInjectionScanResult:
if len(completed_scan_jobs) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != 1:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

return OpenSslCcsInjectionScanResult(is_vulnerable_to_ccs_injection=completed_scan_jobs[0].result())
return OpenSslCcsInjectionScanResult(is_vulnerable_to_ccs_injection=scan_job_results[0].get_result())


def _test_for_ccs_injection(server_info: ServerConnectivityInfo) -> bool:
Expand Down
12 changes: 6 additions & 6 deletions sslyze/plugins/openssl_cipher_suites/implementation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures import Future
from operator import attrgetter

from dataclasses import dataclass
Expand All @@ -17,6 +16,7 @@
ScanJob,
ScanCommandExtraArguments,
ScanCommandWrongUsageError,
ScanJobResult,
)
from typing import ClassVar, Optional
from typing import List
Expand Down Expand Up @@ -113,17 +113,17 @@ def scan_jobs_for_scan_command(

@classmethod
def result_for_completed_scan_jobs(
cls, server_info: ServerConnectivityInfo, completed_scan_jobs: List[Future]
cls, server_info: ServerConnectivityInfo, scan_job_results: List[ScanJobResult]
) -> CipherSuitesScanResult:
expected_scan_jobs_count = len(CipherSuitesRepository.get_all_cipher_suites(cls._tls_version))
if len(completed_scan_jobs) != expected_scan_jobs_count:
raise RuntimeError(f"Unexpected number of scan jobs received: {completed_scan_jobs}")
if len(scan_job_results) != expected_scan_jobs_count:
raise RuntimeError(f"Unexpected number of scan jobs received: {scan_job_results}")

accepted_cipher_suites = []
rejected_cipher_suites = []
for completed_job in completed_scan_jobs:
for completed_job in scan_job_results:
try:
cipher_suite_result = completed_job.result()
cipher_suite_result = completed_job.get_result()
except NoCiphersAvailableBugInSSlyze:
# Happens when we passed a cipher suite and a TLS version that are not supported together by OpenSSL
# Swallowing this exception makes it easier as we can just always use the ALL:COMPLEMENTOFALL OpenSSL
Expand Down
28 changes: 22 additions & 6 deletions sslyze/plugins/plugin_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor

from dataclasses import dataclass

Expand Down Expand Up @@ -40,6 +40,18 @@ class ScanJob:
function_arguments: Any


@dataclass(frozen=True)
class ScanJobResult:
_return_value: Optional[Any]
_exception: Optional[Exception]

def get_result(self) -> Any:
if self._exception:
raise self._exception
else:
return self._return_value


_ScanCommandResultTypeVar = TypeVar("_ScanCommandResultTypeVar", bound=ScanCommandResult)
_ScanCommandExtraArgumentsTypeVar = TypeVar(
"_ScanCommandExtraArgumentsTypeVar", bound=Optional[ScanCommandExtraArguments]
Expand Down Expand Up @@ -68,9 +80,9 @@ def scan_jobs_for_scan_command(
@classmethod
@abstractmethod
def result_for_completed_scan_jobs(
cls, server_info: "ServerConnectivityInfo", completed_scan_jobs: List[Future]
cls, server_info: "ServerConnectivityInfo", scan_job_results: List[ScanJobResult]
) -> _ScanCommandResultTypeVar:
"""Transform the completed scan jobs for a given scan command into a result.
"""Transform the individual scan job results for a given scan command into a scan command result.
"""
pass

Expand All @@ -86,12 +98,16 @@ def scan_server(
thread_pool = ThreadPoolExecutor(max_workers=5)

all_jobs = cls.scan_jobs_for_scan_command(server_info, extra_arguments)
all_futures = []
all_job_results = []
for job in all_jobs:
future = thread_pool.submit(job.function_to_call, *job.function_arguments)
all_futures.append(future)
try:
job_result = ScanJobResult(_return_value=future.result(), _exception=None)
except Exception as e:
job_result = ScanJobResult(_return_value=None, _exception=e)
all_job_results.append(job_result)

result = cls.result_for_completed_scan_jobs(server_info, all_futures)
result = cls.result_for_completed_scan_jobs(server_info, all_job_results)
return result


Expand Down
Loading

0 comments on commit f1dd803

Please sign in to comment.