diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py
index 12435b03f..ca1bcac9c 100644
--- a/bot/gemini/google_gemini_bot.py
+++ b/bot/gemini/google_gemini_bot.py
@@ -1,11 +1,9 @@
 """
-Google gemini bot
-
-@author zhayujie
-@Date 2023/12/15
+Optimized Google Gemini Bot
 """
 # encoding:utf-8
+import time
 from bot.bot import Bot
 import google.generativeai as genai
 from bot.session_manager import SessionManager
@@ -14,33 +12,24 @@ from common.log import logger
 from config import conf
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
-from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
 from google.generativeai.types import HarmCategory, HarmBlockThreshold
 
-# OpenAI chat-model API (available)
 class GoogleGeminiBot(Bot):
-
     def __init__(self):
         super().__init__()
         self.api_key = conf().get("gemini_api_key")
-        # Reuse chatGPT's token-counting scheme
         self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
         self.model = conf().get("model") or "gemini-pro"
         if self.model == "gemini":
             self.model = "gemini-pro"
+
     def reply(self, query, context: Context = None) -> Reply:
         try:
-            if context.type != ContextType.TEXT:
-                logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
-                return Reply(ReplyType.TEXT, None)
-
             logger.info(f"[Gemini] query={query}")
             session_id = context["session_id"]
             session = self.sessions.session_query(query, session_id)
-            gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
-            logger.debug(f"[Gemini] messages={gemini_messages}")
             genai.configure(api_key=self.api_key)
-            model = genai.GenerativeModel(self.model)
+            gemini_messages = self._prepare_messages(query, context, session.messages)
 
             # Add safety settings
             safety_settings = {
@@ -49,58 +38,72 @@ def reply(self, query, context: Context = None) -> Reply:
                 HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                 HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
             }
-
-            # Generate the reply, with safety settings applied
-            response = model.generate_content(
+
+            # Generate the reply
+            response = genai.GenerativeModel(self.model).generate_content(
                 gemini_messages,
                 safety_settings=safety_settings
             )
+
             if response.candidates and response.candidates[0].content:
                 reply_text = response.candidates[0].content.parts[0].text
                 logger.info(f"[Gemini] reply={reply_text}")
                 self.sessions.session_reply(reply_text, session_id)
                 return Reply(ReplyType.TEXT, reply_text)
             else:
-                # No valid response content; it may have been blocked, so log the safety ratings
-                logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
-                if hasattr(response, 'candidates') and response.candidates:
-                    for rating in response.candidates[0].safety_ratings:
-                        logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
+                self._log_safety_ratings(response)
                 error_message = "No valid response generated due to safety constraints."
+                logger.warning(error_message)
                 self.sessions.session_reply(error_message, session_id)
                 return Reply(ReplyType.ERROR, error_message)
-
+
         except Exception as e:
             logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
-            error_message = "Failed to invoke [Gemini] api!"
+            error_message = "Failed to invoke [Gemini] API!"
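
For reference, the multimodal path that `_prepare_messages` and `_upload_and_process_file` introduce boils down to the following standalone sketch of `google.generativeai` usage. The API key, file path, and model name are placeholders, and `gemini-1.5-flash` is only an assumed model choice, not what the patch configures:

```python
import time
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

genai.configure(api_key="YOUR_GEMINI_API_KEY")  # placeholder, not a real key

# Upload a media file; video uploads stay in PROCESSING until ready.
media_file = genai.upload_file("sample.mp4")  # hypothetical path
while media_file.state.name == "PROCESSING":
    time.sleep(5)
    media_file = genai.get_file(media_file.name)

# Mixed list of media and text, mirroring [media_file, "\n\n", query] above.
response = genai.GenerativeModel("gemini-1.5-flash").generate_content(
    [media_file, "\n\n", "Describe this video."],
    safety_settings={
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
    },
)
print(response.candidates[0].content.parts[0].text)
```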
             self.sessions.session_reply(error_message, session_id)
             return Reply(ReplyType.ERROR, error_message)
-
-    def _convert_to_gemini_messages(self, messages: list):
+
+    def _prepare_messages(self, query, context, messages):
+        """Prepare messages based on context type."""
+        if context.type == ContextType.TEXT:
+            return self._convert_to_gemini_messages(self.filter_messages(messages))
+        elif context.type in {ContextType.IMAGE, ContextType.AUDIO, ContextType.VIDEO}:
+            media_file = self._upload_and_process_file(context)
+            return [media_file, "\n\n", query]
+        else:
+            raise ValueError(f"Unsupported input type: {context.type}")
+
+    def _upload_and_process_file(self, context):
+        """Handle media file upload and processing."""
+        media_file = genai.upload_file(context.content)
+        if context.type == ContextType.VIDEO:
+            while media_file.state.name == "PROCESSING":
+                logger.info(f"Video file {media_file.name} is processing...")
+                time.sleep(5)
+                media_file = genai.get_file(media_file.name)
+        logger.info(f"Media file {media_file.name} uploaded successfully.")
+        return media_file
+
+    def _log_safety_ratings(self, response):
+        """Log safety ratings if no valid response is generated."""
+        if hasattr(response, 'candidates') and response.candidates:
+            for rating in response.candidates[0].safety_ratings:
+                logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
+
+    def _convert_to_gemini_messages(self, messages):
+        if isinstance(messages, str):
+            return [{"role": "user", "parts": [{"text": messages}]}]
         res = []
         for msg in messages:
-            if msg.get("role") == "user":
-                role = "user"
-            elif msg.get("role") == "assistant":
-                role = "model"
-            elif msg.get("role") == "system":
-                role = "user"
-            else:
-                continue
-            res.append({
-                "role": role,
-                "parts": [{"text": msg.get("content")}]
-            })
+            role = {"user": "user", "assistant": "model", "system": "user"}.get(msg.get("role"))
+            if role:
+                res.append({"role": role, "parts": [{"text": msg.get("content")}]})
         return res
 
     @staticmethod
     def filter_messages(messages: list):
-        res = []
-        turn = "user"
-        if not messages:
-            return res
-        for i in range(len(messages) - 1, -1, -1):
-            message = messages[i]
+        res, turn = [], "user"
+        for message in reversed(messages or []):
             role = message.get("role")
             if role == "system":
                 res.insert(0, message)
@@ -108,8 +111,5 @@ def filter_messages(messages: list):
             if role != turn:
                 continue
             res.insert(0, message)
-            if turn == "user":
-                turn = "assistant"
-            elif turn == "assistant":
-                turn = "user"
-        return res
+            turn = "assistant" if turn == "user" else "user"
+        return res
\ No newline at end of file
diff --git a/channel/chat_channel.py b/channel/chat_channel.py
index 7e2df3ac4..fcb2c9940 100644
--- a/channel/chat_channel.py
+++ b/channel/chat_channel.py
@@ -258,7 +258,18 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
                 reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
                 reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
             else:
-                reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
+                # Single-chat handling
+                prefix = conf().get("single_chat_reply_prefix", "")
+                suffix = conf().get("single_chat_reply_suffix", "")
+                # Ensure prefix and suffix are strings
+                if isinstance(prefix, list):
+                    prefix = ''.join(prefix)
+                if isinstance(suffix, list):
+                    suffix = ''.join(suffix)
+                reply_text = prefix + reply_text + suffix
+                # Ensure reply.content ends up a string
+                if isinstance(reply.content, list):
+                    reply.content = ''.join(reply.content)
             reply.content = reply_text
         elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
             reply.content = "[" + str(reply.type) + "]\n" + reply.content
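
The list-to-string coercion added here could also be factored into a small helper so the group-chat branch can reuse it; a minimal sketch (the `_as_text` name is illustrative and not part of the patch):

```python
from config import conf  # the project's global config accessor

def _as_text(value) -> str:
    """Coerce a config value that may be a str or a list of str into a str."""
    if isinstance(value, list):
        return ''.join(value)
    return value if isinstance(value, str) else str(value)

# e.g. _as_text(["Hello, ", "world"]) == "Hello, world"; _as_text("Hi") == "Hi"
prefix = _as_text(conf().get("single_chat_reply_prefix", ""))
suffix = _as_text(conf().get("single_chat_reply_suffix", ""))
```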
diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py
index fe36a6836..097c035be 100644
--- a/plugins/tool/tool.py
+++ b/plugins/tool/tool.py
@@ -1,3 +1,4 @@
+from plugins import Plugin, Event, EventContext, EventAction
 from chatgpt_tool_hub.apps import AppFactory
 from chatgpt_tool_hub.apps.app import App
 from chatgpt_tool_hub.tools.tool_register import main_tool_register
@@ -9,6 +10,10 @@
 from common import const
 from config import conf, get_appdata_dir
 from plugins import *
+import os
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 @plugins.register(
@@ -22,6 +27,9 @@ class Tool(Plugin):
     def __init__(self):
         super().__init__()
         self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+        # Initialize these two attributes here
+        self.tool_config = self._read_json()
+        self.app_kwargs = None  # initialized later in _reset_app
         self.app = self._reset_app()
         if not self.tool_config.get("tools"):
             logger.warn("[tool] init failed, ignore ")
@@ -147,7 +155,7 @@ def _build_tool_kwargs(self, kwargs: dict):
             "request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120),
             "temperature": kwargs.get("temperature", 0),  # LLM temperature; 0 is recommended
             # LLM configuration
-            "llm_api_key": conf().get("open_ai_api_key", ""),  # pass the key here if the LLM API authenticates with a key
+            "llm_api_key": conf().get("open_ai_api_key", "sk-***REDACTED***"),  # pass the key here if the LLM API authenticates with a key
             "llm_api_base_url": conf().get("open_ai_api_base", "https://api.openai.com/v1"),  # base URL of an OpenAI-compatible LLM service
             "deployment_id": conf().get("azure_deployment_id", ""),  # used by Azure OpenAI
             # note: the tool has not been tested against other models yet, but config sources are still prioritized; plugin config can override global config
@@ -237,7 +245,7 @@ def _filter_tool_list(self, tool_list: list):
         return valid_list
 
     def _reset_app(self) -> App:
-        self.tool_config = self._read_json()
+        # self.tool_config = self._read_json()  # now read once in __init__
         self.app_kwargs = self._build_tool_kwargs(self.tool_config.get("kwargs", {}))
         app = AppFactory()
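
The precedence that the note comment describes (plugin-level kwargs overriding global config) can be made explicit with a small resolver; a sketch under the assumption that empty strings should also fall through to the global value (`resolve` is a hypothetical helper, not part of the patch):

```python
from config import conf  # the project's global config accessor

def resolve(key: str, plugin_kwargs: dict, default=None):
    """Prefer the plugin's own kwargs; fall back to global config, then default."""
    value = plugin_kwargs.get(key)
    if value in (None, ""):
        value = conf().get(key, default)
    return value

# e.g. inside _build_tool_kwargs:
#   llm_api_key = resolve("open_ai_api_key", kwargs, "")
```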