From cf67291a1e0ad1749eb538ca3855936397bb8f46 Mon Sep 17 00:00:00 2001 From: zxzinn <92992703+zxzinn@users.noreply.github.com> Date: Wed, 24 Dec 2025 15:56:54 +0800 Subject: [PATCH] feat: add safe JSON parsing helper function Add _safe_json_parse helper function to gracefully handle malformed JSON responses by returning response text instead of raising JSONDecodeError. This prevents crashes when servers return empty or invalid JSON bodies, particularly in error scenarios like HTTP 500 responses. Replace all direct _response.json() calls with _safe_json_parse throughout raw_client.py to ensure consistent error handling across all API endpoints. Add comprehensive unit tests covering valid JSON, empty responses, malformed JSON, and production error cases. Signed-off-by: zxzinn <92992703+zxzinn@users.noreply.github.com> --- src/cohere/v2/raw_client.py | 231 ++++++++++++++++++---------------- tests/test_safe_json_parse.py | 58 +++++++++ 2 files changed, 179 insertions(+), 110 deletions(-) create mode 100644 tests/test_safe_json_parse.py diff --git a/src/cohere/v2/raw_client.py b/src/cohere/v2/raw_client.py index e5cd6b435..6c90bd9a2 100644 --- a/src/cohere/v2/raw_client.py +++ b/src/cohere/v2/raw_client.py @@ -48,6 +48,17 @@ OMIT = typing.cast(typing.Any, ...) +def _safe_json_parse(response: typing.Any) -> typing.Any: + """ + Safely parse JSON from HTTP response. + Returns parsed JSON or response text if parsing fails. + """ + try: + return response.json() + except JSONDecodeError: + return response.text + + class RawV2Client: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper @@ -262,7 +273,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -273,7 +284,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -284,7 +295,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -295,7 +306,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -306,7 +317,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -317,7 +328,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -328,7 +339,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -339,7 +350,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -350,7 +361,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -361,7 +372,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -372,7 +383,7 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -383,11 +394,11 @@ def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.text @@ -572,7 +583,7 @@ def chat( V2ChatResponse, construct_type( type_=V2ChatResponse, # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ) return HttpResponse(response=_response, data=_data) @@ -583,7 +594,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -594,7 +605,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -605,7 +616,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -616,7 +627,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -627,7 +638,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -638,7 +649,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -649,7 +660,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -660,7 +671,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -671,7 +682,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -682,7 +693,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -693,7 +704,7 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -704,11 +715,11 @@ def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) @@ -819,7 +830,7 @@ def embed( EmbedByTypeResponse, construct_type( type_=EmbedByTypeResponse, # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ) return HttpResponse(response=_response, data=_data) @@ -830,7 +841,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -841,7 +852,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -852,7 +863,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -863,7 +874,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -874,7 +885,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -885,7 +896,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -896,7 +907,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -907,7 +918,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -918,7 +929,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -929,7 +940,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -940,7 +951,7 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -951,11 +962,11 @@ def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) @@ -1030,7 +1041,7 @@ def rerank( V2RerankResponse, construct_type( type_=V2RerankResponse, # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ) return HttpResponse(response=_response, data=_data) @@ -1041,7 +1052,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1052,7 +1063,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1063,7 +1074,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1074,7 +1085,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1085,7 +1096,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1096,7 +1107,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1107,7 +1118,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1118,7 +1129,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1129,7 +1140,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1140,7 +1151,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1151,7 +1162,7 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1162,11 +1173,11 @@ def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) @@ -1386,7 +1397,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1397,7 +1408,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1408,7 +1419,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1419,7 +1430,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1430,7 +1441,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1441,7 +1452,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1452,7 +1463,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1463,7 +1474,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1474,7 +1485,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1485,7 +1496,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1496,7 +1507,7 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1507,11 +1518,11 @@ async def _iter(): typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.text @@ -1696,7 +1707,7 @@ async def chat( V2ChatResponse, construct_type( type_=V2ChatResponse, # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ) return AsyncHttpResponse(response=_response, data=_data) @@ -1707,7 +1718,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1718,7 +1729,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1729,7 +1740,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1740,7 +1751,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1751,7 +1762,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1762,7 +1773,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1773,7 +1784,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1784,7 +1795,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1795,7 +1806,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1806,7 +1817,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1817,7 +1828,7 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1828,11 +1839,11 @@ async def chat( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) @@ -1943,7 +1954,7 @@ async def embed( EmbedByTypeResponse, construct_type( type_=EmbedByTypeResponse, # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ) return AsyncHttpResponse(response=_response, data=_data) @@ -1954,7 +1965,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1965,7 +1976,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1976,7 +1987,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1987,7 +1998,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -1998,7 +2009,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2009,7 +2020,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2020,7 +2031,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2031,7 +2042,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2042,7 +2053,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2053,7 +2064,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2064,7 +2075,7 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2075,11 +2086,11 @@ async def embed( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) @@ -2154,7 +2165,7 @@ async def rerank( V2RerankResponse, construct_type( type_=V2RerankResponse, # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ) return AsyncHttpResponse(response=_response, data=_data) @@ -2165,7 +2176,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2176,7 +2187,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2187,7 +2198,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2198,7 +2209,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2209,7 +2220,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2220,7 +2231,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2231,7 +2242,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2242,7 +2253,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2253,7 +2264,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2264,7 +2275,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2275,7 +2286,7 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) @@ -2286,11 +2297,11 @@ async def rerank( typing.Optional[typing.Any], construct_type( type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), + object_=_safe_json_parse(_response), ), ), ) - _response_json = _response.json() + _response_json = _safe_json_parse(_response) except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) diff --git a/tests/test_safe_json_parse.py b/tests/test_safe_json_parse.py new file mode 100644 index 000000000..d2cd28bc6 --- /dev/null +++ b/tests/test_safe_json_parse.py @@ -0,0 +1,58 @@ +import unittest +from unittest.mock import Mock +from json.decoder import JSONDecodeError + +from cohere.v2.raw_client import _safe_json_parse + + +class TestSafeJsonParse(unittest.TestCase): + """Test the _safe_json_parse helper function""" + + def test_valid_json_response(self) -> None: + """Test that valid JSON is parsed correctly""" + mock_response = Mock() + mock_response.json.return_value = {"key": "value", "status": "success"} + + result = _safe_json_parse(mock_response) + + self.assertEqual(result, {"key": "value", "status": "success"}) + mock_response.json.assert_called_once() + + def test_empty_response_body(self) -> None: + """Test that empty response body returns text instead of raising JSONDecodeError""" + mock_response = Mock() + mock_response.json.side_effect = JSONDecodeError("Expecting value", "", 0) + mock_response.text = "" + + result = _safe_json_parse(mock_response) + + self.assertEqual(result, "") + mock_response.json.assert_called_once() + + def test_malformed_json_response(self) -> None: + """Test that malformed JSON returns text instead of raising JSONDecodeError""" + mock_response = Mock() + mock_response.json.side_effect = JSONDecodeError("Expecting value", "not json", 0) + mock_response.text = "Internal Server Error" + + result = _safe_json_parse(mock_response) + + self.assertEqual(result, "Internal Server Error") + mock_response.json.assert_called_once() + + def test_500_error_with_empty_body(self) -> None: + """Test the actual production error case: HTTP 500 with empty response body""" + mock_response = Mock() + mock_response.status_code = 500 + mock_response.json.side_effect = JSONDecodeError("Expecting value: line 1 column 1 (char 0)", "", 0) + mock_response.text = "" + + result = _safe_json_parse(mock_response) + + self.assertEqual(result, "") + self.assertIsInstance(result, str) + mock_response.json.assert_called_once() + + +if __name__ == "__main__": + unittest.main()