Skip to content

Commit 8cfb2ce

Browse files
committed
Added: ballot bdd to link votes together for 1 voter
1 parent a41901b commit 8cfb2ce

File tree

2 files changed

+61
-12
lines changed

2 files changed

+61
-12
lines changed

app/crud.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,23 @@ def create_invite_tokens(
218218
_check_election_is_not_ended(get_election(db, election_ref))
219219
now = datetime.now()
220220
params = {"date_created": now, "date_modified": now, "election_ref": election_ref}
221-
db_votes = [models.Vote(**params) for _ in range(num_voters * num_candidates)]
222-
db.bulk_save_objects(db_votes, return_defaults=True)
223-
db.commit()
221+
222+
try:
223+
db_ballots = [models.Ballot(election_ref=election_ref) for _ in range(num_voters)]
224+
db.bulk_save_objects(db_ballots, return_defaults=True)
225+
226+
db_votes = []
227+
228+
for ballot in db_ballots:
229+
for _ in range(num_candidates):
230+
db_votes.append(models.Vote(**params, ballot_id=ballot.id))
231+
232+
db.bulk_save_objects(db_votes, return_defaults=True)
233+
db.commit()
234+
except Exception as e:
235+
db.rollback()
236+
raise e
237+
224238
vote_ids = [int(str(v.id)) for v in db_votes]
225239
tokens = [
226240
create_ballot_token(vote_ids[i::num_voters], election_ref)
@@ -405,14 +419,25 @@ def create_ballot(db: Session, ballot: schemas.BallotCreate) -> schemas.BallotGe
405419
)
406420
_check_ballot_is_consistent(election, ballot)
407421

408-
# Ideally, we would use RETURNING but it does not work yet for SQLite
409-
db_votes = [
410-
models.Vote(**v.model_dump(), election_ref=ballot.election_ref) for v in ballot.votes
411-
]
412-
db.add_all(db_votes)
413-
db.commit()
414-
for v in db_votes:
415-
db.refresh(v)
422+
try:
423+
db_ballot = models.Ballot(election_ref=ballot.election_ref)
424+
db.add(db_ballot)
425+
db.flush()
426+
427+
# Create votes and associate them with the ballot
428+
db_votes = [
429+
models.Vote(**v.model_dump(), election_ref=ballot.election_ref, ballot_id=db_ballot.id)
430+
for v in ballot.votes
431+
]
432+
db.add_all(db_votes)
433+
db.commit()
434+
db.refresh(db_ballot)
435+
436+
for v in db_votes:
437+
db.refresh(v)
438+
except Exception as e:
439+
db.rollback()
440+
raise e
416441

417442
votes_get = [schemas.VoteGet.model_validate(v) for v in db_votes]
418443
vote_ids = [v.id for v in votes_get]
@@ -511,6 +536,11 @@ def update_ballot(
511536
if len(db_votes) != len(vote_ids):
512537
raise errors.NotFoundError("votes")
513538

539+
# Verify all votes belong to the same ballot
540+
ballot_ids = {v.ballot_id for v in db_votes if v.ballot_id is not None}
541+
if len(ballot_ids) > 1:
542+
raise errors.ForbiddenError("All votes must belong to the same ballot")
543+
514544
election = schemas.ElectionGet.model_validate(db_votes[0].election)
515545

516546
for vote, db_vote in zip(ballot.votes, db_votes):

app/models.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from sqlalchemy.sql import func
33
from sqlalchemy.orm import relationship
44
from .database import Base
5-
5+
import uuid # Pour pouvoir appeler uuid.uuid4()
6+
from sqlalchemy import UUID
67

78
class Election(Base):
89
__tablename__ = "elections"
@@ -24,6 +25,7 @@ class Election(Base):
2425
grades = relationship("Grade", back_populates="election")
2526
candidates = relationship("Candidate", back_populates="election")
2627
votes = relationship("Vote", back_populates="election")
28+
ballots = relationship("Ballot", back_populates="election")
2729

2830

2931
class Candidate(Base):
@@ -72,3 +74,20 @@ class Vote(Base):
7274

7375
election_ref = Column(String(20), ForeignKey("elections.ref"))
7476
election = relationship("Election", back_populates="votes")
77+
78+
ballot_id = Column(Integer, ForeignKey("ballots.id"), nullable=True)
79+
ballot = relationship("Ballot", back_populates="votes")
80+
81+
82+
class Ballot(Base):
83+
__tablename__ = "ballots"
84+
85+
id = Column(Integer, primary_key=True, index=True)
86+
87+
voter_uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True, index=True)
88+
89+
date_created = Column(DateTime, server_default=func.now())
90+
election_ref = Column(String(20), ForeignKey("elections.ref"))
91+
92+
election = relationship("Election", back_populates="ballots")
93+
votes = relationship("Vote", back_populates="ballot")

0 commit comments

Comments
 (0)