diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 9c0658465..8fd8ce5f9 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,5 +1,7 @@ import pytest +import pytest_asyncio import os +import asyncio from quart import Quart from astrbot.dashboard.server import AstrBotDashboard from astrbot.core.db.sqlite import SQLiteDatabase @@ -9,36 +11,46 @@ from astrbot.core.star.star import star_registry -@pytest.fixture(scope="module") -def core_lifecycle_td(): - db = SQLiteDatabase("data/data_v3.db") +@pytest_asyncio.fixture(scope="module") +async def core_lifecycle_td(tmp_path_factory): + """Creates and initializes a core lifecycle instance with a temporary database.""" + tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db" + db = SQLiteDatabase(str(tmp_db_path)) log_broker = LogBroker() - core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db) - return core_lifecycle_td + core_lifecycle = AstrBotCoreLifecycle(log_broker, db) + await core_lifecycle.initialize() + return core_lifecycle @pytest.fixture(scope="module") -def app(core_lifecycle_td): - db = SQLiteDatabase("data/data_v3.db") - server = AstrBotDashboard(core_lifecycle_td, db) +def app(core_lifecycle_td: AstrBotCoreLifecycle): + """Creates a Quart app instance for testing.""" + shutdown_event = asyncio.Event() + # The db instance is already part of the core_lifecycle_td + server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event) return server.app -@pytest.fixture(scope="module") -def header(): - return {} - - -@pytest.mark.asyncio -async def test_init_core_lifecycle_td(core_lifecycle_td): - await core_lifecycle_td.initialize() - assert core_lifecycle_td is not None +@pytest_asyncio.fixture(scope="module") +async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): + """Handles login and returns an authenticated header.""" + test_client = app.test_client() + response = await test_client.post( + "/api/auth/login", + json={ + "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], + "password": core_lifecycle_td.astrbot_config["dashboard"]["password"], + }, + ) + data = await response.get_json() + assert data["status"] == "ok" + token = data["data"]["token"] + return {"Authorization": f"Bearer {token}"} @pytest.mark.asyncio -async def test_auth_login( - app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict -): +async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): + """Tests the login functionality with both wrong and correct credentials.""" test_client = app.test_client() response = await test_client.post( "/api/auth/login", json={"username": "wrong", "password": "password"} @@ -55,31 +67,32 @@ async def test_auth_login( ) data = await response.get_json() assert data["status"] == "ok" and "token" in data["data"] - header["Authorization"] = f"Bearer {data['data']['token']}" @pytest.mark.asyncio -async def test_get_stat(app: Quart, header: dict): +async def test_get_stat(app: Quart, authenticated_header: dict): test_client = app.test_client() response = await test_client.get("/api/stat/get") assert response.status_code == 401 - response = await test_client.get("/api/stat/get", headers=header) + response = await test_client.get("/api/stat/get", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "ok" and "platform" in data["data"] @pytest.mark.asyncio -async def test_plugins(app: Quart, header: dict): +async def test_plugins(app: Quart, authenticated_header: dict): test_client = app.test_client() # 已经安装的插件 - response = await test_client.get("/api/plugin/get", headers=header) + response = await test_client.get("/api/plugin/get", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "ok" # 插件市场 - response = await test_client.get("/api/plugin/market_list", headers=header) + response = await test_client.get( + "/api/plugin/market_list", headers=authenticated_header + ) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "ok" @@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict): response = await test_client.post( "/api/plugin/install", json={"url": "https://github.com/Soulter/astrbot_plugin_essential"}, - headers=header, + headers=authenticated_header, ) assert response.status_code == 200 data = await response.get_json() @@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict): # 插件更新 response = await test_client.post( - "/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header + "/api/plugin/update", + json={"name": "astrbot_plugin_essential"}, + headers=authenticated_header, ) assert response.status_code == 200 data = await response.get_json() @@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict): response = await test_client.post( "/api/plugin/uninstall", json={"name": "astrbot_plugin_essential"}, - headers=header, + headers=authenticated_header, ) assert response.status_code == 200 data = await response.get_json() @@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict): @pytest.mark.asyncio -async def test_check_update(app: Quart, header: dict): +async def test_check_update(app: Quart, authenticated_header: dict): test_client = app.test_client() - response = await test_client.get("/api/update/check", headers=header) + response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "success" @@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict): @pytest.mark.asyncio async def test_do_update( - app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, + tmp_path_factory, ): - global VERSION test_client = app.test_client() - os.makedirs("data/astrbot_release", exist_ok=True) - core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release" - VERSION = "114.514.1919810" - response = await test_client.post( - "/api/update/do", headers=header, json={"version": "latest"} + + # Use a temporary path for the mock update to avoid side effects + temp_release_dir = tmp_path_factory.mktemp("release") + release_path = temp_release_dir / "astrbot" + + async def mock_update(*args, **kwargs): + """Mocks the update process by creating a directory in the temp path.""" + os.makedirs(release_path, exist_ok=True) + return + + async def mock_download_dashboard(*args, **kwargs): + """Mocks the dashboard download to prevent network access.""" + return + + async def mock_pip_install(*args, **kwargs): + """Mocks pip install to prevent actual installation.""" + return + + monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update) + monkeypatch.setattr( + "astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard + ) + monkeypatch.setattr( + "astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "error" # 已经是最新版本 response = await test_client.post( - "/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False} + "/api/update/do", + headers=authenticated_header, + json={"version": "v3.4.0", "reboot": False}, ) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "ok" - assert os.path.exists("data/astrbot_release/astrbot") + assert os.path.exists(release_path) diff --git a/tests/test_main.py b/tests/test_main.py index 0f5e51d13..d7e45b01d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,9 @@ import os import sys + +# 将项目根目录添加到 sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + import pytest from unittest import mock from main import check_env, check_dashboard_files @@ -27,29 +31,58 @@ def test_check_env(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files(monkeypatch): +async def test_check_dashboard_files_not_exists(monkeypatch): + """Tests dashboard download when files do not exist.""" monkeypatch.setattr(os.path, "exists", lambda x: False) - async def mock_get(*args, **kwargs): - class MockResponse: - status = 200 + with mock.patch("main.download_dashboard") as mock_download: + await check_dashboard_files() + mock_download.assert_called_once() + - async def read(self): - return b"content" +@pytest.mark.asyncio +async def test_check_dashboard_files_exists_and_version_match(monkeypatch): + """Tests that dashboard is not downloaded when it exists and version matches.""" + # Mock os.path.exists to return True + monkeypatch.setattr(os.path, "exists", lambda x: True) + + # Mock get_dashboard_version to return the current version + with mock.patch("main.get_dashboard_version") as mock_get_version: + # We need to import VERSION from main's context + from main import VERSION - return MockResponse() + mock_get_version.return_value = f"v{VERSION}" - with mock.patch("aiohttp.ClientSession.get", new=mock_get): - with mock.patch("builtins.open", mock.mock_open()) as mock_file: - with mock.patch("zipfile.ZipFile.extractall") as mock_extractall: + with mock.patch("main.download_dashboard") as mock_download: + await check_dashboard_files() + # Assert that download_dashboard was NOT called + mock_download.assert_not_called() - async def mock_aenter(_): - await check_dashboard_files() - mock_file.assert_called_once_with("data/dashboard.zip", "wb") - mock_extractall.assert_called_once() - async def mock_aexit(obj, exc_type, exc, tb): - return +@pytest.mark.asyncio +async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): + """Tests that a warning is logged when dashboard version mismatches.""" + monkeypatch.setattr(os.path, "exists", lambda x: True) + + with mock.patch("main.get_dashboard_version") as mock_get_version: + mock_get_version.return_value = "v0.0.1" # A different version + + with mock.patch("main.logger.warning") as mock_logger_warning: + await check_dashboard_files() + mock_logger_warning.assert_called_once() + call_args, _ = mock_logger_warning.call_args + assert "不符" in call_args[0] + + +@pytest.mark.asyncio +async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch): + """Tests that providing a valid webui_dir skips all checks.""" + valid_dir = "/tmp/my-custom-webui" + monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir) - mock_extractall.__aenter__ = mock_aenter - mock_extractall.__aexit__ = mock_aexit + with mock.patch("main.download_dashboard") as mock_download: + with mock.patch("main.get_dashboard_version") as mock_get_version: + result = await check_dashboard_files(webui_dir=valid_dir) + assert result == valid_dir + mock_download.assert_not_called() + mock_get_version.assert_not_called() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py deleted file mode 100644 index 7afedaefa..000000000 --- a/tests/test_pipeline.py +++ /dev/null @@ -1,285 +0,0 @@ -import pytest -import logging -import os -import asyncio -from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext -from astrbot.core.star import PluginManager -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.platform.astrbot_message import ( - AstrBotMessage, - MessageMember, - MessageType, -) -from astrbot.core.message.message_event_result import MessageChain, ResultContentType -from astrbot.core.message.components import Plain, At -from astrbot.core.platform.platform_metadata import PlatformMetadata -from astrbot.core.platform.manager import PlatformManager -from astrbot.core.provider.manager import ProviderManager -from astrbot.core.db.sqlite import SQLiteDatabase -from astrbot.core.star.context import Context -from asyncio import Queue - -SESSION_ID_IN_WHITELIST = "test_sid_wl" -SESSION_ID_NOT_IN_WHITELIST = "test_sid" -TEST_LLM_PROVIDER = { - "id": "zhipu_default", - "type": "openai_chat_completion", - "enable": True, - "key": [os.getenv("ZHIPU_API_KEY")], - "api_base": "https://open.bigmodel.cn/api/paas/v4/", - "model_config": { - "model": "glm-4-flash", - }, -} - -TEST_COMMANDS = [ - ["help", "已注册的 AstrBot 内置指令"], - ["tool ls", "函数工具"], - ["tool on websearch", "激活工具"], - ["tool off websearch", "停用工具"], - ["plugin", "已加载的插件"], - ["t2i", "文本转图片模式"], - ["sid", "此 ID 可用于设置会话白名单。"], - ["op test_op", "授权成功。"], - ["deop test_op", "取消授权成功。"], - ["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"], - ["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"], - ["provider", "当前载入的 LLM 提供商"], - ["reset", "重置成功"], - # ["model", "查看、切换提供商模型列表"], - ["history", "历史记录:"], - ["key", "当前 Key"], - ["persona", "[Persona]"], -] - - -class FakeAstrMessageEvent(AstrMessageEvent): - def __init__(self, abm: AstrBotMessage = None): - meta = PlatformMetadata("test_platform", "test") - super().__init__( - message_str=abm.message_str, - message_obj=abm, - platform_meta=meta, - session_id=abm.session_id, - ) - - async def send(self, message: MessageChain): - await super().send(message) - - @staticmethod - def create_fake_event( - message_str: str, - session_id: str = "test_sid", - is_at: bool = False, - is_group: bool = False, - sender_id: str = "123456", - ): - abm = AstrBotMessage() - abm.message_str = message_str - abm.group_id = "test" - abm.message = [Plain(message_str)] - if is_at: - abm.message.append(At(qq="bot")) - abm.self_id = "bot" - abm.sender = MessageMember(sender_id, "mika") - abm.timestamp = 1234567890 - abm.message_id = "test" - abm.session_id = session_id - if is_group: - abm.type = MessageType.GROUP_MESSAGE - else: - abm.type = MessageType.FRIEND_MESSAGE - return FakeAstrMessageEvent(abm) - - -@pytest.fixture(scope="module") -def event_queue(): - return Queue() - - -@pytest.fixture(scope="module") -def config(): - cfg = AstrBotConfig() - cfg["platform_settings"]["id_whitelist"] = [ - "test_platform:FriendMessage:test_sid_wl", - "test_platform:GroupMessage:test_sid_wl", - ] - cfg["admins_id"] = ["123456"] - cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"] - cfg["provider"] = [TEST_LLM_PROVIDER] - return cfg - - -@pytest.fixture(scope="module") -def db(): - return SQLiteDatabase("data/data_v3.db") - - -@pytest.fixture(scope="module") -def platform_manager(event_queue, config): - return PlatformManager(config, event_queue) - - -@pytest.fixture(scope="module") -def provider_manager(config, db): - return ProviderManager(config, db) - - -@pytest.fixture(scope="module") -def star_context(event_queue, config, db, platform_manager, provider_manager): - star_context = Context(event_queue, config, db, provider_manager, platform_manager) - return star_context - - -@pytest.fixture(scope="module") -def plugin_manager(star_context, config): - plugin_manager = PluginManager(star_context, config) - # await plugin_manager.reload() - asyncio.run(plugin_manager.reload()) - return plugin_manager - - -@pytest.fixture(scope="module") -def pipeline_context(config, plugin_manager): - return PipelineContext(config, plugin_manager) - - -@pytest.fixture(scope="module") -def pipeline_scheduler(pipeline_context): - return PipelineScheduler(pipeline_context) - - -@pytest.mark.asyncio -async def test_platform_initialization(platform_manager: PlatformManager): - await platform_manager.initialize() - - -@pytest.mark.asyncio -async def test_provider_initialization(provider_manager: ProviderManager): - await provider_manager.initialize() - - -@pytest.mark.asyncio -async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler): - await pipeline_scheduler.initialize() - - -@pytest.mark.asyncio -async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog): - """测试唤醒""" - # 群聊无 @ 无指令 - caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True) - with caplog.at_level(logging.DEBUG): - await pipeline_scheduler.execute(mock_event) - assert any( - "执行阶段 WhitelistCheckStage" not in message for message in caplog.messages - ) - # 群聊有 @ 无指令 - mock_event = FakeAstrMessageEvent.create_fake_event( - "test", is_group=True, is_at=True - ) - with caplog.at_level(logging.DEBUG): - await pipeline_scheduler.execute(mock_event) - assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages) - # 群聊有指令 - mock_event = FakeAstrMessageEvent.create_fake_event( - "/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST - ) - await pipeline_scheduler.execute(mock_event) - assert mock_event._has_send_oper is True - - -@pytest.mark.asyncio -async def test_pipeline_wl( - pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog -): - caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event( - "test", SESSION_ID_IN_WHITELIST, sender_id="123" - ) - with caplog.at_level(logging.INFO): - await pipeline_scheduler.execute(mock_event) - assert any( - "不在会话白名单中,已终止事件传播。" not in message - for message in caplog.messages - ), "日志中未找到预期的消息" - - mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123") - with caplog.at_level(logging.INFO): - await pipeline_scheduler.execute(mock_event) - assert any( - "不在会话白名单中,已终止事件传播。" in message for message in caplog.messages - ), "日志中未找到预期的消息" - - -@pytest.mark.asyncio -async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog): - # 测试默认屏蔽词 - caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event( - "色情", session_id=SESSION_ID_IN_WHITELIST - ) # 测试需要。 - with caplog.at_level(logging.INFO): - await pipeline_scheduler.execute(mock_event) - assert any("内容安全检查不通过" in message for message in caplog.messages), ( - "日志中未找到预期的消息" - ) - # 测试额外屏蔽词 - mock_event = FakeAstrMessageEvent.create_fake_event( - "TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST - ) - with caplog.at_level(logging.INFO): - await pipeline_scheduler.execute(mock_event) - assert any("内容安全检查不通过" in message for message in caplog.messages), ( - "日志中未找到预期的消息" - ) - mock_event = FakeAstrMessageEvent.create_fake_event( - "_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST - ) - with caplog.at_level(logging.INFO): - await pipeline_scheduler.execute(mock_event) - assert any("内容安全检查不通过" not in message for message in caplog.messages) - # TODO: 测试 百度AI 的内容安全检查 - - -@pytest.mark.asyncio -async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog): - caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event( - "just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST - ) - with caplog.at_level(logging.DEBUG): - await pipeline_scheduler.execute(mock_event) - assert any("请求 LLM" in message for message in caplog.messages) - assert mock_event.get_result() is not None - assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT - - -@pytest.mark.asyncio -async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog): - caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event( - "help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST - ) - with caplog.at_level(logging.DEBUG): - await pipeline_scheduler.execute(mock_event) - assert any("请求 LLM" in message for message in caplog.messages) - assert any( - "web_searcher - search_from_search_engine" in message - for message in caplog.messages - ) - - -@pytest.mark.asyncio -async def test_commands(pipeline_scheduler: PipelineScheduler, caplog): - for command in TEST_COMMANDS: - caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event( - command[0], session_id=SESSION_ID_IN_WHITELIST - ) - with caplog.at_level(logging.DEBUG): - await pipeline_scheduler.execute(mock_event) - # assert any("执行阶段 ProcessStage" in message for message in caplog.messages) - assert any(command[1] in message for message in caplog.messages) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1a7831536..44486a11d 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,5 +1,6 @@ import pytest import os +from unittest.mock import MagicMock from astrbot.core.star.star_manager import PluginManager from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.star import star_registry @@ -8,18 +9,51 @@ from astrbot.core.db.sqlite import SQLiteDatabase from asyncio import Queue -event_queue = Queue() - -config = AstrBotConfig() - -db = SQLiteDatabase("data/data_v3.db") - -star_context = Context(event_queue, config, db) - @pytest.fixture -def plugin_manager_pm(): - return PluginManager(star_context, config) +def plugin_manager_pm(tmp_path): + """ + Provides a fully isolated PluginManager instance for testing. + - Uses a temporary directory for plugins. + - Uses a temporary database. + - Creates a fresh context for each test. + """ + # Create temporary resources + temp_plugins_path = tmp_path / "plugins" + temp_plugins_path.mkdir() + temp_db_path = tmp_path / "test_db.db" + + # Create fresh, isolated instances for the context + event_queue = Queue() + config = AstrBotConfig() + db = SQLiteDatabase(str(temp_db_path)) + + # Set the plugin store path in the config to the temporary directory + config.plugin_store_path = str(temp_plugins_path) + + # Mock dependencies for the context + provider_manager = MagicMock() + platform_manager = MagicMock() + conversation_manager = MagicMock() + message_history_manager = MagicMock() + persona_manager = MagicMock() + astrbot_config_mgr = MagicMock() + + star_context = Context( + event_queue, + config, + db, + provider_manager, + platform_manager, + conversation_manager, + message_history_manager, + persona_manager, + astrbot_config_mgr, + ) + + # Create the PluginManager instance + manager = PluginManager(star_context, config) + yield manager def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): @@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_plugin_crud(plugin_manager_pm: PluginManager): - """测试插件安装和重载""" - os.makedirs("data/plugins", exist_ok=True) +async def test_install_plugin(plugin_manager_pm: PluginManager): + """Tests successful plugin installation in an isolated environment.""" test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - plugin_path = await plugin_manager_pm.install_plugin(test_repo) - exists = False - for md in star_registry: - if md.name == "astrbot_plugin_essential": - exists = True - break - assert plugin_path is not None + plugin_info = await plugin_manager_pm.install_plugin(test_repo) + plugin_path = os.path.join( + plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential" + ) + + assert plugin_info is not None assert os.path.exists(plugin_path) - assert exists is True, "插件 astrbot_plugin_essential 未成功载入" - # shutil.rmtree(plugin_path) + assert any(md.name == "astrbot_plugin_essential" for md in star_registry), ( + "Plugin 'astrbot_plugin_essential' was not loaded into star_registry." + ) + - # install plugin which is not exists +@pytest.mark.asyncio +async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): + """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): - plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha") + await plugin_manager_pm.install_plugin( + "https://github.com/Soulter/non_existent_repo" + ) + + +@pytest.mark.asyncio +async def test_update_plugin(plugin_manager_pm: PluginManager): + """Tests updating an existing plugin in an isolated environment.""" + # First, install the plugin + test_repo = "https://github.com/Soulter/astrbot_plugin_essential" + await plugin_manager_pm.install_plugin(test_repo) - # update + # Then, update it await plugin_manager_pm.update_plugin("astrbot_plugin_essential") + +@pytest.mark.asyncio +async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): + """Tests that updating a non-existent plugin raises an exception.""" with pytest.raises(Exception): - await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha") + await plugin_manager_pm.update_plugin("non_existent_plugin") + + +@pytest.mark.asyncio +async def test_uninstall_plugin(plugin_manager_pm: PluginManager): + """Tests successful plugin uninstallation in an isolated environment.""" + # First, install the plugin + test_repo = "https://github.com/Soulter/astrbot_plugin_essential" + await plugin_manager_pm.install_plugin(test_repo) + plugin_path = os.path.join( + plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential" + ) + assert os.path.exists(plugin_path) # Pre-condition - # uninstall + # Then, uninstall it await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential") + assert not os.path.exists(plugin_path) - exists = False - for md in star_registry: - if md.name == "astrbot_plugin_essential": - exists = True - break - assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" - exists = False - for md in star_handlers_registry: - if "astrbot_plugin_essential" in md.handler_module_path: - exists = True - break - assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), ( + "Plugin 'astrbot_plugin_essential' was not unloaded from star_registry." + ) + assert not any( + "astrbot_plugin_essential" in md.handler_module_path + for md in star_handlers_registry + ), ( + "Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry." + ) - with pytest.raises(Exception): - await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha") - # TODO: file installation +@pytest.mark.asyncio +async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager): + """Tests that uninstalling a non-existent plugin raises an exception.""" + with pytest.raises(Exception): + await plugin_manager_pm.uninstall_plugin("non_existent_plugin")