Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions app/api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import secrets

from fastapi import HTTPException, Depends
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette import status

from app.config import API_PWD, API_USER

security = HTTPBasic()


def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
correct_username = secrets.compare_digest(credentials.username, API_USER)
correct_password = secrets.compare_digest(credentials.password, API_PWD)
if not (correct_username and correct_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return credentials.username
12 changes: 9 additions & 3 deletions app/api/v1/exercise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.auth import get_current_username
from app.api.schema.exercise import ExerciseRead, ExerciseCreate, TestCaseRead, TestCaseCreate, \
ExerciseWithUnlockTimestamps
from app.db.database import get_session
Expand Down Expand Up @@ -159,7 +160,8 @@ async def post_skip_current_exercise(tan_code: str, session: AsyncSession = Depe
@router.get("/{exercise_id}",
response_model=ExerciseRead,
status_code=status.HTTP_200_OK)
async def get_exercise(exercise_id: int, session: AsyncSession = Depends(get_session)) -> ExerciseRead:
async def get_exercise(exercise_id: int, _username: str = Depends(get_current_username),
session: AsyncSession = Depends(get_session)) -> ExerciseRead:
statement = select(Exercise).where(Exercise.id == exercise_id)
result = await session.execute(statement)
exercise = result.scalars().first()
Expand All @@ -171,7 +173,9 @@ async def get_exercise(exercise_id: int, session: AsyncSession = Depends(get_ses


@router.post("/", status_code=status.HTTP_201_CREATED, response_model=ExerciseRead)
async def create_exercise(new_exercise: ExerciseCreate, session: AsyncSession = Depends(get_session)) -> ExerciseRead:
async def create_exercise(new_exercise: ExerciseCreate,
_username: str = Depends(get_current_username),
session: AsyncSession = Depends(get_session)) -> ExerciseRead:
exercise = Exercise(**new_exercise.model_dump())
exercise.id = None

Expand All @@ -184,6 +188,7 @@ async def create_exercise(new_exercise: ExerciseCreate, session: AsyncSession =

@router.post("/{exercise_id}/test-cases", response_model=TestCaseRead)
async def create_test_case(exercise_id: int, new_test_case: TestCaseCreate,
_username: str = Depends(get_current_username),
session: AsyncSession = Depends(get_session)) -> TestCaseRead:
test_case = TestCase(exercise_id=exercise_id, **new_test_case.model_dump())
session.add(test_case)
Expand All @@ -195,7 +200,8 @@ async def create_test_case(exercise_id: int, new_test_case: TestCaseCreate,


@router.get("/{exercise_id}/test-cases", response_model=list[TestCaseRead], status_code=status.HTTP_200_OK)
async def get_test_cases(exercise_id: int, session: AsyncSession = Depends(get_session)) -> list[TestCaseRead]:
async def get_test_cases(exercise_id: int, _username: str = Depends(get_current_username),
session: AsyncSession = Depends(get_session)) -> list[TestCaseRead]:
statement = select(TestCase).where(TestCase.exercise_id == exercise_id)
result = await session.execute(statement)
test_cases = result.scalars().all()
Expand Down
4 changes: 3 additions & 1 deletion app/api/v1/logging_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.auth import get_current_username
from app.api.schema.logging_event import LoggingEventRead, LoggingEventCreate
from app.db.database import get_session
from app.db.model.logging_event import LoggingEvent
Expand All @@ -18,7 +19,8 @@
@router.get("/{tan_code}",
response_model=list[LoggingEventRead],
status_code=status.HTTP_200_OK)
async def get_logging_events(tan_code: str, session: AsyncSession = Depends(get_session)) -> list[LoggingEventRead]:
async def get_logging_events(tan_code: str, _username: str = Depends(get_current_username),
session: AsyncSession = Depends(get_session)) -> list[LoggingEventRead]:
statement = select(Tan).where(Tan.code == tan_code)
result = await session.execute(statement)
tan = result.scalars().first()
Expand Down
7 changes: 5 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
ORIGINS = os.environ.get("BLOCKSEMBLER_ORIGINS", "*").split(',')
BASE_URL = os.environ.get('BLOCKSEMBLER_API_BASE_URL', '')

API_PWD = os.environ.get('BLOCKSEMBLER_API_PWD', 's3cr3t!')
API_USER = os.environ.get('BLOCKSEMBLER_ACCESS_TOKEN', 'admin')

DATABASE_URL = os.environ.get("BLOCKSEMBLER_DB_URI",
"postgresql+asyncpg://postgres:postgres@blocksembler-db:5432/blocksembler")
"postgresql+asyncpg://postgres:postgres@localhost:5432/blocksembler")

MQ_URL = os.environ.get('BLOCKSEMBLER_MQ_URL', 'blocksembler-mq')
MQ_URL = os.environ.get('BLOCKSEMBLER_MQ_URL', 'localhost')
MQ_PORT = os.environ.get('BLOCKSEMBLER_MQ_PORT', '5672')
MQ_USER = os.environ.get("BLOCKSEMBLER_MQ_USER", "blocksembler")
MQ_PWD = os.environ.get("BLOCKSEMBLER_MQ_PWD", "blocksembler")
Expand Down
14 changes: 9 additions & 5 deletions tests/test_exercise.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
from datetime import datetime, timezone

from fastapi import status
from fastapi import status, security
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

from app.config import API_USER, API_PWD
from app.db.database import get_session
from app.main import app
from app.util import get_datetime_now
from tests.util.auth_util import basic_auth_header
from tests.util.db_util import create_test_tables, get_override_dependency, insert_demo_data, DB_URI
from tests.util.demo_data import EXERCISES

Expand All @@ -29,10 +31,11 @@ def setup_class(self):
asyncio.run(insert_demo_data(self.async_session))

def test_get_exercise(self):
app.dependency_overrides[security] = get_override_dependency(self.engine)
app.dependency_overrides[get_session] = get_override_dependency(self.engine)
client = TestClient(app)

response = client.get("/exercises/1")
response = client.get("/exercises/1", headers=basic_auth_header(API_USER, API_PWD))

assert response.json() == EXERCISES[0]
assert response.status_code == 200
Expand All @@ -48,7 +51,7 @@ def test_post_exercise(self):
"next_exercise_id": None,
}

response = client.post("/exercises", json=new_exercise)
response = client.post("/exercises", json=new_exercise, headers=basic_auth_header(API_USER, API_PWD))
result_exercise = response.json()

print(result_exercise)
Expand Down Expand Up @@ -137,7 +140,8 @@ def test_post_test_case(self):
"expected_instructions": [],
}

response = client.post("/exercises/1/test-cases", json=new_test_case)
response = client.post("/exercises/1/test-cases", json=new_test_case,
headers=basic_auth_header(API_USER, API_PWD))

result_test_case = response.json()

Expand All @@ -153,7 +157,7 @@ def test_get_test_case(self):
app.dependency_overrides[get_session] = get_override_dependency(self.engine)
client = TestClient(app)

response = client.get("/exercises/2/test-cases")
response = client.get("/exercises/2/test-cases", headers=basic_auth_header(API_USER, API_PWD))
result_test_cases = response.json()

expected_test_case = {
Expand Down
4 changes: 3 additions & 1 deletion tests/test_logging_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

from app.api.schema.logging_event import LoggingEventRead
from app.config import API_USER, API_PWD
from app.db.database import get_session
from app.main import app
from tests.util.auth_util import basic_auth_header
from tests.util.db_util import insert_demo_data, DB_URI, create_test_tables, get_override_dependency
from tests.util.demo_data import LOGGING_EVENTS

Expand All @@ -23,7 +25,7 @@ def test_get_logging_events(self):
app.dependency_overrides[get_session] = get_override_dependency(self.engine)
client = TestClient(app)

response = client.get("/logging-events/logging-test-tan")
response = client.get("/logging-events/logging-test-tan", headers=basic_auth_header(API_USER, API_PWD))

assert response.status_code == 200
assert len(response.json()) == 2
Expand Down
7 changes: 7 additions & 0 deletions tests/util/auth_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import base64


def basic_auth_header(username: str, password: str):
credentials = f"{username}:{password}"
encoded = base64.b64encode(credentials.encode()).decode()
return {"Authorization": f"Basic {encoded}"}