diff --git a/pycti/api/opencti_api_client.py b/pycti/api/opencti_api_client.py index 07826224f..bbde7d68b 100644 --- a/pycti/api/opencti_api_client.py +++ b/pycti/api/opencti_api_client.py @@ -71,6 +71,25 @@ from pycti.utils.opencti_stix2_utils import OpenCTIStix2Utils +def build_request_headers(token: str, custom_headers: str, app_logger): + headers_dict = { + "User-Agent": "pycti/" + __version__, + "Authorization": "Bearer " + token, + } + # Build and add custom headers + if custom_headers is not None: + for header_pair in custom_headers.strip().split(";"): + if header_pair: # Skip empty header pairs + try: + key, value = header_pair.split(":", 1) + headers_dict[key.strip()] = value.strip() + except ValueError: + app_logger.warning( + "Ignored invalid header pair", {"header_pair": header_pair} + ) + return headers_dict + + class File: def __init__(self, name, data, mime="text/plain"): self.name = name @@ -99,24 +118,28 @@ class OpenCTIApiClient: ``` :param json_logging: format the logs as json if set to True :type json_logging: bool, optional + :param bundle_send_to_queue: if bundle will be sent to queue + :type bundle_send_to_queue: bool, optional :param cert: If String, file path to pem file. If Tuple, a ('path_to_cert.crt', 'path_to_key.key') pair representing the certificate and the key. :type cert: str, tuple, optional - :param auth: Add a AuthBase class with custom authentication for you OpenCTI infrastructure. - :type auth: requests.auth.AuthBase, optional + :param custom_headers: Add custom headers to use with the graphql queries + :type custom_headers: str, optional must in the format header01:value;header02:value + :param perform_health_check: if client init must check the api access + :type perform_health_check: bool, optional """ def __init__( self, url: str, token: str, - log_level="info", + log_level: str = "info", ssl_verify: Union[bool, str] = False, proxies: Union[Dict[str, str], None] = None, - json_logging=False, - bundle_send_to_queue=True, + json_logging: bool = False, + bundle_send_to_queue: bool = True, cert: Union[str, Tuple[str, str], None] = None, - auth=None, - perform_health_check=True, + custom_headers: str = None, + perform_health_check: bool = True, ): """Constructor method""" @@ -138,17 +161,10 @@ def __init__( # Define API self.api_token = token self.api_url = url + "/graphql" - self.request_headers = { - "User-Agent": "pycti/" + __version__, - "Authorization": "Bearer " + token, - } - - if auth is not None: - self.session = requests.session() - self.session.auth = auth - else: - self.session = requests.session() - + self.request_headers = build_request_headers( + token, custom_headers, self.app_logger + ) + self.session = requests.session() # Define the dependencies self.work = OpenCTIApiWork(self) self.playbook = OpenCTIApiPlaybook(self) diff --git a/pycti/connector/opencti_connector_helper.py b/pycti/connector/opencti_connector_helper.py index 5a187126f..2b157c676 100644 --- a/pycti/connector/opencti_connector_helper.py +++ b/pycti/connector/opencti_connector_helper.py @@ -902,6 +902,12 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.opencti_token = get_config_variable( "OPENCTI_TOKEN", ["opencti", "token"], config ) + self.opencti_custom_headers = get_config_variable( + "OPENCTI_CUSTOM_HEADERS", + ["opencti", "custom_headers"], + config, + default=None, + ) self.opencti_ssl_verify = get_config_variable( "OPENCTI_SSL_VERIFY", ["opencti", "ssl_verify"], config, False, False ) @@ -1078,6 +1084,7 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.log_level, self.opencti_ssl_verify, json_logging=self.opencti_json_logging, + custom_headers=self.opencti_custom_headers, bundle_send_to_queue=self.bundle_send_to_queue, ) # - Impersonate API that will use applicant id @@ -1088,6 +1095,7 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.log_level, self.opencti_ssl_verify, json_logging=self.opencti_json_logging, + custom_headers=self.opencti_custom_headers, bundle_send_to_queue=self.bundle_send_to_queue, ) self.connector_logger = self.api.logger_class(self.connect_name)