Skip to content

Commit 14d9fa4

Browse files
committed
Fix parser bug and implement ast_unparse in Python
1 parent 4f91ea3 commit 14d9fa4

File tree

2 files changed

+78
-78
lines changed

2 files changed

+78
-78
lines changed

Lib/_ast_unparse.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -573,21 +573,11 @@ def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
573573
quote_type = quote_types[0]
574574
self.write(f"{quote_type}{string}{quote_type}")
575575

576-
def visit_JoinedStr(self, node):
577-
self.write("f")
578-
579-
fstring_parts = []
580-
for value in node.values:
581-
with self.buffered() as buffer:
582-
self._write_fstring_inner(value)
583-
fstring_parts.append(
584-
("".join(buffer), isinstance(value, Constant))
585-
)
586-
587-
new_fstring_parts = []
576+
def _ftstring_helper(self, node, ftstring_parts):
577+
new_ftstring_parts = []
588578
quote_types = list(_ALL_QUOTES)
589579
fallback_to_repr = False
590-
for value, is_constant in fstring_parts:
580+
for value, is_constant in ftstring_parts:
591581
if is_constant:
592582
value, new_quote_types = self._str_literal_helper(
593583
value,
@@ -606,30 +596,47 @@ def visit_JoinedStr(self, node):
606596
new_quote_types = [q for q in quote_types if q not in value]
607597
if new_quote_types:
608598
quote_types = new_quote_types
609-
new_fstring_parts.append(value)
599+
new_ftstring_parts.append(value)
610600

611601
if fallback_to_repr:
612602
# If we weren't able to find a quote type that works for all parts
613603
# of the JoinedStr, fallback to using repr and triple single quotes.
614604
quote_types = ["'''"]
615-
new_fstring_parts.clear()
616-
for value, is_constant in fstring_parts:
605+
new_ftstring_parts.clear()
606+
for value, is_constant in ftstring_parts:
617607
if is_constant:
618608
value = repr('"' + value) # force repr to use single quotes
619609
expected_prefix = "'\""
620610
assert value.startswith(expected_prefix), repr(value)
621611
value = value[len(expected_prefix):-1]
622-
new_fstring_parts.append(value)
612+
new_ftstring_parts.append(value)
623613

624-
value = "".join(new_fstring_parts)
614+
value = "".join(new_ftstring_parts)
625615
quote_type = quote_types[0]
626616
self.write(f"{quote_type}{value}{quote_type}")
627617

628-
def _write_fstring_inner(self, node, is_format_spec=False):
618+
def _write_ftstring(self, node, prefix):
619+
self.write(prefix)
620+
fstring_parts = []
621+
for value in node.values:
622+
with self.buffered() as buffer:
623+
self._write_ftstring_inner(value)
624+
fstring_parts.append(
625+
("".join(buffer), isinstance(value, Constant))
626+
)
627+
self._ftstring_helper(node, fstring_parts)
628+
629+
def visit_JoinedStr(self, node):
630+
self._write_ftstring(node, "f")
631+
632+
def visit_TemplateStr(self, node):
633+
self._write_ftstring(node, "t")
634+
635+
def _write_ftstring_inner(self, node, is_format_spec=False):
629636
if isinstance(node, JoinedStr):
630637
# for both the f-string itself, and format_spec
631638
for value in node.values:
632-
self._write_fstring_inner(value, is_format_spec=is_format_spec)
639+
self._write_ftstring_inner(value, is_format_spec=is_format_spec)
633640
elif isinstance(node, Constant) and isinstance(node.value, str):
634641
value = node.value.replace("{", "{{").replace("}", "}}")
635642

@@ -641,26 +648,41 @@ def _write_fstring_inner(self, node, is_format_spec=False):
641648
self.write(value)
642649
elif isinstance(node, FormattedValue):
643650
self.visit_FormattedValue(node)
651+
elif isinstance(node, Interpolation):
652+
self.visit_Interpolation(node)
644653
else:
645654
raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
646655

647-
def visit_FormattedValue(self, node):
648-
def unparse_inner(inner):
649-
unparser = type(self)()
650-
unparser.set_precedence(_Precedence.TEST.next(), inner)
651-
return unparser.visit(inner)
656+
def _unparse_interpolation_value(self, inner):
657+
unparser = type(self)()
658+
unparser.set_precedence(_Precedence.TEST.next(), inner)
659+
return unparser.visit(inner)
660+
661+
def _write_fstring_conversion(self, node):
662+
if node.conversion != -1:
663+
self.write(f"!{chr(node.conversion)}")
664+
665+
def _write_tstring_conversion(self, node):
666+
if node.conversion is not None:
667+
self.write(f"!{node.conversion}")
652668

669+
def _write_interpolation(self, node, write_conversion):
653670
with self.delimit("{", "}"):
654-
expr = unparse_inner(node.value)
671+
expr = self._unparse_interpolation_value(node.value)
655672
if expr.startswith("{"):
656673
# Separate pair of opening brackets as "{ {"
657674
self.write(" ")
658675
self.write(expr)
659-
if node.conversion != -1:
660-
self.write(f"!{chr(node.conversion)}")
676+
write_conversion(node)
661677
if node.format_spec:
662678
self.write(":")
663-
self._write_fstring_inner(node.format_spec, is_format_spec=True)
679+
self._write_ftstring_inner(node.format_spec, is_format_spec=True)
680+
681+
def visit_FormattedValue(self, node):
682+
self._write_interpolation(node, self._write_fstring_conversion)
683+
684+
def visit_Interpolation(self, node):
685+
self._write_interpolation(node, self._write_tstring_conversion)
664686

665687
def visit_Name(self, node):
666688
self.write(node.id)

Parser/action_helpers.c

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,8 +1654,8 @@ _build_concatenated_unicode(Parser *p, asdl_expr_seq *strings, int lineno,
16541654
end_lineno, end_col_offset, arena);
16551655
}
16561656

1657-
static expr_ty
1658-
_build_concatenated_joined_str(Parser *p, asdl_expr_seq *strings,
1657+
static asdl_expr_seq *
1658+
_build_concatenated_str(Parser *p, asdl_expr_seq *strings,
16591659
int lineno, int col_offset, int end_lineno,
16601660
int end_col_offset, PyArena *arena)
16611661
{
@@ -1669,6 +1669,9 @@ _build_concatenated_joined_str(Parser *p, asdl_expr_seq *strings,
16691669
case JoinedStr_kind:
16701670
n_flattened_elements += asdl_seq_LEN(elem->v.JoinedStr.values);
16711671
break;
1672+
case TemplateStr_kind:
1673+
n_flattened_elements += asdl_seq_LEN(elem->v.TemplateStr.values);
1674+
break;
16721675
default:
16731676
n_flattened_elements++;
16741677
break;
@@ -1695,6 +1698,15 @@ _build_concatenated_joined_str(Parser *p, asdl_expr_seq *strings,
16951698
asdl_seq_SET(flattened, current_pos++, subvalue);
16961699
}
16971700
break;
1701+
case TemplateStr_kind:
1702+
for (Py_ssize_t j = 0; j < asdl_seq_LEN(elem->v.TemplateStr.values); j++) {
1703+
expr_ty subvalue = asdl_seq_GET(elem->v.TemplateStr.values, j);
1704+
if (subvalue == NULL) {
1705+
return NULL;
1706+
}
1707+
asdl_seq_SET(flattened, current_pos++, subvalue);
1708+
}
1709+
break;
16981710
default:
16991711
asdl_seq_SET(flattened, current_pos++, elem);
17001712
break;
@@ -1795,6 +1807,16 @@ _build_concatenated_joined_str(Parser *p, asdl_expr_seq *strings,
17951807
}
17961808

17971809
assert(current_pos == n_elements);
1810+
return values;
1811+
}
1812+
1813+
static expr_ty
1814+
_build_concatenated_joined_str(Parser *p, asdl_expr_seq *strings,
1815+
int lineno, int col_offset, int end_lineno,
1816+
int end_col_offset, PyArena *arena)
1817+
{
1818+
asdl_expr_seq *values = _build_concatenated_str(p, strings, lineno,
1819+
col_offset, end_lineno, end_col_offset, arena);
17981820
return _PyAST_JoinedStr(values, lineno, col_offset, end_lineno, end_col_offset, p->arena);
17991821
}
18001822

@@ -1803,53 +1825,9 @@ _build_concatenated_template_str(Parser *p, asdl_expr_seq *strings,
18031825
int lineno, int col_offset, int end_lineno,
18041826
int end_col_offset, PyArena *arena)
18051827
{
1806-
Py_ssize_t len = asdl_seq_LEN(strings);
1807-
assert(len > 0);
1808-
1809-
Py_ssize_t n_flattened_elements = 0;
1810-
for (Py_ssize_t i = 0; i < len; i++) {
1811-
expr_ty elem = asdl_seq_GET(strings, i);
1812-
switch(elem->kind) {
1813-
case TemplateStr_kind:
1814-
n_flattened_elements += asdl_seq_LEN(elem->v.JoinedStr.values);
1815-
break;
1816-
default:
1817-
n_flattened_elements++;
1818-
break;
1819-
}
1820-
}
1821-
1822-
1823-
asdl_expr_seq* flattened = _Py_asdl_expr_seq_new(n_flattened_elements, p->arena);
1824-
if (flattened == NULL) {
1825-
return NULL;
1826-
}
1827-
1828-
Py_ssize_t pos = 0;
1829-
for (Py_ssize_t i = 0; i < len; i++) {
1830-
expr_ty elem = asdl_seq_GET(strings, i);
1831-
1832-
switch (elem->kind) {
1833-
case TemplateStr_kind:
1834-
for (Py_ssize_t j = 0; j < asdl_seq_LEN(elem->v.TemplateStr.values); j++) {
1835-
expr_ty subitem = asdl_seq_GET(elem->v.TemplateStr.values, j);
1836-
asdl_seq_SET(flattened, pos++, subitem);
1837-
}
1838-
break;
1839-
case JoinedStr_kind: {
1840-
expr_ty joined_str = _build_concatenated_joined_str(p,
1841-
elem->v.JoinedStr.values, lineno, col_offset,
1842-
end_lineno, end_col_offset, arena);
1843-
asdl_seq_SET(flattened, pos++, joined_str);
1844-
break;
1845-
}
1846-
default:
1847-
asdl_seq_SET(flattened, pos++, elem);
1848-
break;
1849-
}
1850-
}
1851-
1852-
return _PyAST_TemplateStr(flattened, lineno, col_offset, end_lineno,
1828+
asdl_expr_seq *values = _build_concatenated_str(p, strings, lineno,
1829+
col_offset, end_lineno, end_col_offset, arena);
1830+
return _PyAST_TemplateStr(values, lineno, col_offset, end_lineno,
18531831
end_col_offset, arena);
18541832
}
18551833

0 commit comments

Comments
 (0)