Skip to content

Commit

Permalink
#3313 auto-upgrade tcp sockets to ssl
Browse files Browse the repository at this point in the history
  • Loading branch information
totaam committed Jul 12, 2023
1 parent 62f0f9c commit c3b450b
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 10 deletions.
60 changes: 55 additions & 5 deletions xpra/client/base/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xpra.common import SPLASH_EXIT_DELAY, FULL_INFO, LOG_HELLO
from xpra.child_reaper import getChildReaper, reaper_cleanup
from xpra.net import compression
from xpra.net.common import may_log_packet, PACKET_TYPES
from xpra.net.common import may_log_packet, PACKET_TYPES, SSL_UPGRADE
from xpra.make_thread import start_thread
from xpra.net.protocol.factory import get_client_protocol_class
from xpra.net.protocol.constants import CONNECTION_LOST, GIBBERISH, INVALID
Expand Down Expand Up @@ -164,9 +164,8 @@ def init(self, opts):
self.encryption_keyfile = opts.encryption_keyfile or opts.tcp_encryption_keyfile
self.init_challenge_handlers(opts.challenge_handlers)
self.install_signal_handlers()
#this is now done in UI client only,
#most simple clients are just wasting time doing this
#self.init_aliases()
#we need this to expose the 'packet-types' capability,
self.init_aliases()


def show_progress(self, pct, text=""):
Expand Down Expand Up @@ -333,8 +332,9 @@ def get_scheduler(self):
raise NotImplementedError()

def setup_connection(self, conn):
netlog("setup_connection(%s) timeout=%s, socktype=%s", conn, conn.timeout, conn.socktype)
protocol_class = get_client_protocol_class(conn.socktype)
netlog("setup_connection(%s) timeout=%s, socktype=%s, protocol-class=",
conn, conn.timeout, conn.socktype, protocol_class)
protocol = protocol_class(self.get_scheduler(), conn, self.process_packet, self.next_packet)
#ssh channel may contain garbage initially,
#tell the protocol to wait for a valid header:
Expand Down Expand Up @@ -744,6 +744,53 @@ def _process_connection_lost(self, _packet) -> None:
self.warn_and_quit(exit_code, msg)


def _process_ssl_upgrade(self, packet) -> None:
assert SSL_UPGRADE
ssl_attrs = typedict(packet[1])
start_thread(self.ssl_upgrade, "ssl-upgrade", True, args=(ssl_attrs, ))

def ssl_upgrade(self, ssl_attrs) -> None:
# send ssl-upgrade request!
ssllog = Logger("client", "ssl")
ssllog(f"ssl-upgrade({ssl_attrs})")
conn = self._protocol._conn
socktype = conn.socktype
new_socktype = {"tcp" : "ssl", "ws" : "wss"}.get(socktype)
if not new_socktype:
raise ValueError(f"cannot upgrade {socktype} to ssl")
log.info(f"upgrading {conn} to {new_socktype}")
self.send("ssl-upgrade", {})
from xpra.net.socket_util import ssl_wrap_socket, get_ssl_attributes, ssl_handshake
overrides = {
"verify_mode" : "none",
"check_hostname" : "no",
}
overrides.update(conn.options.get("ssl-options", {}))
ssl_options = get_ssl_attributes(None, False, overrides)
kwargs = dict((k.replace("-", "_"), v) for k, v in ssl_options.items())
# wait for the 'ssl-upgrade' packet to be sent...
# this should be done by watching the IO and formatting threads instead
import time
time.sleep(1)
def read_callback(packet):
if packet:
ssllog.error("Error: received another packet during ssl socket upgrade:")
ssllog.error(" %s", packet)
self.quit(EXIT_INTERNAL_ERROR)
conn = self._protocol.steal_connection(read_callback)
if not self._protocol.wait_for_io_threads_exit(1):
log.error("Error: failed to terminate network threads for ssl upgrade")
self.quit(EXIT_INTERNAL_ERROR)
return
ssl_sock = ssl_wrap_socket(conn._socket, **kwargs)
ssl_sock = ssl_handshake(ssl_sock)
authlog("ssl handshake complete")
from xpra.net.bytestreams import SSLSocketConnection
ssl_conn = SSLSocketConnection(ssl_sock, conn.local, conn.remote, conn.endpoint, new_socktype)
self._protocol = self.setup_connection(ssl_conn)
self._protocol.start()


########################################
# Authentication
def _process_challenge(self, packet) -> None:
Expand Down Expand Up @@ -1022,6 +1069,7 @@ def _process_hello(self, packet) -> None:
netlog.info("received hello:")
print_nested_dict(packet[1], print_fn=netlog.info)
self.remove_packet_handlers("challenge")
self.remove_packet_handlers("ssl-upgrade")
if not self.password_sent and self.has_password():
p = self._protocol
if not p or p.TYPE=="xpra":
Expand Down Expand Up @@ -1161,6 +1209,8 @@ def init_packet_handlers(self) -> None:
self._packet_handlers = {}
self._ui_packet_handlers = {}
self.add_packet_handler("hello", self._process_hello, False)
if SSL_UPGRADE:
self.add_packet_handler("ssl-upgrade", self._process_ssl_upgrade)
self.add_packet_handlers({
"challenge": self._process_challenge,
"disconnect": self._process_disconnect,
Expand Down
1 change: 0 additions & 1 deletion xpra/client/gui/ui_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def __init__(self): # pylint: disable=super-init-not-called

def init(self, opts) -> None:
""" initialize variables from configuration """
self.init_aliases()
for c in CLIENT_BASES:
log(f"init: {c}")
c.init(self, opts)
Expand Down
2 changes: 2 additions & 0 deletions xpra/net/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ConnectionClosedException(Exception):

MAX_PACKET_SIZE : int = envint("XPRA_MAX_PACKET_SIZE", 16*1024*1024)
FLUSH_HEADER : bool = envbool("XPRA_FLUSH_HEADER", True)
SSL_UPGRADE : bool = envbool("XPRA_SSL_UPGRADE", False)

SOCKET_TYPES : Tuple[str, ...] = ("tcp", "ws", "wss", "ssl", "ssh", "rfb", "vsock", "socket", "named-pipe", "quic")

Expand Down Expand Up @@ -60,6 +61,7 @@ class ConnectionClosedException(Exception):
#generic:
"hello",
"challenge",
"ssl-upgrade",
"info", "info-response",
#server state:
"server-event", "startup-complete",
Expand Down
3 changes: 3 additions & 0 deletions xpra/net/protocol/socket_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from xpra.os_util import memoryview_to_bytes, strtobytes, bytestostr, hexstr
from xpra.util import repr_ellipsized, ellipsizer, csv, envint, envbool, typedict
from xpra.make_thread import make_thread, start_thread
from xpra.net.bytestreams import SOCKET_TIMEOUT, set_socket_timeout
from xpra.net.protocol.header import (
unpack_header, pack_header, find_xpra_header,
FLAGS_CIPHER, FLAGS_NOHEADER, FLAGS_FLUSH, HEADER_SIZE,
Expand Down Expand Up @@ -234,6 +235,8 @@ def parse_remote_caps(self, caps : typedict) -> None:
for k,v in caps.dictget("aliases", {}).items():
self.send_aliases[bytestostr(k)] = v
self.send_flush_flag = FLUSH_HEADER and caps.boolget("flush", False)
set_socket_timeout(self._conn, SOCKET_TIMEOUT)


def set_receive_aliases(self, aliases:Dict) -> None:
self.receive_aliases = aliases
Expand Down
2 changes: 2 additions & 0 deletions xpra/net/websockets/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def parse_ws_frame(self, buf:ByteString) -> None:
if not fin:
if opcode not in (OPCODE_BINARY, OPCODE_TEXT):
op = OPCODES.get(opcode, opcode)
log(f"invalid opcode {opcode} from {buf}")
log(f"parsed as {parsed}")
raise RuntimeError(f"cannot handle fragmented {op} frames")
#fragmented, keep this payload for later
self.ws_payload_opcode = opcode
Expand Down
3 changes: 2 additions & 1 deletion xpra/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,8 @@ def sockpathfail_cb(msg):

if dtype in ("tcp", "ssl", "ws", "wss", "vnc"):
sock = retry_socket_connect(display_desc)
sock.settimeout(None)
# use non-blocking until the connection is finalized
sock.settimeout(0.1)
conn = SocketConnection(sock, sock.getsockname(), sock.getpeername(), display_name,
dtype, socket_options=display_desc)

Expand Down
3 changes: 2 additions & 1 deletion xpra/scripts/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,9 @@ def add_query() -> None:
host, port = add_host_port(DEFAULT_PORTS.get(protocol, DEFAULT_PORT))
add_path()
add_query()
# always parse ssl options so we can auto-upgrade:
desc["ssl-options"] = get_ssl_options(desc, opts, cmdline)
if protocol in ("ssl", "wss", "quic"):
desc["ssl-options"] = get_ssl_options(desc, opts, cmdline)
alt_scheme = "https"
else:
alt_scheme = "http"
Expand Down
52 changes: 50 additions & 2 deletions xpra/server/server_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from xpra.scripts.server import deadly_signal, clean_session_files, rm_session_dir
from xpra.server.server_util import write_pidfile, rm_pidfile
from xpra.scripts.config import parse_bool, parse_with_unit, TRUE_OPTIONS, FALSE_OPTIONS
from xpra.net.common import may_log_packet, SOCKET_TYPES, MAX_PACKET_SIZE, DEFAULT_PORTS
from xpra.net.common import may_log_packet, SOCKET_TYPES, MAX_PACKET_SIZE, DEFAULT_PORTS, SSL_UPGRADE
from xpra.net.socket_util import (
hosts, mdns_publish, peek_connection,
PEEK_TIMEOUT_MS, SOCKET_PEEK_TIMEOUT_MS,
Expand All @@ -41,6 +41,7 @@
get_network_caps, get_info as get_net_info,
import_netifaces, get_interfaces_addresses,
)
from xpra.net.protocol.factory import get_server_protocol_class
from xpra.net.protocol.socket_handler import SocketProtocol
from xpra.net.protocol.constants import CONNECTION_LOST, GIBBERISH, INVALID
from xpra.net.digest import get_salt, gendigest, choose_digest
Expand Down Expand Up @@ -1038,6 +1039,7 @@ def init_packet_handlers(self) -> None:
self._default_packet_handlers : Dict[str,Callable] = {
"hello": self._process_hello,
"disconnect": self._process_disconnect,
"ssl-upgrade": self._process_ssl_upgrade,
CONNECTION_LOST: self._process_connection_lost,
GIBBERISH: self._process_gibberish,
INVALID: self._process_invalid,
Expand Down Expand Up @@ -2003,6 +2005,20 @@ def auth_failed(msg:str):
if auth_caps is None:
return

# try to auto upgrade to ssl:
packet_types = c.strtupleget("packet-types", ())
if SSL_UPGRADE and not auth_caps and "ssl-upgrade" in packet_types and conn.socktype in ("tcp", ):
options = conn.options
if options.get("ssl-upgrade", "yes").lower() in TRUE_OPTIONS:
ssl_options = self.get_ssl_socket_options(options)
cert = ssl_options.get("cert")
if cert:
log.info(f"sending ssl upgrade for {conn}")
cert_data = load_binary_file(cert)
ssl_attrs = {"cert-data" : cert_data}
proto.send_now(("ssl-upgrade", ssl_attrs))
return

def send_fake_challenge() -> None:
#fake challenge so the client will send the real hello:
salt = get_salt()
Expand All @@ -2012,7 +2028,6 @@ def send_fake_challenge() -> None:

#skip the authentication module we have "passed" already:
remaining_authenticators = tuple(x for x in proto.authenticators if not x.passed)

authlog("processing authentication with %s, remaining=%s, digest_modes=%s, salt_digest_modes=%s",
proto.authenticators, remaining_authenticators, digest_modes, salt_digest_modes)
#verify each remaining authenticator:
Expand Down Expand Up @@ -2092,6 +2107,39 @@ def auth_verified(self, proto:SocketProtocol, caps:typedict, auth_caps:Dict) ->
self.idle_add(self.call_hello_oked, proto, caps, auth_caps)


def _process_ssl_upgrade(self, proto, packet):
socktype = proto._conn.socktype
new_socktype = {"tcp" : "ssl", "ws" : "wss"}.get(socktype)
if not new_socktype:
raise ValueError(f"cannot upgrade {socktype} to ssl")
self.cancel_verify_connection_accepted(proto)
self.cancel_upgrade_to_rfb_timer(proto)
if proto in self._potential_protocols:
self._potential_protocols.remove(proto)
ssllog("ssl-upgrade: %s", packet[1:])
conn = proto.steal_connection()
# threads should be able to terminate immediately
# as there's no traffic yet:
ioe = proto.wait_for_io_threads_exit(1)
if not ioe:
self.disconnect_protocol(proto, "failed to terminate network threads for ssl upgrade")
conn.close()
return
options = conn.options
socktype = conn.socktype
ssl_sock = self._ssl_wrap_socket(socktype, conn._socket, options)
if not ssl_sock:
self.disconnect_protocol(proto, "failed to upgrade socket to ssl")
conn.close()
return
# sock, sockname, address, endpoint = conn._socket, conn.local, conn.remote, conn.endpoint
ssl_conn = SSLSocketConnection(ssl_sock, conn.local, conn.remote, conn.endpoint, "ssl", socket_options=options)
ssl_conn.socktype_wrapped = socktype
protocol_class = get_server_protocol_class(new_socktype)
proto = self.make_protocol(new_socktype, ssl_conn, options, protocol_class)
ssllog.info("upgraded %s to %s", conn, new_socktype)


def setup_encryption(self, proto:SocketProtocol, c : typedict) -> Optional[Dict[str,Any]]:
def auth_failed(msg):
self.auth_failed(proto, msg)
Expand Down

0 comments on commit c3b450b

Please sign in to comment.