diff --git a/rewrite/rewrite/python/format/auto_format.py b/rewrite/rewrite/python/format/auto_format.py index d3e94c53..c4f68aed 100644 --- a/rewrite/rewrite/python/format/auto_format.py +++ b/rewrite/rewrite/python/format/auto_format.py @@ -2,6 +2,7 @@ from .blank_lines import BlankLinesVisitor from .normalize_format import NormalizeFormatVisitor +from .remove_trailing_whitespace_visitor import RemoveTrailingWhitespaceVisitor from .spaces_visitor import SpacesVisitor from .normalize_tabs_or_spaces import NormalizeTabsOrSpacesVisitor from .. import TabsAndIndentsStyle @@ -32,4 +33,5 @@ def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> cu.get_style(TabsAndIndentsStyle) or IntelliJ.tabs_and_indents(), self._stop_after ).visit(tree, p, self._cursor.fork()) + tree = RemoveTrailingWhitespaceVisitor(self._stop_after).visit(tree, self._cursor.fork()) return tree diff --git a/rewrite/rewrite/python/format/remove_trailing_whitespace_visitor.py b/rewrite/rewrite/python/format/remove_trailing_whitespace_visitor.py new file mode 100644 index 00000000..81e4e0e4 --- /dev/null +++ b/rewrite/rewrite/python/format/remove_trailing_whitespace_visitor.py @@ -0,0 +1,54 @@ +from typing import Optional, cast, Union + +import rewrite.java as j +import rewrite.python as p +from rewrite import Tree, Marker +from rewrite.java import Space, J, TrailingComma +from rewrite.python import PythonVisitor, CompilationUnit, PySpace +from rewrite.visitor import P, Cursor, T + + +class RemoveTrailingWhitespaceVisitor(PythonVisitor): + def __init__(self, stop_after: Optional[p.Tree] = None): + self._stop_after = stop_after + self._stop = False + + def visit_compilation_unit(self, compilation_unit: CompilationUnit, p: P) -> J: + if not compilation_unit.prefix.comments: + compilation_unit = compilation_unit.with_prefix(Space.EMPTY) + cu: j.CompilationUnit = cast(j.CompilationUnit, super().visit_compilation_unit(compilation_unit, p)) + + if cu.eof.whitespace: + clean = "".join([_ for _ in cu.eof.whitespace if _ in ['\n', '\r']]) + cu = cu.with_eof(cu.eof.with_whitespace(clean)) + + return cu + + def visit_space(self, space: Optional[Space], loc: Optional[Union[PySpace.Location, Space.Location]], + p: P) -> Space: + s = cast(Space, super().visit_space(space, loc, p)) + if not s or not s.whitespace: + return s + return self._normalize_whitespace(s) + + def visit_marker(self, marker: Marker, p: P) -> Marker: + m = cast(Marker, super().visit_marker(marker, p)) + if isinstance(m, TrailingComma): + return m.with_suffix(self._normalize_whitespace(m.suffix)) + return m + + @staticmethod + def _normalize_whitespace(s): + last_newline = s.whitespace.rfind('\n') + if last_newline > 0: + ws = [c for i, c in enumerate(s.whitespace) if i >= last_newline or c in {'\r', '\n'}] + s = s.with_whitespace(''.join(ws)) + return s + + def post_visit(self, tree: T, p: P) -> Optional[T]: + if self._stop_after and tree == self._stop_after: + self._stop = True + return tree + + def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[T]: + return tree if self._stop else super().visit(tree, p, parent) diff --git a/rewrite/tests/python/all/format/remove_trailing_whitespace_visitor_test.py b/rewrite/tests/python/all/format/remove_trailing_whitespace_visitor_test.py new file mode 100644 index 00000000..d3ac82d3 --- /dev/null +++ b/rewrite/tests/python/all/format/remove_trailing_whitespace_visitor_test.py @@ -0,0 +1,68 @@ +import pytest + +from rewrite.python.format.remove_trailing_whitespace_visitor import RemoveTrailingWhitespaceVisitor +from rewrite.test import rewrite_run, RecipeSpec, from_visitor, python + + +@pytest.mark.parametrize('n_spaces', list(range(0, 4))) +@pytest.mark.parametrize('linebreaks', ['\n', '\r\n', '\r\n\n', '\n\n']) +def test_tabs_to_spaces(n_spaces, linebreaks): + # noinspection PyInconsistentIndentation + spaces = ' ' * n_spaces + rewrite_run( + python( + # language=python + f"""\ + class Foo:{spaces}{linebreaks}\ + def bar(self): + return 42{linebreaks} + {spaces}{linebreaks} + """, + # language=python + f"""\ + class Foo:{linebreaks}\ + def bar(self): + return 42{linebreaks} + {linebreaks} + """ + ), + spec=RecipeSpec() + .with_recipes( + from_visitor(RemoveTrailingWhitespaceVisitor()) + ) + ) + + +@pytest.mark.parametrize('n_spaces', list(range(0, 4))) +@pytest.mark.parametrize('linebreaks', ['\n', '\r\n', '\r\n\n', '\n\n']) +def test_tabs_to_spaces_with_trailing_comma(n_spaces, linebreaks): + # noinspection PyInconsistentIndentation + spaces = ' ' * n_spaces + rewrite_run( + python( + # language=python + f"""\ + def bar():{spaces}{linebreaks}\ + return [ + 1,{spaces}{linebreaks}\ + 2, + 3,{spaces}{linebreaks}\ + ] + {spaces}{linebreaks} + """, + # language=python + f"""\ + def bar():{linebreaks}\ + return [ + 1,{linebreaks}\ + 2, + 3,{linebreaks}\ + ] + {linebreaks} + """ + ), + spec=RecipeSpec() + .with_recipes( + from_visitor(RemoveTrailingWhitespaceVisitor()) + ) + )