From c3b450b4ccefa579b1495ff1fcf8c2cc65b71be1 Mon Sep 17 00:00:00 2001 From: Antoine Martin Date: Wed, 12 Jul 2023 18:07:10 +0200 Subject: [PATCH] #3313 auto-upgrade tcp sockets to ssl --- xpra/client/base/client_base.py | 60 ++++++++++++++++++++++++++--- xpra/client/gui/ui_client_base.py | 1 - xpra/net/common.py | 2 + xpra/net/protocol/socket_handler.py | 3 ++ xpra/net/websockets/protocol.py | 2 + xpra/scripts/main.py | 3 +- xpra/scripts/parsing.py | 3 +- xpra/server/server_core.py | 52 ++++++++++++++++++++++++- 8 files changed, 116 insertions(+), 10 deletions(-) diff --git a/xpra/client/base/client_base.py b/xpra/client/base/client_base.py index a537f1561d..6adf884a59 100644 --- a/xpra/client/base/client_base.py +++ b/xpra/client/base/client_base.py @@ -18,7 +18,7 @@ from xpra.common import SPLASH_EXIT_DELAY, FULL_INFO, LOG_HELLO from xpra.child_reaper import getChildReaper, reaper_cleanup from xpra.net import compression -from xpra.net.common import may_log_packet, PACKET_TYPES +from xpra.net.common import may_log_packet, PACKET_TYPES, SSL_UPGRADE from xpra.make_thread import start_thread from xpra.net.protocol.factory import get_client_protocol_class from xpra.net.protocol.constants import CONNECTION_LOST, GIBBERISH, INVALID @@ -164,9 +164,8 @@ def init(self, opts): self.encryption_keyfile = opts.encryption_keyfile or opts.tcp_encryption_keyfile self.init_challenge_handlers(opts.challenge_handlers) self.install_signal_handlers() - #this is now done in UI client only, - #most simple clients are just wasting time doing this - #self.init_aliases() + #we need this to expose the 'packet-types' capability, + self.init_aliases() def show_progress(self, pct, text=""): @@ -333,8 +332,9 @@ def get_scheduler(self): raise NotImplementedError() def setup_connection(self, conn): - netlog("setup_connection(%s) timeout=%s, socktype=%s", conn, conn.timeout, conn.socktype) protocol_class = get_client_protocol_class(conn.socktype) + netlog("setup_connection(%s) timeout=%s, socktype=%s, protocol-class=", + conn, conn.timeout, conn.socktype, protocol_class) protocol = protocol_class(self.get_scheduler(), conn, self.process_packet, self.next_packet) #ssh channel may contain garbage initially, #tell the protocol to wait for a valid header: @@ -744,6 +744,53 @@ def _process_connection_lost(self, _packet) -> None: self.warn_and_quit(exit_code, msg) + def _process_ssl_upgrade(self, packet) -> None: + assert SSL_UPGRADE + ssl_attrs = typedict(packet[1]) + start_thread(self.ssl_upgrade, "ssl-upgrade", True, args=(ssl_attrs, )) + + def ssl_upgrade(self, ssl_attrs) -> None: + # send ssl-upgrade request! + ssllog = Logger("client", "ssl") + ssllog(f"ssl-upgrade({ssl_attrs})") + conn = self._protocol._conn + socktype = conn.socktype + new_socktype = {"tcp" : "ssl", "ws" : "wss"}.get(socktype) + if not new_socktype: + raise ValueError(f"cannot upgrade {socktype} to ssl") + log.info(f"upgrading {conn} to {new_socktype}") + self.send("ssl-upgrade", {}) + from xpra.net.socket_util import ssl_wrap_socket, get_ssl_attributes, ssl_handshake + overrides = { + "verify_mode" : "none", + "check_hostname" : "no", + } + overrides.update(conn.options.get("ssl-options", {})) + ssl_options = get_ssl_attributes(None, False, overrides) + kwargs = dict((k.replace("-", "_"), v) for k, v in ssl_options.items()) + # wait for the 'ssl-upgrade' packet to be sent... + # this should be done by watching the IO and formatting threads instead + import time + time.sleep(1) + def read_callback(packet): + if packet: + ssllog.error("Error: received another packet during ssl socket upgrade:") + ssllog.error(" %s", packet) + self.quit(EXIT_INTERNAL_ERROR) + conn = self._protocol.steal_connection(read_callback) + if not self._protocol.wait_for_io_threads_exit(1): + log.error("Error: failed to terminate network threads for ssl upgrade") + self.quit(EXIT_INTERNAL_ERROR) + return + ssl_sock = ssl_wrap_socket(conn._socket, **kwargs) + ssl_sock = ssl_handshake(ssl_sock) + authlog("ssl handshake complete") + from xpra.net.bytestreams import SSLSocketConnection + ssl_conn = SSLSocketConnection(ssl_sock, conn.local, conn.remote, conn.endpoint, new_socktype) + self._protocol = self.setup_connection(ssl_conn) + self._protocol.start() + + ######################################## # Authentication def _process_challenge(self, packet) -> None: @@ -1022,6 +1069,7 @@ def _process_hello(self, packet) -> None: netlog.info("received hello:") print_nested_dict(packet[1], print_fn=netlog.info) self.remove_packet_handlers("challenge") + self.remove_packet_handlers("ssl-upgrade") if not self.password_sent and self.has_password(): p = self._protocol if not p or p.TYPE=="xpra": @@ -1161,6 +1209,8 @@ def init_packet_handlers(self) -> None: self._packet_handlers = {} self._ui_packet_handlers = {} self.add_packet_handler("hello", self._process_hello, False) + if SSL_UPGRADE: + self.add_packet_handler("ssl-upgrade", self._process_ssl_upgrade) self.add_packet_handlers({ "challenge": self._process_challenge, "disconnect": self._process_disconnect, diff --git a/xpra/client/gui/ui_client_base.py b/xpra/client/gui/ui_client_base.py index be3a5b1881..d99f1d8326 100644 --- a/xpra/client/gui/ui_client_base.py +++ b/xpra/client/gui/ui_client_base.py @@ -171,7 +171,6 @@ def __init__(self): # pylint: disable=super-init-not-called def init(self, opts) -> None: """ initialize variables from configuration """ - self.init_aliases() for c in CLIENT_BASES: log(f"init: {c}") c.init(self, opts) diff --git a/xpra/net/common.py b/xpra/net/common.py index a85180c8eb..acea95f5cb 100644 --- a/xpra/net/common.py +++ b/xpra/net/common.py @@ -30,6 +30,7 @@ class ConnectionClosedException(Exception): MAX_PACKET_SIZE : int = envint("XPRA_MAX_PACKET_SIZE", 16*1024*1024) FLUSH_HEADER : bool = envbool("XPRA_FLUSH_HEADER", True) +SSL_UPGRADE : bool = envbool("XPRA_SSL_UPGRADE", False) SOCKET_TYPES : Tuple[str, ...] = ("tcp", "ws", "wss", "ssl", "ssh", "rfb", "vsock", "socket", "named-pipe", "quic") @@ -60,6 +61,7 @@ class ConnectionClosedException(Exception): #generic: "hello", "challenge", + "ssl-upgrade", "info", "info-response", #server state: "server-event", "startup-complete", diff --git a/xpra/net/protocol/socket_handler.py b/xpra/net/protocol/socket_handler.py index 1fd5b20f4e..07064ef789 100644 --- a/xpra/net/protocol/socket_handler.py +++ b/xpra/net/protocol/socket_handler.py @@ -19,6 +19,7 @@ from xpra.os_util import memoryview_to_bytes, strtobytes, bytestostr, hexstr from xpra.util import repr_ellipsized, ellipsizer, csv, envint, envbool, typedict from xpra.make_thread import make_thread, start_thread +from xpra.net.bytestreams import SOCKET_TIMEOUT, set_socket_timeout from xpra.net.protocol.header import ( unpack_header, pack_header, find_xpra_header, FLAGS_CIPHER, FLAGS_NOHEADER, FLAGS_FLUSH, HEADER_SIZE, @@ -234,6 +235,8 @@ def parse_remote_caps(self, caps : typedict) -> None: for k,v in caps.dictget("aliases", {}).items(): self.send_aliases[bytestostr(k)] = v self.send_flush_flag = FLUSH_HEADER and caps.boolget("flush", False) + set_socket_timeout(self._conn, SOCKET_TIMEOUT) + def set_receive_aliases(self, aliases:Dict) -> None: self.receive_aliases = aliases diff --git a/xpra/net/websockets/protocol.py b/xpra/net/websockets/protocol.py index e95e66547d..1ec30dd9c2 100644 --- a/xpra/net/websockets/protocol.py +++ b/xpra/net/websockets/protocol.py @@ -107,6 +107,8 @@ def parse_ws_frame(self, buf:ByteString) -> None: if not fin: if opcode not in (OPCODE_BINARY, OPCODE_TEXT): op = OPCODES.get(opcode, opcode) + log(f"invalid opcode {opcode} from {buf}") + log(f"parsed as {parsed}") raise RuntimeError(f"cannot handle fragmented {op} frames") #fragmented, keep this payload for later self.ws_payload_opcode = opcode diff --git a/xpra/scripts/main.py b/xpra/scripts/main.py index dbdd6404c1..a9eb0b8b14 100755 --- a/xpra/scripts/main.py +++ b/xpra/scripts/main.py @@ -1112,7 +1112,8 @@ def sockpathfail_cb(msg): if dtype in ("tcp", "ssl", "ws", "wss", "vnc"): sock = retry_socket_connect(display_desc) - sock.settimeout(None) + # use non-blocking until the connection is finalized + sock.settimeout(0.1) conn = SocketConnection(sock, sock.getsockname(), sock.getpeername(), display_name, dtype, socket_options=display_desc) diff --git a/xpra/scripts/parsing.py b/xpra/scripts/parsing.py index eecb31c91e..29467b0d24 100755 --- a/xpra/scripts/parsing.py +++ b/xpra/scripts/parsing.py @@ -539,8 +539,9 @@ def add_query() -> None: host, port = add_host_port(DEFAULT_PORTS.get(protocol, DEFAULT_PORT)) add_path() add_query() + # always parse ssl options so we can auto-upgrade: + desc["ssl-options"] = get_ssl_options(desc, opts, cmdline) if protocol in ("ssl", "wss", "quic"): - desc["ssl-options"] = get_ssl_options(desc, opts, cmdline) alt_scheme = "https" else: alt_scheme = "http" diff --git a/xpra/server/server_core.py b/xpra/server/server_core.py index aa164ca61c..90457b3b95 100644 --- a/xpra/server/server_core.py +++ b/xpra/server/server_core.py @@ -26,7 +26,7 @@ from xpra.scripts.server import deadly_signal, clean_session_files, rm_session_dir from xpra.server.server_util import write_pidfile, rm_pidfile from xpra.scripts.config import parse_bool, parse_with_unit, TRUE_OPTIONS, FALSE_OPTIONS -from xpra.net.common import may_log_packet, SOCKET_TYPES, MAX_PACKET_SIZE, DEFAULT_PORTS +from xpra.net.common import may_log_packet, SOCKET_TYPES, MAX_PACKET_SIZE, DEFAULT_PORTS, SSL_UPGRADE from xpra.net.socket_util import ( hosts, mdns_publish, peek_connection, PEEK_TIMEOUT_MS, SOCKET_PEEK_TIMEOUT_MS, @@ -41,6 +41,7 @@ get_network_caps, get_info as get_net_info, import_netifaces, get_interfaces_addresses, ) +from xpra.net.protocol.factory import get_server_protocol_class from xpra.net.protocol.socket_handler import SocketProtocol from xpra.net.protocol.constants import CONNECTION_LOST, GIBBERISH, INVALID from xpra.net.digest import get_salt, gendigest, choose_digest @@ -1038,6 +1039,7 @@ def init_packet_handlers(self) -> None: self._default_packet_handlers : Dict[str,Callable] = { "hello": self._process_hello, "disconnect": self._process_disconnect, + "ssl-upgrade": self._process_ssl_upgrade, CONNECTION_LOST: self._process_connection_lost, GIBBERISH: self._process_gibberish, INVALID: self._process_invalid, @@ -2003,6 +2005,20 @@ def auth_failed(msg:str): if auth_caps is None: return + # try to auto upgrade to ssl: + packet_types = c.strtupleget("packet-types", ()) + if SSL_UPGRADE and not auth_caps and "ssl-upgrade" in packet_types and conn.socktype in ("tcp", ): + options = conn.options + if options.get("ssl-upgrade", "yes").lower() in TRUE_OPTIONS: + ssl_options = self.get_ssl_socket_options(options) + cert = ssl_options.get("cert") + if cert: + log.info(f"sending ssl upgrade for {conn}") + cert_data = load_binary_file(cert) + ssl_attrs = {"cert-data" : cert_data} + proto.send_now(("ssl-upgrade", ssl_attrs)) + return + def send_fake_challenge() -> None: #fake challenge so the client will send the real hello: salt = get_salt() @@ -2012,7 +2028,6 @@ def send_fake_challenge() -> None: #skip the authentication module we have "passed" already: remaining_authenticators = tuple(x for x in proto.authenticators if not x.passed) - authlog("processing authentication with %s, remaining=%s, digest_modes=%s, salt_digest_modes=%s", proto.authenticators, remaining_authenticators, digest_modes, salt_digest_modes) #verify each remaining authenticator: @@ -2092,6 +2107,39 @@ def auth_verified(self, proto:SocketProtocol, caps:typedict, auth_caps:Dict) -> self.idle_add(self.call_hello_oked, proto, caps, auth_caps) + def _process_ssl_upgrade(self, proto, packet): + socktype = proto._conn.socktype + new_socktype = {"tcp" : "ssl", "ws" : "wss"}.get(socktype) + if not new_socktype: + raise ValueError(f"cannot upgrade {socktype} to ssl") + self.cancel_verify_connection_accepted(proto) + self.cancel_upgrade_to_rfb_timer(proto) + if proto in self._potential_protocols: + self._potential_protocols.remove(proto) + ssllog("ssl-upgrade: %s", packet[1:]) + conn = proto.steal_connection() + # threads should be able to terminate immediately + # as there's no traffic yet: + ioe = proto.wait_for_io_threads_exit(1) + if not ioe: + self.disconnect_protocol(proto, "failed to terminate network threads for ssl upgrade") + conn.close() + return + options = conn.options + socktype = conn.socktype + ssl_sock = self._ssl_wrap_socket(socktype, conn._socket, options) + if not ssl_sock: + self.disconnect_protocol(proto, "failed to upgrade socket to ssl") + conn.close() + return + # sock, sockname, address, endpoint = conn._socket, conn.local, conn.remote, conn.endpoint + ssl_conn = SSLSocketConnection(ssl_sock, conn.local, conn.remote, conn.endpoint, "ssl", socket_options=options) + ssl_conn.socktype_wrapped = socktype + protocol_class = get_server_protocol_class(new_socktype) + proto = self.make_protocol(new_socktype, ssl_conn, options, protocol_class) + ssllog.info("upgraded %s to %s", conn, new_socktype) + + def setup_encryption(self, proto:SocketProtocol, c : typedict) -> Optional[Dict[str,Any]]: def auth_failed(msg): self.auth_failed(proto, msg)