Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/2781.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
When :confval:`strict_markers` is enabled, marker names used in :option:`-m` expressions are now validated against registered markers.
30 changes: 30 additions & 0 deletions src/_pytest/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from typing import Final
from typing import final
from typing import IO
from typing import NamedTuple
from typing import TextIO
from typing import TYPE_CHECKING
import warnings
Expand Down Expand Up @@ -1012,6 +1013,19 @@ def __len__(self) -> int:
return len(self._config._inicfg)


class RegisteredMarker(NamedTuple):
"""A marker registered in the configuration.

:param name: The marker name (e.g., ``skipif``).
:param signature: The full marker signature (e.g., ``skipif(condition)``).
:param description: The marker description.
"""

name: str
signature: str
description: str


@final
class Config:
"""Access to configuration values, pluginmanager and plugin hooks.
Expand Down Expand Up @@ -1671,6 +1685,22 @@ def getini(self, name: str) -> Any:
self._inicache[canonical_name] = val = self._getini(canonical_name)
return val

def _iter_registered_markers(self) -> Iterator[RegisteredMarker]:
"""Iterate over all markers registered in the configuration.

Yields :class:`RegisteredMarker` named tuples with ``name``,
``signature``, and ``description`` fields.
"""
for line in self.getini("markers"):
# Example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`.
parts = line.split(":", 1)
signature = parts[0]
description = parts[1].strip() if len(parts) == 2 else ""
name = signature.split("(")[0].strip()
yield RegisteredMarker(name, signature, description)

# Meant for easy monkeypatching by legacypath plugin.
# Can be inlined back (with no cover removed) once legacypath is gone.
def _getini_unknown_type(self, name: str, type: str, value: object):
Expand Down
34 changes: 28 additions & 6 deletions src/_pytest/mark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,9 @@ def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.markers:
config._do_configure()
tw = _pytest.config.create_terminal_writer(config)
for line in config.getini("markers"):
parts = line.split(":", 1)
name = parts[0]
rest = parts[1] if len(parts) == 2 else ""
tw.write(f"@pytest.mark.{name}:", bold=True)
tw.line(rest)
for marker in config._iter_registered_markers():
tw.write(f"@pytest.mark.{marker.signature}:", bold=True)
tw.line(f" {marker.description}" if marker.description else "")
tw.line()
config._ensure_unconfigure()
return 0
Expand Down Expand Up @@ -258,6 +255,10 @@ def deselect_by_mark(items: list[Item], config: Config) -> None:
return

expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'")

# Validate marker names in the expression if strict_markers is enabled.
_validate_marker_names(expr, config)

remaining: list[Item] = []
deselected: list[Item] = []
for item in items:
Expand All @@ -270,6 +271,27 @@ def deselect_by_mark(items: list[Item], config: Config) -> None:
items[:] = remaining


def _validate_marker_names(expr: Expression, config: Config) -> None:
"""Validate that all marker names in the expression are registered.

Only validates when strict_markers is enabled.
"""
strict_markers = config.getini("strict_markers")
if strict_markers is None:
strict_markers = config.getini("strict")
if not strict_markers:
return

registered_markers = {marker.name for marker in config._iter_registered_markers()}
unknown_markers = expr.idents() - registered_markers
if unknown_markers:
unknown_str = ", ".join(sorted(unknown_markers))
raise UsageError(
f"Unknown marker(s) in '-m' expression: {unknown_str}. "
"Use 'pytest --markers' to see available markers."
)


def _parse_expression(expr: str, exc_message: str) -> Expression:
try:
return Expression.compile(expr)
Expand Down
24 changes: 17 additions & 7 deletions src/_pytest/mark/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ class Token:


class Scanner:
__slots__ = ("current", "input", "tokens")
__slots__ = ("current", "idents", "input", "tokens")

def __init__(self, input: str) -> None:
self.input = input
self.idents: set[str] = set()
self.tokens = self.lex(input)
self.current = next(self.tokens)

Expand Down Expand Up @@ -163,13 +164,13 @@ def reject(self, expected: Sequence[TokenType]) -> NoReturn:
IDENT_PREFIX = "$"


def expression(s: Scanner) -> ast.Expression:
def expression(s: Scanner) -> tuple[ast.Expression, frozenset[str]]:
if s.accept(TokenType.EOF):
ret: ast.expr = ast.Constant(False)
else:
ret = expr(s)
s.accept(TokenType.EOF, reject=True)
return ast.fix_missing_locations(ast.Expression(ret))
return ast.fix_missing_locations(ast.Expression(ret)), frozenset(s.idents)


def expr(s: Scanner) -> ast.expr:
Expand Down Expand Up @@ -197,6 +198,7 @@ def not_expr(s: Scanner) -> ast.expr:
return ret
ident = s.accept(TokenType.IDENT)
if ident:
s.idents.add(ident.value)
name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
if s.accept(TokenType.LPAREN):
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
Expand Down Expand Up @@ -314,12 +316,16 @@ class Expression:
The expression can be evaluated against different matchers.
"""

__slots__ = ("_code", "input")
__slots__ = ("_code", "_idents", "input")

def __init__(self, input: str, code: types.CodeType) -> None:
def __init__(
self, input: str, code: types.CodeType, idents: frozenset[str]
) -> None:
#: The original input line, as a string.
self.input: Final = input
self._code: Final = code
#: All identifiers which appear in the expression.
self._idents: Final = idents

@classmethod
def compile(cls, input: str) -> Expression:
Expand All @@ -329,13 +335,17 @@ def compile(cls, input: str) -> Expression:

:raises SyntaxError: If the expression is malformed.
"""
astexpr = expression(Scanner(input))
astexpr, idents = expression(Scanner(input))
code = compile(
astexpr,
filename="<pytest match expression>",
mode="eval",
)
return Expression(input, code)
return Expression(input, code, idents)

def idents(self) -> frozenset[str]:
"""Return the set of all identifiers which appear in the expression."""
return self._idents

def evaluate(self, matcher: ExpressionMatcher) -> bool:
"""Evaluate the match expression.
Expand Down
9 changes: 3 additions & 6 deletions src/_pytest/mark/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,9 @@ def __getattr__(self, name: str) -> MarkDecorator:
# name is in the set we definitely know it, but a mark may be known and
# not in the set. We therefore start by updating the set!
if name not in self._markers:
for line in self._config.getini("markers"):
# example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`.
marker = line.split(":")[0].split("(")[0].strip()
self._markers.add(marker)
self._markers.update(
m.name for m in self._config._iter_registered_markers()
)

# If the name is not in the set of known marks after updating,
# then it really is time to issue a warning or an error.
Expand Down
60 changes: 60 additions & 0 deletions testing/test_mark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

import os
import sys
from typing import cast
from unittest import mock

from _pytest.config import ExitCode
from _pytest.config import UsageError
from _pytest.mark import _validate_marker_names
from _pytest.mark import MarkGenerator
from _pytest.mark.expression import Expression
from _pytest.mark.structures import EMPTY_PARAMETERSET_OPTION
from _pytest.nodes import Collector
from _pytest.nodes import Node
Expand Down Expand Up @@ -213,6 +217,62 @@ def test_hello():
)


class TestValidateMarkerNames:
"""Tests for _validate_marker_names (issue #2781)."""

class FakeConfig:
def __init__(
self,
markers: list[str],
strict_markers: bool | None = None,
strict: bool = False,
) -> None:
self._ini: dict[str, list[str] | bool | None] = {
"markers": markers,
"strict_markers": strict_markers,
"strict": strict,
}

def getini(self, name: str) -> list[str] | bool | None:
return self._ini[name]

def _make_config(
self,
strict_markers: bool | None = None,
strict: bool = False,
) -> pytest.Config:
return cast(
pytest.Config,
self.FakeConfig(
markers=["registered: a registered marker"],
strict_markers=strict_markers,
strict=strict,
),
)

def test_unknown_marker_with_strict_markers(self) -> None:
expr = Expression.compile("unknown_marker")

with pytest.raises(UsageError, match=r"Unknown marker.*unknown_marker"):
_validate_marker_names(expr, self._make_config(strict_markers=True))

def test_unknown_marker_with_strict(self) -> None:
expr = Expression.compile("unknown_marker")

with pytest.raises(UsageError, match=r"Unknown marker.*unknown_marker"):
_validate_marker_names(expr, self._make_config(strict=True))

def test_registered_marker_passes(self) -> None:
expr = Expression.compile("registered")

_validate_marker_names(expr, self._make_config(strict_markers=True))

def test_no_validation_without_strict(self) -> None:
expr = Expression.compile("any_marker")

_validate_marker_names(expr, self._make_config())


@pytest.mark.parametrize(
("expr", "expected_passed"),
[
Expand Down
18 changes: 18 additions & 0 deletions testing/test_mark_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,21 @@ def test_str_keyword_expressions(
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected


@pytest.mark.parametrize(
("expr", "expected_idents"),
(
("", frozenset()),
("foo", frozenset(["foo"])),
("foo and bar", frozenset(["foo", "bar"])),
("foo or bar", frozenset(["foo", "bar"])),
("not foo", frozenset(["foo"])),
("(foo and bar) or baz", frozenset(["foo", "bar", "baz"])),
("foo and foo", frozenset(["foo"])), # Duplicates are deduplicated.
("mark(a=1)", frozenset(["mark"])), # Only marker name, not kwargs.
),
)
def test_expression_idents(expr: str, expected_idents: frozenset[str]) -> None:
"""Test that Expression.idents() returns the identifiers in the expression."""
assert Expression.compile(expr).idents() == expected_idents
Loading