diff --git a/pycti/api/opencti_api_client.py b/pycti/api/opencti_api_client.py index 6b39ec86d..8b083383a 100644 --- a/pycti/api/opencti_api_client.py +++ b/pycti/api/opencti_api_client.py @@ -751,6 +751,7 @@ def send_bundle_to_api(self, **kwargs): """ connector_id = kwargs.get("connector_id", None) + work_id = kwargs.get("work_id", None) bundle = kwargs.get("bundle", None) if connector_id is not None and bundle is not None: @@ -758,13 +759,13 @@ def send_bundle_to_api(self, **kwargs): "Pushing a bundle to queue through API", {connector_id} ) mutation = """ - mutation StixBundlePush($connectorId: String!, $bundle: String!) { - stixBundlePush(connectorId: $connectorId, bundle: $bundle) + mutation StixBundlePush($connectorId: String!, $bundle: String!, $work_id: String) { + stixBundlePush(connectorId: $connectorId, bundle: $bundle, work_id: $work_id) } """ return self.query( mutation, - {"connectorId": connector_id, "bundle": bundle}, + {"connectorId": connector_id, "bundle": bundle, "work_id": work_id}, ) else: self.app_logger.error( diff --git a/pycti/api/opencti_api_connector.py b/pycti/api/opencti_api_connector.py index cabf9d0c3..09dc79016 100644 --- a/pycti/api/opencti_api_connector.py +++ b/pycti/api/opencti_api_connector.py @@ -72,6 +72,7 @@ def list(self) -> Dict: } listen listen_exchange + listen_callback_uri push push_exchange push_routing diff --git a/pycti/connector/opencti_connector.py b/pycti/connector/opencti_connector.py index 114874434..1601dcb22 100644 --- a/pycti/connector/opencti_connector.py +++ b/pycti/connector/opencti_connector.py @@ -43,6 +43,7 @@ def __init__( auto: bool, only_contextual: bool, playbook_compatible: bool, + listen_callback_uri=None, ): self.id = connector_id self.name = connector_name @@ -56,6 +57,7 @@ def __init__( self.auto = auto self.only_contextual = only_contextual self.playbook_compatible = playbook_compatible + self.listen_callback_uri = listen_callback_uri def to_input(self) -> dict: """connector input to use in API query @@ -72,5 +74,6 @@ def to_input(self) -> dict: "auto": self.auto, "only_contextual": self.only_contextual, "playbook_compatible": self.playbook_compatible, + "listen_callback_uri": self.listen_callback_uri, } } diff --git a/pycti/connector/opencti_connector_helper.py b/pycti/connector/opencti_connector_helper.py index 0441d88e2..f188c59b3 100644 --- a/pycti/connector/opencti_connector_helper.py +++ b/pycti/connector/opencti_connector_helper.py @@ -18,6 +18,9 @@ from typing import Callable, Dict, List, Optional, Union import pika +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from filigran_sseclient import SSEClient from pika.exceptions import NackError, UnroutableError from pydantic import TypeAdapter @@ -30,6 +33,8 @@ TRUTHY: List[str] = ["yes", "true", "True"] FALSY: List[str] = ["no", "false", "False"] +app = FastAPI() + def killProgramHook(etype, value, tb): os.kill(os.getpid(), signal.SIGTERM) @@ -141,6 +146,35 @@ def ssl_cert_chain(ssl_context, cert_data, key_data, passphrase): os.unlink(key_file_path) +def create_callback_ssl_context(config) -> ssl.SSLContext: + listen_protocol_api_ssl_key = get_config_variable( + "LISTEN_PROTOCOL_API_SSL_KEY", + ["connector", "listen_protocol_api_ssl_key"], + config, + default="", + ) + listen_protocol_api_ssl_cert = get_config_variable( + "LISTEN_PROTOCOL_API_SSL_CERT", + ["connector", "listen_protocol_api_ssl_cert"], + config, + default="", + ) + listen_protocol_api_ssl_passphrase = get_config_variable( + "LISTEN_PROTOCOL_API_SSL_PASSPHRASE", + ["connector", "listen_protocol_api_ssl_passphrase"], + config, + default="", + ) + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_cert_chain( + ssl_context, + listen_protocol_api_ssl_cert, + listen_protocol_api_ssl_key, + listen_protocol_api_ssl_passphrase, + ) + return ssl_context + + def create_mq_ssl_context(config) -> ssl.SSLContext: use_ssl_ca = get_config_variable("MQ_USE_SSL_CA", ["mq", "use_ssl_ca"], config) use_ssl_cert = get_config_variable( @@ -183,9 +217,14 @@ class ListenQueue(threading.Thread): def __init__( self, helper, + opencti_token, config: Dict, connector_config: Dict, applicant_id, + listen_protocol, + listen_protocol_api_ssl, + listen_protocol_api_path, + listen_protocol_api_port, callback, ) -> None: threading.Thread.__init__(self) @@ -196,6 +235,11 @@ def __init__( self.helper = helper self.callback = callback self.config = config + self.opencti_token = opencti_token + self.listen_protocol = listen_protocol + self.listen_protocol_api_ssl = listen_protocol_api_ssl + self.listen_protocol_api_path = listen_protocol_api_path + self.listen_protocol_api_port = listen_protocol_api_port self.connector_applicant_id = applicant_id self.host = connector_config["connection"]["host"] self.vhost = connector_config["connection"]["vhost"] @@ -375,52 +419,122 @@ def _data_handler(self, json_data) -> None: "Failing reporting the processing" ) + async def _http_process_callback(self, request: Request): + # 01. Check the authentication + authorization: str = request.headers.get("Authorization", "") + items = authorization.split() if isinstance(authorization, str) else [] + if ( + len(items) != 2 + or items[0].lower() != "bearer" + or items[1] != self.opencti_token + ): + return JSONResponse( + status_code=401, content={"error": "Invalid credentials"} + ) + # 02. Parse the data and execute + try: + data = await request.json() # Get the JSON payload + except json.JSONDecodeError as e: + self.helper.connector_logger.error( + "Invalid JSON payload", {"cause": str(e)} + ) + return JSONResponse( + status_code=400, + content={"error": "Invalid JSON payload"}, + ) + try: + self._data_handler(data) + except Exception as e: + self.helper.connector_logger.error( + "Error processing message", {"cause": str(e)} + ) + return JSONResponse( + status_code=500, + content={"error": "Error processing message"}, + ) + # all good + return JSONResponse( + status_code=202, content={"message": "Message successfully received"} + ) + def run(self) -> None: - self.helper.connector_logger.info("Starting ListenQueue thread") - while not self.exit_event.is_set(): - try: - self.helper.connector_logger.info("ListenQueue connecting to rabbitMq.") - # Connect the broker - self.pika_credentials = pika.PlainCredentials(self.user, self.password) - self.pika_parameters = pika.ConnectionParameters( - heartbeat=10, - blocked_connection_timeout=30, - host=self.host, - port=self.port, - virtual_host=self.vhost, - credentials=self.pika_credentials, - ssl_options=( - pika.SSLOptions(create_mq_ssl_context(self.config), self.host) - if self.use_ssl - else None - ), - ) - self.pika_connection = pika.BlockingConnection(self.pika_parameters) - self.channel = self.pika_connection.channel() + if self.listen_protocol == "AMQP": + self.helper.connector_logger.info("Starting ListenQueue thread") + while not self.exit_event.is_set(): try: - # confirm_delivery is only for cluster mode rabbitMQ - # when not in cluster mode this line raise an exception - self.channel.confirm_delivery() + self.helper.connector_logger.info( + "ListenQueue connecting to rabbitMq." + ) + # Connect the broker + self.pika_credentials = pika.PlainCredentials( + self.user, self.password + ) + self.pika_parameters = pika.ConnectionParameters( + heartbeat=10, + blocked_connection_timeout=30, + host=self.host, + port=self.port, + virtual_host=self.vhost, + credentials=self.pika_credentials, + ssl_options=( + pika.SSLOptions( + create_mq_ssl_context(self.config), self.host + ) + if self.use_ssl + else None + ), + ) + self.pika_connection = pika.BlockingConnection(self.pika_parameters) + self.channel = self.pika_connection.channel() + try: + # confirm_delivery is only for cluster mode rabbitMQ + # when not in cluster mode this line raise an exception + self.channel.confirm_delivery() + except Exception as err: # pylint: disable=broad-except + self.helper.connector_logger.debug(str(err)) + self.channel.basic_qos(prefetch_count=1) + assert self.channel is not None + self.channel.basic_consume( + queue=self.queue_name, on_message_callback=self._process_message + ) + self.channel.start_consuming() except Exception as err: # pylint: disable=broad-except - self.helper.connector_logger.debug(str(err)) - self.channel.basic_qos(prefetch_count=1) - assert self.channel is not None - self.channel.basic_consume( - queue=self.queue_name, on_message_callback=self._process_message - ) - self.channel.start_consuming() - except Exception as err: # pylint: disable=broad-except - try: - self.pika_connection.close() - except Exception as errInException: - self.helper.connector_logger.debug( - type(errInException).__name__, {"reason": str(errInException)} + try: + self.pika_connection.close() + except Exception as errInException: + self.helper.connector_logger.debug( + type(errInException).__name__, + {"reason": str(errInException)}, + ) + self.helper.connector_logger.error( + type(err).__name__, {"reason": str(err)} ) - self.helper.connector_logger.error( - type(err).__name__, {"reason": str(err)} - ) - # Wait some time and then retry ListenQueue again. - time.sleep(10) + # Wait some time and then retry ListenQueue again. + time.sleep(10) + elif self.listen_protocol == "API": + self.helper.connector_logger.info("Starting Listen HTTP thread") + app.add_api_route( + self.listen_protocol_api_path, + self._http_process_callback, + methods=["POST"], + ) + config = uvicorn.Config( + app, + host="0.0.0.0", + port=self.listen_protocol_api_port, + reload=False, + log_config=None, + log_level=None, + ) + config.load() # Manually calling the .load() to trigger needed actions outside HTTPS + if self.listen_protocol_api_ssl: + ssl_ctx = create_callback_ssl_context(self.config) + config.ssl = ssl_ctx + server = uvicorn.Server(config) + server.run() + + else: + raise ValueError("Unsupported listen protocol type") def stop(self): self.helper.connector_logger.info("Preparing ListenQueue for clean shutdown") @@ -790,8 +904,39 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.connect_id = get_config_variable( "CONNECTOR_ID", ["connector", "id"], config ) - self.queue_protocol = get_config_variable( - "QUEUE_PROTOCOL", ["connector", "queue_protocol"], config, default="amqp" + self.listen_protocol = get_config_variable( + "CONNECTOR_LISTEN_PROTOCOL", + ["connector", "listen_protocol"], + config, + default="AMQP", + ).upper() + self.listen_protocol_api_port = get_config_variable( + "CONNECTOR_LISTEN_PROTOCOL_API_PORT", + ["connector", "listen_protocol_api_port"], + config, + default=7070, + ) + self.listen_protocol_api_path = get_config_variable( + "CONNECTOR_LISTEN_PROTOCOL_API_PATH", + ["connector", "listen_protocol_api_path"], + config, + default="/api/callback", + ) + self.listen_protocol_api_ssl = get_config_variable( + "CONNECTOR_LISTEN_PROTOCOL_API_SSL", + ["connector", "listen_protocol_api_ssl"], + config, + default=False, + ) + self.listen_protocol_api_uri = get_config_variable( + "CONNECTOR_LISTEN_PROTOCOL_API_URI", + ["connector", "listen_protocol_api_uri"], + config, + default=( + "https://127.0.0.1:7070" + if self.listen_protocol_api_ssl + else "http://127.0.0.1:7070" + ), ) self.connect_type = get_config_variable( "CONNECTOR_TYPE", ["connector", "type"], config @@ -957,6 +1102,7 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.connect_auto, self.connect_only_contextual, playbook_compatible, + self.listen_protocol_api_uri + self.listen_protocol_api_path, ) connector_configuration = self.api.connector.register(self.connector) self.connector_logger.info( @@ -972,6 +1118,25 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.connector_state = connector_configuration["connector_state"] self.connector_config = connector_configuration["config"] + # Configure the push information protocol + self.queue_protocol = get_config_variable( + env_var="CONNECTOR_QUEUE_PROTOCOL", + yaml_path=["connector", "queue_protocol"], + config=config, + ) + if not self.queue_protocol: # for backwards compatibility + self.queue_protocol = get_config_variable( + env_var="QUEUE_PROTOCOL", + yaml_path=["connector", "queue_protocol"], + config=config, + ) + if self.queue_protocol: + self.connector_logger.error( + "QUEUE_PROTOCOL is deprecated, please use CONNECTOR_QUEUE_PROTOCOL instead." + ) + if not self.queue_protocol: + self.queue_protocol = "amqp" + # Overwrite connector config for RabbitMQ if given manually / in conf self.connector_config["connection"]["host"] = get_config_variable( "MQ_HOST", @@ -1441,9 +1606,14 @@ def listen( self.listen_queue = ListenQueue( self, + self.opencti_token, self.config, self.connector_config, self.applicant_id, + self.listen_protocol, + self.listen_protocol_api_ssl, + self.listen_protocol_api_path, + self.listen_protocol_api_port, message_callback, ) self.listen_queue.start() @@ -1742,13 +1912,13 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list: raise ValueError("Nothing to import") if bundle_send_to_queue: - if work_id: - self.api.work.add_expectations(work_id, expectations_number) - if draft_id: - self.api.work.add_draft_context(work_id, draft_id) + if work_id and draft_id: + self.api.work.add_draft_context(work_id, draft_id) if entities_types is None: entities_types = [] if self.queue_protocol == "amqp": + if work_id: + self.api.work.add_expectations(work_id, expectations_number) pika_credentials = pika.PlainCredentials( self.connector_config["connection"]["user"], self.connector_config["connection"]["pass"], @@ -1791,7 +1961,7 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list: pika_connection.close() elif self.queue_protocol == "api": self.api.send_bundle_to_api( - connector_id=self.connector_id, bundle=bundle + connector_id=self.connector_id, bundle=bundle, work_id=work_id ) else: raise ValueError( diff --git a/requirements.txt b/requirements.txt index 5f6da73b2..5853077d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,8 @@ prometheus-client~=0.21.1 opentelemetry-api>=1.22.0,<=1.30.0 opentelemetry-sdk>=1.22.0,<=1.30.0 deprecation~=2.1.0 +fastapi>=0.115.8,<0.116.0 +uvicorn[standard]>=0.33.0,<0.35.0 # OpenCTI filigran-sseclient>=1.0.2 stix2~=3.0.1 \ No newline at end of file