Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/griffe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@
ExprFormatted,
ExprGeneratorExp,
ExprIfExp,
ExprInterpolation,
ExprJoinedStr,
ExprKeyword,
ExprLambda,
Expand All @@ -310,6 +311,7 @@
ExprSetComp,
ExprSlice,
ExprSubscript,
ExprTemplateStr,
ExprTuple,
ExprUnaryOp,
ExprVarKeyword,
Expand Down Expand Up @@ -453,6 +455,7 @@
"ExprFormatted",
"ExprGeneratorExp",
"ExprIfExp",
"ExprInterpolation",
"ExprJoinedStr",
"ExprKeyword",
"ExprLambda",
Expand All @@ -465,6 +468,7 @@
"ExprSetComp",
"ExprSlice",
"ExprSubscript",
"ExprTemplateStr",
"ExprTuple",
"ExprUnaryOp",
"ExprVarKeyword",
Expand Down
52 changes: 52 additions & 0 deletions src/griffe/_internal/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import ast
import sys
from dataclasses import dataclass
from dataclasses import fields as getfields
from enum import IntEnum, auto
Expand Down Expand Up @@ -532,6 +533,20 @@ def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
yield from _yield(self.orelse, flat=flat, outer_precedence=precedence, is_left=False)


@dataclass(eq=True, slots=True)
class ExprInterpolation(Expr):
"""Template string interpolation like `{name}`."""

value: str | Expr
"""Interpolated value."""

def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
yield "{"
# Prevent parentheses from being added, avoiding `{(1 + 1)}`
yield from _yield(self.value, flat=flat, outer_precedence=_OperatorPrecedence.NONE)
yield "}"


@dataclass(eq=True, slots=True)
class ExprJoinedStr(Expr):
"""Joined strings like `f"a {b} c"`."""
Expand Down Expand Up @@ -915,6 +930,19 @@ def canonical_path(self) -> str:
return self.left.canonical_path


@dataclass(eq=True, slots=True)
class ExprTemplateStr(Expr):
"""Template strings like `t"a {name}"`."""

values: Sequence[str | Expr]
"""Joined values."""

def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
yield "t'"
yield from _join(self.values, "", flat=flat)
yield "'"


@dataclass(eq=True, slots=True)
class ExprTuple(Expr):
"""Tuples like `(0, 1, 2)`."""
Expand Down Expand Up @@ -1213,6 +1241,12 @@ def _build_ifexp(node: ast.IfExp, parent: Module | Class, **kwargs: Any) -> Expr
)


if sys.version_info >= (3, 14):

def _build_interpolation(node: ast.Interpolation, parent: Module | Class, **kwargs: Any) -> Expr:
return ExprInterpolation(_build(node.value, parent, **kwargs))


def _build_joinedstr(
node: ast.JoinedStr,
parent: Module | Class,
Expand Down Expand Up @@ -1311,6 +1345,16 @@ def _build_subscript(
return ExprSubscript(left, slice_expr)


if sys.version_info >= (3, 14):

def _build_templatestr(
node: ast.TemplateStr,
parent: Module | Class,
**kwargs: Any,
) -> Expr:
return ExprTemplateStr([_build(value, parent, in_joined_str=True, **kwargs) for value in node.values])


def _build_tuple(
node: ast.Tuple,
parent: Module | Class,
Expand Down Expand Up @@ -1369,6 +1413,14 @@ def __call__(self, node: Any, parent: Module | Class, **kwargs: Any) -> Expr: ..
ast.YieldFrom: _build_yield_from,
}

if sys.version_info >= (3, 14):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine all 3 version checks in this file into one block

_node_map.update(
{
ast.Interpolation: _build_interpolation,
ast.TemplateStr: _build_templatestr,
},
)


def _build(node: ast.AST, parent: Module | Class, /, **kwargs: Any) -> Expr:
return _node_map[type(node)](node, parent, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import sys
from ast import PyCF_ONLY_AST

import pytest
Expand Down Expand Up @@ -51,6 +52,7 @@
"call(something=something)",
# Strings.
"f'a {round(key, 2)} {z}'",
*(["t'a {round(key, 2)} {z}'"] if sys.version_info >= (3, 14) else []),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also try something like

Suggested change
*(["t'a {round(key, 2)} {z}'"] if sys.version_info >= (3, 14) else []),
pytest.param("t'a {round(key, 2)} {z}'", marks=pytest.mark.skipif(sys.version_info < (3, 14))),

# Slices.
"o[x]",
"o[x, y]",
Expand Down
Loading