|
| 1 | +import importlib |
| 2 | +import json |
| 3 | +import logging |
| 4 | +import os |
| 5 | +from typing import Dict |
| 6 | +from typing import List |
| 7 | +from typing import Optional |
| 8 | + |
| 9 | +from oidcmsg.logging import configure_logging |
| 10 | +from oidcmsg.util import load_yaml_config |
| 11 | + |
| 12 | +DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir', |
| 13 | + 'private_path', 'public_path', 'db_file'] |
| 14 | + |
| 15 | +URIS = ["redirect_uris", 'issuer', 'base_url'] |
| 16 | + |
| 17 | + |
| 18 | +def lower_or_upper(config, param, default=None): |
| 19 | + res = config.get(param.lower(), default) |
| 20 | + if not res: |
| 21 | + res = config.get(param.upper(), default) |
| 22 | + return res |
| 23 | + |
| 24 | + |
| 25 | +def add_base_path(conf: dict, base_path: str, file_attributes: List[str]): |
| 26 | + for key, val in conf.items(): |
| 27 | + if key in file_attributes: |
| 28 | + if val.startswith("/"): |
| 29 | + continue |
| 30 | + elif val == "": |
| 31 | + conf[key] = "./" + val |
| 32 | + else: |
| 33 | + conf[key] = os.path.join(base_path, val) |
| 34 | + if isinstance(val, dict): |
| 35 | + conf[key] = add_base_path(val, base_path, file_attributes) |
| 36 | + |
| 37 | + return conf |
| 38 | + |
| 39 | + |
| 40 | +def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int): |
| 41 | + for key, val in conf.items(): |
| 42 | + if key in uris: |
| 43 | + if not val: |
| 44 | + continue |
| 45 | + |
| 46 | + if isinstance(val, list): |
| 47 | + _new = [v.format(domain=domain, port=port) for v in val] |
| 48 | + else: |
| 49 | + _new = val.format(domain=domain, port=port) |
| 50 | + conf[key] = _new |
| 51 | + elif isinstance(val, dict): |
| 52 | + conf[key] = set_domain_and_port(val, uris, domain, port) |
| 53 | + return conf |
| 54 | + |
| 55 | + |
| 56 | +class Base: |
| 57 | + """ Configuration base class """ |
| 58 | + |
| 59 | + def __init__(self, |
| 60 | + conf: Dict, |
| 61 | + base_path: str = '', |
| 62 | + file_attributes: Optional[List[str]] = None, |
| 63 | + ): |
| 64 | + |
| 65 | + if file_attributes is None: |
| 66 | + file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES |
| 67 | + |
| 68 | + if base_path and file_attributes: |
| 69 | + # this adds a base path to all paths in the configuration |
| 70 | + add_base_path(conf, base_path, file_attributes) |
| 71 | + |
| 72 | + def __getitem__(self, item): |
| 73 | + if item in self.__dict__: |
| 74 | + return self.__dict__[item] |
| 75 | + else: |
| 76 | + raise KeyError |
| 77 | + |
| 78 | + def get(self, item, default=None): |
| 79 | + return getattr(self, item, default) |
| 80 | + |
| 81 | + def __contains__(self, item): |
| 82 | + return item in self.__dict__ |
| 83 | + |
| 84 | + def items(self): |
| 85 | + for key in self.__dict__: |
| 86 | + if key.startswith('__') and key.endswith('__'): |
| 87 | + continue |
| 88 | + yield key, getattr(self, key) |
| 89 | + |
| 90 | + def extend(self, entity_conf, conf, base_path, file_attributes, domain, port): |
| 91 | + for econf in entity_conf: |
| 92 | + _path = econf.get("path") |
| 93 | + _cnf = conf |
| 94 | + if _path: |
| 95 | + for step in _path: |
| 96 | + _cnf = _cnf[step] |
| 97 | + _attr = econf["attr"] |
| 98 | + _cls = econf["class"] |
| 99 | + setattr(self, _attr, |
| 100 | + _cls(_cnf, base_path=base_path, file_attributes=file_attributes, |
| 101 | + domain=domain, port=port)) |
| 102 | + |
| 103 | + |
| 104 | +class Configuration(Base): |
| 105 | + """Server Configuration""" |
| 106 | + |
| 107 | + def __init__(self, |
| 108 | + conf: Dict, |
| 109 | + base_path: str = '', |
| 110 | + entity_conf: Optional[List[dict]] = None, |
| 111 | + file_attributes: Optional[List[str]] = None, |
| 112 | + domain: Optional[str] = "", |
| 113 | + port: Optional[int] = 0, |
| 114 | + ): |
| 115 | + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) |
| 116 | + |
| 117 | + log_conf = conf.get('logging') |
| 118 | + if log_conf: |
| 119 | + self.logger = configure_logging(config=log_conf).getChild(__name__) |
| 120 | + else: |
| 121 | + self.logger = logging.getLogger('oidcrp') |
| 122 | + |
| 123 | + self.web_conf = lower_or_upper(conf, "webserver") |
| 124 | + |
| 125 | + # entity info |
| 126 | + if not domain: |
| 127 | + domain = conf.get("domain", "127.0.0.1") |
| 128 | + |
| 129 | + if not port: |
| 130 | + port = conf.get("port", 80) |
| 131 | + |
| 132 | + if entity_conf: |
| 133 | + self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path, |
| 134 | + file_attributes=file_attributes, domain=domain, port=port) |
| 135 | + |
| 136 | + |
| 137 | +def create_from_config_file(cls, |
| 138 | + filename: str, |
| 139 | + base_path: Optional[str] = '', |
| 140 | + entity_conf: Optional[List[dict]] = None, |
| 141 | + file_attributes: Optional[List[str]] = None, |
| 142 | + domain: Optional[str] = "", |
| 143 | + port: Optional[int] = 0): |
| 144 | + if filename.endswith(".yaml"): |
| 145 | + """Load configuration as YAML""" |
| 146 | + _cnf = load_yaml_config(filename) |
| 147 | + elif filename.endswith(".json"): |
| 148 | + _str = open(filename).read() |
| 149 | + _cnf = json.loads(_str) |
| 150 | + elif filename.endswith(".py"): |
| 151 | + head, tail = os.path.split(filename) |
| 152 | + tail = tail[:-3] |
| 153 | + module = importlib.import_module(tail) |
| 154 | + _cnf = getattr(module, "CONFIG") |
| 155 | + else: |
| 156 | + raise ValueError("Unknown file type") |
| 157 | + |
| 158 | + return cls(_cnf, |
| 159 | + entity_conf=entity_conf, |
| 160 | + base_path=base_path, file_attributes=file_attributes, |
| 161 | + domain=domain, port=port) |
0 commit comments