Skip to content

Commit

Permalink
Merge pull request #9 from UniversitaDellaCalabria/dev
Browse files Browse the repository at this point in the history
v0.5.0
  • Loading branch information
peppelinux authored Dec 14, 2021
2 parents 7a58b38 + 8f97ecb commit a3cc919
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 29 deletions.
2 changes: 1 addition & 1 deletion satosa_oidcop/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.7"
__version__ = "0.5.0"
45 changes: 17 additions & 28 deletions satosa_oidcop/core/storage/mongo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import copy
import datetime
import json
import logging
import pymongo

Expand Down Expand Up @@ -44,14 +45,7 @@ def _connect(self):

def get_client_by_id(self, client_id: str):
self._connect()
res = self.client_db.find({"client_id": client_id})

# improvement: unique index on client_id in client collection
if res.count():
# it returns the first one
return res.next()
else:
return {}
return self.client_db.find_one({"client_id": client_id}) or {}

def store_session_to_db(self, session_manager: SessionManager, claims: dict):
ses_man_dump = session_manager.dump()
Expand All @@ -68,7 +62,7 @@ def store_session_to_db(self, session_manager: SessionManager, claims: dict):
"id_token": "",
"refresh_token": "",
"claims": claims or {},
"dump": _db,
"dump": json.dumps(_db),
"key": ses_man_dump["key"],
"salt": ses_man_dump["salt"],
}
Expand Down Expand Up @@ -105,13 +99,13 @@ def store_session_to_db(self, session_manager: SessionManager, claims: dict):

self._connect()
q = {"grant_id": data["grant_id"]}
grant = self.session_db.find(q)
if grant.count():
grant = self.session_db.find_one(q)
if grant:
# if update preserve the claims
data["claims"] = grant.next()["claims"]
data["claims"] = grant["claims"]
self.session_db.update_one(q, {"$set": data})
else:
self.session_db.insert(data, check_keys=False)
self.session_db.insert_one(data)

def load_session_from_db(
self, parse_req, http_headers: dict, session_manager: SessionManager, **kwargs
Expand Down Expand Up @@ -156,21 +150,20 @@ def load_session_from_db(
return data

self._connect()
res = self.session_db.find(_q)
if res.count():
_data = res.next()
data["key"] = _data["key"]
data["salt"] = _data["salt"]
data["db"] = _data["dump"]
res = self.session_db.find_one(_q)
if res:
data["key"] = res["key"]
data["salt"] = res["salt"]
data["db"] = json.loads(res["dump"])
session_manager.flush()
session_manager.load(data)
return data

def get_claims_from_sid(self, sid: str):
self._connect()
res = self.session_db.find({"sid": sid})
if res.count():
return res.next()["claims"]
res = self.session_db.find_one({"sid": sid})
if res:
return res["claims"]

def insert_client(self, client_data: dict):
_client_data = copy.deepcopy(client_data)
Expand All @@ -180,7 +173,7 @@ def insert_client(self, client_data: dict):
logger.warning(
f"OIDC Client {client_id} already present in the client db")
return
self.client_db.insert(_client_data)
self.client_db.insert_one(_client_data)

def get_client_by_basic_auth(self, request_authorization: str):
cred = base64.b64decode(
Expand All @@ -194,11 +187,7 @@ def get_client_by_basic_auth(self, request_authorization: str):
client_secret = cred[1]

self._connect()
res = self.client_db.find(
{"client_id": client_id, "client_secret": client_secret}
)
if res.count():
return res.next()
return self.client_db.find_one({"client_id": client_id, "client_secret": client_secret})

def get_registered_clients_id(self):
self._connect()
Expand Down

0 comments on commit a3cc919

Please sign in to comment.