Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions .flake8

This file was deleted.

42 changes: 20 additions & 22 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,27 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 flake8-docstrings
- name: Lint with flake8
run: flake8 .
- uses: actions/checkout@v5
- uses: astral-sh/ruff-action@v3


test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install . --user
- name: Test using unittests
run: python -m unittest --failfast
- uses: actions/checkout@v5

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.8.22"

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version-file: ".python-version"

- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Run tests
run: uv run python -m unittest --failfast
3 changes: 3 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ on:
tags:
- v*.*.*

permissions:
contents: write

jobs:
pypi-publish:
name: Upload release to PyPI
Expand Down
60 changes: 25 additions & 35 deletions algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
This module contains the :class:`Battle` class, which speciefies how each type of battle is fought and scored,
some basic battle types, and related classed.
"""
from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import StrEnum
from importlib.metadata import entry_points
from abc import abstractmethod
from inspect import isclass
from itertools import count
from pathlib import Path
Expand All @@ -16,7 +17,6 @@
Annotated,
Any,
ClassVar,
Iterable,
Literal,
ParamSpec,
Protocol,
Expand All @@ -26,9 +26,8 @@
Unpack,
overload,
)
from typing_extensions import TypedDict
from annotated_types import Ge

from annotated_types import Ge
from pydantic import (
ConfigDict,
Field,
Expand All @@ -39,11 +38,10 @@
ValidatorFunctionWrapHandler,
)
from pydantic_core import CoreSchema
from pydantic_core.core_schema import (
tagged_union_schema,
with_info_wrap_validator_function,
)
from pydantic_core.core_schema import tagged_union_schema, with_info_wrap_validator_function
from typing_extensions import TypedDict

from algobattle.problem import InstanceModel, Problem, SolutionModel
from algobattle.program import (
Generator,
GeneratorResult,
Expand All @@ -53,15 +51,7 @@
Solver,
SolverResult,
)
from algobattle.problem import InstanceModel, Problem, SolutionModel
from algobattle.util import (
Encodable,
EncodableModel,
ExceptionInfo,
BaseModel,
Role,
)

from algobattle.util import BaseModel, Encodable, EncodableModel, ExceptionInfo, Role

_BattleConfig: TypeAlias = Any
"""Type alias used to generate correct typings when subclassing :class:`Battle`.
Expand All @@ -73,7 +63,7 @@
"""
T = TypeVar("T")
P = ParamSpec("P")
Type = type
_type = type


class ProgramLogConfigTime(StrEnum):
Expand All @@ -91,7 +81,7 @@ class ProgramLogConfigLocation(StrEnum):
inline = "inline"


class ProgramLogConfigView(Protocol): # noqa: D101
class ProgramLogConfigView(Protocol):
when: ProgramLogConfigTime = ProgramLogConfigTime.error
output: ProgramLogConfigLocation = ProgramLogConfigLocation.inline

Expand All @@ -100,7 +90,7 @@ class ProgramRunInfo(BaseModel):
"""Data about a program's execution."""

runtime: float = 0
overriden: RunConfigOverride = Field(default_factory=dict)
overriden: RunConfigOverride = Field(default_factory=RunConfigOverride, validate_default=True)
error: ExceptionInfo | None = None
battle_data: SerializeAsAny[EncodableModel] | None = None
instance: SerializeAsAny[InstanceModel] | None = None
Expand Down Expand Up @@ -177,7 +167,7 @@ def end_fight(self) -> None:
"""Informs the ui that the fight has finished running and has been added to the battle's `.fight_results`."""


class RunKwargs(TypedDict, total=False):
class RunKwargs(TypedDict, total=False, closed=True):
"""The keyword arguments used by the FightHandler.run family of functions."""

timeout_generator: float | None
Expand Down Expand Up @@ -394,7 +384,7 @@ class Config(BaseModel):
"""Type of battle that will be used."""

@classmethod
def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandler) -> CoreSchema:
def __get_pydantic_core_schema__(cls, source: _type, handler: GetCoreSchemaHandler) -> CoreSchema:
# there's two bugs we need to catch:
# 1. this function is called during the pydantic BaseModel metaclass's __new__, so the BattleConfig class
# won't be ready at that point and be missing in the namespace
Expand Down Expand Up @@ -446,7 +436,7 @@ def check_installed(val: object, handler: ValidatorFunctionWrapHandler, info: Va
installed = ", ".join(b.name() for b in Battle._battle_types.values())
raise ValueError(
f"The specified battle type '{passed}' is not installed. Installed types are: {installed}"
)
) from e

return with_info_wrap_validator_function(check_installed, subclass_schema)

Expand All @@ -459,7 +449,7 @@ class FallbackConfig(Config):

if TYPE_CHECKING:
# to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type
def __getattr__(self, __attr: str) -> Any:
def __getattr__(self, attr: str, /) -> Any:
...

class UiData(BaseModel):
Expand Down Expand Up @@ -560,7 +550,7 @@ class Config(Battle.Config):
exit after that many failures, or `"unlimited"` to never exit early.
"""

class UiData(Battle.UiData): # noqa: D106
class UiData(Battle.UiData):
reached: list[int]
cap: int
note: str
Expand Down Expand Up @@ -620,7 +610,7 @@ def score(self, config: Config) -> float:
return 0 if len(self.results) == 0 else sum(self.results) / len(self.results)

@staticmethod
def format_score(score: float) -> str: # noqa: D102
def format_score(score: float) -> str:
return str(int(score))


Expand All @@ -637,7 +627,7 @@ class Config(Battle.Config):
num_fights: int = 10
"""Number of iterations in each round."""

class UiData(Battle.UiData): # noqa: D106
class UiData(Battle.UiData):
round: int

async def run_battle(self, fight: FightHandler, config: Config, min_size: int, ui: BattleUi) -> None:
Expand All @@ -659,7 +649,7 @@ def score(self, config: Config) -> float:
return sum(f.score for f in self.fights) / len(self.fights)

@staticmethod
def format_score(score: float) -> str: # noqa: D102
def format_score(score: float) -> str:
return format(score, ".0%")


Expand All @@ -681,7 +671,7 @@ class Fight:
gen_sols: set[Role]
sol_sols: set[Role]

def encode(self, target: Path, role: Role) -> None: # noqa: D102
def encode(self, target: Path, role: Role) -> None:
target.mkdir()
for i, fight in enumerate(self.history):
fight_dir = target / str(i)
Expand Down Expand Up @@ -715,16 +705,16 @@ class Config(Battle.Config):
"""Number of fights that will be fought."""
weighting: Annotated[float, Ge(0)] = 1.1
"""How much each successive fight should be weighted more than the previous."""
scores: set[Role] = {Role.generator, Role.solver}
scores: set[Role] = {Role.generator, Role.solver} # noqa: RUF012
"""Who to show each fight's scores to."""
instances: set[Role] = {Role.generator, Role.solver}
instances: set[Role] = {Role.generator, Role.solver} # noqa: RUF012
"""Who to show the instances to."""
generator_solutions: set[Role] = {Role.generator}
generator_solutions: set[Role] = {Role.generator} # noqa: RUF012
"""Who to show the generator's solutions to, if the problem requires them."""
solver_solutions: set[Role] = {Role.solver}
solver_solutions: set[Role] = {Role.solver} # noqa: RUF012
"""Who to show the solver's solutions to."""

class UiData(Battle.UiData): # noqa: D106
class UiData(Battle.UiData):
round: int

async def run_battle(self, fight: FightHandler, config: Config, min_size: int, ui: BattleUi) -> None:
Expand Down Expand Up @@ -760,5 +750,5 @@ def score(self, config: Config) -> float:
return total / quotient

@staticmethod
def format_score(score: float) -> str: # noqa: D102
def format_score(score: float) -> str:
return format(score, ".0%")
Loading