@@ -218,9 +218,23 @@ def create_invite_tokens(
218218 _check_election_is_not_ended (get_election (db , election_ref ))
219219 now = datetime .now ()
220220 params = {"date_created" : now , "date_modified" : now , "election_ref" : election_ref }
221- db_votes = [models .Vote (** params ) for _ in range (num_voters * num_candidates )]
222- db .bulk_save_objects (db_votes , return_defaults = True )
223- db .commit ()
221+
222+ try :
223+ db_ballots = [models .Ballot (election_ref = election_ref ) for _ in range (num_voters )]
224+ db .bulk_save_objects (db_ballots , return_defaults = True )
225+
226+ db_votes = []
227+
228+ for ballot in db_ballots :
229+ for _ in range (num_candidates ):
230+ db_votes .append (models .Vote (** params , ballot_id = ballot .id ))
231+
232+ db .bulk_save_objects (db_votes , return_defaults = True )
233+ db .commit ()
234+ except Exception as e :
235+ db .rollback ()
236+ raise e
237+
224238 vote_ids = [int (str (v .id )) for v in db_votes ]
225239 tokens = [
226240 create_ballot_token (vote_ids [i ::num_voters ], election_ref )
@@ -405,14 +419,25 @@ def create_ballot(db: Session, ballot: schemas.BallotCreate) -> schemas.BallotGe
405419 )
406420 _check_ballot_is_consistent (election , ballot )
407421
408- # Ideally, we would use RETURNING but it does not work yet for SQLite
409- db_votes = [
410- models .Vote (** v .model_dump (), election_ref = ballot .election_ref ) for v in ballot .votes
411- ]
412- db .add_all (db_votes )
413- db .commit ()
414- for v in db_votes :
415- db .refresh (v )
422+ try :
423+ db_ballot = models .Ballot (election_ref = ballot .election_ref )
424+ db .add (db_ballot )
425+ db .flush ()
426+
427+ # Create votes and associate them with the ballot
428+ db_votes = [
429+ models .Vote (** v .model_dump (), election_ref = ballot .election_ref , ballot_id = db_ballot .id )
430+ for v in ballot .votes
431+ ]
432+ db .add_all (db_votes )
433+ db .commit ()
434+ db .refresh (db_ballot )
435+
436+ for v in db_votes :
437+ db .refresh (v )
438+ except Exception as e :
439+ db .rollback ()
440+ raise e
416441
417442 votes_get = [schemas .VoteGet .model_validate (v ) for v in db_votes ]
418443 vote_ids = [v .id for v in votes_get ]
@@ -511,6 +536,11 @@ def update_ballot(
511536 if len (db_votes ) != len (vote_ids ):
512537 raise errors .NotFoundError ("votes" )
513538
539+ # Verify all votes belong to the same ballot
540+ ballot_ids = {v .ballot_id for v in db_votes if v .ballot_id is not None }
541+ if len (ballot_ids ) > 1 :
542+ raise errors .ForbiddenError ("All votes must belong to the same ballot" )
543+
514544 election = schemas .ElectionGet .model_validate (db_votes [0 ].election )
515545
516546 for vote , db_vote in zip (ballot .votes , db_votes ):
0 commit comments