Skip to content

Commit ab9c0c6

Browse files
committed
check if the pid is running before attaching
Signed-off-by: Keming <kemingy94@gmail.com>
1 parent cfea35b commit ab9c0c6

File tree

4 files changed

+33
-21
lines changed

4 files changed

+33
-21
lines changed

Lib/profiling/sampling/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import subprocess
77
import sys
88

9-
from .sample import sample, sample_live
9+
from .sample import sample, sample_live, _is_process_running
1010
from .pstats_collector import PstatsCollector
1111
from .stack_collector import CollapsedStackCollector, FlamegraphCollector
1212
from .heatmap_collector import HeatmapCollector
@@ -596,6 +596,8 @@ def main():
596596

597597
def _handle_attach(args):
598598
"""Handle the 'attach' command."""
599+
if not _is_process_running(args.pid):
600+
raise sys.exit(f"Process with PID {args.pid} is not running.")
599601
# Check if live mode is requested
600602
if args.live:
601603
_handle_live_attach(args, args.pid)

Lib/profiling/sampling/sample.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def sample(self, collector, duration_sec=10, *, async_aware=False):
8989
collector.collect_failed_sample()
9090
errors += 1
9191
except Exception as e:
92-
if not self._is_process_running():
92+
if not _is_process_running(self.pid):
9393
break
9494
raise e from None
9595

@@ -151,22 +151,6 @@ def sample(self, collector, duration_sec=10, *, async_aware=False):
151151
f"({(expected_samples - num_samples) / expected_samples * 100:.2f}%)"
152152
)
153153

154-
def _is_process_running(self):
155-
if sys.platform == "linux" or sys.platform == "darwin":
156-
try:
157-
os.kill(self.pid, 0)
158-
return True
159-
except ProcessLookupError:
160-
return False
161-
elif sys.platform == "win32":
162-
try:
163-
_remote_debugging.RemoteUnwinder(self.pid)
164-
except Exception:
165-
return False
166-
return True
167-
else:
168-
raise ValueError(f"Unsupported platform: {sys.platform}")
169-
170154
def _print_realtime_stats(self):
171155
"""Print real-time sampling statistics."""
172156
if len(self.sample_intervals) < 2:
@@ -282,6 +266,25 @@ def _print_unwinder_stats(self):
282266
print(f" {ANSIColors.YELLOW}Stale cache invalidations: {stale_invalidations}{ANSIColors.RESET}")
283267

284268

269+
def _is_process_running(pid):
270+
if pid <= 0:
271+
return False
272+
if sys.platform == "linux" or sys.platform == "darwin":
273+
try:
274+
os.kill(pid, 0)
275+
return True
276+
except ProcessLookupError:
277+
return False
278+
elif sys.platform == "win32":
279+
try:
280+
_remote_debugging.RemoteUnwinder(pid)
281+
except Exception:
282+
return False
283+
return True
284+
else:
285+
raise ValueError(f"Unsupported platform: {sys.platform}")
286+
287+
285288
def sample(
286289
pid,
287290
collector,

Lib/test/test_profiling/test_sampling_profiler/test_cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def test_cli_default_collapsed_filename(self):
434434

435435
with (
436436
mock.patch("sys.argv", test_args),
437+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
437438
mock.patch("profiling.sampling.cli.sample") as mock_sample,
438439
):
439440
from profiling.sampling.cli import main
@@ -476,6 +477,7 @@ def test_cli_custom_output_filenames(self):
476477
for test_args, expected_filename, expected_format in test_cases:
477478
with (
478479
mock.patch("sys.argv", test_args),
480+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
479481
mock.patch("profiling.sampling.cli.sample") as mock_sample,
480482
):
481483
main()
@@ -516,6 +518,7 @@ def test_argument_parsing_basic(self):
516518

517519
with (
518520
mock.patch("sys.argv", test_args),
521+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
519522
mock.patch("profiling.sampling.cli.sample") as mock_sample,
520523
):
521524
from profiling.sampling.cli import main
@@ -540,6 +543,7 @@ def test_sort_options(self):
540543

541544
with (
542545
mock.patch("sys.argv", test_args),
546+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
543547
mock.patch("profiling.sampling.cli.sample") as mock_sample,
544548
):
545549
from profiling.sampling.cli import main
@@ -554,6 +558,7 @@ def test_async_aware_flag_defaults_to_running(self):
554558

555559
with (
556560
mock.patch("sys.argv", test_args),
561+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
557562
mock.patch("profiling.sampling.cli.sample") as mock_sample,
558563
):
559564
from profiling.sampling.cli import main
@@ -570,6 +575,7 @@ def test_async_aware_with_async_mode_all(self):
570575

571576
with (
572577
mock.patch("sys.argv", test_args),
578+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
573579
mock.patch("profiling.sampling.cli.sample") as mock_sample,
574580
):
575581
from profiling.sampling.cli import main
@@ -585,6 +591,7 @@ def test_async_aware_default_is_none(self):
585591

586592
with (
587593
mock.patch("sys.argv", test_args),
594+
mock.patch("profiling.sampling.sample._is_process_running", return_value=True),
588595
mock.patch("profiling.sampling.cli.sample") as mock_sample,
589596
):
590597
from profiling.sampling.cli import main

Lib/test/test_profiling/test_sampling_profiler/test_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import profiling.sampling.sample
1818
from profiling.sampling.pstats_collector import PstatsCollector
1919
from profiling.sampling.stack_collector import CollapsedStackCollector
20-
from profiling.sampling.sample import SampleProfiler
20+
from profiling.sampling.sample import SampleProfiler, _is_process_running
2121
except ImportError:
2222
raise unittest.SkipTest(
2323
"Test only runs when _remote_debugging is available"
@@ -681,7 +681,7 @@ def test_is_process_running(self):
681681
self.skipTest(
682682
"Insufficient permissions to read the stack trace"
683683
)
684-
self.assertTrue(profiler._is_process_running())
684+
self.assertTrue(_is_process_running(profiler.pid))
685685
self.assertIsNotNone(profiler.unwinder.get_stack_trace())
686686
subproc.process.kill()
687687
subproc.process.wait()
@@ -690,7 +690,7 @@ def test_is_process_running(self):
690690
)
691691

692692
# Exit the context manager to ensure the process is terminated
693-
self.assertFalse(profiler._is_process_running())
693+
self.assertFalse(_is_process_running(profiler.pid))
694694
self.assertRaises(
695695
ProcessLookupError, profiler.unwinder.get_stack_trace
696696
)

0 commit comments

Comments
 (0)