|
8 | 8 | import unittest |
9 | 9 | from importlib import reload |
10 | 10 | from unittest.mock import MagicMock, patch |
| 11 | +from urllib.error import HTTPError |
11 | 12 |
|
| 13 | +from ads.aqua.common.enums import PredictEndpoints |
12 | 14 | from notebook.base.handlers import IPythonHandler |
13 | 15 | from parameterized import parameterized |
14 | 16 |
|
@@ -280,6 +282,220 @@ def test_post(self, mock_get_model_deployment_response): |
280 | 282 | self.handler.write.assert_any_call("chunk2") |
281 | 283 | self.handler.finish.assert_called_once() |
282 | 284 |
|
| 285 | + def test_extract_text_from_choice_dict_delta_content(self): |
| 286 | + """Test dict choice with delta.content.""" |
| 287 | + choice = {"delta": {"content": "hello"}} |
| 288 | + result = self.handler._extract_text_from_choice(choice) |
| 289 | + self.assertEqual(result, "hello") |
| 290 | + |
| 291 | + def test_extract_text_from_choice_dict_delta_text(self): |
| 292 | + """Test dict choice with delta.text fallback.""" |
| 293 | + choice = {"delta": {"text": "world"}} |
| 294 | + result = self.handler._extract_text_from_choice(choice) |
| 295 | + self.assertEqual(result, "world") |
| 296 | + |
| 297 | + def test_extract_text_from_choice_dict_message_content(self): |
| 298 | + """Test dict choice with message.content.""" |
| 299 | + choice = {"message": {"content": "foo"}} |
| 300 | + result = self.handler._extract_text_from_choice(choice) |
| 301 | + self.assertEqual(result, "foo") |
| 302 | + |
| 303 | + def test_extract_text_from_choice_dict_top_level_text(self): |
| 304 | + """Test dict choice with top-level text.""" |
| 305 | + choice = {"text": "bar"} |
| 306 | + result = self.handler._extract_text_from_choice(choice) |
| 307 | + self.assertEqual(result, "bar") |
| 308 | + |
| 309 | + def test_extract_text_from_choice_object_delta_content(self): |
| 310 | + """Test object choice with delta.content attribute.""" |
| 311 | + choice = MagicMock() |
| 312 | + choice.delta = MagicMock(content="obj-content", text=None) |
| 313 | + result = self.handler._extract_text_from_choice(choice) |
| 314 | + self.assertEqual(result, "obj-content") |
| 315 | + |
| 316 | + def test_extract_text_from_choice_object_message_str(self): |
| 317 | + """Test object choice with message as string.""" |
| 318 | + choice = MagicMock(message="direct-string") |
| 319 | + result = self.handler._extract_text_from_choice(choice) |
| 320 | + self.assertEqual(result, "direct-string") |
| 321 | + |
| 322 | + def test_extract_text_from_choice_none_return(self): |
| 323 | + """Test choice with no text content returns None.""" |
| 324 | + choice = {} |
| 325 | + result = self.handler._extract_text_from_choice(choice) |
| 326 | + self.assertIsNone(result) |
| 327 | + |
| 328 | + def test_extract_text_from_chunk_dict_with_choices(self): |
| 329 | + """Test chunk dict with choices list.""" |
| 330 | + chunk = {"choices": [{"delta": {"content": "chunk-text"}}]} |
| 331 | + result = self.handler._extract_text_from_chunk(chunk) |
| 332 | + self.assertEqual(result, "chunk-text") |
| 333 | + |
| 334 | + def test_extract_text_from_chunk_dict_top_level_content(self): |
| 335 | + """Test chunk dict with top-level content (no choices).""" |
| 336 | + chunk = {"content": "direct-content"} |
| 337 | + result = self.handler._extract_text_from_chunk(chunk) |
| 338 | + self.assertEqual(result, "direct-content") |
| 339 | + |
| 340 | + def test_extract_text_from_chunk_object_choices(self): |
| 341 | + """Test object chunk with choices attribute.""" |
| 342 | + chunk = MagicMock() |
| 343 | + chunk.choices = [{"message": {"content": "obj-chunk"}}] |
| 344 | + result = self.handler._extract_text_from_chunk(chunk) |
| 345 | + self.assertEqual(result, "obj-chunk") |
| 346 | + |
| 347 | + def test_extract_text_from_chunk_empty(self): |
| 348 | + """Test empty/None chunk returns None.""" |
| 349 | + result = self.handler._extract_text_from_chunk({}) |
| 350 | + self.assertIsNone(result) |
| 351 | + result = self.handler._extract_text_from_chunk(None) |
| 352 | + self.assertIsNone(result) |
| 353 | + |
| 354 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 355 | + def test_missing_required_keys_raises_http_error(self, mock_aqua_app): |
| 356 | + """Test missing required payload keys raises HTTPError.""" |
| 357 | + payload = {"prompt": "test"} |
| 358 | + with self.assertRaises(HTTPError) as cm: |
| 359 | + list(self.handler._get_model_deployment_response("test-id", payload)) |
| 360 | + self.assertEqual(cm.exception.status_code, 400) |
| 361 | + self.assertIn("model", str(cm.exception)) |
| 362 | + |
| 363 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 364 | + @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
| 365 | + def test_chat_completions_no_image_yields_chunks(self, mock_extract, mock_aqua_app): |
| 366 | + """Test chat completions without image streams correctly.""" |
| 367 | + mock_deployment = MagicMock() |
| 368 | + mock_deployment.endpoint = "https://test-endpoint" |
| 369 | + mock_aqua_app.return_value.get.return_value = mock_deployment |
| 370 | + |
| 371 | + mock_stream = iter([MagicMock(choices=[{"delta": {"content": "hello"}}])]) |
| 372 | + mock_client = MagicMock() |
| 373 | + mock_client.chat.completions.create.return_value = mock_stream |
| 374 | + with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
| 375 | + payload = { |
| 376 | + "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT, |
| 377 | + "prompt": "test prompt", |
| 378 | + "model": "test-model" |
| 379 | + } |
| 380 | + result = list(self.handler._get_model_deployment_response("test-id", payload)) |
| 381 | + |
| 382 | + mock_extract.assert_called() |
| 383 | + self.assertEqual(result, ["hello"]) |
| 384 | + |
| 385 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 386 | + @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
| 387 | + def test_text_completions_endpoint(self, mock_extract, mock_aqua_app): |
| 388 | + """Test text completions endpoint path.""" |
| 389 | + mock_deployment = MagicMock() |
| 390 | + mock_deployment.endpoint = "https://test-endpoint" |
| 391 | + mock_aqua_app.return_value.get.return_value = mock_deployment |
| 392 | + |
| 393 | + mock_stream = iter([MagicMock(choices=[{"delta": {"content": "text"}}])]) |
| 394 | + mock_client = MagicMock() |
| 395 | + mock_client.completions.create.return_value = mock_stream |
| 396 | + with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
| 397 | + payload = { |
| 398 | + "endpoint_type": PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT, |
| 399 | + "prompt": "test", |
| 400 | + "model": "test-model" |
| 401 | + } |
| 402 | + result = list(self.handler._get_model_deployment_response("test-id", payload)) |
| 403 | + |
| 404 | + self.assertEqual(result, ["text"]) |
| 405 | + |
| 406 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 407 | + @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
| 408 | + def test_image_chat_completions(self, mock_extract, mock_aqua_app): |
| 409 | + """Test chat completions with image input.""" |
| 410 | + mock_deployment = MagicMock() |
| 411 | + mock_deployment.endpoint = "https://test-endpoint" |
| 412 | + mock_aqua_app.return_value.get.return_value = mock_deployment |
| 413 | + |
| 414 | + mock_stream = iter([MagicMock()]) |
| 415 | + mock_client = MagicMock() |
| 416 | + mock_client.chat.completions.create.return_value = mock_stream |
| 417 | + with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
| 418 | + payload = { |
| 419 | + "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT, |
| 420 | + "prompt": "describe image", |
| 421 | + "model": "test-model", |
| 422 | + "encoded_image": "data:image/jpeg;base64,...", |
| 423 | + "file_type": "image/jpeg" |
| 424 | + } |
| 425 | + list(self.handler._get_model_deployment_response("test-id", payload)) |
| 426 | + |
| 427 | + expected_call = call( |
| 428 | + model="test-model", |
| 429 | + messages=[{ |
| 430 | + "role": "user", |
| 431 | + "content": [ |
| 432 | + {"type": "text", "text": "describe image"}, |
| 433 | + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}} # Note: f-string expands |
| 434 | + ] |
| 435 | + }], |
| 436 | + stream=True |
| 437 | + ) |
| 438 | + mock_client.chat.completions.create.assert_has_calls([expected_call]) |
| 439 | + |
| 440 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 441 | + def test_unsupported_endpoint_type_raises_error(self, mock_aqua_app): |
| 442 | + """Test unsupported endpoint_type raises HTTPError.""" |
| 443 | + mock_aqua_app.return_value.get.return_value = MagicMock(endpoint="test") |
| 444 | + payload = { |
| 445 | + "endpoint_type": "invalid-type", |
| 446 | + "prompt": "test", |
| 447 | + "model": "test-model" |
| 448 | + } |
| 449 | + with self.assertRaises(HTTPError) as cm: |
| 450 | + list(self.handler._get_model_deployment_response("test-id", payload)) |
| 451 | + self.assertEqual(cm.exception.status_code, 400) |
| 452 | + |
| 453 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 454 | + @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
| 455 | + def test_responses_endpoint_with_params(self, mock_extract, mock_aqua_app): |
| 456 | + """Test responses endpoint with temperature/top_p filtering.""" |
| 457 | + mock_deployment = MagicMock() |
| 458 | + mock_deployment.endpoint = "https://test-endpoint" |
| 459 | + mock_aqua_app.return_value.get.return_value = mock_deployment |
| 460 | + |
| 461 | + mock_stream = iter([MagicMock()]) |
| 462 | + mock_client = MagicMock() |
| 463 | + mock_client.responses.create.return_value = mock_stream |
| 464 | + with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
| 465 | + payload = { |
| 466 | + "endpoint_type": PredictEndpoints.RESPONSES, |
| 467 | + "prompt": "test", |
| 468 | + "model": "test-model", |
| 469 | + "temperature": 0.7, |
| 470 | + "top_p": 0.9 |
| 471 | + } |
| 472 | + list(self.handler._get_model_deployment_response("test-id", payload)) |
| 473 | + |
| 474 | + mock_client.responses.create.assert_called_once_with( |
| 475 | + model="test-model", |
| 476 | + input="test", |
| 477 | + stream=True, |
| 478 | + temperature=0.7, |
| 479 | + top_p=0.9 |
| 480 | + ) |
| 481 | + |
| 482 | + @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
| 483 | + def test_stop_param_normalization(self, mock_aqua_app): |
| 484 | + """Test stop=[] gets normalized to None.""" |
| 485 | + mock_aqua_app.return_value.get.return_value = MagicMock(endpoint="test") |
| 486 | + payload = { |
| 487 | + "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT, |
| 488 | + "prompt": "test", |
| 489 | + "model": "test-model", |
| 490 | + "stop": [] |
| 491 | + } |
| 492 | + # Just verify it doesn't crash - normalization happens before API calls |
| 493 | + try: |
| 494 | + next(self.handler._get_model_deployment_response("test-id", payload)) |
| 495 | + except HTTPError: |
| 496 | + pass # Expected due to missing client mocks, but normalization should work |
| 497 | + |
| 498 | + |
283 | 499 |
|
284 | 500 | class AquaModelListHandlerTestCase(unittest.TestCase): |
285 | 501 | default_params = { |
|
0 commit comments