From 63463c40a51e8a7454d1f587da83077ef697206d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Sat, 14 Oct 2023 19:24:36 -0500 Subject: [PATCH] Add typings to commonly used APIs (#1333) --- examples/simple/simple_ext1/handlers.py | 2 +- examples/simple/simple_ext2/handlers.py | 2 +- jupyter_server/_tz.py | 11 +- jupyter_server/auth/decorator.py | 10 +- jupyter_server/auth/identity.py | 18 +- jupyter_server/auth/login.py | 7 +- jupyter_server/base/handlers.py | 188 ++++++++------- jupyter_server/config_manager.py | 21 +- jupyter_server/extension/application.py | 4 +- jupyter_server/extension/handler.py | 39 +-- jupyter_server/extension/serverextension.py | 44 ++-- jupyter_server/files/handlers.py | 8 +- jupyter_server/gateway/connections.py | 2 +- jupyter_server/gateway/handlers.py | 13 +- jupyter_server/gateway/managers.py | 14 +- jupyter_server/nbconvert/handlers.py | 1 + jupyter_server/serverapp.py | 224 +++++++++--------- .../services/contents/filemanager.py | 4 +- jupyter_server/services/contents/handlers.py | 2 +- jupyter_server/services/contents/manager.py | 4 +- jupyter_server/services/events/handlers.py | 12 +- jupyter_server/services/kernels/handlers.py | 6 +- .../services/kernelspecs/handlers.py | 5 +- jupyter_server/services/sessions/handlers.py | 14 +- jupyter_server/services/shutdown.py | 3 +- jupyter_server/utils.py | 54 +++-- pyproject.toml | 3 +- tests/base/test_handlers.py | 6 +- tests/services/sessions/test_manager.py | 3 +- tests/test_gateway.py | 4 +- tests/test_utils.py | 2 +- 31 files changed, 407 insertions(+), 323 deletions(-) diff --git a/examples/simple/simple_ext1/handlers.py b/examples/simple/simple_ext1/handlers.py index 8bcd843808..fefbdf610b 100644 --- a/examples/simple/simple_ext1/handlers.py +++ b/examples/simple/simple_ext1/handlers.py @@ -36,7 +36,7 @@ class ParameterHandler(ExtensionHandlerMixin, JupyterHandler): def get(self, matched_part=None, *args, **kwargs): """Handle a get with parameters.""" - var1 = self.get_argument("var1", default=None) + var1 = self.get_argument("var1", default="") components = [x for x in self.request.path.split("/") if x] self.write("

Hello Simple App 1 from Handler.

") self.write(f"

matched_part: {url_escape(matched_part)}

") diff --git a/examples/simple/simple_ext2/handlers.py b/examples/simple/simple_ext2/handlers.py index 743790963d..ea649b68d2 100644 --- a/examples/simple/simple_ext2/handlers.py +++ b/examples/simple/simple_ext2/handlers.py @@ -9,7 +9,7 @@ class ParameterHandler(ExtensionHandlerMixin, JupyterHandler): def get(self, matched_part=None, *args, **kwargs): """Get a parameterized response.""" - var1 = self.get_argument("var1", default=None) + var1 = self.get_argument("var1", default="") components = [x for x in self.request.path.split("/") if x] self.write("

Hello Simple App 2 from Handler.

") self.write(f"

matched_part: {url_escape(matched_part)}

") diff --git a/jupyter_server/_tz.py b/jupyter_server/_tz.py index 4444c93db0..24847b4307 100644 --- a/jupyter_server/_tz.py +++ b/jupyter_server/_tz.py @@ -5,7 +5,10 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + from datetime import datetime, timedelta, tzinfo +from typing import Callable # constant for zero offset ZERO = timedelta(0) @@ -14,11 +17,11 @@ class tzUTC(tzinfo): # noqa """tzinfo object for UTC (zero offset)""" - def utcoffset(self, d): + def utcoffset(self, d: datetime | None) -> timedelta: """Compute utcoffset.""" return ZERO - def dst(self, d): + def dst(self, d: datetime | None) -> timedelta: """Compute dst.""" return ZERO @@ -26,7 +29,7 @@ def dst(self, d): UTC = tzUTC() # type:ignore[abstract] -def utc_aware(unaware): +def utc_aware(unaware: Callable[..., datetime]) -> Callable[..., datetime]: """decorator for adding UTC tzinfo to datetime's utcfoo methods""" def utc_method(*args, **kwargs): @@ -40,7 +43,7 @@ def utc_method(*args, **kwargs): utcnow = utc_aware(datetime.utcnow) -def isoformat(dt): +def isoformat(dt: datetime) -> str: """Return iso-formatted timestamp Like .isoformat(), but uses Z for UTC instead of +00:00 diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py index 930d79be47..a5d6c0543f 100644 --- a/jupyter_server/auth/decorator.py +++ b/jupyter_server/auth/decorator.py @@ -3,19 +3,21 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. from functools import wraps -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, TypeVar, Union, cast from tornado.log import app_log from tornado.web import HTTPError from .utils import HTTP_METHOD_TO_AUTH_ACTION +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + def authorized( - action: Optional[Union[str, Callable]] = None, + action: Optional[Union[str, FuncT]] = None, resource: Optional[str] = None, message: Optional[str] = None, -) -> Callable: +) -> FuncT: """A decorator for tornado.web.RequestHandler methods that verifies whether the current user is authorized to make the following request. @@ -73,4 +75,4 @@ def inner(self, *args, **kwargs): # no-arguments `@authorized` decorator called return wrapper(method) - return wrapper + return cast(FuncT, wrapper) diff --git a/jupyter_server/auth/identity.py b/jupyter_server/auth/identity.py index 72f4b469e0..2440710186 100644 --- a/jupyter_server/auth/identity.py +++ b/jupyter_server/auth/identity.py @@ -29,7 +29,7 @@ # circular imports for type checking if TYPE_CHECKING: - from jupyter_server.base.handlers import JupyterHandler + from jupyter_server.base.handlers import AuthenticatedHandler, JupyterHandler from jupyter_server.serverapp import ServerApp _non_alphanum = re.compile(r"[^A-Za-z0-9]") @@ -321,7 +321,7 @@ def user_from_cookie(self, cookie_value: str) -> User | None: user["color"], ) - def get_cookie_name(self, handler: JupyterHandler) -> str: + def get_cookie_name(self, handler: AuthenticatedHandler) -> str: """Return the login cookie name Uses IdentityProvider.cookie_name, if defined. @@ -333,7 +333,7 @@ def get_cookie_name(self, handler: JupyterHandler) -> str: else: return _non_alphanum.sub("-", f"username-{handler.request.host}") - def set_login_cookie(self, handler: JupyterHandler, user: User) -> None: + def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None: """Call this on handlers to set the login cookie for success""" cookie_options = {} cookie_options.update(self.cookie_options) @@ -350,7 +350,7 @@ def set_login_cookie(self, handler: JupyterHandler, user: User) -> None: handler.set_secure_cookie(cookie_name, self.user_to_cookie(user), **cookie_options) def _force_clear_cookie( - self, handler: JupyterHandler, name: str, path: str = "/", domain: str | None = None + self, handler: AuthenticatedHandler, name: str, path: str = "/", domain: str | None = None ) -> None: """Deletes the cookie with the given name. @@ -376,7 +376,7 @@ def _force_clear_cookie( morsel["domain"] = domain handler.add_header("Set-Cookie", morsel.OutputString()) - def clear_login_cookie(self, handler: JupyterHandler) -> None: + def clear_login_cookie(self, handler: AuthenticatedHandler) -> None: """Clear the login cookie, effectively logging out the session.""" cookie_options = {} cookie_options.update(self.cookie_options) @@ -478,7 +478,7 @@ def generate_anonymous_user(self, handler: JupyterHandler) -> User: handler.log.debug(f"Generating new user for token-authenticated request: {user_id}") return User(user_id, name, display_name, initials, None, color) - def should_check_origin(self, handler: JupyterHandler) -> bool: + def should_check_origin(self, handler: AuthenticatedHandler) -> bool: """Should the Handler check for CORS origin validation? Origin check should be skipped for token-authenticated requests. @@ -489,7 +489,7 @@ def should_check_origin(self, handler: JupyterHandler) -> bool: """ return not self.is_token_authenticated(handler) - def is_token_authenticated(self, handler: JupyterHandler) -> bool: + def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool: """Returns True if handler has been token authenticated. Otherwise, False. Login with a token is used to signal certain things, such as: @@ -713,11 +713,11 @@ def login_available(self): self.settings ) - def should_check_origin(self, handler: JupyterHandler) -> bool: + def should_check_origin(self, handler: AuthenticatedHandler) -> bool: """Whether we should check origin.""" return self.login_handler_class.should_check_origin(handler) # type:ignore[attr-defined] - def is_token_authenticated(self, handler: JupyterHandler) -> bool: + def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool: """Whether we are token authenticated.""" return self.login_handler_class.is_token_authenticated(handler) # type:ignore[attr-defined] diff --git a/jupyter_server/auth/login.py b/jupyter_server/auth/login.py index 2105bab3b8..b9eda58e08 100644 --- a/jupyter_server/auth/login.py +++ b/jupyter_server/auth/login.py @@ -123,9 +123,10 @@ def post(self): if new_password and getattr(self.identity_provider, "allow_password_change", False): config_dir = self.settings.get("config_dir", "") config_file = os.path.join(config_dir, "jupyter_server_config.json") - self.identity_provider.hashed_password = self.settings[ - "password" - ] = set_password(new_password, config_file=config_file) + if hasattr(self.identity_provider, "hashed_password"): + self.identity_provider.hashed_password = self.settings[ + "password" + ] = set_password(new_password, config_file=config_file) self.log.info("Wrote hashed password to %s" % config_file) else: self.set_status(401) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 1c5f93c7e0..e2b136e85a 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -14,7 +14,8 @@ import types import warnings from http.client import responses -from typing import TYPE_CHECKING, Awaitable +from logging import Logger +from typing import TYPE_CHECKING, Any, Awaitable, Sequence, cast from urllib.parse import urlparse import prometheus_client @@ -42,7 +43,17 @@ ) if TYPE_CHECKING: - from jupyter_server.auth.identity import User + from jupyter_client.kernelspec import KernelSpecManager + from jupyter_server_terminals.terminalmanager import TerminalManager + from tornado.concurrent import Future + + from jupyter_server.auth.authorizer import Authorizer + from jupyter_server.auth.identity import IdentityProvider, User + from jupyter_server.serverapp import ServerApp + from jupyter_server.services.config.manager import ConfigManager + from jupyter_server.services.contents.manager import ContentsManager + from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager + from jupyter_server.services.sessions.sessionmanager import SessionManager # ----------------------------------------------------------------------------- # Top-level handlers @@ -59,7 +70,7 @@ def json_sys_info(): return _sys_info_cache -def log(): +def log() -> Logger: """Get the application log.""" if Application.initialized(): return Application.instance().log @@ -75,7 +86,7 @@ def base_url(self) -> str: return self.settings.get("base_url", "/") @property - def content_security_policy(self): + def content_security_policy(self) -> str: """The default Content-Security-Policy header Can be overridden by defining Content-Security-Policy in settings['headers'] @@ -93,7 +104,7 @@ def content_security_policy(self): ] ) - def set_default_headers(self): + def set_default_headers(self) -> None: """Set the default headers.""" headers = {} headers["X-Content-Type-Options"] = "nosniff" @@ -114,7 +125,7 @@ def set_default_headers(self): ) @property - def cookie_name(self): + def cookie_name(self) -> str: warnings.warn( """JupyterHandler.login_handler is deprecated in 2.0, use JupyterHandler.identity_provider. @@ -124,7 +135,7 @@ def cookie_name(self): ) return self.identity_provider.get_cookie_name(self) - def force_clear_cookie(self, name, path="/", domain=None): + def force_clear_cookie(self, name: str, path: str = "/", domain: str | None = None) -> None: """Force a cookie clear.""" warnings.warn( """JupyterHandler.login_handler is deprecated in 2.0, @@ -133,9 +144,9 @@ def force_clear_cookie(self, name, path="/", domain=None): DeprecationWarning, stacklevel=2, ) - return self.identity_provider._force_clear_cookie(self, name, path=path, domain=domain) + self.identity_provider._force_clear_cookie(self, name, path=path, domain=domain) - def clear_login_cookie(self): + def clear_login_cookie(self) -> None: """Clear a login cookie.""" warnings.warn( """JupyterHandler.login_handler is deprecated in 2.0, @@ -144,9 +155,9 @@ def clear_login_cookie(self): DeprecationWarning, stacklevel=2, ) - return self.identity_provider.clear_login_cookie(self) + self.identity_provider.clear_login_cookie(self) - def get_current_user(self): + def get_current_user(self) -> str: """Get the current user.""" clsname = self.__class__.__name__ msg = ( @@ -164,7 +175,7 @@ def get_current_user(self): # haven't called get_user in prepare, raise raise RuntimeError(msg) - def skip_check_origin(self): + def skip_check_origin(self) -> bool: """Ask my login_handler if I should skip the origin_check For example: in the default LoginHandler, if a request is token-authenticated, @@ -176,18 +187,18 @@ def skip_check_origin(self): return not self.identity_provider.should_check_origin(self) @property - def token_authenticated(self): + def token_authenticated(self) -> bool: """Have I been authenticated with a token?""" return self.identity_provider.is_token_authenticated(self) @property - def logged_in(self): + def logged_in(self) -> bool: """Is a user currently logged in?""" user = self.current_user return user and user != "anonymous" @property - def login_handler(self): + def login_handler(self) -> Any: """Return the login handler for this application, if any.""" warnings.warn( """JupyterHandler.login_handler is deprecated in 2.0, @@ -199,12 +210,12 @@ def login_handler(self): return self.identity_provider.login_handler_class @property - def token(self): + def token(self) -> str | None: """Return the login token for this application, if any.""" return self.identity_provider.token @property - def login_available(self): + def login_available(self) -> bool: """May a user proceed to log in? This returns True if login capability is available, irrespective of @@ -214,7 +225,7 @@ def login_available(self): return self.identity_provider.login_available @property - def authorizer(self): + def authorizer(self) -> Authorizer: if "authorizer" not in self.settings: warnings.warn( "The Tornado web application does not have an 'authorizer' defined " @@ -234,10 +245,10 @@ def authorizer(self): identity_provider=self.identity_provider, ) - return self.settings.get("authorizer") + return cast("Authorizer", self.settings.get("authorizer")) @property - def identity_provider(self): + def identity_provider(self) -> IdentityProvider: if "identity_provider" not in self.settings: warnings.warn( "The Tornado web application does not have an 'identity_provider' defined " @@ -265,21 +276,21 @@ class JupyterHandler(AuthenticatedHandler): """ @property - def config(self): + def config(self) -> dict[str, Any] | None: return self.settings.get("config", None) @property - def log(self): + def log(self) -> Logger: """use the Jupyter log by default, falling back on tornado's logger""" return log() @property - def jinja_template_vars(self): + def jinja_template_vars(self) -> dict[str, Any]: """User-supplied values to supply to jinja templates.""" return self.settings.get("jinja_template_vars", {}) @property - def serverapp(self): + def serverapp(self) -> ServerApp | None: return self.settings["serverapp"] # --------------------------------------------------------------- @@ -287,31 +298,31 @@ def serverapp(self): # --------------------------------------------------------------- @property - def version_hash(self): + def version_hash(self) -> str: """The version hash to use for cache hints for static files""" return self.settings.get("version_hash", "") @property - def mathjax_url(self): + def mathjax_url(self) -> str: url = self.settings.get("mathjax_url", "") if not url or url_is_absolute(url): return url return url_path_join(self.base_url, url) @property - def mathjax_config(self): + def mathjax_config(self) -> str: return self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe") @property - def default_url(self): + def default_url(self) -> str: return self.settings.get("default_url", "") @property - def ws_url(self): + def ws_url(self) -> str: return self.settings.get("websocket_url", "") @property - def contents_js_source(self): + def contents_js_source(self) -> str: self.log.debug( "Using contents: %s", self.settings.get("contents_js_source", "services/contents"), @@ -323,27 +334,27 @@ def contents_js_source(self): # --------------------------------------------------------------- @property - def kernel_manager(self): + def kernel_manager(self) -> AsyncMappingKernelManager: return self.settings["kernel_manager"] @property - def contents_manager(self): + def contents_manager(self) -> ContentsManager: return self.settings["contents_manager"] @property - def session_manager(self): + def session_manager(self) -> SessionManager: return self.settings["session_manager"] @property - def terminal_manager(self): + def terminal_manager(self) -> TerminalManager: return self.settings["terminal_manager"] @property - def kernel_spec_manager(self): + def kernel_spec_manager(self) -> KernelSpecManager: return self.settings["kernel_spec_manager"] @property - def config_manager(self): + def config_manager(self) -> ConfigManager: return self.settings["config_manager"] @property @@ -355,25 +366,25 @@ def event_logger(self) -> EventLogger: # --------------------------------------------------------------- @property - def allow_origin(self): + def allow_origin(self) -> str: """Normal Access-Control-Allow-Origin""" return self.settings.get("allow_origin", "") @property - def allow_origin_pat(self): + def allow_origin_pat(self) -> str: """Regular expression version of allow_origin""" return self.settings.get("allow_origin_pat", None) @property - def allow_credentials(self): + def allow_credentials(self) -> bool: """Whether to set Access-Control-Allow-Credentials""" return self.settings.get("allow_credentials", False) - def set_default_headers(self): + def set_default_headers(self) -> None: """Add CORS headers, if defined""" super().set_default_headers() - def set_cors_headers(self): + def set_cors_headers(self) -> None: """Add CORS headers, if defined Now that current_user is async (jupyter-server 2.0), @@ -395,7 +406,7 @@ def set_cors_headers(self): if self.allow_credentials: self.set_header("Access-Control-Allow-Credentials", "true") - def set_attachment_header(self, filename): + def set_attachment_header(self, filename: str) -> None: """Set Content-Disposition: attachment header As a method to ensure handling of filename encoding @@ -406,7 +417,7 @@ def set_attachment_header(self, filename): f"attachment; filename*=utf-8''{escaped_filename}", ) - def get_origin(self): + def get_origin(self) -> str | None: # Handle WebSocket Origin naming convention differences # The difference between version 8 and 13 is that in 8 the # client sends a "Sec-Websocket-Origin" header and in 13 it's @@ -419,7 +430,7 @@ def get_origin(self): # origin_to_satisfy_tornado is present because tornado requires # check_origin to take an origin argument, but we don't use it - def check_origin(self, origin_to_satisfy_tornado=""): + def check_origin(self, origin_to_satisfy_tornado: str = "") -> bool: """Check Origin for cross-site API requests, including websockets Copied from WebSocket with changes: @@ -466,7 +477,7 @@ def check_origin(self, origin_to_satisfy_tornado=""): ) return allow - def check_referer(self): + def check_referer(self) -> bool: """Check Referer for cross-site requests. Disables requests to certain endpoints with external or missing Referer. @@ -512,7 +523,7 @@ def check_referer(self): ) return allow - def check_xsrf_cookie(self): + def check_xsrf_cookie(self) -> None: """Bypass xsrf cookie checks when token-authenticated""" if not hasattr(self, "_jupyter_current_user"): # Called too early, will be checked later @@ -536,7 +547,7 @@ def check_xsrf_cookie(self): else: raise - def check_host(self): + def check_host(self) -> bool: """Check the host header if remote access disallowed. Returns True if the request should continue, False otherwise. @@ -578,7 +589,7 @@ def check_host(self): ) return allow - async def prepare(self): + async def prepare(self) -> Awaitable[None] | None: # type:ignore[override] """Prepare a response.""" # Set the current Jupyter Handler context variable. CallContext.set(CallContext.JUPYTER_HANDLER, self) @@ -603,7 +614,7 @@ async def prepare(self): DeprecationWarning # stacklevel not useful here ) - user = self.get_current_user() + user = User(self.get_current_user()) else: _user = self.identity_provider.get_user(self) if isinstance(_user, Awaitable): @@ -636,7 +647,7 @@ def render_template(self, name, **ns): return template.render(**ns) @property - def template_namespace(self): + def template_namespace(self) -> dict[str, Any]: return dict( base_url=self.base_url, default_url=self.default_url, @@ -659,7 +670,7 @@ def template_namespace(self): **self.jinja_template_vars, ) - def get_json_body(self): + def get_json_body(self) -> dict[str, Any] | None: """Return the body of the request as JSON data.""" if not self.request.body: return None @@ -673,7 +684,7 @@ def get_json_body(self): raise web.HTTPError(400, "Invalid JSON in body of request") from e return model - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs: Any) -> None: """render custom error pages""" exc_info = kwargs.get("exc_info") message = "" @@ -715,13 +726,13 @@ def write_error(self, status_code, **kwargs): class APIHandler(JupyterHandler): """Base class for API handlers""" - async def prepare(self): + async def prepare(self) -> None: """Prepare an API response.""" await super().prepare() if not self.check_origin(): raise web.HTTPError(404) - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs: Any) -> None: """APIHandler errors are JSON, not human pages""" self.set_header("Content-Type", "application/json") message = responses.get(status_code, "Unknown HTTP Error") @@ -741,7 +752,7 @@ def write_error(self, status_code, **kwargs): self.log.warning("wrote error: %r", reply["message"], exc_info=True) self.finish(json.dumps(reply)) - def get_login_url(self): + def get_login_url(self) -> str: """Get the login url.""" # if get_login_url is invoked in an API handler, # that means @web.authenticated is trying to trigger a redirect. @@ -751,7 +762,7 @@ def get_login_url(self): return super().get_login_url() @property - def content_security_policy(self): + def content_security_policy(self) -> str: csp = "; ".join( [ super().content_security_policy, @@ -763,7 +774,7 @@ def content_security_policy(self): # set _track_activity = False on API handlers that shouldn't track activity _track_activity = True - def update_api_activity(self): + def update_api_activity(self) -> None: """Update last_activity of API requests""" # record activity of authenticated requests if ( @@ -773,7 +784,7 @@ def update_api_activity(self): ): self.settings["api_last_activity"] = utcnow() - def finish(self, *args, **kwargs): + def finish(self, *args: Any, **kwargs: Any) -> Future[Any]: """Finish an API response.""" self.update_api_activity() # Allow caller to indicate content-type... @@ -781,7 +792,7 @@ def finish(self, *args, **kwargs): self.set_header("Content-Type", set_content_type) return super().finish(*args, **kwargs) - def options(self, *args, **kwargs): + def options(self, *args: Any, **kwargs: Any) -> None: """Get the options.""" if "Access-Control-Allow-Headers" in self.settings.get("headers", {}): self.set_header( @@ -824,7 +835,7 @@ def options(self, *args, **kwargs): class Template404(JupyterHandler): """Render our 404 template""" - async def prepare(self): + async def prepare(self) -> None: """Prepare a 404 response.""" await super().prepare() raise web.HTTPError(404) @@ -836,21 +847,23 @@ class AuthenticatedFileHandler(JupyterHandler, web.StaticFileHandler): auth_resource = "contents" @property - def content_security_policy(self): + def content_security_policy(self) -> str: # In case we're serving HTML/SVG, confine any Javascript to a unique # origin so it can't interact with the Jupyter server. return super().content_security_policy + "; sandbox allow-scripts" @web.authenticated @authorized - def head(self, path): + def head(self, path: str) -> Awaitable[None]: # type:ignore[override] """Get the head response for a path.""" self.check_xsrf_cookie() return super().head(path) @web.authenticated @authorized - def get(self, path, **kwargs): + def get( # type:ignore[override] + self, path: str, **kwargs: Any + ) -> Awaitable[None]: """Get a file by path.""" self.check_xsrf_cookie() if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", None): @@ -859,7 +872,7 @@ def get(self, path, **kwargs): return web.StaticFileHandler.get(self, path, **kwargs) - def get_content_type(self): + def get_content_type(self) -> str: """Get the content type.""" assert self.absolute_path is not None path = self.absolute_path.strip("/") @@ -876,18 +889,18 @@ def get_content_type(self): else: return super().get_content_type() - def set_headers(self): + def set_headers(self) -> None: """Set the headers.""" super().set_headers() # disable browser caching, rely on 304 replies for savings if "v" not in self.request.arguments: self.add_header("Cache-Control", "no-cache") - def compute_etag(self): + def compute_etag(self) -> str | None: """Compute the etag.""" return None - def validate_absolute_path(self, root, absolute_path): + def validate_absolute_path(self, root: str, absolute_path: str) -> str: """Validate and return the absolute path. Requires tornado 3.1 @@ -905,7 +918,7 @@ def validate_absolute_path(self, root, absolute_path): return abs_path -def json_errors(method): # pragma: no cover +def json_errors(method: Any) -> Any: # pragma: no cover """Decorate methods with this to return GitHub style JSON errors. This should be used on any JSON API on any handler method that can raise HTTPErrors. @@ -949,10 +962,10 @@ class FileFindHandler(JupyterHandler, web.StaticFileHandler): """ # cache search results, don't search for files more than once - _static_paths: dict = {} - root: tuple # type:ignore[assignment] + _static_paths: dict[str, Any] = {} + root: tuple[str] # type:ignore[assignment] - def set_headers(self): + def set_headers(self) -> None: """Set the headers.""" super().set_headers() @@ -968,22 +981,29 @@ def set_headers(self): ): self.set_header("Cache-Control", "no-cache") - def initialize(self, path, default_filename=None, no_cache_paths=None): + def initialize( + self, + path: str | list[str], + default_filename: str | None = None, + no_cache_paths: list[str] | None = None, + ) -> None: """Initialize the file find handler.""" self.no_cache_paths = no_cache_paths or [] if isinstance(path, str): path = [path] - self.root = tuple(os.path.abspath(os.path.expanduser(p)) + os.sep for p in path) + self.root = tuple( + os.path.abspath(os.path.expanduser(p)) + os.sep for p in path + ) # type:ignore[assignment] self.default_filename = default_filename - def compute_etag(self): + def compute_etag(self) -> str | None: """Compute the etag.""" return None @classmethod - def get_absolute_path(cls, roots, path): + def get_absolute_path(cls, roots: Sequence[str], path: str) -> str: """locate a file to serve on our static file search path""" with cls._lock: if path in cls._static_paths: @@ -999,7 +1019,7 @@ def get_absolute_path(cls, roots, path): log().debug(f"Path {path} served from {abspath}") return abspath - def validate_absolute_path(self, root, absolute_path): + def validate_absolute_path(self, root: str, absolute_path: str) -> str | None: """check if the file should be served (raises 404, 403, etc.)""" if not absolute_path: raise web.HTTPError(404) @@ -1016,7 +1036,7 @@ class APIVersionHandler(APIHandler): _track_activity = False - def get(self): + def get(self) -> None: """Get the server version info.""" # not authenticated, so give as few info as possible self.finish(json.dumps({"version": jupyter_server.__version__})) @@ -1028,7 +1048,7 @@ class TrailingSlashHandler(web.RequestHandler): This should be the first, highest priority handler. """ - def get(self): + def get(self) -> None: """Handle trailing slashes in a get.""" assert self.request.uri is not None path, *rest = self.request.uri.partition("?") @@ -1044,7 +1064,7 @@ def get(self): class MainHandler(JupyterHandler): """Simple handler for base_url.""" - def get(self): + def get(self) -> None: """Get the main template.""" html = self.render_template("main.html") self.write(html) @@ -1056,7 +1076,7 @@ class FilesRedirectHandler(JupyterHandler): """Handler for redirecting relative URLs to the /files/ handler""" @staticmethod - async def redirect_to_files(self, path): + async def redirect_to_files(self: Any, path: str) -> None: """make redirect logic a reusable static method so it can be called from other handlers. @@ -1084,19 +1104,19 @@ async def redirect_to_files(self, path): self.log.debug("Redirecting %s to %s", self.request.path, url) self.redirect(url) - def get(self, path=""): + def get(self, path: str = "") -> Awaitable: return self.redirect_to_files(self, path) class RedirectWithParams(web.RequestHandler): """Sam as web.RedirectHandler, but preserves URL parameters""" - def initialize(self, url, permanent=True): + def initialize(self, url: str, permanent: bool = True) -> None: """Initialize a redirect handler.""" self._url = url self._permanent = permanent - def get(self): + def get(self) -> None: """Get a redirect.""" sep = "&" if "?" in self._url else "?" url = sep.join([self._url, self.request.query]) @@ -1108,7 +1128,7 @@ class PrometheusMetricsHandler(JupyterHandler): Return prometheus metrics for this server """ - def get(self): + def get(self) -> None: """Get prometheus metrics.""" if self.settings["authenticate_prometheus"] and not self.logged_in: raise web.HTTPError(403) diff --git a/jupyter_server/config_manager.py b/jupyter_server/config_manager.py index 9f34bb33db..76268d8a23 100644 --- a/jupyter_server/config_manager.py +++ b/jupyter_server/config_manager.py @@ -1,17 +1,22 @@ """Manager to read and modify config data in JSON files.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import copy import errno import glob import json import os +import typing as t from traitlets.config import LoggingConfigurable from traitlets.traitlets import Bool, Unicode +StrDict = t.Dict[str, t.Any] + -def recursive_update(target, new): +def recursive_update(target: StrDict, new: StrDict) -> None: """Recursively update one dictionary using another. None values will delete their keys. @@ -32,7 +37,7 @@ def recursive_update(target, new): target[k] = v -def remove_defaults(data, defaults): +def remove_defaults(data: StrDict, defaults: StrDict) -> None: """Recursively remove items from dict that are already in defaults""" # copy the iterator, since data will be modified for key, value in list(data.items()): @@ -55,7 +60,7 @@ class BaseJSONConfigManager(LoggingConfigurable): config_dir = Unicode(".") read_directory = Bool(True) - def ensure_config_dir_exists(self): + def ensure_config_dir_exists(self) -> None: """Will try to create the config_dir directory.""" try: os.makedirs(self.config_dir, 0o755) @@ -63,15 +68,15 @@ def ensure_config_dir_exists(self): if e.errno != errno.EEXIST: raise - def file_name(self, section_name): + def file_name(self, section_name: str) -> str: """Returns the json filename for the section_name: {config_dir}/{section_name}.json""" return os.path.join(self.config_dir, section_name + ".json") - def directory(self, section_name): + def directory(self, section_name: str) -> str: """Returns the directory name for the section name: {config_dir}/{section_name}.d""" return os.path.join(self.config_dir, section_name + ".d") - def get(self, section_name, include_root=True): + def get(self, section_name: str, include_root: bool = True) -> t.Any: """Retrieve the config data for the specified section. Returns the data as a dictionary, or an empty dictionary if the file @@ -101,7 +106,7 @@ def get(self, section_name, include_root=True): recursive_update(data, json.load(f)) return data - def set(self, section_name, data): + def set(self, section_name: str, data: t.Any) -> None: """Store the given config data.""" filename = self.file_name(section_name) self.ensure_config_dir_exists() @@ -118,7 +123,7 @@ def set(self, section_name, data): with open(filename, "w", encoding="utf-8") as f: f.write(json_content) - def update(self, section_name, new_data): + def update(self, section_name: str, new_data: t.Any) -> None: """Modify the config section by recursively updating it with new_data. Returns the modified config data as a dictionary. diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index 984b2438cd..82ff73a0de 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -550,7 +550,7 @@ def load_classic_server_extension(cls, serverapp): serverapp_class = ServerApp @classmethod - def make_serverapp(cls, **kwargs): + def make_serverapp(cls, **kwargs: t.Any) -> ServerApp: """Instantiate the ServerApp Override to customize the ServerApp before it loads any configuration @@ -573,7 +573,7 @@ def initialize_server(cls, argv=None, load_other_extensions=True, **kwargs): cls.serverapp_config["jpserver_extensions"] = jpserver_extensions find_extensions = False serverapp = cls.make_serverapp(jpserver_extensions=jpserver_extensions, **kwargs) - serverapp.aliases.update(cls.aliases) + serverapp.aliases.update(cls.aliases) # type:ignore[has-type] serverapp.initialize( argv=argv or [], starter_extension=cls.name, diff --git a/jupyter_server/extension/handler.py b/jupyter_server/extension/handler.py index 618011abc8..3018aae1c2 100644 --- a/jupyter_server/extension/handler.py +++ b/jupyter_server/extension/handler.py @@ -1,24 +1,33 @@ """An extension handler.""" -from typing import no_type_check +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from jinja2.exceptions import TemplateNotFound from jupyter_server.base.handlers import FileFindHandler +if TYPE_CHECKING: + from logging import Logger + + from traitlets.config import Config + + from jupyter_server.extension.application import ExtensionApp + from jupyter_server.serverapp import ServerApp + class ExtensionHandlerJinjaMixin: """Mixin class for ExtensionApp handlers that use jinja templating for template rendering. """ - @no_type_check - def get_template(self, name): + def get_template(self, name: str) -> str: """Return the jinja template object for a given name""" try: - env = f"{self.name}_jinja2_env" - return self.settings[env].get_template(name) + env = f"{self.name}_jinja2_env" # type:ignore[attr-defined] + return self.settings[env].get_template(name) # type:ignore[attr-defined] except TemplateNotFound: - return super().get_template(name) + return super().get_template(name) # type:ignore[misc] class ExtensionHandlerMixin: @@ -32,7 +41,7 @@ class ExtensionHandlerMixin: other extensions. """ - def initialize(self, name, *args, **kwargs): + def initialize(self, name: str, *args: Any, **kwargs: Any) -> None: self.name = name try: super().initialize(*args, **kwargs) # type:ignore[misc] @@ -40,16 +49,16 @@ def initialize(self, name, *args, **kwargs): pass @property - def extensionapp(self): + def extensionapp(self) -> ExtensionApp: return self.settings[self.name] # type:ignore[attr-defined] @property - def serverapp(self): + def serverapp(self) -> ServerApp: key = "serverapp" return self.settings[key] # type:ignore[attr-defined] @property - def log(self): + def log(self) -> Logger: if not hasattr(self, "name"): return super().log # type:ignore[misc] # Attempt to pull the ExtensionApp's log, otherwise fall back to ServerApp. @@ -59,11 +68,11 @@ def log(self): return self.serverapp.log @property - def config(self): + def config(self) -> Config: return self.settings[f"{self.name}_config"] # type:ignore[attr-defined] @property - def server_config(self): + def server_config(self) -> Config: return self.settings["config"] # type:ignore[attr-defined] @property @@ -71,14 +80,14 @@ def base_url(self) -> str: return self.settings.get("base_url", "/") # type:ignore[attr-defined] @property - def static_url_prefix(self): + def static_url_prefix(self) -> str: return self.extensionapp.static_url_prefix @property - def static_path(self): + def static_path(self) -> str: return self.settings[f"{self.name}_static_paths"] # type:ignore[attr-defined] - def static_url(self, path, include_host=None, **kwargs): + def static_url(self, path: str, include_host: bool | None = None, **kwargs: Any) -> str: """Returns a static URL for the given relative static file path. This method requires you set the ``{name}_static_path`` setting in your extension (which specifies the root directory diff --git a/jupyter_server/extension/serverextension.py b/jupyter_server/extension/serverextension.py index 6cd8dc14a3..3147ab7dd0 100644 --- a/jupyter_server/extension/serverextension.py +++ b/jupyter_server/extension/serverextension.py @@ -1,9 +1,12 @@ """Utilities for installing extensions""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import logging import os import sys +import typing as t from jupyter_core.application import JupyterApp from jupyter_core.paths import ENV_CONFIG_PATH, SYSTEM_CONFIG_PATH, jupyter_config_dir @@ -15,7 +18,7 @@ from jupyter_server.extension.manager import ExtensionManager, ExtensionPackage -def _get_config_dir(user=False, sys_prefix=False): +def _get_config_dir(user: bool = False, sys_prefix: bool = False) -> str: """Get the location of config files for the current context Returns the string to the environment @@ -38,7 +41,9 @@ def _get_config_dir(user=False, sys_prefix=False): return extdir -def _get_extmanager_for_context(write_dir="jupyter_server_config.d", user=False, sys_prefix=False): +def _get_extmanager_for_context( + write_dir: str = "jupyter_server_config.d", user: bool = False, sys_prefix: bool = False +) -> tuple[str, ExtensionManager]: """Get an extension manager pointing at the current context Returns the path to the current context and an ExtensionManager object. @@ -67,7 +72,7 @@ class ArgumentConflict(ValueError): pass -_base_flags = {} +_base_flags: dict[str, t.Any] = {} _base_flags.update(JupyterApp.flags) _base_flags.pop("y", None) _base_flags.pop("generate-config", None) @@ -110,7 +115,7 @@ class ArgumentConflict(ValueError): ) _base_flags["python"] = _base_flags["py"] -_base_aliases = {} +_base_aliases: dict[str, t.Any] = {} _base_aliases.update(JupyterApp.aliases) @@ -126,12 +131,12 @@ class BaseExtensionApp(JupyterApp): sys_prefix = Bool(True, config=True, help="Use the sys.prefix as the prefix") python = Bool(False, config=True, help="Install from a Python package") - def _log_format_default(self): + def _log_format_default(self) -> str: """A default format for messages""" return "%(message)s" @property - def config_dir(self): + def config_dir(self) -> str: # type:ignore[override] return _get_config_dir(user=self.user, sys_prefix=self.sys_prefix) @@ -148,8 +153,12 @@ def config_dir(self): def toggle_server_extension_python( - import_name, enabled=None, parent=None, user=False, sys_prefix=True -): + import_name: str, + enabled: bool | None = None, + parent: t.Any = None, + user: bool = False, + sys_prefix: bool = True, +) -> None: """Toggle the boolean setting for a given server extension in a Jupyter config file. """ @@ -228,7 +237,7 @@ class ToggleServerExtensionApp(BaseExtensionApp): _toggle_pre_message = "" _toggle_post_message = "" - def toggle_server_extension(self, import_name): + def toggle_server_extension(self, import_name: str) -> None: """Change the status of a named server extension. Uses the value of `self._toggle_value`. @@ -257,17 +266,18 @@ def toggle_server_extension(self, import_name): # Toggle extension config. config = extension_manager.config_manager - if self._toggle_value is True: - config.enable(import_name) - else: - config.disable(import_name) + if config: + if self._toggle_value is True: + config.enable(import_name) + else: + config.disable(import_name) # If successful, let's log. self.log.info(f" - Extension successfully {self._toggle_post_message}.") except Exception as err: self.log.info(f" {RED_X} Validation failed: {err}") - def start(self): + def start(self) -> None: """Perform the App's actions as configured""" if not self.extra_args: sys.exit("Please specify a server extension/package to enable or disable") @@ -312,7 +322,7 @@ class ListServerExtensionsApp(BaseExtensionApp): version = __version__ description = "List all server extensions known by the configuration system" - def list_server_extensions(self): + def list_server_extensions(self) -> None: """List all enabled and disabled server extensions, by config path Enabled extensions are validated, potentially generating warnings. @@ -351,7 +361,7 @@ def list_server_extensions(self): # Add a blank line between paths. self.log.info("") - def start(self): + def start(self) -> None: """Perform the App's actions as configured""" self.list_server_extensions() @@ -377,7 +387,7 @@ class ServerExtensionApp(BaseExtensionApp): "list": (ListServerExtensionsApp, "List server extensions"), } - def start(self): + def start(self) -> None: """Perform the App's actions as configured""" super().start() diff --git a/jupyter_server/files/handlers.py b/jupyter_server/files/handlers.py index 89a8cd580d..9195cc7b9c 100644 --- a/jupyter_server/files/handlers.py +++ b/jupyter_server/files/handlers.py @@ -1,9 +1,11 @@ """Serve files directly from the ContentsManager.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import mimetypes from base64 import decodebytes -from typing import List +from typing import Awaitable from jupyter_core.utils import ensure_async from tornado import web @@ -34,7 +36,7 @@ def content_security_policy(self): @web.authenticated @authorized - def head(self, path): + def head(self, path: str) -> Awaitable[None] | None: # type:ignore[override] """The head response.""" self.get(path, include_body=False) self.check_xsrf_cookie() @@ -91,4 +93,4 @@ async def get(self, path, include_body=True): self.flush() -default_handlers: List[JupyterHandler] = [] +default_handlers: list[JupyterHandler] = [] diff --git a/jupyter_server/gateway/connections.py b/jupyter_server/gateway/connections.py index 0fe973eef6..d0a3a03633 100644 --- a/jupyter_server/gateway/connections.py +++ b/jupyter_server/gateway/connections.py @@ -35,7 +35,7 @@ async def connect(self): # websocket is initialized before connection self.ws = None ws_url = url_path_join( - GatewayClient.instance().ws_url, + GatewayClient.instance().ws_url or "", GatewayClient.instance().kernels_endpoint, url_escape(self.kernel_id), "channels", diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py index 7d532211ed..e5ed39371d 100644 --- a/jupyter_server/gateway/handlers.py +++ b/jupyter_server/gateway/handlers.py @@ -168,15 +168,18 @@ async def _connect(self, kernel_id, message_callback): # websocket is initialized before connection self.ws = None self.kernel_id = kernel_id + client = GatewayClient.instance() + assert client.ws_url is not None + ws_url = url_path_join( - GatewayClient.instance().ws_url, - GatewayClient.instance().kernels_endpoint, + client.ws_url, + client.kernels_endpoint, url_escape(kernel_id), "channels", ) self.log.info(f"Connecting to {ws_url}") kwargs: dict = {} - kwargs = GatewayClient.instance().load_connection_args(**kwargs) + kwargs = client.load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) self.ws_future = cast(Future, websocket_connect(request)) @@ -289,7 +292,9 @@ async def get(self, kernel_name, path, include_body=True): """Get a gateway resource by name and path.""" mimetype: Optional[str] = None ksm = self.kernel_spec_manager - kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path) + kernel_spec_res = await ksm.get_kernel_spec_resource( # type:ignore[attr-defined] + kernel_name, path + ) if kernel_spec_res is None: self.log.warning( "Kernelspec resource '{}' for '{}' not found. Gateway may not support" diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 5a09ede926..eb203d59f4 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -50,7 +50,7 @@ def __init__(self, **kwargs): """Initialize a gateway mapping kernel manager.""" super().__init__(**kwargs) self.kernels_url = url_path_join( - GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint + GatewayClient.instance().url or "", GatewayClient.instance().kernels_endpoint or "" ) def remove_kernel(self, kernel_id): @@ -214,12 +214,12 @@ def __init__(self, **kwargs): """Initialize a gateway kernel spec manager.""" super().__init__(**kwargs) base_endpoint = url_path_join( - GatewayClient.instance().url, GatewayClient.instance().kernelspecs_endpoint + GatewayClient.instance().url or "", GatewayClient.instance().kernelspecs_endpoint ) self.base_endpoint = GatewayKernelSpecManager._get_endpoint_for_user_filter(base_endpoint) self.base_resource_endpoint = url_path_join( - GatewayClient.instance().url, + GatewayClient.instance().url or "", GatewayClient.instance().kernelspecs_resource_endpoint, ) @@ -386,7 +386,7 @@ def __init__(self, **kwargs): """Initialize the gateway kernel manager.""" super().__init__(**kwargs) self.kernels_url = url_path_join( - GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint + GatewayClient.instance().url or "", GatewayClient.instance().kernels_endpoint ) self.kernel_url: str self.kernel = self.kernel_id = None @@ -420,7 +420,7 @@ def client(self, **kwargs): # add kwargs last, for manual overrides kw.update(kwargs) - return self.client_factory(**kw) # type:ignore[operator] + return self.client_factory(**kw) async def refresh_model(self, model=None): """Refresh the kernel model. @@ -485,7 +485,7 @@ async def start_kernel(self, **kwargs): # Let KERNEL_USERNAME take precedent over http_user config option. if os.environ.get("KERNEL_USERNAME") is None and GatewayClient.instance().http_user: - os.environ["KERNEL_USERNAME"] = GatewayClient.instance().http_user + os.environ["KERNEL_USERNAME"] = GatewayClient.instance().http_user or "" payload_envs = os.environ.copy() payload_envs.update(kwargs.get("env", {})) # Add any env entries in this request @@ -739,7 +739,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont """ ws_url = url_path_join( - GatewayClient.instance().ws_url, + GatewayClient.instance().ws_url or "", GatewayClient.instance().kernels_endpoint, url_escape(self.kernel_id), "channels", diff --git a/jupyter_server/nbconvert/handlers.py b/jupyter_server/nbconvert/handlers.py index 3d65b392d3..5de2587fde 100644 --- a/jupyter_server/nbconvert/handlers.py +++ b/jupyter_server/nbconvert/handlers.py @@ -167,6 +167,7 @@ async def post(self, format): exporter = get_exporter(format, config=self.config) model = self.get_json_body() + assert model is not None name = model.get("name", "notebook.ipynb") nbnode = from_dict(model["content"]) diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 09c2078871..7ab7530dd5 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -21,6 +21,7 @@ import sys import threading import time +import typing as t import urllib import warnings from base64 import encodebytes @@ -191,7 +192,7 @@ # ----------------------------------------------------------------------------- -def random_ports(port, n): +def random_ports(port: int, n: int) -> t.Generator[int, None, None]: """Generate a list of n random ports near the given port. The first 5 ports will be sequential, and the remaining n-5 will be @@ -203,7 +204,7 @@ def random_ports(port, n): yield max(1, port + random.randint(-2 * n, 2 * n)) # noqa -def load_handlers(name): +def load_handlers(name: str) -> t.Any: """Load the (URL pattern, handler) tuples for each component.""" mod = __import__(name, fromlist=["default_handlers"]) return mod.default_handlers @@ -836,11 +837,11 @@ class ServerApp(JupyterApp): _stopping = Bool(False, help="Signal that we've begun stopping.") @default("log_level") - def _default_log_level(self): + def _default_log_level(self) -> int: return logging.INFO @default("log_format") - def _default_log_format(self): + def _default_log_format(self) -> str: """override default log format to include date & time""" return ( "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s" @@ -909,7 +910,7 @@ def _default_log_format(self): ) @default("ip") - def _default_ip(self): + def _default_ip(self) -> str: """Return localhost if available, 127.0.0.1 otherwise. On some (horribly broken) systems, localhost cannot be bound. @@ -927,7 +928,7 @@ def _default_ip(self): return "localhost" @validate("ip") - def _validate_ip(self, proposal): + def _validate_ip(self, proposal: t.Any) -> str: value = proposal["value"] if value == "*": value = "" @@ -959,7 +960,7 @@ def _validate_ip(self, proposal): ) @default("port") - def _port_default(self): + def _port_default(self) -> int: return int(os.getenv(self.port_env, self.port_default_value)) port_retries_env = "JUPYTER_PORT_RETRIES" @@ -974,7 +975,7 @@ def _port_default(self): ) @default("port_retries") - def _port_retries_default(self): + def _port_retries_default(self) -> int: return int(os.getenv(self.port_retries_env, self.port_retries_default_value)) sock = Unicode("", config=True, help="The UNIX socket the Jupyter server will listen on.") @@ -986,7 +987,7 @@ def _port_retries_default(self): ) @validate("sock_mode") - def _validate_sock_mode(self, proposal): + def _validate_sock_mode(self, proposal: t.Any) -> int: value = proposal["value"] try: converted_value = int(value.encode(), 8) @@ -1034,7 +1035,7 @@ def _validate_sock_mode(self, proposal): ) @default("cookie_secret_file") - def _default_cookie_secret_file(self): + def _default_cookie_secret_file(self) -> str: return os.path.join(self.runtime_dir, "jupyter_cookie_secret") cookie_secret = Bytes( @@ -1050,7 +1051,7 @@ def _default_cookie_secret_file(self): ) @default("cookie_secret") - def _default_cookie_secret(self): + def _default_cookie_secret(self) -> bytes: if os.path.exists(self.cookie_secret_file): with open(self.cookie_secret_file, "rb") as f: key = f.read() @@ -1061,7 +1062,7 @@ def _default_cookie_secret(self): h.update(self.password.encode()) return h.digest() - def _write_cookie_secret_file(self, secret): + def _write_cookie_secret_file(self, secret: bytes) -> None: """write my secret to my secret_file""" self.log.info(_i18n("Writing Jupyter server cookie secret to %s"), self.cookie_secret_file) try: @@ -1081,11 +1082,11 @@ def _write_cookie_secret_file(self, secret): ) @observe("token") - def _deprecated_token(self, change): + def _deprecated_token(self, change: t.Any) -> None: self._warn_deprecated_config(change, "IdentityProvider") @default("token") - def _deprecated_token_access(self): + def _deprecated_token_access(self) -> None: warnings.warn( "ServerApp.token config is deprecated in jupyter-server 2.0. Use IdentityProvider.token", DeprecationWarning, @@ -1105,7 +1106,7 @@ def _deprecated_token_access(self): ) @default("min_open_files_limit") - def _default_min_open_files_limit(self): + def _default_min_open_files_limit(self) -> t.Optional[int]: if resource is None: # Ignoring min_open_files_limit because the limit cannot be adjusted (for example, on Windows) return None # type:ignore[unreachable] @@ -1164,7 +1165,9 @@ def _default_min_open_files_limit(self): help="""DEPRECATED in 2.0. Use PasswordIdentityProvider.allow_password_change""", ) - def _warn_deprecated_config(self, change, clsname, new_name=None): + def _warn_deprecated_config( + self, change: t.Any, clsname: str, new_name: t.Optional[str] = None + ) -> None: """Warn on deprecated config.""" if new_name is None: new_name = change.name @@ -1184,11 +1187,11 @@ def _warn_deprecated_config(self, change, clsname, new_name=None): ) @observe("password") - def _deprecated_password(self, change): + def _deprecated_password(self, change: t.Any) -> None: self._warn_deprecated_config(change, "PasswordIdentityProvider", new_name="hashed_password") @observe("password_required", "allow_password_change") - def _deprecated_password_config(self, change): + def _deprecated_password_config(self, change: t.Any) -> None: self._warn_deprecated_config(change, "PasswordIdentityProvider") disable_check_xsrf = Bool( @@ -1227,7 +1230,7 @@ def _deprecated_password_config(self, change): ) @default("allow_remote_access") - def _default_allow_remote(self): + def _default_allow_remote(self) -> bool: """Disallow remote access if we're listening only on loopback addresses""" # if blank, self.ip was configured to "*" meaning bind to all interfaces, @@ -1367,7 +1370,7 @@ def _default_allow_remote(self): ) @observe("cookie_options", "get_secure_cookie_kwargs") - def _deprecated_cookie_config(self, change): + def _deprecated_cookie_config(self, change: t.Any) -> None: self._warn_deprecated_config(change, "IdentityProvider") ssl_options = Dict( @@ -1400,7 +1403,7 @@ def _deprecated_cookie_config(self, change): ) @validate("base_url") - def _update_base_url(self, proposal): + def _update_base_url(self, proposal: t.Any) -> str: value = proposal["value"] if not value.startswith("/"): value = "/" + value @@ -1418,14 +1421,14 @@ def _update_base_url(self, proposal): ) @property - def static_file_path(self): + def static_file_path(self) -> t.List[str]: """return extra paths + the default location""" return [*self.extra_static_paths, DEFAULT_STATIC_FILES_PATH] static_custom_path = List(Unicode(), help=_i18n("""Path to search for custom.js, css""")) @default("static_custom_path") - def _default_static_custom_path(self): + def _default_static_custom_path(self) -> t.List[str]: return [os.path.join(d, "custom") for d in (self.config_dir, DEFAULT_STATIC_FILES_PATH)] extra_template_paths = List( @@ -1439,7 +1442,7 @@ def _default_static_custom_path(self): ) @property - def template_file_path(self): + def template_file_path(self) -> t.List[str]: """return extra paths + the default locations""" return self.extra_template_paths + DEFAULT_TEMPLATE_PATH_LIST @@ -1481,7 +1484,7 @@ def template_file_path(self): ) @default("kernel_manager_class") - def _default_kernel_manager_class(self): + def _default_kernel_manager_class(self) -> t.Union[str, t.Type[AsyncMappingKernelManager]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.managers.GatewayMappingKernelManager" return AsyncMappingKernelManager @@ -1492,7 +1495,7 @@ def _default_kernel_manager_class(self): ) @default("session_manager_class") - def _default_session_manager_class(self): + def _default_session_manager_class(self) -> t.Union[str, t.Type[SessionManager]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.managers.GatewaySessionManager" return SessionManager @@ -1504,7 +1507,9 @@ def _default_session_manager_class(self): ) @default("kernel_websocket_connection_class") - def _default_kernel_websocket_connection_class(self): + def _default_kernel_websocket_connection_class( + self, + ) -> t.Union[str, t.Type[ZMQChannelsWebsocketConnection]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.connections.GatewayWebSocketConnection" return ZMQChannelsWebsocketConnection @@ -1529,7 +1534,7 @@ def _default_kernel_websocket_connection_class(self): ) @default("kernel_spec_manager_class") - def _default_kernel_spec_manager_class(self): + def _default_kernel_spec_manager_class(self) -> t.Union[str, t.Type[KernelSpecManager]]: if self.gateway_config.gateway_enabled: return "jupyter_server.gateway.managers.GatewayKernelSpecManager" return KernelSpecManager @@ -1585,7 +1590,7 @@ def _default_kernel_spec_manager_class(self): info_file = Unicode() @default("info_file") - def _default_info_file(self): + def _default_info_file(self) -> str: info_file = "jpserver-%s.json" % os.getpid() return os.path.join(self.runtime_dir, info_file) @@ -1596,14 +1601,14 @@ def _default_info_file(self): browser_open_file = Unicode() @default("browser_open_file") - def _default_browser_open_file(self): + def _default_browser_open_file(self) -> str: basename = "jpserver-%s-open.html" % os.getpid() return os.path.join(self.runtime_dir, basename) browser_open_file_to_run = Unicode() @default("browser_open_file_to_run") - def _default_browser_open_file_to_run(self): + def _default_browser_open_file_to_run(self) -> str: basename = "jpserver-file-to-run-%s-open.html" % os.getpid() return os.path.join(self.runtime_dir, basename) @@ -1618,7 +1623,7 @@ def _default_browser_open_file_to_run(self): ) @observe("pylab") - def _update_pylab(self, change): + def _update_pylab(self, change: t.Any) -> None: """when --pylab is specified, display a warning and exit""" backend = " %s" % change["new"] if change["new"] != "warn" else "" self.log.error( @@ -1634,7 +1639,7 @@ def _update_pylab(self, change): notebook_dir = Unicode(config=True, help=_i18n("DEPRECATED, use root_dir.")) @observe("notebook_dir") - def _update_notebook_dir(self, change): + def _update_notebook_dir(self, change: t.Any) -> None: if self._root_dir_set: # only use deprecated config if new config is not set return @@ -1665,14 +1670,14 @@ def _update_notebook_dir(self, change): _root_dir_set = False @default("root_dir") - def _default_root_dir(self): + def _default_root_dir(self) -> str: if self.file_to_run: self._root_dir_set = True return os.path.dirname(os.path.abspath(self.file_to_run)) else: return os.getcwd() - def _normalize_dir(self, value): + def _normalize_dir(self, value: str) -> str: """Normalize a directory.""" # Strip any trailing slashes # *except* if it's root @@ -1686,14 +1691,14 @@ def _normalize_dir(self, value): return value @validate("root_dir") - def _root_dir_validate(self, proposal): + def _root_dir_validate(self, proposal: t.Any) -> str: value = self._normalize_dir(proposal["value"]) if not os.path.isdir(value): raise TraitError(trans.gettext("No such directory: '%r'") % value) return value @observe("root_dir") - def _root_dir_changed(self, change): + def _root_dir_changed(self, change: t.Any) -> None: # record that root_dir is set, # which affects loading of deprecated notebook_dir self._root_dir_set = True @@ -1705,18 +1710,18 @@ def _root_dir_changed(self, change): ) @default("preferred_dir") - def _default_prefered_dir(self): + def _default_prefered_dir(self) -> str: return self.root_dir @validate("preferred_dir") - def _preferred_dir_validate(self, proposal): + def _preferred_dir_validate(self, proposal: t.Any) -> str: value = self._normalize_dir(proposal["value"]) if not os.path.isdir(value): raise TraitError(trans.gettext("No such preferred dir: '%r'") % value) return value @observe("server_extensions") - def _update_server_extensions(self, change): + def _update_server_extensions(self, change: t.Any) -> None: self.log.warning(_i18n("server_extensions is deprecated, use jpserver_extensions")) self.server_extensions = change["new"] @@ -1747,7 +1752,7 @@ def _update_server_extensions(self, change): ) @observe("kernel_ws_protocol") - def _deprecated_kernel_ws_protocol(self, change): + def _deprecated_kernel_ws_protocol(self, change: t.Any) -> None: self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") limit_rate = Bool( @@ -1757,7 +1762,7 @@ def _deprecated_kernel_ws_protocol(self, change): ) @observe("limit_rate") - def _deprecated_limit_rate(self, change): + def _deprecated_limit_rate(self, change: t.Any) -> None: self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") iopub_msg_rate_limit = Float( @@ -1767,7 +1772,7 @@ def _deprecated_limit_rate(self, change): ) @observe("iopub_msg_rate_limit") - def _deprecated_iopub_msg_rate_limit(self, change): + def _deprecated_iopub_msg_rate_limit(self, change: t.Any) -> None: self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") iopub_data_rate_limit = Float( @@ -1777,7 +1782,7 @@ def _deprecated_iopub_msg_rate_limit(self, change): ) @observe("iopub_data_rate_limit") - def _deprecated_iopub_data_rate_limit(self, change): + def _deprecated_iopub_data_rate_limit(self, change: t.Any) -> None: self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") rate_limit_window = Float( @@ -1787,7 +1792,7 @@ def _deprecated_iopub_data_rate_limit(self, change): ) @observe("rate_limit_window") - def _deprecated_rate_limit_window(self, change): + def _deprecated_rate_limit_window(self, change: t.Any) -> None: self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") shutdown_no_activity_timeout = Integer( @@ -1819,7 +1824,7 @@ def _deprecated_rate_limit_window(self, change): ) @default("terminals_enabled") - def _default_terminals_enabled(self): + def _default_terminals_enabled(self) -> bool: return True authenticate_prometheus = Bool( @@ -1848,11 +1853,11 @@ def _default_terminals_enabled(self): ) @property - def starter_app(self): + def starter_app(self) -> t.Any: """Get the Extension that started this server.""" return self._starter_app - def parse_command_line(self, argv=None): + def parse_command_line(self, argv: t.Optional[t.List[str]] = None) -> None: """Parse the command line options.""" super().parse_command_line(argv) @@ -1873,7 +1878,7 @@ def parse_command_line(self, argv=None): c.ServerApp.file_to_run = f self.update_config(c) - def init_configurables(self): + def init_configurables(self) -> None: """Initialize configurables.""" # If gateway server is configured, replace appropriate managers to perform redirection. To make # this determination, instantiate the GatewayClient config singleton. @@ -1897,7 +1902,7 @@ def init_configurables(self): stacklevel=2, ) - self.kernel_spec_manager = self.kernel_spec_manager_class( # type:ignore[operator] + self.kernel_spec_manager = self.kernel_spec_manager_class( parent=self, ) @@ -1927,13 +1932,13 @@ def init_configurables(self): # Trigger a default/validation here explicitly while we still support the # deprecated trait on ServerApp (FIXME remove when deprecation finalized) self.contents_manager.preferred_dir # noqa - self.session_manager = self.session_manager_class( # type:ignore[operator] + self.session_manager = self.session_manager_class( parent=self, log=self.log, kernel_manager=self.kernel_manager, contents_manager=self.contents_manager, ) - self.config_manager = self.config_manager_class( # type:ignore[operator] + self.config_manager = self.config_manager_class( parent=self, log=self.log, ) @@ -1989,7 +1994,7 @@ def init_configurables(self): parent=self, log=self.log, identity_provider=self.identity_provider ) - def init_logging(self): + def init_logging(self) -> None: """Initialize logging.""" # This prevents double log messages because tornado use a root logger that # self.log is a child of. The logging module dipatches log messages to a log @@ -2005,7 +2010,7 @@ def init_logging(self): logger.parent = self.log logger.setLevel(self.log.level) - def init_event_logger(self): + def init_event_logger(self) -> None: """Initialize the Event Bus.""" self.event_logger = EventLogger(parent=self) # Load the core Jupyter Server event schemas @@ -2023,7 +2028,7 @@ def init_event_logger(self): # Use this pathlib object to register the schema self.event_logger.register_event_schema(schema_path) - def init_webapp(self): + def init_webapp(self) -> None: """initialize tornado webapp""" self.tornado_settings["allow_origin"] = self.allow_origin self.tornado_settings["websocket_compression_options"] = self.websocket_compression_options @@ -2125,7 +2130,7 @@ def init_webapp(self): # LegacyIdentityProvider needs access to the tornado settings dict self.identity_provider.settings = self.web_app.settings - def init_resources(self): + def init_resources(self) -> None: """initialize system resources""" if resource is None: self.log.debug( # type:ignore[unreachable] @@ -2144,7 +2149,7 @@ def init_resources(self): ) resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) - def _get_urlparts(self, path=None, include_token=False): + def _get_urlparts(self, path: t.Optional[str] = None, include_token: bool = False) -> t.Any: """Constructs a urllib named tuple, ParseResult, with default values set by server config. The returned tuple can be manipulated using the `_replace` method. @@ -2178,7 +2183,7 @@ def _get_urlparts(self, path=None, include_token=False): return urlparts @property - def public_url(self): + def public_url(self) -> str: parts = self._get_urlparts(include_token=True) # Update with custom pieces. if self.custom_display_url: @@ -2191,7 +2196,7 @@ def public_url(self): return parts.geturl() @property - def local_url(self): + def local_url(self) -> str: parts = self._get_urlparts(include_token=True) # Update with custom pieces. if not self.sock: @@ -2199,7 +2204,7 @@ def local_url(self): return parts.geturl() @property - def display_url(self): + def display_url(self) -> str: """Human readable string with URLs for interacting with the running Jupyter Server """ @@ -2207,11 +2212,11 @@ def display_url(self): return url @property - def connection_url(self): + def connection_url(self) -> str: urlparts = self._get_urlparts(path=self.base_url) return urlparts.geturl() - def init_signal(self): + def init_signal(self) -> None: """Initialize signal handlers.""" if ( not sys.platform.startswith("win") @@ -2227,7 +2232,7 @@ def init_signal(self): # only on BSD-based systems signal.signal(signal.SIGINFO, self._signal_info) - def _handle_sigint(self, sig, frame): + def _handle_sigint(self, sig: t.Any, frame: t.Any) -> None: """SIGINT handler spawns confirmation dialog""" # register more forceful signal handler for ^C^C case signal.signal(signal.SIGINT, self._signal_stop) @@ -2237,11 +2242,11 @@ def _handle_sigint(self, sig, frame): thread.daemon = True thread.start() - def _restore_sigint_handler(self): + def _restore_sigint_handler(self) -> None: """callback for restoring original SIGINT handler""" signal.signal(signal.SIGINT, self._handle_sigint) - def _confirm_exit(self): + def _confirm_exit(self) -> None: """confirm shutdown on ^C A second ^C, or answering 'y' within 5s will cause shutdown, @@ -2285,21 +2290,21 @@ def _confirm_exit(self): # from main thread self.io_loop.add_callback_from_signal(self._restore_sigint_handler) - def _signal_stop(self, sig, frame): + def _signal_stop(self, sig: t.Any, frame: t.Any) -> None: """Handle a stop signal.""" self.log.critical(_i18n("received signal %s, stopping"), sig) self.stop(from_signal=True) - def _signal_info(self, sig, frame): + def _signal_info(self, sig: t.Any, frame: t.Any) -> None: """Handle an info signal.""" self.log.info(self.running_server_info()) - def init_components(self): + def init_components(self) -> None: """Check the components submodule, and warn if it's unclean""" # TODO: this should still check, but now we use bower, not git submodule pass - def find_server_extensions(self): + def find_server_extensions(self) -> None: """ Searches Jupyter paths for jpserver_extensions. """ @@ -2322,7 +2327,7 @@ def find_server_extensions(self): self.config.ServerApp.jpserver_extensions.update({modulename: enabled}) self.jpserver_extensions.update({modulename: enabled}) - def init_server_extensions(self): + def init_server_extensions(self) -> None: """ If an extension's metadata includes an 'app' key, the value must be a subclass of ExtensionApp. An instance @@ -2335,7 +2340,7 @@ def init_server_extensions(self): self.extension_manager.from_jpserver_extensions(self.jpserver_extensions) self.extension_manager.link_all_extensions() - def load_server_extensions(self): + def load_server_extensions(self) -> None: """Load any extensions specified by config. Import the module, then call the load_jupyter_server_extension function, @@ -2345,7 +2350,7 @@ def load_server_extensions(self): """ self.extension_manager.load_all_extensions() - def init_mime_overrides(self): + def init_mime_overrides(self) -> None: # On some Windows machines, an application has registered incorrect # mimetypes in the registry. # Tornado uses this when serving .css and .js files, causing browsers to @@ -2361,7 +2366,7 @@ def init_mime_overrides(self): # for python <3.8 mimetypes.add_type("application/wasm", ".wasm") - def shutdown_no_activity(self): + def shutdown_no_activity(self) -> None: """Shutdown server on timeout when there are no kernels or terminals.""" km = self.kernel_manager if len(km) != 0: @@ -2379,7 +2384,7 @@ def shutdown_no_activity(self): ) self.stop() - def init_shutdown_no_activity(self): + def init_shutdown_no_activity(self) -> None: """Initialize a shutdown on no activity.""" if self.shutdown_no_activity_timeout > 0: self.log.info( @@ -2390,7 +2395,7 @@ def init_shutdown_no_activity(self): pc.start() @property - def http_server(self): + def http_server(self) -> httpserver.HTTPServer: """An instance of Tornado's HTTPServer class for the Server Web Application.""" try: return self._http_server @@ -2402,7 +2407,7 @@ def http_server(self): ) raise AttributeError(msg) from None - def init_httpserver(self): + def init_httpserver(self) -> None: """Creates an instance of a Tornado HTTPServer for the Server Web Application and sets the http_server attribute. """ @@ -2428,7 +2433,7 @@ def init_httpserver(self): self._find_http_port() self.io_loop.add_callback(self._bind_http_server) - def _bind_http_server(self): + def _bind_http_server(self) -> None: """Bind our http server.""" success = self._bind_http_server_unix() if self.sock else self._bind_http_server_tcp() if not success: @@ -2440,7 +2445,7 @@ def _bind_http_server(self): ) self.exit(1) - def _bind_http_server_unix(self): + def _bind_http_server_unix(self) -> bool: """Bind an http server on unix.""" if unix_socket_in_use(self.sock): self.log.warning(_i18n("The socket %s is already in use.") % self.sock) @@ -2461,12 +2466,12 @@ def _bind_http_server_unix(self): else: return True - def _bind_http_server_tcp(self): + def _bind_http_server_tcp(self) -> bool: """Bind a tcp server.""" self.http_server.listen(self.port, self.ip) return True - def _find_http_port(self): + def _find_http_port(self) -> None: """Find an available http port.""" success = False port = self.port @@ -2514,7 +2519,7 @@ def _find_http_port(self): self.exit(1) @staticmethod - def _init_asyncio_patch(): + def _init_asyncio_patch() -> None: """set default asyncio policy to be compatible with tornado Tornado 6.0 is not compatible with default asyncio @@ -2541,11 +2546,11 @@ def _init_asyncio_patch(): @catch_config_error def initialize( self, - argv=None, - find_extensions=True, - new_httpserver=True, - starter_extension=None, - ): + argv: t.Optional[t.List[str]] = None, + find_extensions: bool = True, + new_httpserver: bool = True, + starter_extension: t.Any = None, + ) -> None: """Initialize the Server application class, configurables, web application, and http server. Parameters @@ -2604,7 +2609,7 @@ def initialize( if new_httpserver: self.init_httpserver() - async def cleanup_kernels(self): + async def cleanup_kernels(self) -> None: """Shutdown all kernels. The kernels will shutdown themselves when this process no longer exists, @@ -2619,7 +2624,7 @@ async def cleanup_kernels(self): self.log.info(kernel_msg % n_kernels) await ensure_async(self.kernel_manager.shutdown_all()) - async def cleanup_extensions(self): + async def cleanup_extensions(self) -> None: """Call shutdown hooks in all extensions.""" if not getattr(self, "extension_manager", None): return @@ -2630,7 +2635,7 @@ async def cleanup_extensions(self): self.log.info(extension_msg % n_extensions) await ensure_async(self.extension_manager.stop_all_extensions()) - def running_server_info(self, kernel_count=True): + def running_server_info(self, kernel_count: bool = True) -> str: """Return the current working directory and the server url information""" info = self.contents_manager.info_string() + "\n" if kernel_count: @@ -2647,7 +2652,7 @@ def running_server_info(self, kernel_count=True): ) return info - def server_info(self): + def server_info(self) -> t.Dict[str, t.Any]: """Return a JSONable dict of information about this server.""" return { "url": self.connection_url, @@ -2663,7 +2668,7 @@ def server_info(self): "version": ServerApp.version, } - def write_server_info_file(self): + def write_server_info_file(self) -> None: """Write the result of server_info() to the JSON file info_file.""" try: with secure_write(self.info_file) as f: @@ -2671,7 +2676,7 @@ def write_server_info_file(self): except OSError as e: self.log.error(_i18n("Failed to write server-info to %s: %r"), self.info_file, e) - def remove_server_info_file(self): + def remove_server_info_file(self) -> None: """Remove the jpserver-.json file created for this server. Ignores the error raised when the file has already been removed. @@ -2682,7 +2687,7 @@ def remove_server_info_file(self): if e.errno != errno.ENOENT: raise - def _resolve_file_to_run_and_root_dir(self): + def _resolve_file_to_run_and_root_dir(self) -> str: """Returns a relative path from file_to_run to root_dir. If root_dir and file_to_run are incompatible, i.e. on different subtrees, @@ -2712,8 +2717,9 @@ def _resolve_file_to_run_and_root_dir(self): "is on the same path as `root_dir`." ) self.exit(1) + return "" - def _write_browser_open_file(self, url, fh): + def _write_browser_open_file(self, url: str, fh: t.Any) -> None: """Write the browser open file.""" if self.identity_provider.token: url = url_concat(url, {"token": self.identity_provider.token}) @@ -2723,7 +2729,7 @@ def _write_browser_open_file(self, url, fh): template = jinja2_env.get_template("browser-open.html") fh.write(template.render(open_url=url, base_url=self.base_url)) - def write_browser_open_files(self): + def write_browser_open_files(self) -> None: """Write an `browser_open_file` and `browser_open_file_to_run` files This can be used to open a file directly in a browser. @@ -2744,7 +2750,7 @@ def write_browser_open_files(self): with open(self.browser_open_file_to_run, "w", encoding="utf-8") as f: self._write_browser_open_file(file_open_url, f) - def write_browser_open_file(self): + def write_browser_open_file(self) -> None: """Write an jpserver--open.html file This can be used to open the notebook in a browser @@ -2755,7 +2761,7 @@ def write_browser_open_file(self): with open(self.browser_open_file, "w", encoding="utf-8") as f: self._write_browser_open_file(open_url, f) - def remove_browser_open_files(self): + def remove_browser_open_files(self) -> None: """Remove the `browser_open_file` and `browser_open_file_to_run` files created for this server. @@ -2768,7 +2774,7 @@ def remove_browser_open_files(self): if e.errno != errno.ENOENT: raise - def remove_browser_open_file(self): + def remove_browser_open_file(self) -> None: """Remove the jpserver--open.html file created for this server. Ignores the error raised when the file has already been removed. @@ -2779,7 +2785,7 @@ def remove_browser_open_file(self): if e.errno != errno.ENOENT: raise - def _prepare_browser_open(self): + def _prepare_browser_open(self) -> t.Tuple[str, t.Optional[str]]: """Prepare to open the browser.""" if not self.use_redirect_file: uri = self.default_url[len(self.base_url) :] @@ -2802,7 +2808,7 @@ def _prepare_browser_open(self): return assembled_url, open_file - def launch_browser(self): + def launch_browser(self) -> None: """Launch the browser.""" # Deferred import for environments that do not have # the webbrowser module. @@ -2825,7 +2831,7 @@ def target(): threading.Thread(target=target).start() - def start_app(self): + def start_app(self) -> None: """Start the Jupyter Server application.""" super().start() @@ -2908,7 +2914,7 @@ def start_app(self): self.log.critical("\n".join(message)) - async def _cleanup(self): + async def _cleanup(self) -> None: """General cleanup of files, extensions and kernels created by this instance ServerApp. """ @@ -2937,7 +2943,7 @@ async def _cleanup(self): # Stop a server if its set. self.http_server.stop() - def start_ioloop(self): + def start_ioloop(self) -> None: """Start the IO Loop.""" if sys.platform.startswith("win"): # add no-op to wake every 5s @@ -2949,11 +2955,11 @@ def start_ioloop(self): except KeyboardInterrupt: self.log.info(_i18n("Interrupted...")) - def init_ioloop(self): + def init_ioloop(self) -> None: """init self.io_loop so that an extension can use it by io_loop.call_later() to create background tasks""" self.io_loop = ioloop.IOLoop.current() - def start(self): + def start(self) -> None: """Start the Jupyter server app, after initialization This method takes no arguments so all configuration and initialization @@ -2961,13 +2967,13 @@ def start(self): self.start_app() self.start_ioloop() - async def _stop(self): + async def _stop(self) -> None: """Cleanup resources and stop the IO Loop.""" await self._cleanup() if getattr(self, "io_loop", None): self.io_loop.stop() - def stop(self, from_signal=False): + def stop(self, from_signal: bool = False) -> None: """Cleanup resources and stop the server.""" # signal that stopping has begun self._stopping = True @@ -2983,7 +2989,9 @@ def stop(self, from_signal=False): self.io_loop.add_callback(self._stop) -def list_running_servers(runtime_dir=None, log=None): +def list_running_servers( + runtime_dir: t.Optional[str] = None, log: t.Optional[logging.Logger] = None +) -> t.Generator[t.Any, None, None]: """Iterate over the server info files of running Jupyter servers. Given a runtime directory, find jpserver-* files in the security directory, diff --git a/jupyter_server/services/contents/filemanager.py b/jupyter_server/services/contents/filemanager.py index 4ddf9cd721..3558b67402 100644 --- a/jupyter_server/services/contents/filemanager.py +++ b/jupyter_server/services/contents/filemanager.py @@ -33,8 +33,8 @@ try: from os.path import samefile except ImportError: - # windows + py2 - from jupyter_server.utils import samefile_simple as samefile + # windows + from jupyter_server.utils import samefile_simple as samefile # type:ignore[assignment] _script_exporter = None diff --git a/jupyter_server/services/contents/handlers.py b/jupyter_server/services/contents/handlers.py index 18e4a7e685..50e7703db8 100644 --- a/jupyter_server/services/contents/handlers.py +++ b/jupyter_server/services/contents/handlers.py @@ -353,7 +353,7 @@ def get(self, path): class TrustNotebooksHandler(JupyterHandler): """Handles trust/signing of notebooks""" - @web.authenticated + @web.authenticated # type:ignore[misc] @authorized(resource=AUTH_RESOURCE) async def post(self, path=""): """Trust a notebook by path.""" diff --git a/jupyter_server/services/contents/manager.py b/jupyter_server/services/contents/manager.py index dc9c754f84..ebde32e29f 100644 --- a/jupyter_server/services/contents/manager.py +++ b/jupyter_server/services/contents/manager.py @@ -325,7 +325,7 @@ def run_post_save_hooks(self, model, os_path): @default("checkpoints") def _default_checkpoints(self): - return self.checkpoints_class(**self.checkpoints_kwargs) # type:ignore[operator] + return self.checkpoints_class(**self.checkpoints_kwargs) @default("checkpoints_kwargs") def _default_checkpoints_kwargs(self): @@ -761,7 +761,7 @@ class AsyncContentsManager(ContentsManager): @default("checkpoints") def _default_checkpoints(self): - return self.checkpoints_class(**self.checkpoints_kwargs) # type:ignore[operator] + return self.checkpoints_class(**self.checkpoints_kwargs) @default("checkpoints_kwargs") def _default_checkpoints_kwargs(self): diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index 9f58692ea3..fb17d61e6b 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -4,9 +4,9 @@ """ import json from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast -from jupyter_events import EventLogger +import jupyter_events.logger from tornado import web, websocket from jupyter_server.auth import authorized @@ -47,7 +47,9 @@ async def get(self, *args, **kwargs): if res is not None: await res - async def event_listener(self, logger: EventLogger, schema_id: str, data: dict) -> None: + async def event_listener( + self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict + ) -> None: """Write an event message.""" capsule = dict(schema_id=schema_id, **data) self.write_message(json.dumps(capsule)) @@ -105,8 +107,8 @@ async def post(self): try: validate_model(payload) self.event_logger.emit( - schema_id=payload.get("schema_id"), - data=payload.get("data"), + schema_id=cast(str, payload.get("schema_id")), + data=cast("Dict[str, Any]", payload.get("data")), timestamp_override=get_timestamp(payload), ) self.set_status(204) diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index 0ebebf4a51..763988dec7 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -53,7 +53,9 @@ async def post(self): model.setdefault("name", km.default_kernel_name) kernel_id = await ensure_async( - km.start_kernel(kernel_name=model["name"], path=model.get("path")) + km.start_kernel( # type:ignore[has-type] + kernel_name=model["name"], path=model.get("path") + ) ) model = await ensure_async(km.kernel_model(kernel_id)) location = url_path_join(self.base_url, "api", "kernels", url_escape(kernel_id)) @@ -92,7 +94,7 @@ async def post(self, kernel_id, action): """Interrupt or restart a kernel.""" km = self.kernel_manager if action == "interrupt": - await ensure_async(km.interrupt_kernel(kernel_id)) + await ensure_async(km.interrupt_kernel(kernel_id)) # type:ignore[func-returns-value] self.set_status(204) if action == "restart": try: diff --git a/jupyter_server/services/kernelspecs/handlers.py b/jupyter_server/services/kernelspecs/handlers.py index bee99bb158..4afccee246 100644 --- a/jupyter_server/services/kernelspecs/handlers.py +++ b/jupyter_server/services/kernelspecs/handlers.py @@ -4,9 +4,12 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import glob import json import os +from typing import Any pjoin = os.path.join @@ -63,7 +66,7 @@ async def get(self): """Get the list of kernel specs.""" ksm = self.kernel_spec_manager km = self.kernel_manager - model = {} + model: dict[str, Any] = {} model["default"] = km.default_kernel_name model["kernelspecs"] = specs = {} kspecs = await ensure_async(ksm.get_all_specs()) diff --git a/jupyter_server/services/sessions/handlers.py b/jupyter_server/services/sessions/handlers.py index 1b042b152d..ab2ce5c939 100644 --- a/jupyter_server/services/sessions/handlers.py +++ b/jupyter_server/services/sessions/handlers.py @@ -83,10 +83,10 @@ async def post(self): exists = await ensure_async(sm.session_exists(path=path)) if exists: - model = await sm.get_session(path=path) + s_model = await sm.get_session(path=path) else: try: - model = await sm.create_session( + s_model = await sm.create_session( path=path, kernel_name=kernel_name, kernel_id=kernel_id, @@ -106,10 +106,10 @@ async def post(self): except Exception as e: raise web.HTTPError(500, str(e)) from e - location = url_path_join(self.base_url, "api", "sessions", model["id"]) + location = url_path_join(self.base_url, "api", "sessions", s_model["id"]) self.set_header("Location", location) self.set_status(201) - self.finish(json.dumps(model, default=json_default)) + self.finish(json.dumps(s_model, default=json_default)) class SessionHandler(SessionsAPIHandler): @@ -170,16 +170,16 @@ async def patch(self, session_id): changes["kernel_id"] = kernel_id await sm.update_session(session_id, **changes) - model = await sm.get_session(session_id=session_id) + s_model = await sm.get_session(session_id=session_id) - if model["kernel"]["id"] != before["kernel"]["id"]: + if s_model["kernel"]["id"] != before["kernel"]["id"]: # kernel_id changed because we got a new kernel # shutdown the old one fut = asyncio.ensure_future(ensure_async(km.shutdown_kernel(before["kernel"]["id"]))) # If we are not using pending kernels, wait for the kernel to shut down if not getattr(km, "use_pending_kernels", None): await fut - self.finish(json.dumps(model, default=json_default)) + self.finish(json.dumps(s_model, default=json_default)) @web.authenticated @authorized diff --git a/jupyter_server/services/shutdown.py b/jupyter_server/services/shutdown.py index 2f5d490e4a..e64b9dbc91 100644 --- a/jupyter_server/services/shutdown.py +++ b/jupyter_server/services/shutdown.py @@ -19,7 +19,8 @@ async def post(self): """Shut down the server.""" self.log.info("Shutting down on /api/shutdown request.") - await self.serverapp._cleanup() + if self.serverapp: + await self.serverapp._cleanup() ioloop.IOLoop.current().stop() diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py index f08d279872..bd41710275 100644 --- a/jupyter_server/utils.py +++ b/jupyter_server/utils.py @@ -1,6 +1,8 @@ """Notebook related utilities""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import errno import importlib.util import os @@ -8,7 +10,7 @@ import sys import warnings from contextlib import contextmanager -from typing import NewType +from typing import Any, Generator, NewType, Sequence from urllib.parse import ( SplitResult, quote, @@ -23,13 +25,13 @@ from _frozen_importlib_external import _NamespacePath # type:ignore[import] from jupyter_core.utils import ensure_async from packaging.version import Version -from tornado.httpclient import AsyncHTTPClient, HTTPClient, HTTPRequest +from tornado.httpclient import AsyncHTTPClient, HTTPClient, HTTPRequest, HTTPResponse from tornado.netutil import Resolver ApiPath = NewType("ApiPath", str) -def url_path_join(*pieces): +def url_path_join(*pieces: str) -> str: """Join components of url into a relative url Use to prevent double slash when joining subpath. This will leave the @@ -48,12 +50,12 @@ def url_path_join(*pieces): return result -def url_is_absolute(url): +def url_is_absolute(url: str) -> bool: """Determine whether a given URL is absolute""" return urlparse(url).path.startswith("/") -def path2url(path): +def path2url(path: str) -> str: """Convert a local file path to a URL""" pieces = [quote(p) for p in path.split(os.sep)] # preserve trailing / @@ -63,14 +65,14 @@ def path2url(path): return url -def url2path(url): +def url2path(url: str) -> str: """Convert a URL to a local file path""" pieces = [unquote(p) for p in url.split("/")] path = os.path.join(*pieces) return path -def url_escape(path): +def url_escape(path: str) -> str: """Escape special characters in a URL path Turns '/foo bar/' into '/foo%20bar/' @@ -79,7 +81,7 @@ def url_escape(path): return "/".join([quote(p) for p in parts]) -def url_unescape(path): +def url_unescape(path: str) -> str: """Unescape special characters in a URL path Turns '/foo%20bar/' into '/foo bar/' @@ -87,7 +89,7 @@ def url_unescape(path): return "/".join([unquote(p) for p in path.split("/")]) -def samefile_simple(path, other_path): +def samefile_simple(path: str, other_path: str) -> bool: """ Fill in for os.path.samefile when it is unavailable (Windows+py2). @@ -140,7 +142,7 @@ def to_api_path(os_path: str, root: str = "") -> ApiPath: return ApiPath(path) -def check_version(v, check): +def check_version(v: str, check: str) -> bool: """check version string v >= check If dev/prerelease tags result in TypeError for string-number comparison, @@ -156,7 +158,7 @@ def check_version(v, check): # Copy of IPython.utils.process.check_pid: -def _check_pid_win32(pid): +def _check_pid_win32(pid: int) -> bool: import ctypes # OpenProcess returns 0 if no such process (of ours) exists @@ -164,7 +166,7 @@ def _check_pid_win32(pid): return bool(ctypes.windll.kernel32.OpenProcess(1, 0, pid)) # type:ignore[attr-defined] -def _check_pid_posix(pid): +def _check_pid_posix(pid: int) -> bool: """Copy of IPython.utils.process.check_pid""" try: os.kill(pid, 0) @@ -195,22 +197,22 @@ async def run_sync_in_loop(maybe_async): return ensure_async(maybe_async) -def urlencode_unix_socket_path(socket_path): +def urlencode_unix_socket_path(socket_path: str) -> str: """Encodes a UNIX socket path string from a socket path for the `http+unix` URI form.""" return socket_path.replace("/", "%2F") -def urldecode_unix_socket_path(socket_path): +def urldecode_unix_socket_path(socket_path: str) -> str: """Decodes a UNIX sock path string from an encoded sock path for the `http+unix` URI form.""" return socket_path.replace("%2F", "/") -def urlencode_unix_socket(socket_path): +def urlencode_unix_socket(socket_path: str) -> str: """Encodes a UNIX socket URL from a socket path for the `http+unix` URI form.""" return "http+unix://%s" % urlencode_unix_socket_path(socket_path) -def unix_socket_in_use(socket_path): +def unix_socket_in_use(socket_path: str) -> bool: """Checks whether a UNIX socket path on disk is in use by attempting to connect to it.""" if not os.path.exists(socket_path): return False @@ -227,7 +229,9 @@ def unix_socket_in_use(socket_path): @contextmanager -def _request_for_tornado_client(urlstring, method="GET", body=None, headers=None): +def _request_for_tornado_client( + urlstring: str, method: str = "GET", body: Any = None, headers: Any = None +) -> Generator[HTTPRequest, None, None]: """A utility that provides a context that handles HTTP, HTTPS, and HTTP+UNIX request. Creates a tornado HTTPRequest object with a URL @@ -278,7 +282,9 @@ async def resolve(self, host, port, *args, **kwargs): yield request -def fetch(urlstring, method="GET", body=None, headers=None): +def fetch( + urlstring: str, method: str = "GET", body: Any = None, headers: Any = None +) -> HTTPResponse: """ Send a HTTP, HTTPS, or HTTP+UNIX request to a Tornado Web Server. Returns a tornado HTTPResponse. @@ -290,7 +296,9 @@ def fetch(urlstring, method="GET", body=None, headers=None): return response -async def async_fetch(urlstring, method="GET", body=None, headers=None, io_loop=None): +async def async_fetch( + urlstring: str, method: str = "GET", body: Any = None, headers: Any = None, io_loop: Any = None +) -> HTTPResponse: """ Send an asynchronous HTTP, HTTPS, or HTTP+UNIX request to a Tornado Web Server. Returns a tornado HTTPResponse. @@ -302,7 +310,7 @@ async def async_fetch(urlstring, method="GET", body=None, headers=None, io_loop= return response -def is_namespace_package(namespace): +def is_namespace_package(namespace: str) -> bool | None: """Is the provided namespace a Python Namespace Package (PEP420). https://www.python.org/dev/peps/pep-0420/#specification @@ -324,7 +332,7 @@ def is_namespace_package(namespace): return isinstance(spec.submodule_search_locations, _NamespacePath) -def filefind(filename, path_dirs=None): +def filefind(filename: str, path_dirs: Sequence[str] | str | None = None) -> str: """Find a file by looking through a sequence of paths. This iterates through a sequence of paths looking for a file and returns the full, absolute path of the first occurrence of the file. If no set of @@ -378,7 +386,7 @@ def filefind(filename, path_dirs=None): raise OSError(msg) -def expand_path(s): +def expand_path(s: str) -> str: """Expand $VARS and ~names in a string, like a shell :Examples: @@ -399,7 +407,7 @@ def expand_path(s): return s -def import_item(name): +def import_item(name: str) -> Any: """Import and return ``bar`` given the string ``foo.bar``. Calling ``bar = import_item("foo.bar")`` is the functional equivalent of executing the code ``from foo import bar``. diff --git a/pyproject.toml b/pyproject.toml index 47695dc9b5..75d632c834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,7 +104,7 @@ nowarn = "test -W default {args}" [tool.hatch.envs.typing] features = ["test"] -dependencies = [ "mypy>=1.5.1", "traitlets>=5.10.1" ] +dependencies = [ "mypy>=1.5.1", "traitlets>=5.11.2", "jupyter_core>=5.3.2"] [tool.hatch.envs.typing.scripts] test = "mypy --install-types --non-interactive {args:.}" @@ -318,6 +318,7 @@ pydist_resource_paths = ["jupyter_server/static/style/bootstrap.min.css", "jupyt post-version-spec = "dev" [tool.mypy] +python_version = "3.8" check_untyped_defs = true disallow_incomplete_defs = true no_implicit_optional = true diff --git a/tests/base/test_handlers.py b/tests/base/test_handlers.py index 98d6ff73dd..370100fe9d 100644 --- a/tests/base/test_handlers.py +++ b/tests/base/test_handlers.py @@ -50,9 +50,9 @@ def test_jupyter_handler(jp_serverapp): handler.settings["mathjax_config"] = "bar" assert handler.mathjax_url == "/foo" assert handler.mathjax_config == "bar" - handler.settings["terminal_manager"] = "fizz" - assert handler.terminal_manager == "fizz" - handler.settings["allow_origin"] = True + handler.settings["terminal_manager"] = None + assert handler.terminal_manager is None + handler.settings["allow_origin"] = True # type:ignore[unreachable] handler.set_cors_headers() handler.settings["allow_origin"] = False handler.settings["allow_origin_pat"] = "foo" diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index e7f919a98a..813eb853b3 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -1,4 +1,5 @@ import asyncio +from datetime import datetime import pytest from tornado import web @@ -18,7 +19,7 @@ class DummyKernel: execution_state: str - last_activity: str + last_activity: datetime def __init__(self, kernel_name="python"): self.kernel_name = kernel_name diff --git a/tests/test_gateway.py b/tests/test_gateway.py index fec747afeb..37ce8b03ee 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -322,13 +322,13 @@ def test_token_renewer_config(jp_server_config, jp_configurable_serverapp, renew if renewer_type == "default": assert isinstance(gw_client.gateway_token_renewer, NoOpTokenRenewer) token = gw_client.gateway_token_renewer.get_token( - gw_client.auth_header_key, gw_client.auth_scheme, gw_client.auth_token + gw_client.auth_header_key, gw_client.auth_scheme, gw_client.auth_token or "" ) assert token == gw_client.auth_token else: assert isinstance(gw_client.gateway_token_renewer, CustomTestTokenRenewer) token = gw_client.gateway_token_renewer.get_token( - gw_client.auth_header_key, gw_client.auth_scheme, gw_client.auth_token + gw_client.auth_header_key, gw_client.auth_scheme, gw_client.auth_token or "" ) assert token == CustomTestTokenRenewer.TEST_EXPECTED_TOKEN_VALUE diff --git a/tests/test_utils.py b/tests/test_utils.py index fced26cad2..5a3f33138b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -95,7 +95,7 @@ def test_path_utils(tmp_path): def test_check_version(): assert check_version("1.0.2", "1.0.1") assert not check_version("1.0.0", "1.0.1") - assert check_version(1.0, "1.0.1") + assert check_version(1.0, "1.0.1") # type:ignore[arg-type] def test_check_pid():