33import json
44import logging
55import random
6+ import os
7+ import mimetypes
68from typing import Dict , List , Optional
79from collections .abc import AsyncGenerator
810
@@ -193,6 +195,12 @@ def process_image_url(image_url_dict: dict) -> types.Part:
193195 mime_type = url .split (":" )[1 ].split (";" )[0 ]
194196 image_bytes = base64 .b64decode (url .split ("," , 1 )[1 ])
195197 return types .Part .from_bytes (data = image_bytes , mime_type = mime_type )
198+
199+ def process_inline_data (inline_data_dict : dict ) -> types .Part :
200+ """处理内联数据,如音频""" # TODO: 处理视频?
201+ mime_type = inline_data_dict ["mime_type" ]
202+ data = inline_data_dict .get ("data" , "" )
203+ return types .Part .from_bytes (data = data , mime_type = mime_type )
196204
197205 def append_or_extend (contents : list [types .Content ], part : list [types .Part ], content_cls : type [types .Content ]) -> None :
198206 if contents and isinstance (contents [- 1 ], content_cls ):
@@ -212,12 +220,15 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont
212220
213221 if role == "user" :
214222 if isinstance (content , list ):
215- parts = [
216- types .Part .from_text (text = item ["text" ] or " " )
217- if item ["type" ] == "text"
218- else process_image_url (item ["image_url" ])
219- for item in content
220- ]
223+ parts = []
224+ for item in content :
225+ if item ["type" ] == "text" :
226+ parts .append (types .Part .from_text (text = item ["text" ] or " " ))
227+ elif item ["type" ] == "image_url" :
228+ parts .append (process_image_url (item ["image_url" ]))
229+ elif item ["type" ] == "inline_data" :
230+ # 处理内联数据,如音频
231+ parts .append (process_inline_data (item ["inline_data" ]))
221232 else :
222233 parts = [create_text_part (content )]
223234 append_or_extend (gemini_contents , parts , types .UserContent )
@@ -447,13 +458,14 @@ async def text_chat(
447458 prompt : str ,
448459 session_id : str = None ,
449460 image_urls : List [str ] = None ,
461+ audio_urls : List [str ] = None ,
450462 func_tool : FuncCall = None ,
451463 contexts = [],
452464 system_prompt = None ,
453465 tool_calls_result = None ,
454466 ** kwargs ,
455467 ) -> LLMResponse :
456- new_record = await self .assemble_context (prompt , image_urls )
468+ new_record = await self .assemble_context (prompt , image_urls , audio_urls )
457469 context_query = [* contexts , new_record ]
458470 if system_prompt :
459471 context_query .insert (0 , {"role" : "system" , "content" : system_prompt })
@@ -486,14 +498,15 @@ async def text_chat_stream(
486498 self ,
487499 prompt : str ,
488500 session_id : str = None ,
489- image_urls : List [str ] = [],
501+ image_urls : List [str ] = None ,
502+ audio_urls : List [str ] = None ,
490503 func_tool : FuncCall = None ,
491504 contexts = [],
492505 system_prompt = None ,
493506 tool_calls_result = None ,
494507 ** kwargs ,
495508 ) -> AsyncGenerator [LLMResponse , None ]:
496- new_record = await self .assemble_context (prompt , image_urls )
509+ new_record = await self .assemble_context (prompt , image_urls , audio_urls )
497510 context_query = [* contexts , new_record ]
498511 if system_prompt :
499512 context_query .insert (0 , {"role" : "system" , "content" : system_prompt })
@@ -545,30 +558,55 @@ def set_key(self, key):
545558 self .chosen_api_key = key
546559 self ._init_client ()
547560
548- async def assemble_context (self , text : str , image_urls : List [str ] = None ):
561+ async def assemble_context (self , text : str , image_urls : List [str ] = None , audio_urls : List [ str ] = None ):
549562 """
550563 组装上下文。
551564 """
552- if image_urls :
565+ has_media = (image_urls and len (image_urls ) > 0 ) or (audio_urls and len (audio_urls ) > 0 )
566+
567+ if has_media :
553568 user_content = {
554569 "role" : "user" ,
555- "content" : [{"type" : "text" , "text" : text if text else "[图片 ]" }],
570+ "content" : [{"type" : "text" , "text" : text if text else "[媒体内容 ]" }],
556571 }
557- for image_url in image_urls :
558- if image_url .startswith ("http" ):
559- image_path = await download_image_by_url (image_url )
560- image_data = await self .encode_image_bs64 (image_path )
561- elif image_url .startswith ("file:///" ):
562- image_path = image_url .replace ("file:///" , "" )
563- image_data = await self .encode_image_bs64 (image_path )
564- else :
565- image_data = await self .encode_image_bs64 (image_url )
566- if not image_data :
567- logger .warning (f"图片 { image_url } 得到的结果为空,将忽略。" )
568- continue
569- user_content ["content" ].append (
570- {"type" : "image_url" , "image_url" : {"url" : image_data }}
571- )
572+
573+ # 处理图片
574+ if image_urls :
575+ for image_url in image_urls :
576+ if image_url .startswith ("http" ):
577+ image_path = await download_image_by_url (image_url )
578+ image_data = await self .encode_image_bs64 (image_path )
579+ elif image_url .startswith ("file:///" ):
580+ image_path = image_url .replace ("file:///" , "" )
581+ image_data = await self .encode_image_bs64 (image_path )
582+ else :
583+ image_data = await self .encode_image_bs64 (image_url )
584+ if not image_data :
585+ logger .warning (f"图片 { image_url } 得到的结果为空,将忽略。" )
586+ continue
587+ user_content ["content" ].append (
588+ {"type" : "image_url" , "image_url" : {"url" : image_data }}
589+ )
590+
591+ # 处理音频
592+ if audio_urls :
593+ for audio_url in audio_urls :
594+ audio_bytes , mime_type = await self .encode_audio_data (audio_url )
595+ if not audio_bytes or not mime_type :
596+ logger .warning (f"音频 { audio_url } 处理失败,将忽略。" )
597+ continue
598+
599+ # 添加音频数据
600+ user_content ["content" ].append (
601+ {
602+ "type" : "inline_data" ,
603+ "inline_data" : {
604+ "mime_type" : mime_type ,
605+ "data" : audio_bytes
606+ }
607+ }
608+ )
609+
572610 return user_content
573611 else :
574612 return {"role" : "user" , "content" : text }
@@ -584,5 +622,41 @@ async def encode_image_bs64(self, image_url: str) -> str:
584622 return "data:image/jpeg;base64," + image_bs64
585623 return ""
586624
625+ async def encode_audio_data (self , audio_url : str ) -> tuple :
626+ """
627+ 读取音频文件并返回二进制数据
628+
629+ Returns:
630+ tuple: (音频二进制数据, MIME类型)
631+ """
632+ try :
633+ # 直接读取文件二进制数据
634+ with open (audio_url , "rb" ) as f :
635+ audio_bytes = f .read ()
636+
637+ # 推断 MIME 类型
638+ mime_type = mimetypes .guess_type (audio_url )[0 ]
639+ if not mime_type :
640+ # 根据文件扩展名确定 MIME 类型
641+ extension = os .path .splitext (audio_url )[1 ].lower ()
642+ if extension == '.wav' :
643+ mime_type = 'audio/wav'
644+ elif extension == '.mp3' :
645+ mime_type = 'audio/mpeg'
646+ elif extension == '.ogg' :
647+ mime_type = 'audio/ogg'
648+ elif extension == '.flac' :
649+ mime_type = 'audio/flac'
650+ elif extension == '.m4a' :
651+ mime_type = 'audio/mp4'
652+ else :
653+ mime_type = 'audio/wav' # 默认
654+
655+ logger .info (f"音频文件处理成功: { audio_url } ,mime类型: { mime_type } ,大小: { len (audio_bytes )} 字节" )
656+ return audio_bytes , mime_type
657+ except Exception as e :
658+ logger .error (f"音频文件处理失败: { e } " )
659+ return None , None
660+
587661 async def terminate (self ):
588662 logger .info ("Google GenAI 适配器已终止。" )
0 commit comments