Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional, cast, Union

import rewrite.java as j
import rewrite.python as p
from rewrite import Tree
from rewrite.java import Space, J
from rewrite.python import PythonVisitor, CompilationUnit, PySpace
from rewrite.visitor import P, Cursor, T


class RemoveTrailingWhitespaceVisitor(PythonVisitor):
def __init__(self, stop_after: Optional[p.Tree] = None):
self._stop_after = stop_after
self._stop = False

def visit_compilation_unit(self, compilation_unit: CompilationUnit, p: P) -> J:
if not compilation_unit.prefix.comments:
compilation_unit = compilation_unit.with_prefix(Space.EMPTY)
cu: j.CompilationUnit = cast(j.CompilationUnit, super().visit_compilation_unit(compilation_unit, p))

if cu.eof.whitespace:
clean = "".join([_ for _ in cu.eof.whitespace if _ in ['\n', '\r']])
cu = cu.with_eof(cu.eof.with_whitespace(clean))

return cu

def visit_space(self, space: Optional[Space], loc: Optional[Union[PySpace.Location, Space.Location]],
p: P) -> Space:
s = cast(Space, super().visit_space(space, loc, p))

if not s or not s.whitespace:
return s

last_newline = s.whitespace.rfind('\n')
if last_newline > 0:
ws = [c for i, c in enumerate(s.whitespace) if i >= last_newline or c in {',', '\r', '\n'}]
s = s.with_whitespace(''.join(ws))
return s

def post_visit(self, tree: T, p: P) -> Optional[T]:
if self._stop_after and tree == self._stop_after:
self._stop = True
return tree

def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[T]:
return tree if self._stop else super().visit(tree, p, parent)
37 changes: 37 additions & 0 deletions rewrite/tests/python/all/format/remove_trailing_whitespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from rewrite.python.format.remove_trailing_whitespace_visitor import RemoveTrailingWhitespaceVisitor
from rewrite.test import rewrite_run, RecipeSpec, from_visitor, python


class Foo:
pass


@pytest.mark.parametrize('n_spaces', list(range(0, 4)))
@pytest.mark.parametrize('linebreaks', ['\n', '\r\n', '\r\n\n', '\n\n'])
def test_tabs_to_spaces(n_spaces, linebreaks):
# noinspection PyInconsistentIndentation
spaces = ' ' * n_spaces
rewrite_run(
python(
# language=python
f"""\
class Foo:{spaces}{linebreaks}\
def bar(self):
return 42{linebreaks}
{spaces}{linebreaks}
""",
# language=python
f"""\
class Foo:{linebreaks}\
def bar(self):
return 42{linebreaks}
{linebreaks}
"""
),
spec=RecipeSpec()
.with_recipes(
from_visitor(RemoveTrailingWhitespaceVisitor())
)
)
Loading