diff --git a/axelrod/match_generator.py b/axelrod/match_generator.py index 37038e7b5..38d39ddf8 100644 --- a/axelrod/match_generator.py +++ b/axelrod/match_generator.py @@ -1,6 +1,21 @@ +from dataclasses import dataclass +from typing import Any, Dict, Iterator, Tuple + from axelrod.random_ import BulkRandomGenerator +@dataclass +class MatchChunk(object): + index_pair: Tuple[int] + match_params: Dict[str, Any] + repetitions: int + seed: BulkRandomGenerator + + def as_tuple(self) -> Tuple: + """Kept for legacy reasons""" + return (self.index_pair, self.match_params, self.repetitions, self.seed) + + class MatchGenerator(object): def __init__( self, @@ -62,7 +77,7 @@ def __init__( def __len__(self): return self.size - def build_match_chunks(self): + def build_match_chunks(self) -> Iterator[MatchChunk]: """ A generator that returns player index pairs and match parameters for a round robin tournament. @@ -80,7 +95,12 @@ def build_match_chunks(self): for index_pair in edges: match_params = self.build_single_match_params() r = next(self.random_generator) - yield (index_pair, match_params, self.repetitions, r) + yield MatchChunk( + index_pair=index_pair, + match_params=match_params, + repetitions=self.repetitions, + seed=r, + ) def build_single_match_params(self): """ diff --git a/axelrod/tests/unit/test_match_generator.py b/axelrod/tests/unit/test_match_generator.py index 437b47c96..011a105ca 100644 --- a/axelrod/tests/unit/test_match_generator.py +++ b/axelrod/tests/unit/test_match_generator.py @@ -171,7 +171,7 @@ def test_build_match_chunks(self, repetitions): game=test_game, repetitions=repetitions, ) - chunks = list(rr.build_match_chunks()) + chunks = [chunk.as_tuple() for chunk in rr.build_match_chunks()] match_definitions = [ tuple(list(index_pair) + [repetitions]) for (index_pair, match_params, repetitions, _) in chunks @@ -239,7 +239,7 @@ def test_spatial_build_match_chunks(self, repetitions): edges=cycle, repetitions=repetitions, ) - chunks = list(rr.build_match_chunks()) + chunks = [chunk.as_tuple() for chunk in rr.build_match_chunks()] match_definitions = [ tuple(list(index_pair) + [repetitions]) for (index_pair, match_params, repetitions, _) in chunks diff --git a/axelrod/tests/unit/test_tournament.py b/axelrod/tests/unit/test_tournament.py index d7e8fcd1b..6e813c248 100644 --- a/axelrod/tests/unit/test_tournament.py +++ b/axelrod/tests/unit/test_tournament.py @@ -24,7 +24,7 @@ strategy_lists, tournaments, ) -from axelrod.tournament import _close_objects +from axelrod.tournament import MatchChunk, _close_objects C, D = axl.Action.C, axl.Action.D @@ -663,7 +663,9 @@ def make_chunk_generator(): for player2_index in range(player1_index, len(self.players)): index_pair = (player1_index, player2_index) match_params = {"turns": turns, "game": self.game} - yield (index_pair, match_params, self.test_repetitions, 0) + yield MatchChunk( + index_pair, match_params, self.test_repetitions, 0 + ) chunk_generator = make_chunk_generator() interactions = {} diff --git a/axelrod/tournament.py b/axelrod/tournament.py index c79775489..81e24831c 100644 --- a/axelrod/tournament.py +++ b/axelrod/tournament.py @@ -16,7 +16,7 @@ from .game import Game from .match import Match -from .match_generator import MatchGenerator +from .match_generator import MatchChunk, MatchGenerator from .result_set import ResultSet C, D = Action.C, Action.D @@ -427,7 +427,7 @@ def _worker( done_queue.put("STOP") return True - def _play_matches(self, chunk, build_results=True): + def _play_matches(self, chunk: Match, build_results: bool = True): """ Play matches in a given chunk. @@ -446,14 +446,13 @@ def _play_matches(self, chunk, build_results=True): (0, 1) -> [(C, D), (D, C),...] """ interactions = defaultdict(list) - index_pair, match_params, repetitions, seed = chunk - p1_index, p2_index = index_pair + p1_index, p2_index = chunk.index_pair player1 = self.players[p1_index].clone() player2 = self.players[p2_index].clone() - match_params["players"] = (player1, player2) - match_params["seed"] = seed - match = Match(**match_params) - for _ in range(repetitions): + chunk.match_params["players"] = (player1, player2) + chunk.match_params["seed"] = chunk.seed + match = Match(**chunk.match_params) + for _ in range(chunk.repetitions): match.play() if build_results: @@ -461,7 +460,7 @@ def _play_matches(self, chunk, build_results=True): else: results = None - interactions[index_pair].append([match.result, results]) + interactions[chunk.index_pair].append([match.result, results]) return interactions def _calculate_results(self, interactions):