Skip to content

Commit e7c72fc

Browse files
committed
[forward port] #1130, #1134
#1130 #1134
1 parent b09d721 commit e7c72fc

File tree

9 files changed

+244
-0
lines changed

9 files changed

+244
-0
lines changed

smart_tests/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import typer
1010

1111
from smart_tests.app import Application
12+
from smart_tests.commands import detect_flakes
1213
from smart_tests.commands.record.tests import create_nested_commands as create_record_target_commands
1314
from smart_tests.commands.subset import create_nested_commands as create_subset_target_commands
1415
from smart_tests.utils.test_runner_registry import get_registry
@@ -150,6 +151,7 @@ def main(
150151
app.add_typer(inspect.app, name="inspect")
151152
app.add_typer(stats.app, name="stats")
152153
app.add_typer(compare.app, name="compare")
154+
app.add_typer(detect_flakes.app, name="detect-flakes")
153155

154156
# Add record-target as a sub-app to record command
155157
record.app.add_typer(record_target_app, name="test") # Use NestedCommand version
@@ -162,6 +164,7 @@ def main(
162164
app.add_typer(verify.app, name="verify")
163165
app.add_typer(inspect.app, name="inspect")
164166
app.add_typer(stats.app, name="stats")
167+
app.add_typer(detect_flakes.app, name="detect-flakes")
165168

166169
app.callback()(main)
167170

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from enum import Enum
2+
import os
3+
from typing import Annotated
4+
import typer
5+
6+
from smart_tests.app import Application
7+
from smart_tests.commands.test_path_writer import TestPathWriter
8+
from smart_tests.testpath import unparse_test_path
9+
from smart_tests.utils.commands import Command
10+
from smart_tests.utils.dynamic_commands import DynamicCommandBuilder, extract_callback_options
11+
from smart_tests.utils.env_keys import REPORT_ERROR_KEY
12+
from smart_tests.utils.exceptions import print_error_and_die
13+
from smart_tests.utils.session import get_session
14+
from smart_tests.utils.smart_tests_client import SmartTestsClient
15+
from smart_tests.utils.tracking import Tracking, TrackingClient
16+
from smart_tests.utils.typer_types import ignorable_error
17+
18+
19+
class DetectFlakesRetryThreshold(str, Enum):
20+
LOW = "LOW"
21+
MEDIUM = "MEDIUM"
22+
HIGH = "HIGH"
23+
24+
25+
app = typer.Typer(name="detect-flakes", help="Detect flaky tests")
26+
27+
28+
@app.callback()
29+
def detect_flakes(
30+
ctx: typer.Context,
31+
session: Annotated[str, typer.Option(
32+
"--session",
33+
help="In the format builds/<build-name>/test_sessions/<test-session-id>",
34+
metavar="SESSION"
35+
)],
36+
retry_threshold: Annotated[DetectFlakesRetryThreshold, typer.Option(
37+
"--retry-threshold",
38+
help="Thoroughness of how \"flake\" is detected",
39+
case_sensitive=False,
40+
show_default=True,
41+
)] = DetectFlakesRetryThreshold.MEDIUM,
42+
):
43+
app = ctx.obj
44+
tracking_client = TrackingClient(Command.DETECT_FLAKE, app=app)
45+
test_runner = getattr(ctx, 'test_runner', None)
46+
client = SmartTestsClient(app=app, tracking_client=tracking_client, test_runner=test_runner)
47+
48+
test_session = None
49+
try:
50+
test_session = get_session(client=client, session=session)
51+
except ValueError as e:
52+
print_error_and_die(msg=str(e), tracking_client=tracking_client, event=Tracking.ErrorEvent.USER_ERROR)
53+
except Exception as e:
54+
if os.getenv(REPORT_ERROR_KEY):
55+
raise e
56+
else:
57+
typer.echo(ignorable_error(e), err=True)
58+
59+
if test_session is None:
60+
return
61+
62+
class FlakeDetection(TestPathWriter):
63+
def __init__(self, app: Application):
64+
super(FlakeDetection, self).__init__(app)
65+
66+
def run(self):
67+
test_paths = []
68+
try:
69+
res = client.request(
70+
"get",
71+
"detect-flake",
72+
params={
73+
"confidence": retry_threshold.value.upper(),
74+
"session-id": os.path.basename(session),
75+
"test-runner": test_runner,
76+
})
77+
78+
res.raise_for_status()
79+
test_paths = res.json().get("testPaths", [])
80+
if test_paths:
81+
self.print(test_paths)
82+
typer.echo("Trying to retry the following tests:", err=True)
83+
for detail in res.json().get("testDetails", []):
84+
typer.echo(f"{detail.get('reason'): {unparse_test_path(detail.get('fullTestPath'))}}", err=True)
85+
except Exception as e:
86+
tracking_client.send_error_event(
87+
event_name=Tracking.ErrorEvent.INTERNAL_CLI_ERROR,
88+
stack_trace=str(e),
89+
)
90+
if os.getenv(REPORT_ERROR_KEY):
91+
raise e
92+
else:
93+
typer.echo(ignorable_error(e), err=True)
94+
95+
ctx.obj = FlakeDetection(app=ctx.obj)
96+
97+

smart_tests/test_runners/bazel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def subset(client):
3030
client.run()
3131

3232

33+
smart_tests.CommonDetectFlakesImpls(__name__).detect_flakes()
34+
35+
3336
@smart_tests.record.tests
3437
def record_tests(
3538
client,

smart_tests/test_runners/file.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,6 @@ def find_filename():
5555
for r in reports:
5656
client.report(r)
5757
client.run()
58+
59+
60+
smart_tests.CommonDetectFlakesImpls(__name__).detect_flakes()

smart_tests/test_runners/rspec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
subset = smart_tests.CommonSubsetImpls(__name__).scan_files('*_spec.rb')
44
record_tests = smart_tests.CommonRecordTestImpls(__name__).report_files()
5+
detect_flakes = smart_tests.CommonDetectFlakesImpls(__name__).detect_flakes()

smart_tests/test_runners/smart_tests.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from smart_tests.commands.record.tests import app as record_tests_cmd
1010
from smart_tests.commands.subset import app as subset_cmd
11+
from smart_tests.commands.detect_flakes import app as detect_flakes_cmd
1112
from smart_tests.utils.test_runner_registry import cmdname, create_test_runner_wrapper, get_registry
1213

1314

@@ -34,6 +35,13 @@ def subset(f):
3435
return f
3536

3637

38+
def detect_flakes(f):
39+
test_runner_name = cmdname(f.__module__)
40+
registry = get_registry()
41+
registry.register_detect_flakes(test_runner_name, f)
42+
return f
43+
44+
3745
record = types.SimpleNamespace()
3846

3947

@@ -152,3 +160,26 @@ def load_report_files(cls, client, source_roots, file_mask="*.xml"):
152160
return
153161

154162
client.run()
163+
164+
165+
class CommonDetectFlakesImpls:
166+
def __init__(
167+
self,
168+
module_name,
169+
formatter=None,
170+
separator="\n",
171+
):
172+
self.cmdname = cmdname(module_name)
173+
self._formatter = formatter
174+
self._separator = separator
175+
176+
def detect_flakes(self):
177+
def detect_flakes(client):
178+
if self._formatter:
179+
client.formatter = self._formatter
180+
if self._separator:
181+
client.separator = self._separator
182+
183+
client.run()
184+
185+
return wrap(detect_flakes, detect_flakes_cmd, self.cmdname)

smart_tests/utils/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class Command(Enum):
88
RECORD_SESSION = 'RECORD_SESSION'
99
SUBSET = 'SUBSET'
1010
COMMIT = 'COMMIT'
11+
DETECT_FLAKE = 'DETECT_FLAKE'
1112

1213
def display_name(self):
1314
return self.value.lower().replace('_', ' ')

smart_tests/utils/test_runner_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self):
2222
self._subset_functions: Dict[str, Callable] = {}
2323
self._record_test_functions: Dict[str, Callable] = {}
2424
self._split_subset_functions: Dict[str, Callable] = {}
25+
self._detect_flakes_functions: Dict[str, Callable] = {}
2526
# Callback to trigger when new test runners are registered
2627
self._on_register_callback: Callable[[], None] | None = None
2728

@@ -47,6 +48,11 @@ def register_split_subset(self, test_runner_name: str, func: Callable) -> None:
4748
if self._on_register_callback:
4849
self._on_register_callback()
4950

51+
def register_detect_flakes(self, test_runner_name: str, func: Callable) -> None:
52+
self._detect_flakes_functions[test_runner_name] = func
53+
if self._on_register_callback:
54+
self._on_register_callback()
55+
5056
def get_subset_functions(self) -> Dict[str, Callable]:
5157
"""Get all registered subset functions."""
5258
return self._subset_functions.copy()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import os
2+
from unittest import mock
3+
4+
import responses # type: ignore
5+
6+
from smart_tests.utils.http_client import get_base_url
7+
from tests.cli_test_case import CliTestCase
8+
9+
10+
class DetectFlakeTest(CliTestCase):
11+
@responses.activate
12+
@mock.patch.dict(os.environ, {"SMART_TESTS_TOKEN": CliTestCase.smart_tests_token})
13+
def test_detect_flakes_success(self):
14+
mock_json_response = {
15+
"testPaths": [
16+
[{"type": "file", "name": "test_flaky_1.py"}],
17+
[{"type": "file", "name": "test_flaky_2.py"}],
18+
]
19+
}
20+
responses.add(
21+
responses.GET,
22+
f"{get_base_url()}/intake/organizations/{self.organization}/workspaces/{self.workspace}/detect-flake",
23+
json=mock_json_response,
24+
status=200,
25+
)
26+
result = self.cli(
27+
"detect-flakes",
28+
"file",
29+
"--session", self.session,
30+
"--retry-threshold", "high",
31+
mix_stderr=False,
32+
)
33+
self.assert_success(result)
34+
self.assertIn("test_flaky_1.py", result.stdout)
35+
self.assertIn("test_flaky_2.py", result.stdout)
36+
37+
@responses.activate
38+
@mock.patch.dict(os.environ, {"SMART_TESTS_TOKEN": CliTestCase.smart_tests_token})
39+
def test_detect_flakes_without_retry_threshold_success(self):
40+
mock_json_response = {
41+
"testPaths": [
42+
[{"type": "file", "name": "test_flaky_1.py"}],
43+
[{"type": "file", "name": "test_flaky_2.py"}],
44+
]
45+
}
46+
responses.add(
47+
responses.GET,
48+
f"{get_base_url()}/intake/organizations/{self.organization}/workspaces/{self.workspace}/detect-flake",
49+
json=mock_json_response,
50+
status=200,
51+
)
52+
result = self.cli(
53+
"detect-flakes",
54+
"file",
55+
"--session", self.session,
56+
mix_stderr=False,
57+
)
58+
self.assert_success(result)
59+
self.assertIn("test_flaky_1.py", result.stdout)
60+
self.assertIn("test_flaky_2.py", result.stdout)
61+
62+
@responses.activate
63+
@mock.patch.dict(os.environ, {"SMART_TESTS_TOKEN": CliTestCase.smart_tests_token})
64+
def test_detect_flakes_no_flakes(self):
65+
mock_json_response = {"testPaths": []}
66+
responses.add(
67+
responses.GET,
68+
f"{get_base_url()}/intake/organizations/{self.organization}/workspaces/{self.workspace}/detect-flake",
69+
json=mock_json_response,
70+
status=200,
71+
)
72+
result = self.cli(
73+
"detect-flakes",
74+
"file",
75+
"--session", self.session,
76+
"--retry-threshold", "low",
77+
mix_stderr=False,
78+
)
79+
self.assert_success(result)
80+
self.assertEqual(result.stdout, "")
81+
82+
@responses.activate
83+
@mock.patch.dict(os.environ, {"SMART_TESTS_TOKEN": CliTestCase.smart_tests_token})
84+
def test_flake_detection_api_error(self):
85+
responses.add(
86+
responses.GET,
87+
f"{get_base_url()}/intake/organizations/{self.organization}/workspaces/{self.workspace}/detect-flake",
88+
status=500,
89+
)
90+
result = self.cli(
91+
"detect-flakes",
92+
"file",
93+
"--session", self.session,
94+
"--retry-threshold", "medium",
95+
mix_stderr=False,
96+
)
97+
self.assert_exit_code(result, 0)
98+
self.assertIn("Error", result.stderr)
99+
self.assertEqual(result.stdout, "")

0 commit comments

Comments
 (0)