Commit df27ccf

added test cases and PR review comments
1 parent 3c4895b commit df27ccf

2 files changed: +286 -8 lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 70 additions & 8 deletions
@@ -15,6 +15,7 @@
 from ads.aqua.extension.errors import Errors
 from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.config import COMPARTMENT_OCID
+from ads.aqua import logger
 
 
 class AquaDeploymentHandler(AquaAPIhandler):
@@ -222,7 +223,36 @@ def list_shapes(self):
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
 
-    def _extract_text_from_choice(self, choice):
+    def _extract_text_from_choice(self, choice: dict) -> str:
+        """
+        Extract text content from a single choice structure.
+
+        Handles both dictionary-based API responses and object-based SDK responses.
+        For dict choices, it checks delta-based streaming fields, message-based
+        non-streaming fields, and finally top-level text/content keys.
+        For object choices, it inspects `.delta`, `.message`, and top-level
+        `.text` or `.content` attributes.
+
+        Parameters
+        ----------
+        choice : dict or object
+            A choice entry from a model response. It may be:
+            - A dict originating from a JSON API response (streaming or non-streaming).
+            - An SDK-style object with attributes such as `delta`, `message`,
+              `text`, or `content`.
+
+            For dicts, the method checks:
+            • delta → content/text
+            • message → content/text
+            • top-level → text/content
+
+            For objects, the method checks the same fields via attributes.
+
+        Returns
+        -------
+        str or None
+            The extracted text if present; otherwise None.
+        """
         # choice may be a dict or an object
         if isinstance(choice, dict):
             # streaming chunk: {"delta": {"content": "..."}}
@@ -246,7 +276,31 @@ def _extract_text_from_choice(self, choice):
             return getattr(msg, "content", None) or getattr(msg, "text", None)
         return getattr(choice, "text", None) or getattr(choice, "content", None)
 
-    def _extract_text_from_chunk(self, chunk):
+    def _extract_text_from_chunk(self, chunk: dict) -> str:
+        """
+        Extract text content from a model response chunk.
+
+        Supports both dict-form chunks (streaming or non-streaming) and SDK-style
+        object chunks. When choices are present, extraction is delegated to
+        `_extract_text_from_choice`. If no choices exist, top-level text/content
+        fields or attributes are used.
+
+        Parameters
+        ----------
+        chunk : dict or object
+            A chunk returned from a model stream or full response. It may be:
+            - A dict containing a `choices` list or top-level text/content fields.
+            - An SDK-style object with a `choices` attribute or top-level
+              `text`/`content` attributes.
+
+            If `choices` is present, the method extracts text from the first
+            choice using `_extract_text_from_choice`. Otherwise, it falls back
+            to top-level text/content.
+        Returns
+        -------
+        str or None
+            The extracted text if present; otherwise None.
+        """
         if chunk :
             if isinstance(chunk, dict):
                 choices = chunk.get("choices") or []
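For illustration, the lookup order described in the two docstrings above can be sketched for the dict case as follows. This is a standalone sketch with made-up payloads, not the handler's exact code; the real methods also accept SDK-style objects and read the same fields via getattr.

# Sketch of the dict branch only: delta -> message -> top-level.
def extract_text_from_choice_dict(choice: dict):
    delta = choice.get("delta") or {}
    if delta.get("content") or delta.get("text"):
        return delta.get("content") or delta.get("text")
    message = choice.get("message") or {}
    if message.get("content") or message.get("text"):
        return message.get("content") or message.get("text")
    return choice.get("text") or choice.get("content")

def extract_text_from_chunk_dict(chunk: dict):
    if not chunk:
        return None
    choices = chunk.get("choices") or []
    if choices:
        # Delegate to the choice-level lookup, as _extract_text_from_chunk does.
        return extract_text_from_choice_dict(choices[0])
    return chunk.get("text") or chunk.get("content")

print(extract_text_from_choice_dict({"delta": {"content": "hello"}}))      # streaming choice -> hello
print(extract_text_from_choice_dict({"message": {"content": "foo"}}))      # non-streaming choice -> foo
print(extract_text_from_chunk_dict({"choices": [{"text": "chunk-text"}]})) # chunk with choices -> chunk-text
print(extract_text_from_chunk_dict({"content": "direct-content"}))         # top-level fallback -> direct-content
print(extract_text_from_chunk_dict({}))                                    # nothing to extract -> None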
@@ -311,6 +365,13 @@ def _get_model_deployment_response(
 
         model_deployment = AquaDeploymentApp().get(model_deployment_id)
         endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
+
+        required_keys = ["endpoint_type", "prompt", "model"]
+        missing = [k for k in required_keys if k not in payload]
+
+        if missing:
+            raise HTTPError(400, f"Missing required payload keys: {', '.join(missing)}")
+
         endpoint_type = payload["endpoint_type"]
         aqua_client = OpenAI(base_url=endpoint)
 
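A quick worked example of the new required-key check, using a hypothetical request body; the error message is the one raised in the diff above.

payload = {"prompt": "hello"}  # hypothetical body missing endpoint_type and model
required_keys = ["endpoint_type", "prompt", "model"]
missing = [k for k in required_keys if k not in payload]
print(missing)  # -> ['endpoint_type', 'model']
# The handler would then raise:
# HTTPError(400, "Missing required payload keys: endpoint_type, model")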
@@ -381,7 +442,7 @@ def _get_model_deployment_response(
                             {"type": "text", "text": payload["prompt"]},
                             {
                                 "type": "image_url",
-                                "image_url": {"url": f"{self.encoded_image}"},
+                                "image_url": {"url": f"{encoded_image}"},
                             },
                         ],
                     }
@@ -397,7 +458,7 @@ def _get_model_deployment_response(
 
             response = aqua_client.chat.completions.create(**api_kwargs)
 
-        elif self.file_type.startswith("audio"):
+        elif file_type.startswith("audio"):
            api_kwargs = {
                "model": model,
                "messages": [
@@ -407,7 +468,7 @@ def _get_model_deployment_response(
                             {"type": "text", "text": payload["prompt"]},
                             {
                                 "type": "audio_url",
-                                "audio_url": {"url": f"{self.encoded_image}"},
+                                "audio_url": {"url": f"{encoded_image}"},
                             },
                         ],
                     }
@@ -426,7 +487,7 @@ def _get_model_deployment_response(
                 for chunk in response:
                     piece = self._extract_text_from_chunk(chunk)
                     if piece:
-                        print(piece, end="", flush=True)
+                        yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
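Replacing print with yield makes _get_model_deployment_response a generator, so the caller decides how each extracted piece is delivered. A minimal sketch of the pattern with fake chunks; the names here are illustrative, not the handler's.

def stream_pieces(chunks):
    # Producer: emit each piece instead of printing it to stdout.
    for chunk in chunks:
        piece = chunk.get("content")  # stand-in for _extract_text_from_chunk
        if piece:
            yield piece

fake_stream = [{"content": "Hel"}, {"content": "lo"}, {}]
assembled = []
for piece in stream_pieces(fake_stream):
    assembled.append(piece)  # in the handler this becomes self.write(piece); self.flush()
print("".join(assembled))  # -> Hello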
@@ -468,6 +529,8 @@ def _get_model_deployment_response(
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
+        else:
+            raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")
 
     @handle_exceptions
     def post(self, model_deployment_id):
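The new else branch closes the endpoint_type dispatch so an unknown value fails fast with a 400 instead of falling through silently. A generic sketch of the pattern; the branch names are made up, only the error message mirrors the handler.

from tornado.web import HTTPError

def dispatch(endpoint_type: str) -> str:
    if endpoint_type == "chat_completions":      # hypothetical branch
        return "chat path"
    elif endpoint_type == "text_completions":    # hypothetical branch
        return "completions path"
    else:
        raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")

try:
    dispatch("invalid-type")
except HTTPError as ex:
    print(ex.status_code)  # -> 400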
@@ -502,12 +565,11 @@ def post(self, model_deployment_id):
        )
        try:
            for chunk in response_gen:
-                print(chunk)
                self.write(chunk)
                self.flush()
            self.finish()
        except Exception as ex:
-            self.set_status(ex.status_code)
+            self.set_status(getattr(ex, "status_code", 500))
            self.write({"message": "Error occurred", "reason": str(ex)})
            self.finish()
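The getattr fallback matters because only tornado's HTTPError carries a status_code attribute; with a plain exception, ex.status_code would itself raise AttributeError inside the error path. A small check:

from tornado.web import HTTPError

def status_for(ex: BaseException) -> int:
    # Mirrors the handler's fallback: use status_code when present, else 500.
    return getattr(ex, "status_code", 500)

print(status_for(HTTPError(400, "bad payload")))  # -> 400
print(status_for(ValueError("boom")))             # -> 500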

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 216 additions & 0 deletions
@@ -8,7 +8,9 @@
 import unittest
 from importlib import reload
 from unittest.mock import MagicMock, patch
+from tornado.web import HTTPError
 
+from ads.aqua.common.enums import PredictEndpoints
 from notebook.base.handlers import IPythonHandler
 from parameterized import parameterized
 
@@ -280,6 +282,220 @@ def test_post(self, mock_get_model_deployment_response):
         self.handler.write.assert_any_call("chunk2")
         self.handler.finish.assert_called_once()
 
+    def test_extract_text_from_choice_dict_delta_content(self):
+        """Test dict choice with delta.content."""
+        choice = {"delta": {"content": "hello"}}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "hello")
+
+    def test_extract_text_from_choice_dict_delta_text(self):
+        """Test dict choice with delta.text fallback."""
+        choice = {"delta": {"text": "world"}}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "world")
+
+    def test_extract_text_from_choice_dict_message_content(self):
+        """Test dict choice with message.content."""
+        choice = {"message": {"content": "foo"}}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "foo")
+
+    def test_extract_text_from_choice_dict_top_level_text(self):
+        """Test dict choice with top-level text."""
+        choice = {"text": "bar"}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "bar")
+
+    def test_extract_text_from_choice_object_delta_content(self):
+        """Test object choice with delta.content attribute."""
+        choice = MagicMock()
+        choice.delta = MagicMock(content="obj-content", text=None)
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "obj-content")
+
+    def test_extract_text_from_choice_object_message_str(self):
+        """Test object choice with message as string."""
+        choice = MagicMock(message="direct-string")
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "direct-string")
+
+    def test_extract_text_from_choice_none_return(self):
+        """Test choice with no text content returns None."""
+        choice = {}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertIsNone(result)
+
+    def test_extract_text_from_chunk_dict_with_choices(self):
+        """Test chunk dict with choices list."""
+        chunk = {"choices": [{"delta": {"content": "chunk-text"}}]}
+        result = self.handler._extract_text_from_chunk(chunk)
+        self.assertEqual(result, "chunk-text")
+
+    def test_extract_text_from_chunk_dict_top_level_content(self):
+        """Test chunk dict with top-level content (no choices)."""
+        chunk = {"content": "direct-content"}
+        result = self.handler._extract_text_from_chunk(chunk)
+        self.assertEqual(result, "direct-content")
+
+    def test_extract_text_from_chunk_object_choices(self):
+        """Test object chunk with choices attribute."""
+        chunk = MagicMock()
+        chunk.choices = [{"message": {"content": "obj-chunk"}}]
+        result = self.handler._extract_text_from_chunk(chunk)
+        self.assertEqual(result, "obj-chunk")
+
+    def test_extract_text_from_chunk_empty(self):
+        """Test empty/None chunk returns None."""
+        result = self.handler._extract_text_from_chunk({})
+        self.assertIsNone(result)
+        result = self.handler._extract_text_from_chunk(None)
+        self.assertIsNone(result)
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    def test_missing_required_keys_raises_http_error(self, mock_aqua_app):
+        """Test missing required payload keys raises HTTPError."""
+        payload = {"prompt": "test"}
+        with self.assertRaises(HTTPError) as cm:
+            list(self.handler._get_model_deployment_response("test-id", payload))
+        self.assertEqual(cm.exception.status_code, 400)
+        self.assertIn("model", str(cm.exception))
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk')
+    def test_chat_completions_no_image_yields_chunks(self, mock_extract, mock_aqua_app):
+        """Test chat completions without image streams correctly."""
+        mock_deployment = MagicMock()
+        mock_deployment.endpoint = "https://test-endpoint"
+        mock_aqua_app.return_value.get.return_value = mock_deployment
+
+        mock_stream = iter([MagicMock(choices=[{"delta": {"content": "hello"}}])])
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_stream
+        with patch.object(self.handler, 'OpenAI', return_value=mock_client):
+            payload = {
+                "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT,
+                "prompt": "test prompt",
+                "model": "test-model"
+            }
+            result = list(self.handler._get_model_deployment_response("test-id", payload))
+
+        mock_extract.assert_called()
+        self.assertEqual(result, ["hello"])
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk')
+    def test_text_completions_endpoint(self, mock_extract, mock_aqua_app):
+        """Test text completions endpoint path."""
+        mock_deployment = MagicMock()
+        mock_deployment.endpoint = "https://test-endpoint"
+        mock_aqua_app.return_value.get.return_value = mock_deployment
+
+        mock_stream = iter([MagicMock(choices=[{"delta": {"content": "text"}}])])
+        mock_client = MagicMock()
+        mock_client.completions.create.return_value = mock_stream
+        with patch.object(self.handler, 'OpenAI', return_value=mock_client):
+            payload = {
+                "endpoint_type": PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT,
+                "prompt": "test",
+                "model": "test-model"
+            }
+            result = list(self.handler._get_model_deployment_response("test-id", payload))
+
+        self.assertEqual(result, ["text"])
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk')
+    def test_image_chat_completions(self, mock_extract, mock_aqua_app):
+        """Test chat completions with image input."""
+        mock_deployment = MagicMock()
+        mock_deployment.endpoint = "https://test-endpoint"
+        mock_aqua_app.return_value.get.return_value = mock_deployment
+
+        mock_stream = iter([MagicMock()])
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_stream
+        with patch.object(self.handler, 'OpenAI', return_value=mock_client):
+            payload = {
+                "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT,
+                "prompt": "describe image",
+                "model": "test-model",
+                "encoded_image": "data:image/jpeg;base64,...",
+                "file_type": "image/jpeg"
+            }
+            list(self.handler._get_model_deployment_response("test-id", payload))
+
+        expected_call = call(
+            model="test-model",
+            messages=[{
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "describe image"},
+                    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}  # Note: f-string expands
+                ]
+            }],
+            stream=True
+        )
+        mock_client.chat.completions.create.assert_has_calls([expected_call])
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    def test_unsupported_endpoint_type_raises_error(self, mock_aqua_app):
+        """Test unsupported endpoint_type raises HTTPError."""
+        mock_aqua_app.return_value.get.return_value = MagicMock(endpoint="test")
+        payload = {
+            "endpoint_type": "invalid-type",
+            "prompt": "test",
+            "model": "test-model"
+        }
+        with self.assertRaises(HTTPError) as cm:
+            list(self.handler._get_model_deployment_response("test-id", payload))
+        self.assertEqual(cm.exception.status_code, 400)
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk')
+    def test_responses_endpoint_with_params(self, mock_extract, mock_aqua_app):
+        """Test responses endpoint with temperature/top_p filtering."""
+        mock_deployment = MagicMock()
+        mock_deployment.endpoint = "https://test-endpoint"
+        mock_aqua_app.return_value.get.return_value = mock_deployment
+
+        mock_stream = iter([MagicMock()])
+        mock_client = MagicMock()
+        mock_client.responses.create.return_value = mock_stream
+        with patch.object(self.handler, 'OpenAI', return_value=mock_client):
+            payload = {
+                "endpoint_type": PredictEndpoints.RESPONSES,
+                "prompt": "test",
+                "model": "test-model",
+                "temperature": 0.7,
+                "top_p": 0.9
+            }
+            list(self.handler._get_model_deployment_response("test-id", payload))
+
+        mock_client.responses.create.assert_called_once_with(
+            model="test-model",
+            input="test",
+            stream=True,
+            temperature=0.7,
+            top_p=0.9
+        )
+
+    @patch('ads.aqua.modeldeployment.AquaDeploymentApp')
+    def test_stop_param_normalization(self, mock_aqua_app):
+        """Test stop=[] gets normalized to None."""
+        mock_aqua_app.return_value.get.return_value = MagicMock(endpoint="test")
+        payload = {
+            "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT,
+            "prompt": "test",
+            "model": "test-model",
+            "stop": []
+        }
+        # Just verify it doesn't crash - normalization happens before API calls
+        try:
+            next(self.handler._get_model_deployment_response("test-id", payload))
+        except HTTPError:
+            pass  # Expected due to missing client mocks, but normalization should work
+
 
 class AquaModelListHandlerTestCase(unittest.TestCase):
     default_params = {
