Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 79 additions & 43 deletions tests/test_dashboard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
import pytest_asyncio
import os
import asyncio
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
Expand All @@ -9,36 +11,46 @@
from astrbot.core.star.star import star_registry


@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
return core_lifecycle


@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
# The db instance is already part of the core_lifecycle_td
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app


@pytest.fixture(scope="module")
def header():
return {}


@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}


@pytest.mark.asyncio
async def test_auth_login(
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
):
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Tests the login functionality with both wrong and correct credentials."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login", json={"username": "wrong", "password": "password"}
Expand All @@ -55,31 +67,32 @@ async def test_auth_login(
)
data = await response.get_json()
assert data["status"] == "ok" and "token" in data["data"]
header["Authorization"] = f"Bearer {data['data']['token']}"


@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
async def test_get_stat(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/stat/get")
assert response.status_code == 401
response = await test_client.get("/api/stat/get", headers=header)
response = await test_client.get("/api/stat/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok" and "platform" in data["data"]


@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
async def test_plugins(app: Quart, authenticated_header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get("/api/plugin/get", headers=header)
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"

# 插件市场
response = await test_client.get("/api/plugin/market_list", headers=header)
response = await test_client.get(
"/api/plugin/market_list", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
Expand All @@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/install",
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
Expand All @@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):

# 插件更新
response = await test_client.post(
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
"/api/plugin/update",
json={"name": "astrbot_plugin_essential"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
Expand All @@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/uninstall",
json={"name": "astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
Expand All @@ -132,34 +147,55 @@ async def test_plugins(app: Quart, header: dict):


@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
async def test_check_update(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=header)
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"


@pytest.mark.asyncio
async def test_do_update(
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path_factory,
):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "latest"}

# Use a temporary path for the mock update to avoid side effects
temp_release_dir = tmp_path_factory.mktemp("release")
release_path = temp_release_dir / "astrbot"

async def mock_update(*args, **kwargs):
"""Mocks the update process by creating a directory in the temp path."""
os.makedirs(release_path, exist_ok=True)
return

async def mock_download_dashboard(*args, **kwargs):
"""Mocks the dashboard download to prevent network access."""
return

async def mock_pip_install(*args, **kwargs):
"""Mocks pip install to prevent actual installation."""
return

monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error" # 已经是最新版本

response = await test_client.post(
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
"/api/update/do",
headers=authenticated_header,
json={"version": "v3.4.0", "reboot": False},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert os.path.exists("data/astrbot_release/astrbot")
assert os.path.exists(release_path)
69 changes: 51 additions & 18 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import sys

# 将项目根目录添加到 sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import pytest
from unittest import mock
from main import check_env, check_dashboard_files
Expand Down Expand Up @@ -27,29 +31,58 @@ def test_check_env(monkeypatch):


@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist."""
monkeypatch.setattr(os.path, "exists", lambda x: False)

async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
mock_download.assert_called_once()


async def read(self):
return b"content"
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
"""Tests that dashboard is not downloaded when it exists and version matches."""
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)

# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION

return MockResponse()
mock_get_version.return_value = f"v{VERSION}"

with mock.patch("aiohttp.ClientSession.get", new=mock_get):
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
# Assert that download_dashboard was NOT called
mock_download.assert_not_called()

async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()

async def mock_aexit(obj, exc_type, exc, tb):
return
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)

with mock.patch("main.get_dashboard_version") as mock_get_version:
mock_get_version.return_value = "v0.0.1" # A different version

with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "不符" in call_args[0]


@pytest.mark.asyncio
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
"""Tests that providing a valid webui_dir skips all checks."""
valid_dir = "/tmp/my-custom-webui"
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)

mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()
mock_get_version.assert_not_called()
Loading