Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ POSTGRES_PORT=5432
POSTGRES_DB=devdb
POSTGRES_USER=devdb
POSTGRES_TEST_DB=testdb
POSTGRES_TEST_USER=testdb
POSTGRES_PASSWORD=secret

# Redis
Expand Down
1 change: 0 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Settings(BaseSettings):
POSTGRES_PASSWORD: str
POSTGRES_HOST: str
POSTGRES_DB: str
POSTGRES_TEST_USER: str
POSTGRES_TEST_DB: str

@computed_field
Expand Down
2 changes: 1 addition & 1 deletion app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
test_engine = create_async_engine(
global_settings.test_asyncpg_url.unicode_string(),
future=True,
echo=True,
echo=False,
)

# expire_on_commit=False will prevent attributes from being expired
Expand Down
17 changes: 14 additions & 3 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,28 @@ async def test_add_user(client: AsyncClient):

# TODO: parametrize test with diff urls including 404 and 401
async def test_get_token(client: AsyncClient):
payload = {"email": "joe@grillazz.com", "password": "s1lly"}
# Create the user first
user_payload = {
"email": "joe@grillazz.com",
"first_name": "Joe",
"last_name": "Garcia",
"password": "s1lly",
}
create_user_response = await client.post("/user/", json=user_payload)
assert create_user_response.status_code == status.HTTP_201_CREATED

# Now request the token
token_payload = {"email": "joe@grillazz.com", "password": "s1lly"}
response = await client.post(
"/user/token",
data=payload,
data=token_payload,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
assert response.status_code == status.HTTP_201_CREATED
claimset = jwt.decode(
response.json()["access_token"], options={"verify_signature": False}
)
assert claimset["email"] == payload["email"]
assert claimset["email"] == token_payload["email"]
assert claimset["expiry"] == IsPositiveFloat()
assert claimset["platform"] == "python-httpx/0.28.1"

Expand Down
16 changes: 11 additions & 5 deletions tests/api/test_stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from httpx import AsyncClient
from inline_snapshot import snapshot
from polyfactory.factories.pydantic_factory import ModelFactory
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import Stuff
from app.schemas.stuff import StuffSchema

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -32,22 +34,26 @@ async def test_add_stuff(client: AsyncClient):
)


async def test_get_stuff(client: AsyncClient):
async def test_get_stuff(client: AsyncClient, db_session: AsyncSession):
response = await client.get("/stuff/nonexistent")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == snapshot(
{"no_response": "The requested resource was not found"}
)
# test if db_session and client share the same in-memory db and rollback works
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
await client.post("/stuff", json=stuff)
name = stuff["name"]
stuff = Stuff(**stuff)
name = stuff.name
db_session.add(stuff)
await db_session.commit()

response = await client.get(f"/stuff/{name}")
assert response.status_code == status.HTTP_200_OK
assert response.json() == snapshot(
{
"id": IsUUID(4),
"name": stuff["name"],
"description": stuff["description"],
"name": stuff.name,
"description": stuff.description,
}
)

Expand Down
30 changes: 25 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError

from app.database import engine, get_db, get_test_db, test_engine
from app.database import TestAsyncSessionFactory, engine, get_db, test_engine
from app.main import app
from app.models.base import Base
from app.redis import get_redis
Expand Down Expand Up @@ -43,7 +43,7 @@ def _create_db_schema(conn) -> None:
pass


@pytest.fixture(scope="session")
@pytest.fixture(scope="session", autouse=True)
async def start_db():
# The `engine` is configured for the default 'postgres' database.
# We connect to it and create the test database.
Expand All @@ -63,16 +63,36 @@ async def start_db():
await test_engine.dispose()


@pytest.fixture(scope="session")
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
@pytest.fixture()
async def db_session():
connection = await test_engine.connect()
transaction = await connection.begin()
session = TestAsyncSessionFactory(bind=connection)

try:
yield session
finally:
# Rollback the overall transaction, restoring the state before the test ran.
await session.close()
if transaction.is_active:
await transaction.rollback()
await connection.close()


@pytest.fixture(scope="function")
async def client(db_session) -> AsyncGenerator[AsyncClient, Any]:
transport = ASGITransport(
app=app,
)

async def override_get_db():
yield db_session

async with AsyncClient(
base_url="http://testserver/v1",
headers={"Content-Type": "application/json"},
transport=transport,
) as test_client:
app.dependency_overrides[get_db] = get_test_db
app.dependency_overrides[get_db] = override_get_db
app.redis = await get_redis()
yield test_client