diff --git a/google/oauth2/webauthn_handler.py b/google/oauth2/webauthn_handler.py new file mode 100644 index 000000000..eb4fa2662 --- /dev/null +++ b/google/oauth2/webauthn_handler.py @@ -0,0 +1,82 @@ +import abc +import os +import struct +import subprocess + +from google.auth import exceptions +from google.oauth2.webauthn_types import GetRequest, GetResponse + + +class WebAuthnHandler(abc.ABC): + @abc.abstractmethod + def is_available(self) -> bool: + """Check whether this WebAuthn handler is available""" + raise NotImplementedError("is_available method must be implemented") + + @abc.abstractmethod + def get(self, get_request: GetRequest) -> GetResponse: + """WebAuthn get (assertion)""" + raise NotImplementedError("get method must be implemented") + + +class PluginHandler(WebAuthnHandler): + """Offloads WebAuthn get reqeust to a pluggable command-line tool. + + Offloads WebAuthn get to a plugin which takes the form of a + command-line tool. The command-line tool is configurable via the + PluginHandler._ENV_VAR environment variable. + + The WebAuthn plugin should implement the following interface: + + Communication occurs over stdin/stdout, and messages are both sent and + received in the form: + + [4 bytes - payload size (little-endian)][variable bytes - json payload] + """ + + _ENV_VAR = "GOOGLE_AUTH_WEBAUTHN_PLUGIN" + + def is_available(self) -> bool: + try: + self._find_plugin() + except: + return False + else: + return True + + def get(self, get_request: GetRequest) -> GetResponse: + request_json = get_request.to_json() + cmd = self._find_plugin() + response_json = self._call_plugin(cmd, request_json) + return GetResponse.from_json(response_json) + + def _call_plugin(self, cmd: str, input_json: str) -> str: + # Calculate length of input + input_length = len(input_json) + length_bytes_le = struct.pack(" str: + plugin_cmd = os.environ.get(PluginHandler._ENV_VAR) + if plugin_cmd is None: + raise exceptions.InvalidResource( + "{} env var is not set".format(PluginHandler._ENV_VAR) + ) + return plugin_cmd diff --git a/google/oauth2/webauthn_types.py b/google/oauth2/webauthn_types.py new file mode 100644 index 000000000..92c56a559 --- /dev/null +++ b/google/oauth2/webauthn_types.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass +import json +from typing import List, Dict, Optional + +from google.auth import exceptions + + +@dataclass(frozen=True) +class PublicKeyCredentialDescriptor: + """Descriptor for a security key based credential. + + https://www.w3.org/TR/webauthn-3/#dictionary-credential-descriptor + + Args: + id: credential id (key handle). + transports: <'usb'|'nfc'|'ble'|'internal'> List of supported transports. + """ + + id: str + transports: Optional[List[str]] = None + + def to_dict(self): + cred = {"type": "public-key", "id": self.id} + if self.transports: + cred["transports"] = self.transports + return cred + + +@dataclass +class AuthenticationExtensionsClientInputs: + """Client extensions inputs for WebAuthn extensions. + + Args: + appid: app id that can be asserted with in addition to rpid. + https://www.w3.org/TR/webauthn-3/#sctn-appid-extension + """ + + appid: Optional[str] = None + + def to_dict(self): + extensions = {} + if self.appid: + extensions["appid"] = self.appid + return extensions + + +@dataclass +class GetRequest: + """WebAuthn get request + + Args: + origin: Origin where the WebAuthn get assertion takes place. + rpid: Relying Party ID. + challenge: raw challenge. + timeout_ms: Timeout number in millisecond. + allow_credentials: List of allowed credentials. + user_verification: <'required'|'preferred'|'discouraged'> User verification requirement. + extensions: WebAuthn authentication extensions inputs. + """ + + origin: str + rpid: str + challenge: str + timeout_ms: Optional[int] = None + allow_credentials: Optional[List[PublicKeyCredentialDescriptor]] = None + user_verification: Optional[str] = None + extensions: Optional[AuthenticationExtensionsClientInputs] = None + + def to_json(self) -> str: + req_options = {"rpid": self.rpid, "challenge": self.challenge} + if self.timeout_ms: + req_options["timeout"] = self.timeout_ms + if self.allow_credentials: + req_options["allowCredentials"] = [ + c.to_dict() for c in self.allow_credentials + ] + if self.user_verification: + req_options["userVerification"] = self.user_verification + if self.extensions: + req_options["extensions"] = self.extensions.to_dict() + return json.dumps( + { + "type": "get", + "origin": self.origin, + "requestData": req_options, + } + ) + + +@dataclass(frozen=True) +class AuthenticatorAssertionResponse: + """Authenticator response to a WebAuthn get (assertion) request. + + https://www.w3.org/TR/webauthn-3/#authenticatorassertionresponse + + Args: + client_data_json: client data JSON. + authenticator_data: authenticator data. + signature: signature. + user_handle: user handle. + """ + + client_data_json: str + authenticator_data: str + signature: str + user_handle: Optional[str] + + +@dataclass(frozen=True) +class GetResponse: + """WebAuthn get (assertion) response. + + Args: + id: credential id (key handle). + response: The authenticator assertion response. + authenticator_attachment: <'cross-platform'|'platform'> The attachment status of the authenticator. + client_extension_results: WebAuthn authentication extensions output results in a dictionary. + """ + + id: str + response: AuthenticatorAssertionResponse + authenticator_attachment: Optional[str] + client_extension_results: Optional[Dict] + + @staticmethod + def from_json(json_str: str): + """Verify and construct GetResponse from a JSON string.""" + try: + resp_json = json.loads(json_str) + except ValueError: + raise exceptions.MalformedError("Invalid Get JSON response") + if resp_json.get("type") != "getResponse": + raise exceptions.MalformedError( + "Invalid Get response type: {}".format(resp_json.get("type")) + ) + pk_cred = resp_json.get("responseData") + if pk_cred is None: + if resp_json.get("error"): + raise exceptions.ReauthFailError( + "WebAuthn.get failure: {}".format(resp_json["error"]) + ) + else: + raise exceptions.MalformedError("Get response is empty") + if pk_cred.get("type") != "public-key": + raise exceptions.MalformedError( + "Invalid credential type: {}".format(pk_cred.get("type")) + ) + assertion_json = pk_cred["response"] + assertion_resp = AuthenticatorAssertionResponse( + client_data_json=assertion_json["clientDataJSON"], + authenticator_data=assertion_json["authenticatorData"], + signature=assertion_json["signature"], + user_handle=assertion_json.get("userHandle"), + ) + return GetResponse( + id=pk_cred["id"], + response=assertion_resp, + authenticator_attachment=pk_cred.get("authenticatorAttachment"), + client_extension_results=pk_cred.get("clientExtensionResults"), + ) diff --git a/tests/oauth2/test_webauthn_handler.py b/tests/oauth2/test_webauthn_handler.py new file mode 100644 index 000000000..55d72e7e6 --- /dev/null +++ b/tests/oauth2/test_webauthn_handler.py @@ -0,0 +1,156 @@ +import json +import mock +import pytest +import struct + +from google.auth import exceptions +from google.oauth2 import webauthn_handler +from google.oauth2 import webauthn_types + + +@pytest.fixture +def os_get_stub(): + with mock.patch.object( + webauthn_handler.os.environ, + "get", + return_value="gcloud_webauthn_plugin", + name="fake os.environ.get", + ) as mock_os_environ_get: + yield mock_os_environ_get + + +@pytest.fixture +def subprocess_run_stub(): + with mock.patch.object( + webauthn_handler.subprocess, + "run", + name="fake subprocess.run", + ) as mock_subprocess_run: + yield mock_subprocess_run + + +def test_PluginHandler_is_available(os_get_stub): + test_handler = webauthn_handler.PluginHandler() + + assert test_handler.is_available() == True + + os_get_stub.return_value = None + assert test_handler.is_available() == False + + +GET_ASSERTION_REQUEST = webauthn_types.GetRequest( + origin="fake_origin", + rpid="fake_rpid", + challenge="fake_challenge", + allow_credentials=[ + webauthn_types.PublicKeyCredentialDescriptor(id="fake_id_1"), + ], +) + + +def test_malformated_get_assertion_response(os_get_stub, subprocess_run_stub): + response_len = struct.pack("