|
| 1 | +import copy |
| 2 | +from urllib.parse import quote_plus |
| 3 | + |
| 4 | +from cryptojwt import KeyJar |
| 5 | +from cryptojwt.key_jar import init_key_jar |
| 6 | + |
| 7 | +from oidcmsg.message import Message |
| 8 | +from oidcmsg.storage.init import get_storage_class |
| 9 | +from oidcmsg.storage.init import get_storage_conf |
| 10 | +from oidcmsg.storage.init import init_storage |
| 11 | + |
| 12 | + |
| 13 | +def add_issuer(conf, issuer): |
| 14 | + res = {} |
| 15 | + for key, val in conf.items(): |
| 16 | + if key == 'abstract_storage_cls': |
| 17 | + res[key] = val |
| 18 | + else: |
| 19 | + _val = copy.deepcopy(val) |
| 20 | + _val['issuer'] = quote_plus(issuer) |
| 21 | + res[key] = _val |
| 22 | + return res |
| 23 | + |
| 24 | + |
| 25 | +class OidcContext: |
| 26 | + def __init__(self, config=None, keyjar=None, entity_id=''): |
| 27 | + if config is None: |
| 28 | + config = {} |
| 29 | + |
| 30 | + self.db_conf = config.get('db_conf') |
| 31 | + if self.db_conf: |
| 32 | + _iss = config.get('issuer') |
| 33 | + if _iss: |
| 34 | + self.db_conf = add_issuer(self.db_conf, _iss) |
| 35 | + self.storage_cls = get_storage_class(self.db_conf) |
| 36 | + |
| 37 | + self.db = init_storage(self.db_conf) |
| 38 | + self.keyjar = self._keyjar(keyjar, self.db_conf, config, entity_id=entity_id) |
| 39 | + |
| 40 | + def add_boxes(self, boxes, db_conf): |
| 41 | + for key, attr in boxes.items(): |
| 42 | + setattr(self, attr, init_storage(db_conf, key)) |
| 43 | + |
| 44 | + def _keyjar(self, keyjar=None, db_conf=None, conf=None, entity_id=''): |
| 45 | + if keyjar is None: |
| 46 | + if db_conf: |
| 47 | + storage_args = { |
| 48 | + 'abstract_storage_cls': self.storage_cls, |
| 49 | + 'storage_conf': get_storage_conf(db_conf, 'keyjar') |
| 50 | + } |
| 51 | + else: |
| 52 | + storage_args = {} |
| 53 | + |
| 54 | + if 'keys' in conf: |
| 55 | + args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} |
| 56 | + args.update(storage_args) |
| 57 | + _keyjar = init_key_jar(**args) |
| 58 | + else: |
| 59 | + _keyjar = KeyJar(**storage_args) |
| 60 | + if 'jwks' in conf: |
| 61 | + _keyjar.import_jwks(conf['jwks'], '') |
| 62 | + |
| 63 | + if '' in _keyjar and entity_id: |
| 64 | + # make sure I have the keys under my own name too (if I know it) |
| 65 | + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ''), entity_id) |
| 66 | + |
| 67 | + return _keyjar |
| 68 | + else: |
| 69 | + return keyjar |
| 70 | + |
| 71 | + |
| 72 | + def set(self, item, value): |
| 73 | + if isinstance(value, Message): |
| 74 | + self.db[item] = value.to_dict() |
| 75 | + else: |
| 76 | + self.db[item] = value |
| 77 | + |
| 78 | + |
| 79 | + def get(self, item): |
| 80 | + if item == 'seed': |
| 81 | + return bytes(self.db[item], 'utf-8') |
| 82 | + else: |
| 83 | + return self.db[item] |
0 commit comments