diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 9c69a0039..4b947b3f0 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -8,7 +8,10 @@ from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav +from astrbot.core.utils.tencent_record_helper import ( + convert_to_pcm_wav, + tencent_silk_to_wav, +) from ..entities import ProviderType from ..provider import STTProvider @@ -111,17 +114,22 @@ async def get_text(self, audio_url: str) -> str: return "" # 2. Check for conversion - needs_conversion = False - if ( - audio_url.endswith((".amr", ".silk")) - or is_tencent - or b"SILK" in audio_bytes[:8] - ): - needs_conversion = True + conversion_type = None + + if b"SILK" in audio_bytes[:8]: + conversion_type = "silk" + elif b"#!AMR" in audio_bytes[:6]: + conversion_type = "amr" + elif audio_url.endswith(".silk") or is_tencent: + conversion_type = "silk" + elif audio_url.endswith(".amr"): + conversion_type = "amr" # 3. Perform conversion if needed - if needs_conversion: - logger.info("Audio requires conversion, using temporary files...") + if conversion_type: + logger.info( + f"Audio requires conversion ({conversion_type}), using temporary files..." + ) temp_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(temp_dir, exist_ok=True) @@ -132,8 +140,12 @@ async def get_text(self, audio_url: str) -> str: with open(input_path, "wb") as f: f.write(audio_bytes) - logger.info("Converting silk/amr file to wav ...") - await tencent_silk_to_wav(input_path, output_path) + if conversion_type == "silk": + logger.info("Converting silk to wav ...") + await tencent_silk_to_wav(input_path, output_path) + elif conversion_type == "amr": + logger.info("Converting amr to wav ...") + await convert_to_pcm_wav(input_path, output_path) with open(output_path, "rb") as f: audio_bytes = f.read()