Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added method to show correct RP organization_name in OP pages #305

Merged
merged 10 commits into from
Feb 7, 2024
2 changes: 1 addition & 1 deletion examples/provider/dumps/example.json
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
"metadata": {
"federation_entity": {
"federation_resolve_endpoint": "http://127.0.0.1:8002/oidc/op/resolve",
"organization_name": "SPID OIDC identity provider",
"organization_name": "CIE OIDC identity provider",
"homepage_uri": "http://127.0.0.1:8002",
"policy_uri": "http://127.0.0.1:8002/oidc/op/en/website/legal-information",
"logo_uri": "http://127.0.0.1:8002/static/svg/logo-cie.svg",
Expand Down
2 changes: 1 addition & 1 deletion spid_cie_oidc/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.3.0"
__version__ = "1.3.1"
6 changes: 3 additions & 3 deletions spid_cie_oidc/entity/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from cryptojwt.jwk.rsa import new_rsa_key
from cryptography.hazmat.primitives import serialization
from cryptojwt.jwk.rsa import RSAKey

from cryptography.hazmat.primitives.asymmetric import rsa

import cryptography
from django.conf import settings
Expand Down Expand Up @@ -64,9 +64,9 @@ def serialize_rsa_key(rsa_key, kind="public", hash_func="SHA-256"):
cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey
"""
data = {}
if isinstance(rsa_key, cryptography.hazmat.backends.openssl.rsa._RSAPublicKey):
if isinstance(rsa_key, rsa.RSAPublicKey):
data = {"pub_key": rsa_key}
elif isinstance(rsa_key, cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey):
elif isinstance(rsa_key, rsa.RSAPrivateKey):
data = {"priv_key": rsa_key}
elif isinstance(rsa_key, (str, bytes)): # pragma: no cover
if kind == "private":
Expand Down
75 changes: 44 additions & 31 deletions spid_cie_oidc/provider/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OIDCFED_PROVIDER_PROFILES_ACR_4_REFRESH,
OIDCFED_PROVIDER_PROFILES_ID_TOKEN_CLAIMS
)

logger = logging.getLogger(__name__)


Expand All @@ -40,7 +41,7 @@ class OpBase:
Baseclass with common methods for OPs
"""

def redirect_response_data(self, redirect_uri:str, **kwargs) -> HttpResponseRedirect:
def redirect_response_data(self, redirect_uri: str, **kwargs) -> HttpResponseRedirect:
if "?" in redirect_uri:
qstring = "&"
else:
Expand Down Expand Up @@ -114,7 +115,7 @@ def validate_authz_request_object(self, req) -> TrustChain:

jwks = get_jwks(
rp_trust_chain.metadata['openid_relying_party'],
federation_jwks = rp_trust_chain.jwks
federation_jwks=rp_trust_chain.jwks
)
jwk = self.find_jwk(header, jwks)
if not jwk:
Expand Down Expand Up @@ -178,7 +179,7 @@ def check_session(self, request) -> OidcSession:
)

session_not_after = session.created + timezone.timedelta(
minutes = OIDCFED_PROVIDER_AUTH_CODE_MAX_AGE
minutes=OIDCFED_PROVIDER_AUTH_CODE_MAX_AGE
)
if session_not_after < timezone.localtime():
raise ExpiredAuthCode(
Expand All @@ -199,12 +200,12 @@ def check_client_assertion(self, client_id: str, client_assertion: str) -> bool:
_op = self.get_issuer()
_op_eid = _op.sub
_op_eid_authz_endpoint = [_op.metadata['openid_provider']['authorization_endpoint']]

try:
ClientAssertion(**payload)
except Exception as e:
raise Exception(f"Client Assertion: json schema validation error: {e}")

if isinstance(_aud, str):
_aud = [_aud]
_allowed_auds = _aud + _op_eid_authz_endpoint
Expand Down Expand Up @@ -250,9 +251,9 @@ def get_jwt_common_data(self):
}

def get_access_token(
self, iss_sub:str, sub:str, authz: OidcSession, commons:dict
self, iss_sub: str, sub: str, authz: OidcSession, commons: dict
) -> dict:

access_token = {
"iss": iss_sub,
"sub": sub,
Expand All @@ -266,8 +267,8 @@ def get_access_token(
return access_token

def get_id_token_claims(
self,
authz:OidcSession
self,
authz: OidcSession
) -> dict:
_provider_profile = getattr(settings, 'OIDCFED_DEFAULT_PROVIDER_PROFILE', OIDCFED_DEFAULT_PROVIDER_PROFILE)
claims = {}
Expand All @@ -276,21 +277,21 @@ def get_id_token_claims(
return claims

for claim in (
authz.authz_request.get(
"claims", {}
).get("id_token", {}).keys()
authz.authz_request.get(
"claims", {}
).get("id_token", {}).keys()
):
if claim in allowed_id_token_claims and authz.user.attributes.get(claim, None):
claims[claim] = authz.user.attributes[claim]
return claims

def get_id_token(
self,
iss_sub:str,
sub:str,
authz:OidcSession,
jwt_at:str,
commons:dict
self,
iss_sub: str,
sub: str,
authz: OidcSession,
jwt_at: str,
commons: dict
) -> dict:

id_token = {
Expand All @@ -312,19 +313,19 @@ def get_id_token(

def get_refresh_token(
self,
iss_sub:str,
sub:str,
authz:OidcSession,
jwt_at:str,
commons:dict
iss_sub: str,
sub: str,
authz: OidcSession,
jwt_at: str,
commons: dict
) -> dict:
# refresh token is scope offline_access and prompt == consent
refresh_acrs = OIDCFED_PROVIDER_PROFILES_ACR_4_REFRESH[OIDCFED_DEFAULT_PROVIDER_PROFILE]
acrs = authz.authz_request.get('acr_values', [])
if (
"offline_access" in authz.authz_request['scope'] and
'consent' in authz.authz_request['prompt'] and
set(refresh_acrs).intersection(set(acrs))
"offline_access" in authz.authz_request['scope'] and
'consent' in authz.authz_request['prompt'] and
set(refresh_acrs).intersection(set(acrs))
):
refresh_token = {
"sub": sub,
Expand All @@ -337,8 +338,8 @@ def get_refresh_token(
refresh_token.update(commons)
return refresh_token

def get_iss_token_data(self, session : OidcSession, issuer: FederationEntityConfiguration):
_sub = session.pairwised_sub(provider_id = issuer.sub)
def get_iss_token_data(self, session: OidcSession, issuer: FederationEntityConfiguration):
_sub = session.pairwised_sub(provider_id=issuer.sub)
iss_sub = issuer.sub
commons = self.get_jwt_common_data()
jwk = issuer.jwks_core[0]
Expand All @@ -363,7 +364,7 @@ def get_iss_token_data(self, session : OidcSession, issuer: FederationEntityConf

def get_expires_in(self, iat: int, exp: int):
return timezone.timedelta(
seconds = exp - iat
seconds=exp - iat
).seconds

def attributes_names_to_release(self, request, session: OidcSession) -> dict:
Expand Down Expand Up @@ -391,6 +392,18 @@ def attributes_names_to_release(self, request, session: OidcSession) -> dict:
for i in filtered_user_claims.keys()
]
return dict(
i18n_user_claims = i18n_user_claims,
filtered_user_claims = filtered_user_claims
i18n_user_claims=i18n_user_claims,
filtered_user_claims=filtered_user_claims
)

def get_client_organisation_name(self, tc):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

organization

rglauco marked this conversation as resolved.
Show resolved Hide resolved
fed_metadata = tc.metadata.get("federation_entity", {})
name = fed_metadata.get("organization_name", "")
if not name:
op_metadata = tc.metadata.get("openid_relying_party", {})
name = op_metadata.get("organization_name", "")
if not name:
name = op_metadata.get("client_name", "")
if not name:
name = op_metadata.get("client_id", "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fed_metadata = tc.metadata.get("federation_entity", {})
name = fed_metadata.get("organization_name", "")
if not name:
op_metadata = tc.metadata.get("openid_relying_party", {})
name = op_metadata.get("organization_name", "")
if not name:
name = op_metadata.get("client_name", "")
if not name:
name = op_metadata.get("client_id", "")
rp_metadata = (
tc.metadata.get(
"federation_entity", {}
).get("organization_name", "") or
tc.metadata.get(
"openid_relying_party", {}
)
)
if rp_metadata:
name = (
rp_metadata.get("organization_name", "") or
rp_metadata.get("client_name", "") or
rp_metadata.get("client_id", "")
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accepted change, corrected a typo:
rp_metadata = (
tc.metadata.get(
"federation_entity", {}
) .get("organization_name", "")

return name
10 changes: 7 additions & 3 deletions spid_cie_oidc/provider/views/authz_request_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,14 @@ def get(self, request, *args, **kwargs):

# stores the authz request in a hidden field in the form
form = self.get_login_form()()

# context = {
# "client_organization_name": tc.metadata.get(
# "client_name", self.payload["client_id"]
# ),

rglauco marked this conversation as resolved.
Show resolved Hide resolved
context = {
"client_organization_name": tc.metadata.get(
"client_name", self.payload["client_id"]
),
"client_organization_name": self.get_client_organisation_name(tc),
rglauco marked this conversation as resolved.
Show resolved Hide resolved
"hidden_form": AuthzHiddenForm(dict(authz_request_object=req)),
"form": form,
"redirect_uri": self.payload["redirect_uri"],
Expand Down
4 changes: 1 addition & 3 deletions spid_cie_oidc/provider/views/consent_page_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def get(self, request, *args, **kwargs):
context = {
"form": self.get_consent_form()(),
"session": session,
"client_organization_name": tc.metadata.get(
"client_name", session.client_id
),
"client_organization_name": self.get_client_organisation_name(tc),
rglauco marked this conversation as resolved.
Show resolved Hide resolved
"user_claims": sorted(set(i18n_user_claims),),
"redirect_uri": session.authz_request["redirect_uri"],
"state": session.authz_request["state"]
Expand Down
Loading