From 40c477bd1afd09b00fc69e008dc59569838ae963 Mon Sep 17 00:00:00 2001 From: Jeremy Tuloup Date: Tue, 25 May 2021 01:06:14 +0200 Subject: [PATCH] [Notebook port 4835] Add UNIX socket support to notebook server (#525) * [Notebook port 4835] Add UNIX socket support to notebook server * Address some codeQL issues Co-authored-by: Kevin Bates --- .github/workflows/integration-tests.yml | 47 +++ jupyter_server/__init__.py | 2 + jupyter_server/base/handlers.py | 19 +- jupyter_server/serverapp.py | 332 ++++++++++++++---- jupyter_server/tests/conftest.py | 33 ++ jupyter_server/tests/test_serverapp.py | 78 ++++ jupyter_server/tests/unix_sockets/__init__.py | 0 jupyter_server/tests/unix_sockets/conftest.py | 32 ++ jupyter_server/tests/unix_sockets/test_api.py | 77 ++++ .../test_serverapp_integration.py | 174 +++++++++ jupyter_server/utils.py | 132 ++++++- setup.cfg | 1 + 12 files changed, 841 insertions(+), 86 deletions(-) create mode 100644 .github/workflows/integration-tests.yml create mode 100644 jupyter_server/tests/unix_sockets/__init__.py create mode 100644 jupyter_server/tests/unix_sockets/conftest.py create mode 100644 jupyter_server/tests/unix_sockets/test_api.py create mode 100644 jupyter_server/tests/unix_sockets/test_serverapp_integration.py diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 0000000000..9a4c03f038 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,47 @@ +name: Jupyter Server Integration Tests [Linux] +on: + push: + branches: 'master' + pull_request: + branches: '*' +jobs: + build: + runs-on: ${{ matrix.os }}-latest + strategy: + fail-fast: false + matrix: + os: [ubuntu] + python-version: [ '3.6', '3.7', '3.8', '3.9', 'pypy3' ] + steps: + - name: Checkout + uses: actions/checkout@v1 + - name: Install Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + architecture: 'x64' + - name: 
Upgrade packaging dependencies + run: | + pip install --upgrade pip setuptools wheel --user + - name: Get pip cache dir + id: pip-cache + run: | + echo "::set-output name=dir::$(pip cache dir)" + - name: Cache pip + uses: actions/cache@v1 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.py') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- + - name: Install the Python dependencies + run: | + pip install -e ".[test]" + - name: List installed packages + run: | + pip freeze + pip check + - name: Run the tests + run: | + pytest -vv --integration_tests jupyter_server diff --git a/jupyter_server/__init__.py b/jupyter_server/__init__.py index 9a2bce1ffc..bb128fd043 100644 --- a/jupyter_server/__init__.py +++ b/jupyter_server/__init__.py @@ -10,6 +10,8 @@ os.path.join(os.path.dirname(__file__), 'templates'), ] +DEFAULT_JUPYTER_SERVER_PORT = 8888 + del os from ._version import version_info, __version__ diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 743f12d31d..3d1a336301 100755 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -31,7 +31,7 @@ import jupyter_server from jupyter_server._tz import utcnow from jupyter_server.i18n import combine_translations -from jupyter_server.utils import ensure_async, url_path_join, url_is_absolute, url_escape +from jupyter_server.utils import ensure_async, url_path_join, url_is_absolute, url_escape, urldecode_unix_socket_path from jupyter_server.services.security import csp_report_uri #----------------------------------------------------------------------------- @@ -462,13 +462,18 @@ def check_host(self): if host.startswith('[') and host.endswith(']'): host = host[1:-1] - try: - addr = ipaddress.ip_address(host) - except ValueError: - # Not an IP address: check against hostnames - allow = host in self.settings.get('local_hostnames', ['localhost']) + 
# UNIX socket handling + check_host = urldecode_unix_socket_path(host) + if check_host.startswith('/') and os.path.exists(check_host): + allow = True else: - allow = addr.is_loopback + try: + addr = ipaddress.ip_address(host) + except ValueError: + # Not an IP address: check against hostnames + allow = host in self.settings.get('local_hostnames', ['localhost']) + else: + allow = addr.is_loopback if not allow: self.log.warning( diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index e3e404e541..2dc2af39df 100755 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -21,6 +21,7 @@ import select import signal import socket +import stat import sys import tempfile import threading @@ -61,7 +62,11 @@ from tornado.httputil import url_concat from tornado.log import LogFormatter, app_log, access_log, gen_log +if not sys.platform.startswith('win'): + from tornado.netutil import bind_unix_socket + from jupyter_server import ( + DEFAULT_JUPYTER_SERVER_PORT, DEFAULT_STATIC_FILES_PATH, DEFAULT_TEMPLATE_PATH_LIST, __version__, @@ -104,7 +109,10 @@ check_pid, url_escape, urljoin, - pathname2url + pathname2url, + unix_socket_in_use, + urlencode_unix_socket_path, + fetch ) from jupyter_server.extension.serverextension import ServerExtensionApp @@ -148,7 +156,8 @@ view=['jupyter_server.view.handlers'] ) -DEFAULT_SERVER_PORT = 8888 +# Added for backwards compatibility from classic notebook server. +DEFAULT_SERVER_PORT = DEFAULT_JUPYTER_SERVER_PORT #----------------------------------------------------------------------------- # Helper functions @@ -404,7 +413,7 @@ def start(self): def shutdown_server(server_info, timeout=5, log=None): - """Shutdown a notebook server in a separate process. + """Shutdown a Jupyter server in a separate process. *server_info* should be a dictionary as produced by list_running_servers(). 
@@ -418,12 +427,14 @@ def shutdown_server(server_info, timeout=5, log=None): from tornado.httpclient import HTTPClient, HTTPRequest url = server_info['url'] pid = server_info['pid'] - req = HTTPRequest(url + 'api/shutdown', method='POST', body=b'', headers={ - 'Authorization': 'token ' + server_info['token'] - }) + if log: log.debug("POST request to %sapi/shutdown", url) - HTTPClient().fetch(req) + r = fetch( + url, + method="POST", + headers={'Authorization': 'token ' + server_info['token']} + ) # Poll to see if it shut down. for _ in range(timeout*10): if not check_pid(pid): @@ -454,38 +465,67 @@ class JupyterServerStopApp(JupyterApp): version = __version__ description = "Stop currently running Jupyter server for a given port" - port = Integer(DEFAULT_SERVER_PORT, config=True, - help=f"Port of the server to be killed. Default {DEFAULT_SERVER_PORT}") + port = Integer(DEFAULT_JUPYTER_SERVER_PORT, config=True, + help="Port of the server to be killed. Default %s" % DEFAULT_JUPYTER_SERVER_PORT) + + sock = Unicode(u'', config=True, + help="UNIX socket of the server to be killed.") def parse_command_line(self, argv=None): super(JupyterServerStopApp, self).parse_command_line(argv) if self.extra_args: - self.port=int(self.extra_args[0]) + try: + self.port = int(self.extra_args[0]) + except ValueError: + # self.extra_args[0] was not an int, so it must be a string (unix socket). + self.sock = self.extra_args[0] def shutdown_server(self, server): return shutdown_server(server, log=self.log) + def _shutdown_or_exit(self, target_endpoint, server): + print("Shutting down server on %s..." 
% target_endpoint) + if not self.shutdown_server(server): + sys.exit("Could not stop server on %s" % target_endpoint) + + @staticmethod + def _maybe_remove_unix_socket(socket_path): + try: + os.unlink(socket_path) + except (OSError, IOError): + pass + def start(self): servers = list(list_running_servers(self.runtime_dir)) if not servers: - self.exit("There are no running servers") + self.exit("There are no running servers (per %s)" % self.runtime_dir) for server in servers: - if server['port'] == self.port: - print("Shutting down server on port", self.port, "...") - if not self.shutdown_server(server): - sys.exit("Could not stop server") - return - else: - print("There is currently no server running on port {}".format(self.port), file=sys.stderr) - print("Ports currently in use:", file=sys.stderr) - for server in servers: - print(" - {}".format(server['port']), file=sys.stderr) - self.exit(1) + if self.sock: + sock = server.get('sock', None) + if sock and sock == self.sock: + self._shutdown_or_exit(sock, server) + # Attempt to remove the UNIX socket after stopping. 
+ self._maybe_remove_unix_socket(sock) + return + elif self.port: + port = server.get('port', None) + if port == self.port: + self._shutdown_or_exit(port, server) + return + current_endpoint = self.sock or self.port + print( + "There is currently no server running on {}".format(current_endpoint), + file=sys.stderr + ) + print("Ports/sockets currently in use:", file=sys.stderr) + for server in servers: + print(" - {}".format(server.get('sock') or server['port']), file=sys.stderr) + self.exit(1) class JupyterServerListApp(JupyterApp): version = __version__ - description=_i18n("List currently running notebook servers.") + description=_i18n("List currently running Jupyter servers.") flags = dict( jsonlist=({'JupyterServerListApp': {'jsonlist': True}}, @@ -496,7 +536,7 @@ class JupyterServerListApp(JupyterApp): jsonlist = Bool(False, config=True, help=_i18n("If True, the output will be a JSON list of objects, one per " - "active notebook server, each with the details from the " + "active Jupyer server, each with the details from the " "relevant server info file.")) json = Bool(False, config=True, help=_i18n("If True, each line of output will be a JSON object with the " @@ -563,6 +603,8 @@ def start(self): 'ip': 'ServerApp.ip', 'port': 'ServerApp.port', 'port-retries': 'ServerApp.port_retries', + 'sock': 'ServerApp.sock', + 'sock-mode': 'ServerApp.sock_mode', 'transport': 'KernelManager.transport', 'keyfile': 'ServerApp.keyfile', 'certfile': 'ServerApp.certfile', @@ -702,7 +744,7 @@ def _default_ip(self): return 'localhost' @validate('ip') - def _valdate_ip(self, proposal): + def _validate_ip(self, proposal): value = proposal['value'] if value == u'*': value = u'' @@ -722,8 +764,10 @@ def _valdate_ip(self, proposal): ) port_env = 'JUPYTER_PORT' - port_default_value = DEFAULT_SERVER_PORT - port = Integer(port_default_value, config=True, + port_default_value = DEFAULT_JUPYTER_SERVER_PORT + + port = Integer( + config=True, help=_i18n("The port the server will listen on 
(env: JUPYTER_PORT).") ) @@ -742,6 +786,37 @@ def port_default(self): def port_retries_default(self): return int(os.getenv(self.port_retries_env, self.port_retries_default_value)) + sock = Unicode(u'', config=True, + help="The UNIX socket the Jupyter server will listen on." + ) + + sock_mode = Unicode('0600', config=True, + help="The permissions mode for UNIX socket creation (default: 0600)." + ) + + @validate('sock_mode') + def _validate_sock_mode(self, proposal): + value = proposal['value'] + try: + converted_value = int(value.encode(), 8) + assert all(( + # Ensure the mode is at least user readable/writable. + bool(converted_value & stat.S_IRUSR), + bool(converted_value & stat.S_IWUSR), + # And isn't out of bounds. + converted_value <= 2 ** 12 + )) + except ValueError: + raise TraitError( + 'invalid --sock-mode value: %s, please specify as e.g. "0600"' % value + ) + except AssertionError: + raise TraitError( + 'invalid --sock-mode value: %s, must have u+rw (0600) at a minimum' % value + ) + return value + + certfile = Unicode(u'', config=True, help=_i18n("""The full path to an SSL/TLS certificate file.""") ) @@ -786,7 +861,7 @@ def _default_cookie_secret(self): def _write_cookie_secret_file(self, secret): """write my secret to my secret_file""" - self.log.info(_i18n("Writing notebook server cookie secret to %s"), self.cookie_secret_file) + self.log.info(_i18n("Writing Jupyter server cookie secret to %s"), self.cookie_secret_file) try: with secure_write(self.cookie_secret_file, True) as f: f.write(secret) @@ -1306,7 +1381,7 @@ def _update_server_extensions(self, change): self.server_extensions = change['new'] jpserver_extensions = Dict({}, config=True, - help=(_i18n("Dict of Python modules to load as notebook server extensions." + help=(_i18n("Dict of Python modules to load as Jupyter server extensions." "Entry values can be used to enable and disable the loading of" "the extensions. 
The extensions will be loaded in alphabetical " "order.")) @@ -1470,6 +1545,36 @@ def init_webapp(self): self.log.critical(_i18n("\t$ python -m jupyter_server.auth password")) sys.exit(1) + # Socket options validation. + if self.sock: + if self.port != DEFAULT_JUPYTER_SERVER_PORT: + self.log.critical( + ('Options --port and --sock are mutually exclusive. Aborting.'), + ) + sys.exit(1) + else: + # Reset the default port if we're using a UNIX socket. + self.port = 0 + + if self.open_browser: + # If we're bound to a UNIX socket, we can't reliably connect from a browser. + self.log.info( + ('Ignoring --ServerApp.open_browser due to --sock being used.'), + ) + + if self.file_to_run: + self.log.critical( + ('Options --ServerApp.file_to_run and --sock are mutually exclusive.'), + ) + sys.exit(1) + + if sys.platform.startswith('win'): + self.log.critical( + ('Option --sock is not supported on Windows, but got value of %s. Aborting.' % self.sock), + ) + sys.exit(1) + + self.web_app = ServerWebApplication( self, self.default_services, self.kernel_manager, self.contents_manager, self.session_manager, self.kernel_spec_manager, @@ -1519,54 +1624,79 @@ def init_resources(self): ) resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) - @property - def display_url(self): - if self.custom_display_url: - parts = urllib.parse.urlparse(self.custom_display_url) - path = parts.path - ip = parts.hostname + def _get_urlparts(self, path=None, include_token=False): + """Constructs a urllib named tuple, ParseResult, + with default values set by server config. + The returned tuple can be manipulated using the `_replace` method. + """ + if self.sock: + scheme = 'http+unix' + netloc = urlencode_unix_socket_path(self.sock) else: - path = None + # Handle nonexplicit hostname. 
if self.ip in ('', '0.0.0.0'): - ip = "%s" % socket.gethostname() + ip = "%s" % socket.gethostname() else: ip = self.ip - - token = None - if self.token: - # Don't log full token if it came from config - token = self.token if self._token_generated else '...' - - url = ( - self.get_url(ip=ip, path=path, token=token) - + '\n ' - + self.get_url(ip='127.0.0.1', path=path, token=token) - ) - return url - - @property - def connection_url(self): - ip = self.ip if self.ip else 'localhost' - return self.get_url(ip=ip, path=self.base_url) - - def get_url(self, ip=None, path=None, token=None): - """Build a url for the application with reasonable defaults.""" - if not ip: - ip = self.ip if self.ip else 'localhost' + netloc = "{ip}:{port}".format(ip=ip, port=self.port) + if self.certfile: + scheme = 'https' + else: + scheme = 'http' if not path: path = self.default_url - # Build query string. - if token: - token = urllib.parse.urlencode({'token': token}) + query = None + if include_token: + if self.token: # Don't log full token if it came from config + token = self.token if self._token_generated else '...' + query = urllib.parse.urlencode({'token': token}) # Build the URL Parts to dump. urlparts = urllib.parse.ParseResult( - scheme='https' if self.certfile else 'http', - netloc="{ip}:{port}".format(ip=ip, port=self.port), + scheme=scheme, + netloc=netloc, path=path, params=None, - query=token, + query=query, fragment=None ) + return urlparts + + @property + def public_url(self): + parts = self._get_urlparts(include_token=True) + # Update with custom pieces. + if self.custom_display_url: + # Parse custom display_url + custom = urllib.parse.urlparse(self.custom_display_url)._asdict() + # Get pieces that are matter (non None) + custom_updates = {key: item for key, item in custom.items() if item} + # Update public URL parts with custom pieces. 
+ parts = parts._replace(**custom_updates) + return parts.geturl() + + @property + def local_url(self): + parts = self._get_urlparts(include_token=True) + # Update with custom pieces. + if not self.sock: + parts = parts._replace(netloc="127.0.0.1:{port}".format(port=self.port)) + return parts.geturl() + + @property + def display_url(self): + """Human readable string with URLs for interacting + with the running Jupyter Server + """ + url = ( + self.public_url + + '\n or ' + + self.local_url + ) + return url + + @property + def connection_url(self): + urlparts = self._get_urlparts(path=self.base_url) return urlparts.geturl() def init_terminals(self): @@ -1779,6 +1909,37 @@ def init_httpserver(self): max_body_size=self.max_body_size, max_buffer_size=self.max_buffer_size ) + + success = self._bind_http_server() + if not success: + self.log.critical(_i18n('ERROR: the Jupyter server could not be started because ' + 'no available port could be found.')) + self.exit(1) + + def _bind_http_server(self): + return self._bind_http_server_unix() if self.sock else self._bind_http_server_tcp() + + def _bind_http_server_unix(self): + if unix_socket_in_use(self.sock): + self.log.warning(_i18n('The socket %s is already in use.') % self.sock) + return False + + try: + sock = bind_unix_socket(self.sock, mode=int(self.sock_mode.encode(), 8)) + self.http_server.add_socket(sock) + except socket.error as e: + if e.errno == errno.EADDRINUSE: + self.log.warning(_i18n('The socket %s is already in use.') % self.sock) + return False + elif e.errno in (errno.EACCES, getattr(errno, 'WSAEACCES', errno.EACCES)): + self.log.warning(_i18n("Permission to listen on sock %s denied") % self.sock) + return False + else: + raise + else: + return True + + def _bind_http_server_tcp(self): success = None for port in random_ports(self.port, self.port_retries+1): try: @@ -1791,7 +1952,7 @@ def init_httpserver(self): self.log.info(_i18n('The port %i is already in use.') % port) continue elif e.errno in 
(errno.EACCES, getattr(errno, 'WSAEACCES', errno.EACCES)): - self.log.warning(_i18n("Permission to listen on port %i denied") % port) + self.log.warning(_i18n("Permission to listen on port %i denied.") % port) continue else: raise @@ -1801,12 +1962,14 @@ def init_httpserver(self): break if not success: if self.port_retries: - self.log.critical(_i18n('ERROR: the notebook server could not be started because ' + self.log.critical(_i18n('ERROR: the Jupyter server could not be started because ' 'no available port could be found.')) else: - self.log.critical(_i18n('ERROR: the notebook server could not be started because ' + self.log.critical(_i18n('ERROR: the Jupyter server could not be started because ' 'port %i is not available.') % port) self.exit(1) + return success + @staticmethod def _init_asyncio_patch(): @@ -1941,6 +2104,7 @@ def server_info(self): return {'url': self.connection_url, 'hostname': self.ip if self.ip else 'localhost', 'port': self.port, + 'sock': self.sock, 'secure': bool(self.certfile), 'base_url': self.base_url, 'token': self.token, @@ -2130,19 +2294,31 @@ def start_app(self): self.write_browser_open_files() # Handle the browser opening. - if self.open_browser: + if self.open_browser and not self.sock: self.launch_browser() if self.token and self._token_generated: # log full URL with generated token, so there's a copy/pasteable link # with auth info. - self.log.critical('\n'.join([ - '\n', - 'To access the server, open this file in a browser:', - ' %s' % urljoin('file:', pathname2url(self.browser_open_file)), - 'Or copy and paste one of these URLs:', - ' %s' % self.display_url, - ])) + if self.sock: + self.log.critical('\n'.join([ + '\n', + 'Jupyter Server is listening on %s' % self.display_url, + '', + ( + 'UNIX sockets are not browser-connectable, but you can tunnel to ' + 'the instance via e.g.`ssh -L 8888:%s -N user@this_host` and then ' + 'open e.g. %s in a browser.' 
+ ) % (self.sock, self.connection_url) + ])) + else: + self.log.critical('\n'.join([ + '\n', + 'To access the server, open this file in a browser:', + ' %s' % urljoin('file:', pathname2url(self.browser_open_file)), + 'Or copy and paste one of these URLs:', + ' %s' % self.display_url, + ])) def _cleanup(self): """General cleanup of files and kernels created @@ -2186,11 +2362,11 @@ def _stop(): def list_running_servers(runtime_dir=None): - """Iterate over the server info files of running notebook servers. + """Iterate over the server info files of running Jupyter servers. Given a runtime directory, find jpserver-* files in the security directory, and yield dicts of their information, each one pertaining to - a currently running notebook server instance. + a currently running Jupyter server instance. """ if runtime_dir is None: runtime_dir = jupyter_runtime_dir() diff --git a/jupyter_server/tests/conftest.py b/jupyter_server/tests/conftest.py index bdac3802bc..07dc30bb35 100644 --- a/jupyter_server/tests/conftest.py +++ b/jupyter_server/tests/conftest.py @@ -1,3 +1,36 @@ +import pytest + + pytest_plugins = [ "jupyter_server.pytest_plugin" ] + + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--integration_tests", + default=False, + type=bool, + help="only run tests with the 'integration_test' pytest mark.", + ) + + +def pytest_configure(config): + # register an additional marker + config.addinivalue_line( + "markers", "integration_test" + ) + + +def pytest_runtest_setup(item): + is_integration_test = any([mark for mark in item.iter_markers(name="integration_test")]) + + if item.config.getoption("--integration_tests") is True: + if not is_integration_test: + pytest.skip("Only running tests marked as 'integration_test'.") + else: + if is_integration_test: + pytest.skip("Skipping this test because it's marked 'integration_test'. 
Run integration tests using the `--integration_tests` flag.") diff --git a/jupyter_server/tests/test_serverapp.py b/jupyter_server/tests/test_serverapp.py index 7176e41bcb..a74b731469 100644 --- a/jupyter_server/tests/test_serverapp.py +++ b/jupyter_server/tests/test_serverapp.py @@ -208,3 +208,81 @@ def test_resolve_file_to_run_and_root_dir( # Clear the singleton instance after each run. ServerApp.clear_instance() + + +# Test the URLs returned by ServerApp. The `` piece +# in urls shown below will be replaced with the token +# generated by the ServerApp on instance creation. +@pytest.mark.parametrize( + 'config,public_url,local_url,connection_url', + [ + # Token is hidden when configured. + ( + {"token": "test"}, + "http://localhost:8888/?token=...", + "http://127.0.0.1:8888/?token=...", + "http://localhost:8888/" + ), + # Verify port number has changed + ( + {"port": 9999}, + "http://localhost:9999/?token=", + "http://127.0.0.1:9999/?token=", + "http://localhost:9999/" + ), + ( + {"ip": "1.1.1.1"}, + "http://1.1.1.1:8888/?token=", + "http://127.0.0.1:8888/?token=", + "http://1.1.1.1:8888/" + ), + # Verify that HTTPS is returned when certfile is given + ( + {"certfile": "/path/to/dummy/file"}, + "https://localhost:8888/?token=", + "https://127.0.0.1:8888/?token=", + "https://localhost:8888/" + ), + # Verify changed port and a custom display URL + ( + {"port": 9999, "custom_display_url": "http://test.org"}, + "http://test.org/?token=", + "http://127.0.0.1:9999/?token=", + "http://localhost:9999/" + ), + ( + {"base_url": "/", "default_url": "/test/"}, + "http://localhost:8888/test/?token=", + "http://127.0.0.1:8888/test/?token=", + "http://localhost:8888/" + ), + # Verify unix socket URLs are handled properly + ( + {"sock": "/tmp/jp-test.sock"}, + "http+unix://%2Ftmp%2Fjp-test.sock/?token=", + "http+unix://%2Ftmp%2Fjp-test.sock/?token=", + "http+unix://%2Ftmp%2Fjp-test.sock/" + ), + ( + {"base_url": "/", "default_url": "/test/", "sock": "/tmp/jp-test.sock"}, + 
"http+unix://%2Ftmp%2Fjp-test.sock/test/?token=", + "http+unix://%2Ftmp%2Fjp-test.sock/test/?token=", + "http+unix://%2Ftmp%2Fjp-test.sock/" + ), + ] +) +def test_urls(config, public_url, local_url, connection_url): + # Verify we're working with a clean instance. + ServerApp.clear_instance() + serverapp = ServerApp.instance(**config) + # If a token is generated (not set by config), update + # expected_url with token. + if serverapp._token_generated: + public_url = public_url.replace("", serverapp.token) + local_url = local_url.replace("", serverapp.token) + connection_url = connection_url.replace("", serverapp.token) + assert serverapp.public_url == public_url + assert serverapp.local_url == local_url + assert serverapp.connection_url == connection_url + # Cleanup singleton after test. + ServerApp.clear_instance() diff --git a/jupyter_server/tests/unix_sockets/__init__.py b/jupyter_server/tests/unix_sockets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/tests/unix_sockets/conftest.py b/jupyter_server/tests/unix_sockets/conftest.py new file mode 100644 index 0000000000..c3eb43f4fc --- /dev/null +++ b/jupyter_server/tests/unix_sockets/conftest.py @@ -0,0 +1,32 @@ +import os +import pathlib +import pytest +from jupyter_server import DEFAULT_JUPYTER_SERVER_PORT + + +@pytest.fixture +def jp_process_id(): + """Choose a random unused process ID.""" + return os.getpid() + + +@pytest.fixture +def jp_unix_socket_file(jp_process_id): + """Define a temporary socket connection""" + # Rely on `/tmp` to avoid any Linux socket length max buffer + # issues. Key on PID for process-wise concurrency. + tmp_path = pathlib.Path('/tmp') + filename = 'jupyter_server.{}.sock'.format(jp_process_id) + jp_unix_socket_file = tmp_path.joinpath(filename) + yield str(jp_unix_socket_file) + # Clean up the file after the test runs. 
+ if jp_unix_socket_file.exists(): + jp_unix_socket_file.unlink() + + +@pytest.fixture +def jp_http_port(): + """Set the port to the default value, since sock + and port cannot both be configured at the same time. + """ + return DEFAULT_JUPYTER_SERVER_PORT diff --git a/jupyter_server/tests/unix_sockets/test_api.py b/jupyter_server/tests/unix_sockets/test_api.py new file mode 100644 index 0000000000..aaa0d219e7 --- /dev/null +++ b/jupyter_server/tests/unix_sockets/test_api.py @@ -0,0 +1,77 @@ +import sys +import pytest + +# Skip this module if on Windows. Unix sockets are not available on Windows. +pytestmark = pytest.mark.skipif( + sys.platform.startswith('win'), + reason="Unix sockets are not available on Windows." +) + +import os +import urllib +import pathlib + +if not sys.platform.startswith('win'): + from tornado.netutil import bind_unix_socket + +from tornado.escape import url_escape + +import jupyter_server.serverapp +from jupyter_server import DEFAULT_JUPYTER_SERVER_PORT +from jupyter_server.utils import ( + url_path_join, + urlencode_unix_socket, + async_fetch, +) + + +@pytest.fixture +def jp_server_config(jp_unix_socket_file): + """Configure the serverapp fixture with the unix socket.""" + return { + "ServerApp": { + "sock" : jp_unix_socket_file, + "allow_remote_access": True + } + } + + +@pytest.fixture +def http_server_port(jp_unix_socket_file, jp_process_id): + """Unix socket and process ID used by tornado's HTTP Server. 
+ + Overrides the http_server_port fixture from pytest-tornasync and replaces + it with a tuple: (unix socket, process id) + """ + return (bind_unix_socket(jp_unix_socket_file), jp_process_id) + + +@pytest.fixture +def jp_unix_socket_fetch(jp_unix_socket_file, jp_auth_header, jp_base_url, http_server, io_loop): + """A fetch fixture for Jupyter Server tests that use the unix_serverapp fixture""" + async def client(*parts, headers={}, params={}, **kwargs): + # Handle URL strings + host_url = urlencode_unix_socket(jp_unix_socket_file) + path_url = url_path_join(jp_base_url, *parts) + params_url = urllib.parse.urlencode(params) + url = url_path_join(host_url, path_url+ "?" + params_url) + r = await async_fetch(url, headers=headers, io_loop=io_loop, **kwargs) + return r + return client + + +async def test_get_spec(jp_unix_socket_fetch): + # Handle URL strings + parts = ["api", "spec.yaml"] + + # Make request and verify it succeeds.' + response = await jp_unix_socket_fetch(*parts) + assert response.code == 200 + assert response.body != None + + +async def test_list_running_servers(jp_unix_socket_file, http_server): + """Test that a server running on unix sockets is discovered by the server list""" + servers = list(jupyter_server.serverapp.list_running_servers()) + assert len(servers) >= 1 + assert jp_unix_socket_file in {info['sock'] for info in servers} diff --git a/jupyter_server/tests/unix_sockets/test_serverapp_integration.py b/jupyter_server/tests/unix_sockets/test_serverapp_integration.py new file mode 100644 index 0000000000..6d13524a14 --- /dev/null +++ b/jupyter_server/tests/unix_sockets/test_serverapp_integration.py @@ -0,0 +1,174 @@ +import sys +import stat +import pytest + +# Skip this module if on Windows. Unix sockets are not available on Windows. +pytestmark = pytest.mark.skipif( + sys.platform.startswith('win'), + reason="Unix sockets are not available on Windows." 
+)
+
+import os
+import subprocess
+import pathlib
+import time
+
+from jupyter_server.utils import urlencode_unix_socket, urlencode_unix_socket_path
+
+
+@pytest.mark.integration_test
+def test_shutdown_sock_server_integration(jp_unix_socket_file):
+    url = urlencode_unix_socket(jp_unix_socket_file).encode()
+    encoded_sock_path = urlencode_unix_socket_path(jp_unix_socket_file)
+    p = subprocess.Popen(
+        ['jupyter-server', '--sock=%s' % jp_unix_socket_file, '--sock-mode=0700'],
+        stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
+
+    complete = False
+    for line in iter(p.stderr.readline, b''):
+        if url in line:
+            complete = True
+            break
+
+    assert complete, 'did not find socket URL in stderr when launching server'
+
+    socket_path = encoded_sock_path.encode()
+    assert socket_path in subprocess.check_output(['jupyter-server', 'list'])
+
+    # Ensure the requested --sock-mode (0700) is applied to the socket file.
+    assert stat.S_IMODE(os.lstat(jp_unix_socket_file).st_mode) == 0o700
+
+    try:
+        subprocess.check_output(['jupyter-server', 'stop'], stderr=subprocess.STDOUT)
+    except subprocess.CalledProcessError as e:
+        assert 'There is currently no server running on' in e.output.decode()
+    else:
+        raise AssertionError('expected stop command to fail due to target mis-match')
+
+    assert socket_path in subprocess.check_output(['jupyter-server', 'list'])
+
+    subprocess.check_output(['jupyter-server', 'stop', jp_unix_socket_file])
+
+    assert socket_path not in subprocess.check_output(['jupyter-server', 'list'])
+
+    p.wait()
+
+
+@pytest.mark.integration_test
+def test_sock_server_validate_sockmode_type():
+    try:
+        subprocess.check_output(
+            ['jupyter-server', '--sock=/tmp/nonexistent', '--sock-mode=badbadbad'],
+            stderr=subprocess.STDOUT
+        )
+    except subprocess.CalledProcessError as e:
+        assert 'badbadbad' in e.output.decode()
+    else:
+        raise AssertionError('expected execution to fail due to validation of --sock-mode param')
+
+
+@pytest.mark.integration_test
+def test_sock_server_validate_sockmode_accessible():
+    try:
+        subprocess.check_output(
+            ['jupyter-server', '--sock=/tmp/nonexistent', '--sock-mode=0444'],
+            stderr=subprocess.STDOUT
+        )
+    except subprocess.CalledProcessError as e:
+        assert '0444' in e.output.decode()
+    else:
+        raise AssertionError('expected execution to fail due to validation of --sock-mode param')
+
+
+def _ensure_stopped(check_msg='There are no running servers'):
+    try:
+        subprocess.check_output(
+            ['jupyter-server', 'stop'],
+            stderr=subprocess.STDOUT
+        )
+    except subprocess.CalledProcessError as e:
+        assert check_msg in e.output.decode()
+    else:
+        raise AssertionError('expected all servers to be stopped')
+
+
+@pytest.mark.integration_test
+def test_stop_multi_integration(jp_unix_socket_file, jp_http_port):
+    """Tests lifecycle behavior for mixed-mode server types w/ default ports.
+
+    Mostly suitable for local dev testing due to reliance on default port binding.
+    """
+    TEST_PORT = '9797'
+    MSG_TMPL = 'Shutting down server on {}...'
+
+    _ensure_stopped()
+
+    # Default port.
+    p1 = subprocess.Popen(
+        ['jupyter-server', '--no-browser']
+    )
+
+    # Unix socket.
+    p2 = subprocess.Popen(
+        ['jupyter-server', '--sock=%s' % jp_unix_socket_file]
+    )
+
+    # Specified port.
+    p3 = subprocess.Popen(
+        ['jupyter-server', '--no-browser', '--port=%s' % TEST_PORT]
+    )
+
+    time.sleep(3)
+
+    shutdown_msg = MSG_TMPL.format(jp_http_port)
+    assert shutdown_msg in subprocess.check_output(
+        ['jupyter-server', 'stop']
+    ).decode()
+
+    _ensure_stopped('There is currently no server running on 8888')
+
+    assert MSG_TMPL.format(jp_unix_socket_file) in subprocess.check_output(
+        ['jupyter-server', 'stop', jp_unix_socket_file]
+    ).decode()
+
+    assert MSG_TMPL.format(TEST_PORT) in subprocess.check_output(
+        ['jupyter-server', 'stop', TEST_PORT]
+    ).decode()
+
+    _ensure_stopped()
+
+    p1.wait()
+    p2.wait()
+    p3.wait()
+
+
+@pytest.mark.integration_test
+def test_launch_socket_collision(jp_unix_socket_file):
+    """Tests UNIX socket in-use detection for lifecycle correctness."""
+    sock = jp_unix_socket_file
+    check_msg = 'socket %s is already in use' % sock
+
+    _ensure_stopped()
+
+    # Start a server.
+    cmd = ['jupyter-server', '--sock=%s' % sock]
+    p1 = subprocess.Popen(cmd)
+    time.sleep(3)
+
+    # Try to start a server bound to the same UNIX socket.
+    try:
+        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+    except subprocess.CalledProcessError as cpe:
+        assert check_msg in cpe.output.decode()
+    except Exception as ex:
+        raise AssertionError(f"expected 'already in use' error, got '{ex}'!") from ex
+    else:
+        raise AssertionError("expected 'already in use' error, got success instead!")
+
+    # Stop the background server, ensure it's stopped and wait on the process to exit.
+    subprocess.check_call(['jupyter-server', 'stop', sock])
+
+    _ensure_stopped()
+
+    p1.wait()
diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py
index 7a86f38581..e0a532bbdd 100644
--- a/jupyter_server/utils.py
+++ b/jupyter_server/utils.py
@@ -7,12 +7,18 @@
 import errno
 import inspect
 import os
+import socket
 import sys
 from distutils.version import LooseVersion
+from contextlib import contextmanager
 
-from urllib.parse import quote, unquote, urlparse, urljoin
+from urllib.parse import (quote, unquote, urlparse, urljoin,
+                          urlsplit, urlunsplit, SplitResult)
 from urllib.request import pathname2url
 
+from tornado.httpclient import AsyncHTTPClient, HTTPClient, HTTPRequest
+from tornado.netutil import Resolver
+from tornado.ioloop import IOLoop
 
 
 def url_path_join(*pieces):
@@ -222,3 +228,160 @@ def wrapped():
                 raise e
         return result
     return wrapped()
+
+
+def urlencode_unix_socket_path(socket_path):
+    """Encodes a UNIX socket path string from a socket path for the `http+unix` URI form.
+
+    Only `/` is escaped (as `%2F`); every other character is left untouched.
+    """
+    return socket_path.replace('/', '%2F')
+
+
+def urldecode_unix_socket_path(socket_path):
+    """Decodes a UNIX sock path string from an encoded sock path for the `http+unix` URI form.
+
+    Every `%2F` is turned back into `/`; nothing else is decoded.
+    """
+    return socket_path.replace('%2F', '/')
+
+
+def urlencode_unix_socket(socket_path):
+    """Encodes a UNIX socket URL from a socket path for the `http+unix` URI form."""
+    return 'http+unix://%s' % urlencode_unix_socket_path(socket_path)
+
+
+def unix_socket_in_use(socket_path):
+    """Checks whether a UNIX socket path on disk is in use by attempting to connect to it.
+
+    Returns False when the path does not exist or the connection attempt
+    fails, and True when the connection succeeds.
+    """
+    if not os.path.exists(socket_path):
+        return False
+
+    # Create the socket before entering `try` so the `finally` clause never
+    # runs against an unbound `sock` when socket creation itself fails.
+    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    try:
+        sock.connect(socket_path)
+    except OSError:
+        return False
+    else:
+        return True
+    finally:
+        sock.close()
+
+
+@contextmanager
+def _request_for_tornado_client(
+    urlstring,
+    method="GET",
+    body=None,
+    headers=None
+):
+    """A utility that provides a context that handles
+    HTTP, HTTPS, and HTTP+UNIX requests.
+    Creates a tornado HTTPRequest object with a URL
+    that tornado's HTTPClients can accept.
+    If the request is made to a unix socket, configure
+    the AsyncHTTPClient to resolve the URL and connect
+    to the proper socket.
+
+    Raises an Exception for any other URL scheme.
+
+    NOTE: the resolver is installed with `AsyncHTTPClient.configure`
+    and nothing here resets it when the context exits.
+    """
+    parts = urlsplit(urlstring)
+    if parts.scheme in ["http", "https"]:
+        pass
+    elif parts.scheme == "http+unix":
+        # If unix socket, mimic HTTP.
+        parts = SplitResult(
+            scheme="http",
+            netloc=parts.netloc,
+            path=parts.path,
+            query=parts.query,
+            fragment=parts.fragment
+        )
+
+        class UnixSocketResolver(Resolver):
+            """A resolver that routes HTTP requests to unix sockets
+            in tornado HTTP clients.
+            Due to constraints in Tornado's API, the scheme of the URL
+            must be `http` (not `http+unix`). Applications should replace
+            the scheme in URLs before making a request to the HTTP client.
+            """
+            def initialize(self, resolver):
+                self.resolver = resolver
+
+            def close(self):
+                self.resolver.close()
+
+            async def resolve(self, host, port, *args, **kwargs):
+                # The host component carries the encoded socket path; port is unused.
+                return [
+                    (socket.AF_UNIX, urldecode_unix_socket_path(host))
+                ]
+
+        resolver = UnixSocketResolver(resolver=Resolver())
+        AsyncHTTPClient.configure(None, resolver=resolver)
+    else:
+        raise Exception("Unknown URL scheme.")
+
+    # Yield the request for the given client.
+    url = urlunsplit(parts)
+    request = HTTPRequest(
+        url,
+        method=method,
+        body=body,
+        headers=headers
+    )
+    yield request
+
+
+def fetch(
+    urlstring,
+    method="GET",
+    body=None,
+    headers=None
+):
+    """
+    Send a HTTP, HTTPS, or HTTP+UNIX request
+    to a Tornado Web Server. Returns a tornado HTTPResponse.
+
+    `method`, `body` and `headers` are forwarded to the tornado HTTPRequest.
+    """
+    with _request_for_tornado_client(
+        urlstring, method=method, body=body, headers=headers
+    ) as request:
+        client = HTTPClient(AsyncHTTPClient)
+        try:
+            response = client.fetch(request)
+        finally:
+            # Release the blocking client's resources even when the fetch fails.
+            client.close()
+    return response
+
+
+async def async_fetch(
+    urlstring,
+    method="GET",
+    body=None,
+    headers=None,
+    io_loop=None
+):
+    """
+    Send an asynchronous HTTP, HTTPS, or HTTP+UNIX request
+    to a Tornado Web Server. Returns a tornado HTTPResponse.
+
+    `method`, `body` and `headers` are forwarded to the tornado HTTPRequest.
+    """
+    with _request_for_tornado_client(
+        urlstring, method=method, body=body, headers=headers
+    ) as request:
+        # NOTE(review): `io_loop` is handed positionally to AsyncHTTPClient;
+        # confirm this matches the constructor of the supported tornado versions.
+        response = await AsyncHTTPClient(io_loop).fetch(request)
+    return response
diff --git a/setup.cfg b/setup.cfg
index b9db7693d2..552401894c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -42,6 +42,7 @@ install_requires =
     prometheus_client
     anyio>=3.1.0,<4
     websocket-client
+    requests-unixsocket
 
 [options.extras_require]
 test = coverage; pytest; pytest-cov; pytest-mock; requests; pytest-tornasync; pytest-console-scripts; ipykernel