From e9ea2c627a36c207596c269ac1be383201da5549 Mon Sep 17 00:00:00 2001 From: grillazz Date: Sat, 3 Jan 2026 19:04:07 +0100 Subject: [PATCH 1/3] refactor: update test fixtures and remove unused environment variables --- .env | 1 - app/config.py | 1 - app/database.py | 2 +- tests/api/test_stuff.py | 19 +++++++++++++------ tests/conftest.py | 34 ++++++++++++++++++++++++++++------ 5 files changed, 42 insertions(+), 15 deletions(-) diff --git a/.env b/.env index 3dbfa84..842a23b 100644 --- a/.env +++ b/.env @@ -7,7 +7,6 @@ POSTGRES_PORT=5432 POSTGRES_DB=devdb POSTGRES_USER=devdb POSTGRES_TEST_DB=testdb -POSTGRES_TEST_USER=testdb POSTGRES_PASSWORD=secret # Redis diff --git a/app/config.py b/app/config.py index c79d0d6..dbe4877 100644 --- a/app/config.py +++ b/app/config.py @@ -33,7 +33,6 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str POSTGRES_HOST: str POSTGRES_DB: str - POSTGRES_TEST_USER: str POSTGRES_TEST_DB: str @computed_field diff --git a/app/database.py b/app/database.py index 7f34e70..e1003d7 100644 --- a/app/database.py +++ b/app/database.py @@ -18,7 +18,7 @@ test_engine = create_async_engine( global_settings.test_asyncpg_url.unicode_string(), future=True, - echo=True, + echo=False, ) # expire_on_commit=False will prevent attributes from being expired diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index e420b07..985c892 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -4,8 +4,10 @@ from httpx import AsyncClient from inline_snapshot import snapshot from polyfactory.factories.pydantic_factory import ModelFactory +from sqlalchemy.ext.asyncio import AsyncSession from app.schemas.stuff import StuffSchema +from app.models import Stuff pytestmark = pytest.mark.anyio @@ -14,7 +16,7 @@ class StuffFactory(ModelFactory[StuffSchema]): __model__ = StuffSchema -async def test_add_stuff(client: AsyncClient): +async def test_add_stuff(client: AsyncClient, db_session: AsyncSession): stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") response = await client.post("/stuff", json=stuff) assert response.status_code == status.HTTP_201_CREATED @@ -32,22 +34,27 @@ async def test_add_stuff(client: AsyncClient): ) -async def test_get_stuff(client: AsyncClient): +async def test_get_stuff(client: AsyncClient, db_session: AsyncSession): response = await client.get("/stuff/nonexistent") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == snapshot( {"no_response": "The requested resource was not found"} ) stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") - await client.post("/stuff", json=stuff) - name = stuff["name"] + # await client.post("/stuff", json=stuff) + # name = stuff["name"] + stuff = Stuff(**stuff) + name = stuff.name + db_session.add(stuff) + await db_session.commit() + response = await client.get(f"/stuff/{name}") assert response.status_code == status.HTTP_200_OK assert response.json() == snapshot( { "id": IsUUID(4), - "name": stuff["name"], - "description": stuff["description"], + "name": stuff.name, + "description": stuff.description, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 107d3bc..9ad5d46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,12 @@ from typing import Any import pytest +from fastapi.exceptions import ResponseValidationError from httpx import ASGITransport, AsyncClient from sqlalchemy import text -from sqlalchemy.exc import ProgrammingError +from sqlalchemy.exc import ProgrammingError, SQLAlchemyError -from app.database import engine, get_db, get_test_db, test_engine +from app.database import engine, get_db, test_engine, TestAsyncSessionFactory from app.main import app from app.models.base import Base from app.redis import get_redis @@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None: pass -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", autouse=True) async def start_db(): # The `engine` is configured for the default 'postgres' database. # We connect to it and create the test database. @@ -63,16 +64,37 @@ async def start_db(): await test_engine.dispose() -@pytest.fixture(scope="session") -async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 +@pytest.fixture() +async def db_session(): + connection = await test_engine.connect() + transaction = await connection.begin() + session = TestAsyncSessionFactory(bind=connection) + + try: + yield session + finally: + # Rollback the overall transaction, restoring the state before the test ran. + await session.close() + if transaction.is_active: + await transaction.rollback() + await connection.close() + + +@pytest.fixture(scope="function") +async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 transport = ASGITransport( app=app, ) + + async def override_get_db(): + yield db_session + await db_session.commit() + async with AsyncClient( base_url="http://testserver/v1", headers={"Content-Type": "application/json"}, transport=transport, ) as test_client: - app.dependency_overrides[get_db] = get_test_db + app.dependency_overrides[get_db] = override_get_db app.redis = await get_redis() yield test_client From 5080eda3acf9cfe979acf9f024e3473cd5bdec8c Mon Sep 17 00:00:00 2001 From: grillazz Date: Sun, 11 Jan 2026 08:37:48 +0100 Subject: [PATCH 2/3] refactor: clean up imports and enhance test setup for user token retrieval --- tests/api/test_auth.py | 17 ++++++++++++++--- tests/api/test_stuff.py | 7 +++---- tests/conftest.py | 5 ++--- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 9000ef3..2d6cd84 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -37,17 +37,28 @@ async def test_add_user(client: AsyncClient): # TODO: parametrize test with diff urls including 404 and 401 async def test_get_token(client: AsyncClient): - payload = {"email": "joe@grillazz.com", "password": "s1lly"} + # Create the user first + user_payload = { + "email": "joe@grillazz.com", + "first_name": "Joe", + "last_name": "Garcia", + "password": "s1lly", + } + create_user_response = await client.post("/user/", json=user_payload) + assert create_user_response.status_code == status.HTTP_201_CREATED + + # Now request the token + token_payload = {"email": "joe@grillazz.com", "password": "s1lly"} response = await client.post( "/user/token", - data=payload, + data=token_payload, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) assert response.status_code == status.HTTP_201_CREATED claimset = jwt.decode( response.json()["access_token"], options={"verify_signature": False} ) - assert claimset["email"] == payload["email"] + assert claimset["email"] == token_payload["email"] assert claimset["expiry"] == IsPositiveFloat() assert claimset["platform"] == "python-httpx/0.28.1" diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index 985c892..3b57da0 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -6,8 +6,8 @@ from polyfactory.factories.pydantic_factory import ModelFactory from sqlalchemy.ext.asyncio import AsyncSession -from app.schemas.stuff import StuffSchema from app.models import Stuff +from app.schemas.stuff import StuffSchema pytestmark = pytest.mark.anyio @@ -16,7 +16,7 @@ class StuffFactory(ModelFactory[StuffSchema]): __model__ = StuffSchema -async def test_add_stuff(client: AsyncClient, db_session: AsyncSession): +async def test_add_stuff(client: AsyncClient): stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") response = await client.post("/stuff", json=stuff) assert response.status_code == status.HTTP_201_CREATED @@ -40,9 +40,8 @@ async def test_get_stuff(client: AsyncClient, db_session: AsyncSession): assert response.json() == snapshot( {"no_response": "The requested resource was not found"} ) + # test if db_session and client share the same in-memory db and rollback works stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") - # await client.post("/stuff", json=stuff) - # name = stuff["name"] stuff = Stuff(**stuff) name = stuff.name db_session.add(stuff) diff --git a/tests/conftest.py b/tests/conftest.py index 9ad5d46..1a954d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,11 @@ from typing import Any import pytest -from fastapi.exceptions import ResponseValidationError from httpx import ASGITransport, AsyncClient from sqlalchemy import text -from sqlalchemy.exc import ProgrammingError, SQLAlchemyError +from sqlalchemy.exc import ProgrammingError -from app.database import engine, get_db, test_engine, TestAsyncSessionFactory +from app.database import TestAsyncSessionFactory, engine, get_db, test_engine from app.main import app from app.models.base import Base from app.redis import get_redis From 2373ea25451c8e92dbb728e1b28d2b8119772afb Mon Sep 17 00:00:00 2001 From: grillazz Date: Sun, 11 Jan 2026 08:46:16 +0100 Subject: [PATCH 3/3] refactor: remove unnecessary await from database session commit in client fixture --- tests/conftest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1a954d7..25b3a50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,14 +80,13 @@ async def db_session(): @pytest.fixture(scope="function") -async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 +async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: transport = ASGITransport( app=app, ) async def override_get_db(): yield db_session - await db_session.commit() async with AsyncClient( base_url="http://testserver/v1",