|
13 | 13 | from typing import Any, Callable, Dict, List, Set, Tuple, Union |
14 | 14 |
|
15 | 15 | import pytz |
| 16 | +import requests |
16 | 17 | import wrapt |
17 | 18 | from PIL import Image |
18 | 19 | from requests import codes |
@@ -69,6 +70,53 @@ def validate_project_state( |
69 | 70 | ) |
70 | 71 |
|
71 | 72 |
|
| 73 | +@wrapt.decorator |
| 74 | +def validate_image_file_size( |
| 75 | + wrapped: Callable[..., str], |
| 76 | + instance: Any, # pylint: disable=unused-argument |
| 77 | + args: Tuple[_RequestObjectProxy, _Context], |
| 78 | + kwargs: Dict, |
| 79 | +) -> str: |
| 80 | + """ |
| 81 | + Validate the file size of the image given to the query endpoint. |
| 82 | +
|
| 83 | + Args: |
| 84 | + wrapped: An endpoint function for `requests_mock`. |
| 85 | + instance: The class that the endpoint function is in. |
| 86 | + args: The arguments given to the endpoint function. |
| 87 | + kwargs: The keyword arguments given to the endpoint function. |
| 88 | +
|
| 89 | + Returns: |
| 90 | + The result of calling the endpoint. |
| 91 | + An `UNPROCESSABLE_ENTITY` response if the image is given and is not |
| 92 | + either a PNG or a JPEG. |
| 93 | + """ |
| 94 | + request, _ = args |
| 95 | + body_file = io.BytesIO(request.body) |
| 96 | + |
| 97 | + _, pdict = cgi.parse_header(request.headers['Content-Type']) |
| 98 | + parsed = cgi.parse_multipart( |
| 99 | + fp=body_file, |
| 100 | + pdict={ |
| 101 | + 'boundary': pdict['boundary'].encode(), |
| 102 | + }, |
| 103 | + ) |
| 104 | + |
| 105 | + [image] = parsed['image'] |
| 106 | + |
| 107 | + image_file = io.BytesIO(image) |
| 108 | + pil_image = Image.open(image_file) |
| 109 | + file_size_bytes = len(pil_image.tobytes()) |
| 110 | + |
| 111 | + if pil_image.format != 'PNG': |
| 112 | + return wrapped(*args, **kwargs) |
| 113 | + |
| 114 | + documented_max_png_bytes = 2 * 1024 * 1024 |
| 115 | + if file_size_bytes > documented_max_png_bytes: |
| 116 | + raise requests.exceptions.ConnectionError |
| 117 | + return wrapped(*args, **kwargs) |
| 118 | + |
| 119 | + |
72 | 120 | @wrapt.decorator |
73 | 121 | def validate_image_format( |
74 | 122 | wrapped: Callable[..., str], |
@@ -626,6 +674,7 @@ def decorator(method: Callable[..., str]) -> Callable[..., str]: |
626 | 674 | validate_date_header_given, |
627 | 675 | validate_include_target_data, |
628 | 676 | validate_max_num_results, |
| 677 | + validate_image_file_size, |
629 | 678 | validate_image_file_contents, |
630 | 679 | validate_image_format, |
631 | 680 | validate_image_field_given, |
|
0 commit comments