|
5 | 5 | # This is proprietary source code of DataRobot, Inc. and its affiliates. |
6 | 6 | # Released under the terms of DataRobot Tool and Utility Agreement. |
7 | 7 | # |
| 8 | +import io |
| 9 | +import os |
| 10 | +import subprocess |
| 11 | +import threading |
| 12 | +import time |
| 13 | +from queue import Queue |
8 | 14 | from typing import Tuple, Any |
9 | 15 | from unittest.mock import patch, Mock |
10 | 16 |
|
|
13 | 19 | from datarobot_drum.drum.enum import TargetType, ArgumentsOptions |
14 | 20 | from datarobot_drum.drum.root_predictors.drum_server_utils import DrumServerRun, DrumServerProcess |
15 | 21 |
|
| 22 | +from custom_model_runner.datarobot_drum.drum.root_predictors.utils import ( |
| 23 | + _queue_output, |
| 24 | + _stream_p_open, |
| 25 | +) |
| 26 | + |
16 | 27 |
|
17 | 28 | @pytest.fixture |
18 | 29 | def module_under_test(): |
@@ -110,7 +121,7 @@ class TestingThread: |
110 | 121 | def __init__(self, name, target, args: Tuple[str, DrumServerProcess, Any]): |
111 | 122 | self.name = name |
112 | 123 | self.target = target |
113 | | - self.command, self.process_object_holder, self.verbose = args |
| 124 | + self.command, self.process_object_holder, self.verbose, self.stream_output = args |
114 | 125 |
|
115 | 126 | def start(self): |
116 | 127 | self.process_object_holder.process = Mock(pid=123) |
@@ -141,3 +152,161 @@ def test_calls_thread_correctly(self, mock_get_command, runner): |
141 | 152 | with runner: |
142 | 153 | pass |
143 | 154 | assert runner.server_thread.command == mock_get_command.return_value |
| 155 | + |
| 156 | + |
| 157 | +class TestStreamingOutput: |
| 158 | + def test_queue_output_handles_stdout_and_stderr(self): |
| 159 | + # Create mock objects for stdout, stderr and queue |
| 160 | + mock_stdout = Mock() |
| 161 | + mock_stderr = Mock() |
| 162 | + test_queue = Queue() |
| 163 | + |
| 164 | + # Configure stdout mock to return a sequence of lines and then stop |
| 165 | + stdout_lines = [b"stdout line 1\n", b"stdout line 2\n", b""] |
| 166 | + mock_stdout.readline.side_effect = stdout_lines |
| 167 | + |
| 168 | + # Configure stderr mock to return a sequence of lines and then stop |
| 169 | + stderr_lines = [b"stderr line 1\n", b"stderr line 2\n", b""] |
| 170 | + mock_stderr.readline.side_effect = stderr_lines |
| 171 | + |
| 172 | + # Call the function |
| 173 | + _queue_output(mock_stdout, mock_stderr, test_queue) |
| 174 | + |
| 175 | + # Verify the queue contains all the expected output lines |
| 176 | + expected_lines = [ |
| 177 | + b"stdout line 1\n", |
| 178 | + b"stdout line 2\n", |
| 179 | + b"stderr line 1\n", |
| 180 | + b"stderr line 2\n", |
| 181 | + ] |
| 182 | + |
| 183 | + # Get all items from the queue and compare with expected |
| 184 | + actual_lines = [] |
| 185 | + while not test_queue.empty(): |
| 186 | + actual_lines.append(test_queue.get()) |
| 187 | + |
| 188 | + assert actual_lines == expected_lines |
| 189 | + |
| 190 | + # Verify close was called on both streams |
| 191 | + mock_stdout.close.assert_called_once() |
| 192 | + mock_stderr.close.assert_called_once() |
| 193 | + |
| 194 | + @patch("custom_model_runner.datarobot_drum.drum.root_predictors.utils.logger.info") |
| 195 | + def test_stream_p_open_handles_process_output(self, mock_logger): |
| 196 | + # Create a mock Popen object |
| 197 | + mock_process = Mock(spec=subprocess.Popen) |
| 198 | + |
| 199 | + # Create real BytesIO objects that will be read by the real _queue_output function |
| 200 | + mock_process.stdout = io.BytesIO(b"stdout line 1\nstdout line 2\n") |
| 201 | + mock_process.stderr = io.BytesIO(b"stderr line 1\nstderr line 2\n") |
| 202 | + |
| 203 | + # Configure process to terminate after first poll |
| 204 | + mock_process.poll.side_effect = [None, 0] |
| 205 | + |
| 206 | + # Call the function with real stdout/stderr streams |
| 207 | + stdout, stderr = _stream_p_open(mock_process) |
| 208 | + |
| 209 | + # Verify the results |
| 210 | + assert stdout == "" |
| 211 | + assert stderr == "" |
| 212 | + |
| 213 | + # Verify logger.info was called with the expected content |
| 214 | + expected_outputs = [ |
| 215 | + b"stdout line 1", |
| 216 | + b"stdout line 2", |
| 217 | + b"stderr line 1", |
| 218 | + b"stderr line 2", |
| 219 | + ] |
| 220 | + |
| 221 | + # Get the actual arguments passed to logger.info |
| 222 | + actual_outputs = [call_args[0][0] for call_args in mock_logger.call_args_list] |
| 223 | + |
| 224 | + # Verify all expected outputs were logged |
| 225 | + assert len(actual_outputs) == len(expected_outputs) |
| 226 | + for expected in expected_outputs: |
| 227 | + assert expected in actual_outputs |
| 228 | + |
| 229 | + # Verify poll was called to check for process termination |
| 230 | + assert mock_process.poll.call_count == 2 |
| 231 | + |
| 232 | + @patch("custom_model_runner.datarobot_drum.drum.root_predictors.utils.logger.info") |
| 233 | + def test_stream_p_open_handles_empty_streams(self, mock_logger): |
| 234 | + # Create a mock Popen object with empty streams |
| 235 | + mock_process = Mock(spec=subprocess.Popen) |
| 236 | + mock_process.stdout = io.BytesIO(b"") |
| 237 | + mock_process.stderr = io.BytesIO(b"") |
| 238 | + |
| 239 | + # Configure process to terminate immediately |
| 240 | + mock_process.poll.return_value = 0 |
| 241 | + |
| 242 | + # Call the function |
| 243 | + stdout, stderr = _stream_p_open(mock_process) |
| 244 | + |
| 245 | + # Verify results |
| 246 | + assert stdout == "" |
| 247 | + assert stderr == "" |
| 248 | + mock_logger.assert_not_called() # No output to log |
| 249 | + mock_process.poll.assert_called_once() # Process termination was checked |
| 250 | + |
| 251 | + @patch("custom_model_runner.datarobot_drum.drum.root_predictors.utils.logger.info") |
| 252 | + def test_stream_p_open_handles_delayed_output(self, mock_logger): |
| 253 | + # Create a mock process |
| 254 | + mock_process = Mock(spec=subprocess.Popen) |
| 255 | + |
| 256 | + # Create pipes that we can write to after the function starts |
| 257 | + r_stdout, w_stdout = os.pipe() |
| 258 | + r_stderr, w_stderr = os.pipe() |
| 259 | + |
| 260 | + # Convert file descriptors to file-like objects |
| 261 | + mock_process.stdout = os.fdopen(r_stdout, "rb") |
| 262 | + mock_process.stderr = os.fdopen(r_stderr, "rb") |
| 263 | + |
| 264 | + # Set up process to run and then terminate |
| 265 | + mock_process.poll.side_effect = [None, None, 0] |
| 266 | + |
| 267 | + # Create a thread to run the function we're testing |
| 268 | + def run_stream_p_open(): |
| 269 | + return _stream_p_open(mock_process) |
| 270 | + |
| 271 | + thread = threading.Thread(target=run_stream_p_open) |
| 272 | + thread.daemon = True |
| 273 | + thread.start() |
| 274 | + |
| 275 | + # Write data to the pipes |
| 276 | + time.sleep(0.1) # Give the function time to start |
| 277 | + os.write(w_stdout, b"delayed stdout\n") |
| 278 | + os.write(w_stderr, b"delayed stderr\n") |
| 279 | + os.close(w_stdout) |
| 280 | + os.close(w_stderr) |
| 281 | + |
| 282 | + # Wait for the thread to finish |
| 283 | + thread.join(timeout=2) |
| 284 | + assert not thread.is_alive() |
| 285 | + |
| 286 | + # Verify logger.info was called with the expected content |
| 287 | + mock_logger.assert_any_call(b"delayed stdout") |
| 288 | + mock_logger.assert_any_call(b"delayed stderr") |
| 289 | + |
| 290 | + @patch("custom_model_runner.datarobot_drum.drum.root_predictors.utils.logger.info") |
| 291 | + def test_stream_p_open_ignores_empty_lines(self, mock_logger): |
| 292 | + # Create a mock Popen object |
| 293 | + mock_process = Mock(spec=subprocess.Popen) |
| 294 | + |
| 295 | + # Create streams with empty lines |
| 296 | + mock_process.stdout = io.BytesIO(b"regular line\n \n\n") |
| 297 | + mock_process.stderr = io.BytesIO(b"error line\n \n") |
| 298 | + |
| 299 | + # Configure process to terminate after output is processed |
| 300 | + mock_process.poll.side_effect = [None, 0] |
| 301 | + |
| 302 | + # Call the function |
| 303 | + stdout, stderr = _stream_p_open(mock_process) |
| 304 | + |
| 305 | + # Verify results |
| 306 | + assert stdout == "" |
| 307 | + assert stderr == "" |
| 308 | + |
| 309 | + # Verify only non-empty lines were logged |
| 310 | + mock_logger.assert_any_call(b"regular line") |
| 311 | + mock_logger.assert_any_call(b"error line") |
| 312 | + assert mock_logger.call_count == 2 # Only two lines should be logged |
0 commit comments