Skip to content

Commit d52e511

Browse files
fix(openai): Avoid consuming iterables passed to the Embeddings API (#5491)
Avoid consuming single-use iterators passed to the Embeddings API. All iterables that are not dictionaries or strings are transformed to lists in the internals of `openai` before they are passed to an API call.
1 parent 963b4dd commit d52e511

File tree

2 files changed

+258
-31
lines changed

2 files changed

+258
-31
lines changed

sentry_sdk/integrations/openai.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from sentry_sdk.tracing import Span
5151
from sentry_sdk._types import TextPart
5252

53-
from openai.types.responses import ResponseInputParam
53+
from openai.types.responses import ResponseInputParam, SequenceNotStr
5454
from openai import Omit
5555

5656
try:
@@ -220,20 +220,6 @@ def _calculate_token_usage(
220220
)
221221

222222

223-
def _get_input_messages(
224-
kwargs: "dict[str, Any]",
225-
) -> "Optional[Union[Iterable[Any], list[str]]]":
226-
# Input messages (the prompt or data sent to the model)
227-
messages = kwargs.get("messages")
228-
if messages is None:
229-
messages = kwargs.get("input")
230-
231-
if isinstance(messages, str):
232-
messages = [messages]
233-
234-
return messages
235-
236-
237223
def _commmon_set_input_data(
238224
span: "Span",
239225
kwargs: "dict[str, Any]",
@@ -413,15 +399,47 @@ def _set_embeddings_input_data(
413399
kwargs: "dict[str, Any]",
414400
integration: "OpenAIIntegration",
415401
) -> None:
416-
messages = _get_input_messages(kwargs)
402+
messages: "Union[str, SequenceNotStr[str], Iterable[int], Iterable[Iterable[int]]]" = kwargs.get(
403+
"input"
404+
)
417405

418406
if (
419-
messages is not None
420-
and len(messages) > 0 # type: ignore
421-
and should_send_default_pii()
422-
and integration.include_prompts
407+
not should_send_default_pii()
408+
or not integration.include_prompts
409+
or messages is None
423410
):
424-
normalized_messages = normalize_message_roles(messages) # type: ignore
411+
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
412+
_commmon_set_input_data(span, kwargs)
413+
414+
return
415+
416+
if isinstance(messages, str):
417+
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
418+
_commmon_set_input_data(span, kwargs)
419+
420+
normalized_messages = normalize_message_roles([messages]) # type: ignore
421+
scope = sentry_sdk.get_current_scope()
422+
messages_data = truncate_and_annotate_embedding_inputs(
423+
normalized_messages, span, scope
424+
)
425+
if messages_data is not None:
426+
set_data_normalized(
427+
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, messages_data, unpack=False
428+
)
429+
430+
return
431+
432+
# dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
433+
if not isinstance(messages, Iterable) or isinstance(messages, dict):
434+
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
435+
_commmon_set_input_data(span, kwargs)
436+
return
437+
438+
messages = list(messages)
439+
kwargs["input"] = messages
440+
441+
if len(messages) > 0:
442+
normalized_messages = normalize_message_roles(messages)
425443
scope = sentry_sdk.get_current_scope()
426444
messages_data = truncate_and_annotate_embedding_inputs(
427445
normalized_messages, span, scope

tests/integrations/openai/test_openai.py

Lines changed: 219 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,13 @@ async def test_bad_chat_completion_async(sentry_init, capture_events):
930930

931931
@pytest.mark.parametrize(
932932
"send_default_pii, include_prompts",
933-
[(True, True), (True, False), (False, True), (False, False)],
933+
[
934+
(True, False),
935+
(False, True),
936+
(False, False),
937+
],
934938
)
935-
def test_embeddings_create(
939+
def test_embeddings_create_no_pii(
936940
sentry_init, capture_events, send_default_pii, include_prompts
937941
):
938942
sentry_init(
@@ -966,10 +970,109 @@ def test_embeddings_create(
966970
assert tx["type"] == "transaction"
967971
span = tx["spans"][0]
968972
assert span["op"] == "gen_ai.embeddings"
969-
if send_default_pii and include_prompts:
970-
assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
973+
974+
assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
975+
976+
assert span["data"]["gen_ai.usage.input_tokens"] == 20
977+
assert span["data"]["gen_ai.usage.total_tokens"] == 30
978+
979+
980+
@pytest.mark.parametrize(
981+
"input",
982+
[
983+
pytest.param(
984+
"hello",
985+
id="string",
986+
),
987+
pytest.param(
988+
["First text", "Second text", "Third text"],
989+
id="string_sequence",
990+
),
991+
pytest.param(
992+
iter(["First text", "Second text", "Third text"]),
993+
id="string_iterable",
994+
),
995+
pytest.param(
996+
[5, 8, 13, 21, 34],
997+
id="tokens",
998+
),
999+
pytest.param(
1000+
iter(
1001+
[5, 8, 13, 21, 34],
1002+
),
1003+
id="token_iterable",
1004+
),
1005+
pytest.param(
1006+
[
1007+
[5, 8, 13, 21, 34],
1008+
[8, 13, 21, 34, 55],
1009+
],
1010+
id="tokens_sequence",
1011+
),
1012+
pytest.param(
1013+
iter(
1014+
[
1015+
[5, 8, 13, 21, 34],
1016+
[8, 13, 21, 34, 55],
1017+
]
1018+
),
1019+
id="tokens_sequence_iterable",
1020+
),
1021+
],
1022+
)
1023+
def test_embeddings_create(sentry_init, capture_events, input, request):
1024+
sentry_init(
1025+
integrations=[OpenAIIntegration(include_prompts=True)],
1026+
traces_sample_rate=1.0,
1027+
send_default_pii=True,
1028+
)
1029+
events = capture_events()
1030+
1031+
client = OpenAI(api_key="z")
1032+
1033+
returned_embedding = CreateEmbeddingResponse(
1034+
data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
1035+
model="some-model",
1036+
object="list",
1037+
usage=EmbeddingTokenUsage(
1038+
prompt_tokens=20,
1039+
total_tokens=30,
1040+
),
1041+
)
1042+
1043+
client.embeddings._post = mock.Mock(return_value=returned_embedding)
1044+
with start_transaction(name="openai tx"):
1045+
response = client.embeddings.create(input=input, model="text-embedding-3-large")
1046+
1047+
assert len(response.data[0].embedding) == 3
1048+
1049+
tx = events[0]
1050+
assert tx["type"] == "transaction"
1051+
span = tx["spans"][0]
1052+
assert span["op"] == "gen_ai.embeddings"
1053+
1054+
param_id = request.node.callspec.id
1055+
if param_id == "string":
1056+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
1057+
elif param_id == "string_sequence" or param_id == "string_iterable":
1058+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
1059+
"First text",
1060+
"Second text",
1061+
"Third text",
1062+
]
1063+
elif param_id == "tokens" or param_id == "token_iterable":
1064+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
1065+
5,
1066+
8,
1067+
13,
1068+
21,
1069+
34,
1070+
]
9711071
else:
972-
assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
1072+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
1073+
[5, 8, 13, 21, 34],
1074+
[8, 13, 21, 34, 55],
1075+
]
9731076

9741077
assert span["data"]["gen_ai.usage.input_tokens"] == 20
9751078
assert span["data"]["gen_ai.usage.total_tokens"] == 30
@@ -978,9 +1081,13 @@ def test_embeddings_create(
9781081
@pytest.mark.asyncio
9791082
@pytest.mark.parametrize(
9801083
"send_default_pii, include_prompts",
981-
[(True, True), (True, False), (False, True), (False, False)],
1084+
[
1085+
(True, False),
1086+
(False, True),
1087+
(False, False),
1088+
],
9821089
)
983-
async def test_embeddings_create_async(
1090+
async def test_embeddings_create_async_no_pii(
9841091
sentry_init, capture_events, send_default_pii, include_prompts
9851092
):
9861093
sentry_init(
@@ -1014,10 +1121,112 @@ async def test_embeddings_create_async(
10141121
assert tx["type"] == "transaction"
10151122
span = tx["spans"][0]
10161123
assert span["op"] == "gen_ai.embeddings"
1017-
if send_default_pii and include_prompts:
1018-
assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
1124+
1125+
assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
1126+
1127+
assert span["data"]["gen_ai.usage.input_tokens"] == 20
1128+
assert span["data"]["gen_ai.usage.total_tokens"] == 30
1129+
1130+
1131+
@pytest.mark.asyncio
1132+
@pytest.mark.parametrize(
1133+
"input",
1134+
[
1135+
pytest.param(
1136+
"hello",
1137+
id="string",
1138+
),
1139+
pytest.param(
1140+
["First text", "Second text", "Third text"],
1141+
id="string_sequence",
1142+
),
1143+
pytest.param(
1144+
iter(["First text", "Second text", "Third text"]),
1145+
id="string_iterable",
1146+
),
1147+
pytest.param(
1148+
[5, 8, 13, 21, 34],
1149+
id="tokens",
1150+
),
1151+
pytest.param(
1152+
iter(
1153+
[5, 8, 13, 21, 34],
1154+
),
1155+
id="token_iterable",
1156+
),
1157+
pytest.param(
1158+
[
1159+
[5, 8, 13, 21, 34],
1160+
[8, 13, 21, 34, 55],
1161+
],
1162+
id="tokens_sequence",
1163+
),
1164+
pytest.param(
1165+
iter(
1166+
[
1167+
[5, 8, 13, 21, 34],
1168+
[8, 13, 21, 34, 55],
1169+
]
1170+
),
1171+
id="tokens_sequence_iterable",
1172+
),
1173+
],
1174+
)
1175+
async def test_embeddings_create_async(sentry_init, capture_events, input, request):
1176+
sentry_init(
1177+
integrations=[OpenAIIntegration(include_prompts=True)],
1178+
traces_sample_rate=1.0,
1179+
send_default_pii=True,
1180+
)
1181+
events = capture_events()
1182+
1183+
client = AsyncOpenAI(api_key="z")
1184+
1185+
returned_embedding = CreateEmbeddingResponse(
1186+
data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
1187+
model="some-model",
1188+
object="list",
1189+
usage=EmbeddingTokenUsage(
1190+
prompt_tokens=20,
1191+
total_tokens=30,
1192+
),
1193+
)
1194+
1195+
client.embeddings._post = AsyncMock(return_value=returned_embedding)
1196+
with start_transaction(name="openai tx"):
1197+
response = await client.embeddings.create(
1198+
input=input, model="text-embedding-3-large"
1199+
)
1200+
1201+
assert len(response.data[0].embedding) == 3
1202+
1203+
tx = events[0]
1204+
assert tx["type"] == "transaction"
1205+
span = tx["spans"][0]
1206+
assert span["op"] == "gen_ai.embeddings"
1207+
1208+
param_id = request.node.callspec.id
1209+
if param_id == "string":
1210+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
1211+
elif param_id == "string_sequence" or param_id == "string_iterable":
1212+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
1213+
"First text",
1214+
"Second text",
1215+
"Third text",
1216+
]
1217+
elif param_id == "tokens" or param_id == "token_iterable":
1218+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
1219+
5,
1220+
8,
1221+
13,
1222+
21,
1223+
34,
1224+
]
10191225
else:
1020-
assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
1226+
assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
1227+
[5, 8, 13, 21, 34],
1228+
[8, 13, 21, 34, 55],
1229+
]
10211230

10221231
assert span["data"]["gen_ai.usage.input_tokens"] == 20
10231232
assert span["data"]["gen_ai.usage.total_tokens"] == 30

0 commit comments

Comments (0)