diff --git a/README.md b/README.md index 0b5f8f2..dabf996 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,18 @@ uvicorn app.main:app --reload --env-file .env.local > If you need to alter the database, you can create new migrations using [alembic](https://alembic.sqlalchemy.org/en/latest/index.html). +## Migration: + +On bdd structure change: +Try an autogenerated migration: +``` +alembic revision --autogenerate -m "change context" +``` + +To apply it: +``` +alembic upgrade head +``` ## TODO diff --git a/app/auth.py b/app/auth.py index 28c74e9..1f86b5e 100644 --- a/app/auth.py +++ b/app/auth.py @@ -27,12 +27,13 @@ def jws_verify(token: str) -> Mapping[str, t.Any]: def create_ballot_token( vote_ids: int | list[int], election_ref: str, + ballot_id:int ) -> str: if isinstance(vote_ids, int): vote_ids = [vote_ids] vote_ids = sorted(vote_ids) return jws.sign( - {"votes": vote_ids, "election": election_ref}, + { "votes": vote_ids, "election": election_ref, "ballot": ballot_id }, settings.secret, algorithm="HS256", ) diff --git a/app/crud.py b/app/crud.py index b93edca..1f2a9cf 100644 --- a/app/crud.py +++ b/app/crud.py @@ -218,14 +218,33 @@ def create_invite_tokens( _check_election_is_not_ended(get_election(db, election_ref)) now = datetime.now() params = {"date_created": now, "date_modified": now, "election_ref": election_ref} - db_votes = [models.Vote(**params) for _ in range(num_voters * num_candidates)] - db.bulk_save_objects(db_votes, return_defaults=True) - db.commit() + + try: + db_ballots = [models.Ballot(election_ref=election_ref) for _ in range(num_voters)] + db.bulk_save_objects(db_ballots, return_defaults=True) + + db_votes = [] + + for ballot in db_ballots: + for _ in range(num_candidates): + db_votes.append(models.Vote(**params, ballot_id=ballot.id)) + + db.bulk_save_objects(db_votes, return_defaults=True) + db.commit() + except Exception as e: + db.rollback() + raise e + + tokens = [] vote_ids = [int(str(v.id)) for v in db_votes] - tokens = [ - create_ballot_token(vote_ids[i::num_voters], election_ref) - for i in range(num_voters) - ] + + for i, ballot in enumerate(db_ballots): + start = i * num_candidates + end = start + num_candidates + tokens.append( + create_ballot_token(vote_ids[start:end], election_ref, int(str(ballot.id))) + ) + return tokens @@ -405,18 +424,29 @@ def create_ballot(db: Session, ballot: schemas.BallotCreate) -> schemas.BallotGe ) _check_ballot_is_consistent(election, ballot) - # Ideally, we would use RETURNING but it does not work yet for SQLite - db_votes = [ - models.Vote(**v.model_dump(), election_ref=ballot.election_ref) for v in ballot.votes - ] - db.add_all(db_votes) - db.commit() - for v in db_votes: - db.refresh(v) + try: + db_ballot = models.Ballot(election_ref=ballot.election_ref) + db.add(db_ballot) + db.flush() + + # Create votes and associate them with the ballot + db_votes = [ + models.Vote(**v.model_dump(), election_ref=ballot.election_ref, ballot_id=db_ballot.id) + for v in ballot.votes + ] + db.add_all(db_votes) + db.commit() + db.refresh(db_ballot) + + for v in db_votes: + db.refresh(v) + except Exception as e: + db.rollback() + raise e votes_get = [schemas.VoteGet.model_validate(v) for v in db_votes] vote_ids = [v.id for v in votes_get] - token = create_ballot_token(vote_ids, ballot.election_ref) + token = create_ballot_token(vote_ids, ballot.election_ref, int(db_ballot.id)) return schemas.BallotGet(votes=votes_get, token=token, election=election) @@ -511,6 +541,13 @@ def update_ballot( if len(db_votes) != len(vote_ids): raise errors.NotFoundError("votes") + # Verify all votes belong to the same ballot + ballot_ids = {int(v.ballot_id) for v in db_votes if v.ballot_id is not None} + + if len(ballot_ids) > 1: + raise errors.ForbiddenError("All votes must belong to the same ballot") + + # old API does not contains ballot id in the token election = schemas.ElectionGet.model_validate(db_votes[0].election) for vote, db_vote in zip(ballot.votes, db_votes): @@ -521,7 +558,6 @@ def update_ballot( db.commit() votes_get = [schemas.VoteGet.model_validate(v) for v in db_votes] - token = create_ballot_token(vote_ids, election_ref) return schemas.BallotGet(votes=votes_get, token=token, election=election) diff --git a/app/models.py b/app/models.py index 2525050..0248b4c 100644 --- a/app/models.py +++ b/app/models.py @@ -2,7 +2,8 @@ from sqlalchemy.sql import func from sqlalchemy.orm import relationship from .database import Base - +import uuid # Pour pouvoir appeler uuid.uuid4() +from sqlalchemy import UUID class Election(Base): __tablename__ = "elections" @@ -24,6 +25,7 @@ class Election(Base): grades = relationship("Grade", back_populates="election") candidates = relationship("Candidate", back_populates="election") votes = relationship("Vote", back_populates="election") + ballots = relationship("Ballot", back_populates="election") class Candidate(Base): @@ -72,3 +74,20 @@ class Vote(Base): election_ref = Column(String(20), ForeignKey("elections.ref")) election = relationship("Election", back_populates="votes") + + ballot_id = Column(Integer, ForeignKey("ballots.id"), nullable=True) + ballot = relationship("Ballot", back_populates="votes") + + +class Ballot(Base): + __tablename__ = "ballots" + + id = Column(Integer, primary_key=True, index=True) + + voter_uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True, index=True) + + date_created = Column(DateTime, server_default=func.now()) + election_ref = Column(String(20), ForeignKey("elections.ref")) + + election = relationship("Election", back_populates="ballots") + votes = relationship("Vote", back_populates="ballot") \ No newline at end of file diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index 4b1fd42..6eceddf 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -41,9 +41,9 @@ def test_ballot_token(): """ vote_ids = list(range(1000)) election_ref = "qwertyuiop" - token = create_ballot_token(vote_ids, election_ref) + token = create_ballot_token(vote_ids, election_ref, 1) data = jws_verify(token) - assert data == {"votes": vote_ids, "election": election_ref} + assert data == {"votes": vote_ids, "election": election_ref, "ballot": 1} def test_admin_token(): diff --git a/migrations/versions/81b4c6fc826d_add_ballot_table_and_ballot_id_to_votes.py b/migrations/versions/81b4c6fc826d_add_ballot_table_and_ballot_id_to_votes.py new file mode 100644 index 0000000..eb38880 --- /dev/null +++ b/migrations/versions/81b4c6fc826d_add_ballot_table_and_ballot_id_to_votes.py @@ -0,0 +1,43 @@ +"""Add ballot table and ballot_id to votes + +Revision ID: 81b4c6fc826d +Revises: 48bf0bdc1ca1 +Create Date: 2025-10-31 17:07:55.504501 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '81b4c6fc826d' +down_revision = '48bf0bdc1ca1' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('ballots', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('voter_uuid', sa.UUID(), nullable=True), + sa.Column('date_created', sa.DateTime(), server_default=sa.text('now()'), nullable=True), + sa.Column('election_ref', sa.String(length=20), nullable=True), + sa.ForeignKeyConstraint(['election_ref'], ['elections.ref'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_ballots_id'), 'ballots', ['id'], unique=False) + op.create_index(op.f('ix_ballots_voter_uuid'), 'ballots', ['voter_uuid'], unique=True) + op.add_column('votes', sa.Column('ballot_id', sa.Integer(), nullable=True)) + op.create_foreign_key(None, 'votes', 'ballots', ['ballot_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'votes', type_='foreignkey') + op.drop_column('votes', 'ballot_id') + op.drop_index(op.f('ix_ballots_voter_uuid'), table_name='ballots') + op.drop_index(op.f('ix_ballots_id'), table_name='ballots') + op.drop_table('ballots') + # ### end Alembic commands ###