Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various client authentication related fixes #49

Merged
merged 1 commit into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions src/idpyoidc/server/client_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,15 @@ def _verify(
request: Optional[Union[dict, Message]] = None,
authorization_token: Optional[str] = None,
endpoint=None, # Optional[Endpoint]
get_client_id_from_token=None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a pointer to a function/method?

wouldn't it be better Optional[...]?

Originally posted by @peppelinux in #39 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a pointer to a function/method.
In verify_client in the same module it's defined to be:

get_client_id_from_token: Optional[Callable] = None

**kwargs,
):
_token = request.get("access_token")
if _token is None:
raise ClientAuthenticationError("No access token")

res = {"token": _token}
_client_id = request.get("client_id")
_client_id = get_client_id_from_token(endpoint_context, _token, request)
if _client_id:
res["client_id"] = _client_id
return res
Expand Down Expand Up @@ -483,6 +484,7 @@ def verify_client(

auth_info = {}
methods = endpoint_context.client_authn_method
client_id = None
allowed_methods = getattr(endpoint, "client_authn_method")
if not allowed_methods:
allowed_methods = list(methods.keys())
Expand All @@ -499,48 +501,47 @@ def verify_client(
endpoint=endpoint,
get_client_id_from_token=get_client_id_from_token,
)
break
except (BearerTokenAuthenticationError, ClientAuthenticationError):
raise
except Exception as err:
logger.info("Verifying auth using {} failed: {}".format(_method.tag, err))
continue

if auth_info.get("method") == "none":
return auth_info
if auth_info.get("method") == "none" and auth_info.get("client_id") is None:
break

client_id = auth_info.get("client_id")
if client_id is None:
raise ClientAuthenticationError("Failed to verify client")
client_id = auth_info.get("client_id")
if client_id is None:
raise ClientAuthenticationError("Failed to verify client")

if also_known_as:
client_id = also_known_as[client_id]
auth_info["client_id"] = client_id
if also_known_as:
client_id = also_known_as[client_id]
auth_info["client_id"] = client_id

if client_id not in endpoint_context.cdb:
raise UnknownClient("Unknown Client ID")
if client_id not in endpoint_context.cdb:
raise UnknownClient("Unknown Client ID")

_cinfo = endpoint_context.cdb[client_id]
_cinfo = endpoint_context.cdb[client_id]

if not valid_client_info(_cinfo):
logger.warning("Client registration has timed out or " "client secret is expired.")
raise InvalidClient("Not valid client")
if not valid_client_info(_cinfo):
logger.warning("Client registration has timed out or " "client secret is expired.")
raise InvalidClient("Not valid client")

# Validate that the used method is allowed for this client/endpoint
client_allowed_methods = _cinfo.get(
f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method")
)
if client_allowed_methods is not None and _method and _method.tag not in client_allowed_methods:
logger.info(
f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: "
f"`{', '.join(client_allowed_methods)}`"
)
raise UnAuthorizedClient(
f"Authentication method: {_method.tag} not allowed for client: {client_id} in "
f"endpoint: {endpoint.name}"
# Validate that the used method is allowed for this client/endpoint
client_allowed_methods = _cinfo.get(
f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method")
)
if client_allowed_methods is not None and auth_info["method"] not in client_allowed_methods:
logger.info(
f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: "
f"`{', '.join(client_allowed_methods)}`"
)
auth_info = {}
continue
break

# store what authn method was used
if auth_info.get("method"):
if "method" in auth_info and client_id:
_request_type = request.__class__.__name__
_used_authn_method = _cinfo.get("auth_method")
if _used_authn_method:
Expand Down
6 changes: 5 additions & 1 deletion src/idpyoidc/server/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def set_client_authn_methods(self, **kwargs):
kwargs[self.auth_method_attribute] = _methods
elif _methods is not None: # [] or '' or something not None but regarded as nothing.
self.client_authn_method = ["none"] # Ignore default value
elif self.default_capabilities:
self.client_authn_method = self.default_capabilities.get("client_authn_method")
self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs)
return kwargs

def get_provider_info_attributes(self):
Expand Down Expand Up @@ -249,7 +252,8 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No
if authn_info == {} and self.client_authn_method and len(self.client_authn_method):
LOGGER.debug("client_authn_method: %s", self.client_authn_method)
raise UnAuthorizedClient("Authorization failed")

if "client_id" not in authn_info and authn_info.get("method") != "none":
raise UnAuthorizedClient("Authorization failed")
return authn_info

def do_post_parse_request(
Expand Down
2 changes: 1 addition & 1 deletion src/idpyoidc/server/oidc/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def parse_request(self, request, http_info=None, **kwargs):

# Verify that the client is allowed to do this
auth_info = self.client_authentication(request, http_info, **kwargs)
if not auth_info or auth_info["method"] == "none":
if not auth_info:
pass
elif isinstance(auth_info, ResponseMessage):
return auth_info
Expand Down
2 changes: 1 addition & 1 deletion src/idpyoidc/server/oidc/userinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def parse_request(self, request, http_info=None, **kwargs):
try:
auth_info = self.client_authentication(request, http_info, **kwargs)
except ClientAuthenticationError as e:
return self.error_cls(error="invalid_token", error_description=e.args[0])
return self.error_cls(error="invalid_token", error_description="Invalid token")

if isinstance(auth_info, ResponseMessage):
return auth_info
Expand Down
15 changes: 7 additions & 8 deletions tests/test_server_17_client_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def create_method(self):

def test_bearer_body(self):
request = {"access_token": "1234567890"}
assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"}
assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"}

def test_bearer_body_no_token(self):
request = {}
Expand Down Expand Up @@ -504,13 +504,12 @@ def test_verify_per_client_per_endpoint(self):
)
assert res == {"method": "public", "client_id": client_id}

with pytest.raises(ClientAuthenticationError) as e:
verify_client(
self.endpoint_context,
request,
endpoint=self.server.server_get("endpoint", "endpoint_1"),
)
assert e.value.args[0] == "Failed to verify client"
res = verify_client(
self.endpoint_context,
request,
endpoint=self.server.server_get("endpoint", "endpoint_1"),
)
assert res == {}

request = {"client_id": client_id, "client_secret": client_secret}
res = verify_client(
Expand Down
15 changes: 7 additions & 8 deletions tests/test_server_20d_client_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def create_method(self):

def test_bearer_body(self):
request = {"access_token": "1234567890"}
assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"}
assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"}

def test_bearer_body_no_token(self):
request = {}
Expand Down Expand Up @@ -457,13 +457,12 @@ def test_verify_per_client_per_endpoint(self):
)
assert res == {"method": "public", "client_id": client_id}

with pytest.raises(ClientAuthenticationError) as e:
verify_client(
self.endpoint_context,
request,
endpoint=self.server.server_get("endpoint", "token"),
)
assert e.value.args[0] == "Failed to verify client"
res = verify_client(
self.endpoint_context,
request,
endpoint=self.server.server_get("endpoint", "token"),
)
assert res == {}

request = {"client_id": client_id, "client_secret": client_secret}
res = verify_client(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server_23_oidc_registration_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def create_endpoint(self):
"registration": {
"path": "registration",
"class": Registration,
"kwargs": {"client_auth_method": None},
"kwargs": {"client_authn_method": ["none"]},
},
"authorization": {
"path": "authorization",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server_32_oidc_read_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def create_endpoint(self):
"registration": {
"path": "registration",
"class": Registration,
"kwargs": {"client_auth_method": None},
"kwargs": {"client_authn_method": ["none"]},
},
"registration_api": {
"path": "registration_api",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_server_60_dpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def create_endpoint(self):
"class": Authorization,
"kwargs": {},
},
"token": {"path": "{}/token", "class": Token, "kwargs": {}},
"token": {
"path": "{}/token",
"class": Token,
"kwargs": {"client_authn_method": ["none"]},
},
},
"client_authn": verify_client,
"authentication": {
Expand Down