Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions src/glean/api_client/_hooks/multipart_fix_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Hook to fix multipart form file field names that incorrectly have '[]' suffix."""

from typing import Any, Dict, List, Tuple
from .types import SDKInitHook
from glean.api_client.httpclient import HttpClient
from glean.api_client.utils import forms


class MultipartFileFieldFixHook(SDKInitHook):
"""
Fixes multipart form serialization where file field names incorrectly have '[]' suffix.

Speakeasy sometimes generates code that adds '[]' to file field names in multipart forms,
but this is incorrect. File fields should not have the array suffix, only regular form
fields should use this convention.

This hook patches the serialize_multipart_form function to fix the issue at the source.
"""

def sdk_init(self, base_url: str, client: HttpClient) -> Tuple[str, HttpClient]:
"""Initialize the SDK and patch the multipart form serialization."""
self._patch_multipart_serialization()
return base_url, client

def _patch_multipart_serialization(self):
"""Patch the serialize_multipart_form function to fix file field names."""
# Store reference to original function
original_serialize_multipart_form = forms.serialize_multipart_form

def fixed_serialize_multipart_form(
media_type: str, request: Any
) -> Tuple[str, Dict[str, Any], List[Tuple[str, Any]]]:
"""Fixed version of serialize_multipart_form that doesn't add '[]' to file field names."""
# Call the original function
result_media_type, form_data, files_list = (
original_serialize_multipart_form(media_type, request)
)

# Fix file field names in the files list
fixed_files = []
for item in files_list:
if isinstance(item, tuple) and len(item) >= 2:
field_name = item[0]
file_data = item[1]

# Remove '[]' suffix from file field names only
# We can identify file fields by checking if the data looks like file content
if field_name.endswith("[]") and self._is_file_field_data(
file_data
):
fixed_field_name = field_name[:-2] # Remove '[]' suffix
fixed_item = (fixed_field_name,) + item[1:]
fixed_files.append(fixed_item)
else:
fixed_files.append(item)
else:
fixed_files.append(item)

return result_media_type, form_data, fixed_files

# Replace the original function with our fixed version
forms.serialize_multipart_form = fixed_serialize_multipart_form

def _is_file_field_data(self, file_data: Any) -> bool:
"""
Determine if the data represents file field content.

File fields typically have tuple format: (filename, content) or (filename, content, content_type)
where content is bytes, file-like object, or similar.
"""
if isinstance(file_data, tuple) and len(file_data) >= 2:
# Check the structure: (filename, content, [optional content_type])
filename = file_data[0]
content = file_data[1]

# If filename is empty, this is likely JSON content, not a file
if filename == "":
return False

# File content is typically bytes, string, or file-like object
# But exclude empty strings and None values
if content is None or content == "":
return False

return (
isinstance(content, (bytes, str))
or hasattr(content, "read") # File-like object
or (
hasattr(content, "__iter__") and not isinstance(content, str)
) # Iterable but not string
)
return False
4 changes: 4 additions & 0 deletions src/glean/api_client/_hooks/registration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .types import Hooks
from .multipart_fix_hook import MultipartFileFieldFixHook
from .agent_file_upload_error_hook import AgentFileUploadErrorHook


Expand All @@ -13,5 +14,8 @@ def init_hooks(hooks: Hooks):
with an instance of a hook that implements that specific Hook interface
Hooks are registered per SDK instance, and are valid for the lifetime of the SDK instance"""

# Register hook to fix multipart file field names that incorrectly have '[]' suffix
hooks.register_sdk_init_hook(MultipartFileFieldFixHook())

# Register hook to provide helpful error messages for agent file upload issues
hooks.register_after_error_hook(AgentFileUploadErrorHook())
203 changes: 203 additions & 0 deletions tests/test_multipart_fix_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Test for the multipart file field fix hook."""

from unittest.mock import Mock, patch

import pytest

from src.glean.api_client._hooks.multipart_fix_hook import MultipartFileFieldFixHook
from src.glean.api_client.httpclient import HttpClient


class TestMultipartFileFieldFixHook:
"""Test cases for the MultipartFileFieldFixHook."""

def setup_method(self):
"""Set up test fixtures."""
self.hook = MultipartFileFieldFixHook()
self.mock_client = Mock(spec=HttpClient)

def test_sdk_init_returns_unchanged_params(self):
"""Test that SDK init returns the same base_url and client."""
base_url = "https://api.example.com"

with patch.object(self.hook, "_patch_multipart_serialization"):
result_url, result_client = self.hook.sdk_init(base_url, self.mock_client)

assert result_url == base_url
assert result_client == self.mock_client

def test_sdk_init_calls_patch_function(self):
"""Test that SDK init calls the patch function."""
base_url = "https://api.example.com"

with patch.object(self.hook, "_patch_multipart_serialization") as mock_patch:
self.hook.sdk_init(base_url, self.mock_client)
mock_patch.assert_called_once()

def test_is_file_field_data_identifies_file_content(self):
"""Test the file field data identification logic."""
# Test file field formats
assert self.hook._is_file_field_data(("test.txt", b"content"))
assert self.hook._is_file_field_data(("test.txt", b"content", "text/plain"))
assert self.hook._is_file_field_data(("test.txt", "string content"))

# Test with file-like object
mock_file = Mock()
mock_file.read = Mock()
assert self.hook._is_file_field_data(("test.txt", mock_file))

# Test non-file field formats
assert not self.hook._is_file_field_data("regular_value")
assert not self.hook._is_file_field_data(123)
assert not self.hook._is_file_field_data(("single_item",))
assert not self.hook._is_file_field_data((None, None))

@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
def test_patch_multipart_serialization_replaces_function(self, mock_forms_module):
"""Test that the patching replaces the serialize_multipart_form function."""
# Mock the original function
original_function = Mock()
mock_forms_module.serialize_multipart_form = original_function

# Call the patch method
self.hook._patch_multipart_serialization()

# Verify that the function was replaced
assert mock_forms_module.serialize_multipart_form != original_function

@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
def test_patched_function_fixes_file_field_names(self, mock_forms_module):
"""Test that the patched function correctly fixes file field names."""
# Mock original function to return data with '[]' suffix
original_function = Mock()
original_function.return_value = (
"multipart/form-data",
{"regular_field": "value"},
[
("file[]", ("test.txt", b"file content", "text/plain")),
("documents[]", ("doc.pdf", b"pdf content", "application/pdf")),
("regular_array[]", "regular_value"), # This should not be changed
],
)
mock_forms_module.serialize_multipart_form = original_function

# Apply the patch
self.hook._patch_multipart_serialization()

# Get the patched function
patched_function = mock_forms_module.serialize_multipart_form

# Call the patched function
media_type, form_data, files_list = patched_function(
"multipart/form-data", Mock()
)

# Verify the results
assert media_type == "multipart/form-data"
assert form_data == {"regular_field": "value"}

# Check that file field names are fixed but regular fields are not
expected_files = [
("file", ("test.txt", b"file content", "text/plain")),
("documents", ("doc.pdf", b"pdf content", "application/pdf")),
("regular_array[]", "regular_value"), # Should remain unchanged
]
assert files_list == expected_files

@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
def test_patched_function_preserves_correct_names(self, mock_forms_module):
"""Test that the patched function preserves already correct field names."""
# Mock original function to return data without '[]' suffix
original_function = Mock()
original_function.return_value = (
"multipart/form-data",
{},
[
("file", ("test.txt", b"file content", "text/plain")),
("document", ("doc.pdf", b"pdf content", "application/pdf")),
],
)
mock_forms_module.serialize_multipart_form = original_function

# Apply the patch
self.hook._patch_multipart_serialization()

# Get the patched function
patched_function = mock_forms_module.serialize_multipart_form

# Call the patched function
media_type, form_data, files_list = patched_function(
"multipart/form-data", Mock()
)

# Verify that nothing was changed
expected_files = [
("file", ("test.txt", b"file content", "text/plain")),
("document", ("doc.pdf", b"pdf content", "application/pdf")),
]
assert files_list == expected_files

@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
def test_patched_function_handles_mixed_fields(self, mock_forms_module):
"""Test handling of mixed file and non-file fields."""
# Mock original function with mixed field types
original_function = Mock()
original_function.return_value = (
"multipart/form-data",
{"form_field": "value"},
[
("correct_file", ("test1.txt", b"content1", "text/plain")),
("wrong_file[]", ("test2.txt", b"content2", "text/plain")),
("form_array[]", "form_value"), # Regular form field, should keep []
(
"json_field[]",
("", '{"key": "value"}', "application/json"),
), # JSON field, might need []
],
)
mock_forms_module.serialize_multipart_form = original_function

# Apply the patch
self.hook._patch_multipart_serialization()

# Get the patched function
patched_function = mock_forms_module.serialize_multipart_form

# Call the patched function
media_type, form_data, files_list = patched_function(
"multipart/form-data", Mock()
)

# Verify the results - only actual file fields should have [] removed
expected_files = [
("correct_file", ("test1.txt", b"content1", "text/plain")),
("wrong_file", ("test2.txt", b"content2", "text/plain")), # Fixed
("form_array[]", "form_value"), # Preserved - not a file field
(
"json_field[]",
("", '{"key": "value"}', "application/json"),
), # Preserved - JSON content
]
assert files_list == expected_files

def test_file_field_detection_edge_cases(self):
"""Test edge cases for file field detection."""
# Empty content
assert not self.hook._is_file_field_data(("test.txt", ""))

# None content
assert not self.hook._is_file_field_data(("test.txt", None))

# List/tuple content (should be considered file-like)
assert self.hook._is_file_field_data(("test.txt", [1, 2, 3]))
assert self.hook._is_file_field_data(("test.txt", (1, 2, 3)))

# String content (should be considered file content)
assert self.hook._is_file_field_data(("test.txt", "string content"))

# But not if it's the first element
assert not self.hook._is_file_field_data(("string content",))


if __name__ == "__main__":
pytest.main([__file__])