diff --git a/apps/trigger/handler/impl/task/application_task.py b/apps/trigger/handler/impl/task/application_task.py index 307e2210a5f..d25fe33281d 100644 --- a/apps/trigger/handler/impl/task/application_task.py +++ b/apps/trigger/handler/impl/task/application_task.py @@ -6,16 +6,17 @@ @date:2026/1/14 19:14 @desc: """ +import json +from functools import reduce import uuid_utils.compat as uuid from django.db.models import QuerySet -from application.models import ChatUserType, Chat, ChatRecord, ChatSourceChoices +from application.models import ChatUserType, Chat, ChatRecord, ChatSourceChoices, Application from chat.serializers.chat import ChatSerializers from knowledge.models.knowledge_action import State - from trigger.handler.base_task import BaseTriggerTask -from trigger.models import TaskRecord +from trigger.models import TaskRecord, TriggerTask def get_reference(fields, obj): @@ -28,32 +29,82 @@ def get_reference(fields, obj): return obj -def get_field_value(value, kwargs): +def conversion_custom_value(value, _type): + if ['array', 'dict', 'float', 'int', 'boolean'].__contains__(_type): + return json.loads(value) + if _type == 'any': + try: + return json.loads(value) + except Exception as e: + pass + return value + + +def valid_value_type(value, _type): + if _type == 'array': + return isinstance(value, list) + if _type == 'dict': + return isinstance(value, dict) + if _type == 'float': + return isinstance(value, float) + if _type == 'int': + return isinstance(value, int) + if _type == 'boolean': + return isinstance(value, bool) + return isinstance(value, str) + + +def get_field_value(value, kwargs, _type, required, default_value, field): source = value.get('source') if source == 'custom': - return value.get('value') + _value = value.get('value') + if _value: + _value = conversion_custom_value(_value, _type) + else: + if default_value: + return default_value + if required: + raise Exception(f'{field} is required') + else: + return None else: - return get_reference(value.get('value'), kwargs) + _value = get_reference(value.get('value'), kwargs) + valid = valid_value_type(_value, _type) + if not valid: + raise Exception(f'{field} type error') + return _value -def get_application_execute_parameters(parameter_setting, kwargs): +def get_application_execute_parameters(parameter_setting, application_parameters_setting, kwargs): + many_field = ['api_input_field_list', 'user_input_field_list'] parameters = {'form_data': {}} - question_setting = parameter_setting.get('question') - if question_setting: - parameters['message'] = get_field_value(question_setting, kwargs) - filed_list = ['image_list', 'document_list', 'audio_list', 'video_list', 'other_list'] - for field in filed_list: - field_setting = parameter_setting.get(field) - if field_setting: - parameters[field] = get_field_value(field_setting, kwargs) - api_input_field_list = parameter_setting.get('api_input_field_list') - if api_input_field_list: - for key, value in api_input_field_list.items(): - parameters['form_data'][key] = get_field_value(value, kwargs) - user_input_field_list = parameter_setting.get('user_input_field_list') - if user_input_field_list: - for key, value in user_input_field_list.items(): - parameters['form_data'][key] = get_field_value(value, kwargs) + for key, value in application_parameters_setting.items(): + setting = parameter_setting.get(key) + if setting: + if many_field.__contains__(key): + for ck, cv in value.items(): + _setting = setting.get(ck) + if _setting: + _value = get_field_value(_setting, kwargs, cv.get('type'), cv.get('required'), + cv.get('default_value'), cv.get('field')) + parameters['form_data'][ck] = _value + else: + if cv.get('default_value'): + parameters['form_data'][ck] = cv.get('default_value') + else: + if cv.get('required'): + raise Exception(f'{ck} is required') + else: + value = get_field_value(setting, kwargs, value.get('type'), value.get('required'), + value.get('default_value'), value.get('field')) + parameters['message' if key == 'question' else key] = value + else: + if value.get('default_value'): + parameters['message' if key == 'question' else key] = value.get('default_value') + else: + if value.get('required'): + raise Exception(f'{"message" if key == "question" else key} is required') + return parameters @@ -76,13 +127,63 @@ def get_workflow_state(details): return State.SUCCESS +def get_user_field_component_input_type(input_type): + if input_type == "MultiRow": + return 'array' + if input_type == "SwitchInput": + return 'boolean' + return 'string' + + +def get_application_parameters_setting(application): + application_parameter_setting = {'question': { + 'required': True, + 'type': 'string' + }} + if application.type == 'SIMPLE': + return application_parameter_setting + else: + base_node_list = [n for n in application.work_flow.get('nodes') if n.get('type') == "base-node"] + if len(base_node_list) == 0: + raise Exception('错误的应用工作流信息') + base_node = base_node_list[0] + api_input_field_list = base_node.get('properties').get('api_input_field_list') or [] + api_input_field_list = {user_field.get('variable'): { + 'required': user_field.get('is_required'), + 'default_value': user_field.get('default_value'), + 'type': 'string' + } for user_field in api_input_field_list} + user_input_field_list = base_node.get('properties').get('user_input_field_list') or [] + user_input_field_list = {user_field.get('field'): { + 'required': user_field.get('required'), + 'default_value': user_field.get('default_value'), + 'type': get_user_field_component_input_type(user_field.get('input_type')) + } for user_field in user_input_field_list} + application_parameter_setting['api_input_field_list'] = api_input_field_list + application_parameter_setting['user_input_field_list'] = user_input_field_list + node_data = base_node.get('properties').get('node_data') or {} + file_upload_enable = node_data.get('file_upload_enable') + if file_upload_enable: + file_upload_setting = node_data.get('file_upload_setting') or {} + for field in ['audio', 'document', 'image', 'other', 'video']: + v = file_upload_setting.get(field) + if v: + application_parameter_setting[field] = {'required': False, 'default_value': [], 'type': 'array'} + return application_parameter_setting + + class ApplicationTask(BaseTriggerTask): def support(self, trigger_task, **kwargs): return trigger_task.get('source_type') == 'APPLICATION' def execute(self, trigger_task, **kwargs): parameter_setting = trigger_task.get('parameter') - parameters = get_application_execute_parameters(parameter_setting, kwargs) + application = QuerySet(Application).filter(id=trigger_task.get('source_id')).only('type', 'work_flow').first() + if application is None: + QuerySet(TriggerTask).filter(id=trigger_task.get('id')).delete() + return + application_parameters_setting = get_application_parameters_setting(application) + parameters = get_application_execute_parameters(parameter_setting, application_parameters_setting, kwargs) parameters['re_chat'] = False parameters['stream'] = True chat_id = uuid.uuid7()