diff --git a/.env b/.env index 3dbfa84..842a23b 100644 --- a/.env +++ b/.env @@ -7,7 +7,6 @@ POSTGRES_PORT=5432 POSTGRES_DB=devdb POSTGRES_USER=devdb POSTGRES_TEST_DB=testdb -POSTGRES_TEST_USER=testdb POSTGRES_PASSWORD=secret # Redis diff --git a/app/config.py b/app/config.py index c79d0d6..dbe4877 100644 --- a/app/config.py +++ b/app/config.py @@ -33,7 +33,6 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str POSTGRES_HOST: str POSTGRES_DB: str - POSTGRES_TEST_USER: str POSTGRES_TEST_DB: str @computed_field diff --git a/app/database.py b/app/database.py index 7f34e70..e1003d7 100644 --- a/app/database.py +++ b/app/database.py @@ -18,7 +18,7 @@ test_engine = create_async_engine( global_settings.test_asyncpg_url.unicode_string(), future=True, - echo=True, + echo=False, ) # expire_on_commit=False will prevent attributes from being expired diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 9000ef3..2d6cd84 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -37,17 +37,28 @@ async def test_add_user(client: AsyncClient): # TODO: parametrize test with diff urls including 404 and 401 async def test_get_token(client: AsyncClient): - payload = {"email": "joe@grillazz.com", "password": "s1lly"} + # Create the user first + user_payload = { + "email": "joe@grillazz.com", + "first_name": "Joe", + "last_name": "Garcia", + "password": "s1lly", + } + create_user_response = await client.post("/user/", json=user_payload) + assert create_user_response.status_code == status.HTTP_201_CREATED + + # Now request the token + token_payload = {"email": "joe@grillazz.com", "password": "s1lly"} response = await client.post( "/user/token", - data=payload, + data=token_payload, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) assert response.status_code == status.HTTP_201_CREATED claimset = jwt.decode( response.json()["access_token"], options={"verify_signature": False} ) - assert claimset["email"] == payload["email"] + assert claimset["email"] == token_payload["email"] assert claimset["expiry"] == IsPositiveFloat() assert claimset["platform"] == "python-httpx/0.28.1" diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index e420b07..3b57da0 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -4,7 +4,9 @@ from httpx import AsyncClient from inline_snapshot import snapshot from polyfactory.factories.pydantic_factory import ModelFactory +from sqlalchemy.ext.asyncio import AsyncSession +from app.models import Stuff from app.schemas.stuff import StuffSchema pytestmark = pytest.mark.anyio @@ -32,22 +34,26 @@ async def test_add_stuff(client: AsyncClient): ) -async def test_get_stuff(client: AsyncClient): +async def test_get_stuff(client: AsyncClient, db_session: AsyncSession): response = await client.get("/stuff/nonexistent") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == snapshot( {"no_response": "The requested resource was not found"} ) + # test if db_session and client share the same in-memory db and rollback works stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") - await client.post("/stuff", json=stuff) - name = stuff["name"] + stuff = Stuff(**stuff) + name = stuff.name + db_session.add(stuff) + await db_session.commit() + response = await client.get(f"/stuff/{name}") assert response.status_code == status.HTTP_200_OK assert response.json() == snapshot( { "id": IsUUID(4), - "name": stuff["name"], - "description": stuff["description"], + "name": stuff.name, + "description": stuff.description, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 107d3bc..25b3a50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from sqlalchemy import text from sqlalchemy.exc import ProgrammingError -from app.database import engine, get_db, get_test_db, test_engine +from app.database import TestAsyncSessionFactory, engine, get_db, test_engine from app.main import app from app.models.base import Base from app.redis import get_redis @@ -43,7 +43,7 @@ def _create_db_schema(conn) -> None: pass -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", autouse=True) async def start_db(): # The `engine` is configured for the default 'postgres' database. # We connect to it and create the test database. @@ -63,16 +63,36 @@ async def start_db(): await test_engine.dispose() -@pytest.fixture(scope="session") -async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 +@pytest.fixture() +async def db_session(): + connection = await test_engine.connect() + transaction = await connection.begin() + session = TestAsyncSessionFactory(bind=connection) + + try: + yield session + finally: + # Rollback the overall transaction, restoring the state before the test ran. + await session.close() + if transaction.is_active: + await transaction.rollback() + await connection.close() + + +@pytest.fixture(scope="function") +async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: transport = ASGITransport( app=app, ) + + async def override_get_db(): + yield db_session + async with AsyncClient( base_url="http://testserver/v1", headers={"Content-Type": "application/json"}, transport=transport, ) as test_client: - app.dependency_overrides[get_db] = get_test_db + app.dependency_overrides[get_db] = override_get_db app.redis = await get_redis() yield test_client