Skip to content

Commit e20e5b5

Browse files
author
tboy1337
committed
Refactor PostgreSQLExecutor to support Windows compatibility for process management
- Introduced _get_base_command method to dynamically set the base command for PostgreSQL based on the operating system. - Added _windows_terminate_process method to handle process termination on Windows, ensuring graceful shutdown before force-killing. - Updated stop method to utilize the new Windows-specific termination logic.
1 parent 638ab96 commit e20e5b5

File tree

2 files changed

+272
-9
lines changed

2 files changed

+272
-9
lines changed

pytest_postgresql/executor.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import subprocess
2525
import tempfile
2626
import time
27+
import os
28+
import signal
2729
from typing import Any, Optional, TypeVar
2830

2931
from mirakuru import TCPExecutor
@@ -48,13 +50,28 @@ class PostgreSQLExecutor(TCPExecutor):
4850
<http://www.postgresql.org/docs/current/static/app-pg-ctl.html>`_
4951
"""
5052

51-
BASE_PROC_START_COMMAND = (
52-
'{executable} start -D "{datadir}" '
53-
"-o \"-F -p {port} -c log_destination='stderr' "
54-
"-c logging_collector=off "
55-
"-c unix_socket_directories='{unixsocketdir}' {postgres_options}\" "
56-
'-l "{logfile}" {startparams}'
57-
)
53+
def _get_base_command(self) -> str:
54+
"""Get the base PostgreSQL command, Windows-compatible."""
55+
if platform.system() == "Windows":
56+
# Windows doesn't handle single quotes well in subprocess calls
57+
return (
58+
'{executable} start -D "{datadir}" '
59+
"-o \"-F -p {port} -c log_destination=stderr "
60+
"-c logging_collector=off "
61+
"-c unix_socket_directories={unixsocketdir} {postgres_options}\" "
62+
'-l "{logfile}" {startparams}'
63+
)
64+
else:
65+
# Unix/Linux systems work with single quotes
66+
return (
67+
'{executable} start -D "{datadir}" '
68+
"-o \"-F -p {port} -c log_destination='stderr' "
69+
"-c logging_collector=off "
70+
"-c unix_socket_directories='{unixsocketdir}' {postgres_options}\" "
71+
'-l "{logfile}" {startparams}'
72+
)
73+
74+
BASE_PROC_START_COMMAND = "" # Will be set dynamically
5875

5976
VERSION_RE = re.compile(r".* (?P<version>\d+(?:\.\d+)?)")
6077
MIN_SUPPORTED_VERSION = parse("10")
@@ -108,7 +125,7 @@ def __init__(
108125
self.logfile = logfile
109126
self.startparams = startparams
110127
self.postgres_options = postgres_options
111-
command = self.BASE_PROC_START_COMMAND.format(
128+
command = self._get_base_command().format(
112129
executable=self.executable,
113130
datadir=self.datadir,
114131
port=port,
@@ -219,17 +236,50 @@ def running(self) -> bool:
219236
status_code = subprocess.getstatusoutput(f'{self.executable} status -D "{self.datadir}"')[0]
220237
return status_code == 0
221238

239+
def _windows_terminate_process(self, sig: Optional[int] = None) -> None:
240+
"""Terminate process on Windows."""
241+
if self.process is None:
242+
return
243+
244+
try:
245+
if platform.system() == "Windows":
246+
# On Windows, try to terminate gracefully first
247+
self.process.terminate()
248+
# Give it a chance to terminate gracefully
249+
try:
250+
self.process.wait(timeout=5)
251+
except subprocess.TimeoutExpired:
252+
# If it doesn't terminate gracefully, force kill
253+
self.process.kill()
254+
self.process.wait()
255+
else:
256+
# On Unix systems, use the signal
257+
actual_sig = sig or signal.SIGTERM
258+
os.killpg(self.process.pid, actual_sig)
259+
except (OSError, AttributeError):
260+
# Process might already be dead or other issues
261+
pass
262+
222263
def stop(self: T, sig: Optional[int] = None, exp_sig: Optional[int] = None) -> T:
223264
"""Issue a stop request to executable."""
224265
subprocess.check_output(
225266
f'{self.executable} stop -D "{self.datadir}" -m f',
226267
shell=True,
227268
)
228269
try:
229-
super().stop(sig, exp_sig)
270+
if platform.system() == "Windows":
271+
self._windows_terminate_process(sig)
272+
else:
273+
super().stop(sig, exp_sig)
230274
except ProcessFinishedWithError:
231275
# Finished, leftovers ought to be cleaned afterwards anyway
232276
pass
277+
except AttributeError as e:
278+
# Handle case where os.killpg doesn't exist (shouldn't happen now)
279+
if "killpg" in str(e):
280+
self._windows_terminate_process(sig)
281+
else:
282+
raise
233283
return self
234284

235285
def __del__(self) -> None:
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""Test Windows compatibility fixes for pytest-postgresql."""
2+
3+
import platform
4+
import pytest
5+
from unittest.mock import Mock, patch, MagicMock
6+
import subprocess
7+
import signal
8+
9+
from pytest_postgresql.executor import PostgreSQLExecutor
10+
11+
12+
class TestWindowsCompatibility:
13+
"""Test Windows-specific functionality."""
14+
15+
def test_get_base_command_windows(self):
16+
"""Test that Windows base command doesn't use single quotes."""
17+
executor = PostgreSQLExecutor(
18+
executable="/path/to/pg_ctl",
19+
host="localhost",
20+
port=5432,
21+
datadir="/tmp/data",
22+
unixsocketdir="/tmp/socket",
23+
logfile="/tmp/log",
24+
startparams="-w",
25+
dbname="test",
26+
)
27+
28+
with patch('platform.system', return_value='Windows'):
29+
command = executor._get_base_command()
30+
# Windows command should not have single quotes around stderr
31+
assert "log_destination=stderr" in command
32+
assert "log_destination='stderr'" not in command
33+
assert "unix_socket_directories={unixsocketdir}" in command
34+
assert "unix_socket_directories='{unixsocketdir}'" not in command
35+
36+
def test_get_base_command_unix(self):
37+
"""Test that Unix base command uses single quotes."""
38+
executor = PostgreSQLExecutor(
39+
executable="/path/to/pg_ctl",
40+
host="localhost",
41+
port=5432,
42+
datadir="/tmp/data",
43+
unixsocketdir="/tmp/socket",
44+
logfile="/tmp/log",
45+
startparams="-w",
46+
dbname="test",
47+
)
48+
49+
with patch('platform.system', return_value='Linux'):
50+
command = executor._get_base_command()
51+
# Unix command should have single quotes around stderr
52+
assert "log_destination='stderr'" in command
53+
assert "unix_socket_directories='{unixsocketdir}'" in command
54+
55+
def test_windows_terminate_process(self):
56+
"""Test Windows process termination."""
57+
executor = PostgreSQLExecutor(
58+
executable="/path/to/pg_ctl",
59+
host="localhost",
60+
port=5432,
61+
datadir="/tmp/data",
62+
unixsocketdir="/tmp/socket",
63+
logfile="/tmp/log",
64+
startparams="-w",
65+
dbname="test",
66+
)
67+
68+
# Mock process
69+
mock_process = MagicMock()
70+
executor.process = mock_process
71+
72+
with patch('platform.system', return_value='Windows'):
73+
executor._windows_terminate_process()
74+
75+
# Should call terminate first
76+
mock_process.terminate.assert_called_once()
77+
mock_process.wait.assert_called()
78+
79+
def test_windows_terminate_process_force_kill(self):
80+
"""Test Windows process termination with force kill on timeout."""
81+
executor = PostgreSQLExecutor(
82+
executable="/path/to/pg_ctl",
83+
host="localhost",
84+
port=5432,
85+
datadir="/tmp/data",
86+
unixsocketdir="/tmp/socket",
87+
logfile="/tmp/log",
88+
startparams="-w",
89+
dbname="test",
90+
)
91+
92+
# Mock process that times out
93+
mock_process = MagicMock()
94+
mock_process.wait.side_effect = [subprocess.TimeoutExpired(cmd="test", timeout=5), None]
95+
executor.process = mock_process
96+
97+
with patch('platform.system', return_value='Windows'):
98+
executor._windows_terminate_process()
99+
100+
# Should call terminate, wait (timeout), then kill, then wait again
101+
mock_process.terminate.assert_called_once()
102+
mock_process.kill.assert_called_once()
103+
assert mock_process.wait.call_count == 2
104+
105+
def test_stop_method_windows(self):
106+
"""Test stop method on Windows."""
107+
executor = PostgreSQLExecutor(
108+
executable="/path/to/pg_ctl",
109+
host="localhost",
110+
port=5432,
111+
datadir="/tmp/data",
112+
unixsocketdir="/tmp/socket",
113+
logfile="/tmp/log",
114+
startparams="-w",
115+
dbname="test",
116+
)
117+
118+
# Mock subprocess and process
119+
with patch('subprocess.check_output') as mock_subprocess, \
120+
patch('platform.system', return_value='Windows'), \
121+
patch.object(executor, '_windows_terminate_process') as mock_terminate:
122+
123+
result = executor.stop()
124+
125+
# Should call pg_ctl stop and Windows terminate
126+
mock_subprocess.assert_called_once()
127+
mock_terminate.assert_called_once_with(None)
128+
assert result is executor
129+
130+
def test_stop_method_unix(self):
131+
"""Test stop method on Unix systems."""
132+
executor = PostgreSQLExecutor(
133+
executable="/path/to/pg_ctl",
134+
host="localhost",
135+
port=5432,
136+
datadir="/tmp/data",
137+
unixsocketdir="/tmp/socket",
138+
logfile="/tmp/log",
139+
startparams="-w",
140+
dbname="test",
141+
)
142+
143+
# Mock subprocess and super().stop
144+
with patch('subprocess.check_output') as mock_subprocess, \
145+
patch('platform.system', return_value='Linux'), \
146+
patch('pytest_postgresql.executor.TCPExecutor.stop') as mock_super_stop:
147+
148+
mock_super_stop.return_value = executor
149+
result = executor.stop()
150+
151+
# Should call pg_ctl stop and parent class stop
152+
mock_subprocess.assert_called_once()
153+
mock_super_stop.assert_called_once_with(None, None)
154+
assert result is executor
155+
156+
def test_stop_method_fallback_on_killpg_error(self):
157+
"""Test stop method falls back to Windows termination on killpg AttributeError."""
158+
executor = PostgreSQLExecutor(
159+
executable="/path/to/pg_ctl",
160+
host="localhost",
161+
port=5432,
162+
datadir="/tmp/data",
163+
unixsocketdir="/tmp/socket",
164+
logfile="/tmp/log",
165+
startparams="-w",
166+
dbname="test",
167+
)
168+
169+
# Mock subprocess and super().stop to raise AttributeError
170+
with patch('subprocess.check_output') as mock_subprocess, \
171+
patch('platform.system', return_value='Linux'), \
172+
patch('pytest_postgresql.executor.TCPExecutor.stop',
173+
side_effect=AttributeError("module 'os' has no attribute 'killpg'")), \
174+
patch.object(executor, '_windows_terminate_process') as mock_terminate:
175+
176+
result = executor.stop()
177+
178+
# Should call pg_ctl stop, fail on super().stop, then use Windows terminate
179+
mock_subprocess.assert_called_once()
180+
mock_terminate.assert_called_once_with(None)
181+
assert result is executor
182+
183+
def test_command_formatting_windows(self):
184+
"""Test that command is properly formatted for Windows."""
185+
with patch('platform.system', return_value='Windows'):
186+
executor = PostgreSQLExecutor(
187+
executable="C:/Program Files/PostgreSQL/bin/pg_ctl.exe",
188+
host="localhost",
189+
port=5555,
190+
datadir="C:/temp/data",
191+
unixsocketdir="C:/temp/socket",
192+
logfile="C:/temp/log.txt",
193+
startparams="-w -s",
194+
dbname="testdb",
195+
postgres_options="-c shared_preload_libraries=test"
196+
)
197+
198+
# The command should be properly formatted without quotes around stderr
199+
expected_parts = [
200+
"C:/Program Files/PostgreSQL/bin/pg_ctl.exe start",
201+
'-D "C:/temp/data"',
202+
"-o \"-F -p 5555 -c log_destination=stderr",
203+
"-c logging_collector=off",
204+
"-c unix_socket_directories=C:/temp/socket",
205+
"-c shared_preload_libraries=test\"",
206+
'-l "C:/temp/log.txt"',
207+
"-w -s"
208+
]
209+
210+
# Check if all expected parts are in the command
211+
command = executor.command
212+
for part in expected_parts:
213+
assert part in command, f"Expected '{part}' in command: {command}"

0 commit comments

Comments
 (0)