Skip to content

Commit 1f308b0

Browse files
fix: implement Speakeasy hook to fix multipart file field names
- Add MultipartFileFieldFixHook that patches serialize_multipart_form at runtime - Fix issue where Speakeasy generates incorrect '[]' suffix on file field names - Hook removes '[]' suffix only from file fields, preserves it for form arrays - Add comprehensive tests covering file field detection and patching logic Fixes multipart form uploads by intercepting serialization before HTTP requests. The hook survives Speakeasy code regeneration since it patches at SDK init time.
1 parent d448641 commit 1f308b0

File tree

3 files changed

+291
-0
lines changed

3 files changed

+291
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Hook to fix multipart form file field names that incorrectly have '[]' suffix."""
2+
3+
from typing import Any, Dict, List, Tuple
4+
from .types import SDKInitHook
5+
from glean.api_client.httpclient import HttpClient
6+
7+
8+
class MultipartFileFieldFixHook(SDKInitHook):
9+
"""
10+
Fixes multipart form serialization where file field names incorrectly have '[]' suffix.
11+
12+
Speakeasy sometimes generates code that adds '[]' to file field names in multipart forms,
13+
but this is incorrect. File fields should not have the array suffix, only regular form
14+
fields should use this convention.
15+
16+
This hook patches the serialize_multipart_form function to fix the issue at the source.
17+
"""
18+
19+
def sdk_init(self, base_url: str, client: HttpClient) -> Tuple[str, HttpClient]:
20+
"""Initialize the SDK and patch the multipart form serialization."""
21+
self._patch_multipart_serialization()
22+
return base_url, client
23+
24+
def _patch_multipart_serialization(self):
25+
"""Patch the serialize_multipart_form function to fix file field names."""
26+
from glean.api_client.utils import forms
27+
28+
# Store reference to original function
29+
original_serialize_multipart_form = forms.serialize_multipart_form
30+
31+
def fixed_serialize_multipart_form(
32+
media_type: str, request: Any
33+
) -> Tuple[str, Dict[str, Any], List[Tuple[str, Any]]]:
34+
"""Fixed version of serialize_multipart_form that doesn't add '[]' to file field names."""
35+
# Call the original function
36+
result_media_type, form_data, files_list = (
37+
original_serialize_multipart_form(media_type, request)
38+
)
39+
40+
# Fix file field names in the files list
41+
fixed_files = []
42+
for item in files_list:
43+
if isinstance(item, tuple) and len(item) >= 2:
44+
field_name = item[0]
45+
file_data = item[1]
46+
47+
# Remove '[]' suffix from file field names only
48+
# We can identify file fields by checking if the data looks like file content
49+
if field_name.endswith("[]") and self._is_file_field_data(
50+
file_data
51+
):
52+
fixed_field_name = field_name[:-2] # Remove '[]' suffix
53+
fixed_item = (fixed_field_name,) + item[1:]
54+
fixed_files.append(fixed_item)
55+
else:
56+
fixed_files.append(item)
57+
else:
58+
fixed_files.append(item)
59+
60+
return result_media_type, form_data, fixed_files
61+
62+
# Replace the original function with our fixed version
63+
forms.serialize_multipart_form = fixed_serialize_multipart_form
64+
65+
def _is_file_field_data(self, file_data: Any) -> bool:
66+
"""
67+
Determine if the data represents file field content.
68+
69+
File fields typically have tuple format: (filename, content) or (filename, content, content_type)
70+
where content is bytes, file-like object, or similar.
71+
"""
72+
if isinstance(file_data, tuple) and len(file_data) >= 2:
73+
# Check the structure: (filename, content, [optional content_type])
74+
content = file_data[1]
75+
76+
# File content is typically bytes, string, or file-like object
77+
return (
78+
isinstance(content, (bytes, str))
79+
or hasattr(content, "read") # File-like object
80+
or (
81+
hasattr(content, "__iter__") and not isinstance(content, str)
82+
) # Iterable but not string
83+
)
84+
return False

src/glean/api_client/_hooks/registration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .types import Hooks
2+
from .multipart_fix_hook import MultipartFileFieldFixHook
23

34

45
# This file is only ever generated once on the first generation and then is free to be modified.
@@ -11,3 +12,6 @@ def init_hooks(hooks: Hooks):
1112
"""Add hooks by calling hooks.register{sdk_init/before_request/after_success/after_error}Hook
1213
with an instance of a hook that implements that specific Hook interface
1314
Hooks are registered per SDK instance, and are valid for the lifetime of the SDK instance"""
15+
16+
# Register hook to fix multipart file field names that incorrectly have '[]' suffix
17+
hooks.register_sdk_init_hook(MultipartFileFieldFixHook())

tests/test_multipart_fix_hook.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""Test for the multipart file field fix hook."""
2+
3+
from unittest.mock import Mock, patch
4+
5+
import pytest
6+
7+
from src.glean.api_client._hooks.multipart_fix_hook import MultipartFileFieldFixHook
8+
from src.glean.api_client.httpclient import HttpClient
9+
10+
11+
class TestMultipartFileFieldFixHook:
12+
"""Test cases for the MultipartFileFieldFixHook."""
13+
14+
def setup_method(self):
15+
"""Set up test fixtures."""
16+
self.hook = MultipartFileFieldFixHook()
17+
self.mock_client = Mock(spec=HttpClient)
18+
19+
def test_sdk_init_returns_unchanged_params(self):
20+
"""Test that SDK init returns the same base_url and client."""
21+
base_url = "https://api.example.com"
22+
23+
with patch.object(self.hook, "_patch_multipart_serialization"):
24+
result_url, result_client = self.hook.sdk_init(base_url, self.mock_client)
25+
26+
assert result_url == base_url
27+
assert result_client == self.mock_client
28+
29+
def test_sdk_init_calls_patch_function(self):
30+
"""Test that SDK init calls the patch function."""
31+
base_url = "https://api.example.com"
32+
33+
with patch.object(self.hook, "_patch_multipart_serialization") as mock_patch:
34+
self.hook.sdk_init(base_url, self.mock_client)
35+
mock_patch.assert_called_once()
36+
37+
def test_is_file_field_data_identifies_file_content(self):
38+
"""Test the file field data identification logic."""
39+
# Test file field formats
40+
assert self.hook._is_file_field_data(("test.txt", b"content"))
41+
assert self.hook._is_file_field_data(("test.txt", b"content", "text/plain"))
42+
assert self.hook._is_file_field_data(("test.txt", "string content"))
43+
44+
# Test with file-like object
45+
mock_file = Mock()
46+
mock_file.read = Mock()
47+
assert self.hook._is_file_field_data(("test.txt", mock_file))
48+
49+
# Test non-file field formats
50+
assert not self.hook._is_file_field_data("regular_value")
51+
assert not self.hook._is_file_field_data(123)
52+
assert not self.hook._is_file_field_data(("single_item",))
53+
assert not self.hook._is_file_field_data((None, None))
54+
55+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
56+
def test_patch_multipart_serialization_replaces_function(self, mock_forms_module):
57+
"""Test that the patching replaces the serialize_multipart_form function."""
58+
# Mock the original function
59+
original_function = Mock()
60+
mock_forms_module.serialize_multipart_form = original_function
61+
62+
# Call the patch method
63+
self.hook._patch_multipart_serialization()
64+
65+
# Verify that the function was replaced
66+
assert mock_forms_module.serialize_multipart_form != original_function
67+
68+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
69+
def test_patched_function_fixes_file_field_names(self, mock_forms_module):
70+
"""Test that the patched function correctly fixes file field names."""
71+
# Mock original function to return data with '[]' suffix
72+
original_function = Mock()
73+
original_function.return_value = (
74+
"multipart/form-data",
75+
{"regular_field": "value"},
76+
[
77+
("file[]", ("test.txt", b"file content", "text/plain")),
78+
("documents[]", ("doc.pdf", b"pdf content", "application/pdf")),
79+
("regular_array[]", "regular_value"), # This should not be changed
80+
],
81+
)
82+
mock_forms_module.serialize_multipart_form = original_function
83+
84+
# Apply the patch
85+
self.hook._patch_multipart_serialization()
86+
87+
# Get the patched function
88+
patched_function = mock_forms_module.serialize_multipart_form
89+
90+
# Call the patched function
91+
media_type, form_data, files_list = patched_function(
92+
"multipart/form-data", Mock()
93+
)
94+
95+
# Verify the results
96+
assert media_type == "multipart/form-data"
97+
assert form_data == {"regular_field": "value"}
98+
99+
# Check that file field names are fixed but regular fields are not
100+
expected_files = [
101+
("file", ("test.txt", b"file content", "text/plain")),
102+
("documents", ("doc.pdf", b"pdf content", "application/pdf")),
103+
("regular_array[]", "regular_value"), # Should remain unchanged
104+
]
105+
assert files_list == expected_files
106+
107+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
108+
def test_patched_function_preserves_correct_names(self, mock_forms_module):
109+
"""Test that the patched function preserves already correct field names."""
110+
# Mock original function to return data without '[]' suffix
111+
original_function = Mock()
112+
original_function.return_value = (
113+
"multipart/form-data",
114+
{},
115+
[
116+
("file", ("test.txt", b"file content", "text/plain")),
117+
("document", ("doc.pdf", b"pdf content", "application/pdf")),
118+
],
119+
)
120+
mock_forms_module.serialize_multipart_form = original_function
121+
122+
# Apply the patch
123+
self.hook._patch_multipart_serialization()
124+
125+
# Get the patched function
126+
patched_function = mock_forms_module.serialize_multipart_form
127+
128+
# Call the patched function
129+
media_type, form_data, files_list = patched_function(
130+
"multipart/form-data", Mock()
131+
)
132+
133+
# Verify that nothing was changed
134+
expected_files = [
135+
("file", ("test.txt", b"file content", "text/plain")),
136+
("document", ("doc.pdf", b"pdf content", "application/pdf")),
137+
]
138+
assert files_list == expected_files
139+
140+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
141+
def test_patched_function_handles_mixed_fields(self, mock_forms_module):
142+
"""Test handling of mixed file and non-file fields."""
143+
# Mock original function with mixed field types
144+
original_function = Mock()
145+
original_function.return_value = (
146+
"multipart/form-data",
147+
{"form_field": "value"},
148+
[
149+
("correct_file", ("test1.txt", b"content1", "text/plain")),
150+
("wrong_file[]", ("test2.txt", b"content2", "text/plain")),
151+
("form_array[]", "form_value"), # Regular form field, should keep []
152+
(
153+
"json_field[]",
154+
("", '{"key": "value"}', "application/json"),
155+
), # JSON field, might need []
156+
],
157+
)
158+
mock_forms_module.serialize_multipart_form = original_function
159+
160+
# Apply the patch
161+
self.hook._patch_multipart_serialization()
162+
163+
# Get the patched function
164+
patched_function = mock_forms_module.serialize_multipart_form
165+
166+
# Call the patched function
167+
media_type, form_data, files_list = patched_function(
168+
"multipart/form-data", Mock()
169+
)
170+
171+
# Verify the results - only actual file fields should have [] removed
172+
expected_files = [
173+
("correct_file", ("test1.txt", b"content1", "text/plain")),
174+
("wrong_file", ("test2.txt", b"content2", "text/plain")), # Fixed
175+
("form_array[]", "form_value"), # Preserved - not a file field
176+
(
177+
"json_field[]",
178+
("", '{"key": "value"}', "application/json"),
179+
), # Preserved - JSON content
180+
]
181+
assert files_list == expected_files
182+
183+
def test_file_field_detection_edge_cases(self):
184+
"""Test edge cases for file field detection."""
185+
# Empty content
186+
assert not self.hook._is_file_field_data(("test.txt", ""))
187+
188+
# None content
189+
assert not self.hook._is_file_field_data(("test.txt", None))
190+
191+
# List/tuple content (should be considered file-like)
192+
assert self.hook._is_file_field_data(("test.txt", [1, 2, 3]))
193+
assert self.hook._is_file_field_data(("test.txt", (1, 2, 3)))
194+
195+
# String content (should be considered file content)
196+
assert self.hook._is_file_field_data(("test.txt", "string content"))
197+
198+
# But not if it's the first element
199+
assert not self.hook._is_file_field_data(("string content",))
200+
201+
202+
if __name__ == "__main__":
203+
pytest.main([__file__])

0 commit comments

Comments
 (0)