Skip to content

Commit ba297e3

Browse files
authored
[RAPTOR-15440] consolidate usage of werkzeug.parse_options_header in one api (#1805)
1 parent f53a901 commit ba297e3

File tree

5 files changed

+91
-15
lines changed

5 files changed

+91
-15
lines changed

custom_model_runner/datarobot_drum/drum/root_predictors/generic_predictor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77
import urllib
88

9-
import werkzeug
109
from datarobot_drum.drum.adapters.cli.drum_score_adapter import DrumScoreAdapter
1110
from datarobot_drum.drum.enum import GPU_PREDICTORS
1211
from datarobot_drum.drum.enum import TARGET_TYPE_ARG_KEYWORD
@@ -21,6 +20,8 @@
2120
_resolve_outgoing_unstructured_data,
2221
)
2322

23+
from datarobot_drum.drum.root_predictors.utils import get_mimetype_charset_from_content_type_header
24+
2425

2526
class GenericPredictorComponent:
2627
def __init__(self, params: dict):
@@ -122,10 +123,9 @@ def materialize(self):
122123
def _materialize_unstructured(self, input_filename, output_filename):
123124
kwargs_params = {}
124125
query_params = dict(urllib.parse.parse_qsl(self._params.get("query_params")))
125-
mimetype, content_type_params_dict = werkzeug.http.parse_options_header(
126+
mimetype, charset = get_mimetype_charset_from_content_type_header(
126127
self._params.get("content_type")
127128
)
128-
charset = content_type_params_dict.get("charset")
129129

130130
with open(input_filename, "rb") as f:
131131
data_binary = f.read()

custom_model_runner/datarobot_drum/drum/root_predictors/predict_mixin.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77
import logging
88

9-
import werkzeug
109
from flask import request, Response, stream_with_context
1110
from requests_toolbelt import MultipartEncoder
1211

@@ -40,6 +39,8 @@
4039
_resolve_outgoing_unstructured_data,
4140
)
4241

42+
from datarobot_drum.drum.root_predictors.utils import get_mimetype_charset_from_content_type_header
43+
4344

4445
class PredictMixin:
4546
"""
@@ -55,9 +56,7 @@ def _log_if_possible(logger, log_level, message):
5556

5657
@staticmethod
5758
def _validate_content_type_header(header):
58-
ret_mimetype, content_type_params_dict = werkzeug.http.parse_options_header(header)
59-
ret_charset = content_type_params_dict.get("charset")
60-
return ret_mimetype, ret_charset
59+
return get_mimetype_charset_from_content_type_header(header)
6160

6261
@staticmethod
6362
def _fetch_data_from_request(file_key, logger=None):

custom_model_runner/datarobot_drum/drum/root_predictors/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
import shlex
1111
import subprocess
1212
import time
13+
import werkzeug
14+
1315
from queue import Queue, Empty
1416
from threading import Thread
1517

18+
1619
from datarobot_drum.drum.common import get_drum_logger
1720
from datarobot_drum.drum.enum import (
1821
ArgumentOptionsEnvVars,
@@ -235,3 +238,9 @@ def _cmd_add_class_labels(
235238
labels_str = " ".join(['"{}"'.format(label) for label in labels])
236239
cmd += " --class-labels {}".format(labels_str)
237240
return cmd
241+
242+
243+
def get_mimetype_charset_from_content_type_header(header):
244+
mimetype, content_type_params_dict = werkzeug.http.parse_options_header(header)
245+
charset = content_type_params_dict.get("charset")
246+
return mimetype, charset

tests/functional/test_unstructured_mode_per_framework.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77
import pytest
88
import requests
9-
import werkzeug
109

1110
from datarobot_drum.drum.enum import ArgumentsOptions
1211
from datarobot_drum.drum.server import HTTP_422_UNPROCESSABLE_ENTITY
@@ -18,8 +17,23 @@
1817
_create_custom_model_dir,
1918
)
2019

20+
# This is only for Java environment case.
21+
# For the java env we don't install from source,
22+
# so this func will be available when Java env uses released DRUM with this func.
23+
# Anyhow, unstructured is not supported by java, so these tests are skipped for Java env.
24+
try:
25+
from datarobot_drum.drum.root_predictors.utils import (
26+
get_mimetype_charset_from_content_type_header,
27+
)
28+
except ImportError:
29+
30+
def get_mimetype_charset_from_content_type_header(header):
31+
return None, None
32+
33+
2134
from requests_toolbelt import MultipartEncoder
2235

36+
2337
from tests.constants import (
2438
R_NO_ARTIFACTS,
2539
SKLEARN_NO_ARTIFACTS,
@@ -338,11 +352,11 @@ def test_response_one_var_return(
338352
)
339353
assert response.ok
340354
content_type_header = response.headers["Content-Type"]
341-
mimetype, content_type_params_dict = werkzeug.http.parse_options_header(
355+
mimetype, charset = get_mimetype_charset_from_content_type_header(
342356
content_type_header
343357
)
344358
assert mimetype == "text/plain"
345-
assert content_type_params_dict["charset"] == UTF8
359+
assert charset == UTF8
346360
if data is None:
347361
assert len(response.content) == 0
348362
else:
@@ -363,11 +377,11 @@ def test_response_one_var_return(
363377
)
364378
assert response.ok
365379
content_type_header = response.headers["Content-Type"]
366-
mimetype, content_type_params_dict = werkzeug.http.parse_options_header(
380+
mimetype, charset = get_mimetype_charset_from_content_type_header(
367381
content_type_header
368382
)
369383
assert mimetype == "text/plain"
370-
assert content_type_params_dict["charset"] == UTF8
384+
assert charset == UTF8
371385
if data is None:
372386
assert len(response.content) == 0
373387
else:
@@ -381,10 +395,9 @@ def test_response_one_var_return(
381395
)
382396
assert response.ok
383397
content_type_header = response.headers["Content-Type"]
384-
mimetype, content_type_params_dict = werkzeug.http.parse_options_header(
398+
mimetype, charset = get_mimetype_charset_from_content_type_header(
385399
content_type_header
386400
)
387401
assert "application/octet-stream" == mimetype
388-
# check params dict is empty
389-
assert not any(content_type_params_dict)
402+
assert charset is None
390403
assert response.content == data_bytes

tests/unit/datarobot_drum/drum/root_predictors/test_drum_server_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,58 @@ def test_stream_p_open_ignores_empty_lines(self, mock_logger):
310310
mock_logger.assert_any_call(b"regular line")
311311
mock_logger.assert_any_call(b"error line")
312312
assert mock_logger.call_count == 2 # Only two lines should be logged
313+
314+
315+
class TestGetMimetypeCharsetFromContentTypeHeader:
316+
def test_typical_header(self):
317+
from custom_model_runner.datarobot_drum.drum.root_predictors.utils import (
318+
get_mimetype_charset_from_content_type_header,
319+
)
320+
321+
mimetype, charset = get_mimetype_charset_from_content_type_header(
322+
"text/html; charset=utf-8"
323+
)
324+
325+
assert mimetype == "text/html"
326+
assert charset == "utf-8"
327+
328+
def test_header_without_charset(self):
329+
from custom_model_runner.datarobot_drum.drum.root_predictors.utils import (
330+
get_mimetype_charset_from_content_type_header,
331+
)
332+
333+
mimetype, charset = get_mimetype_charset_from_content_type_header("application/json")
334+
335+
assert mimetype == "application/json"
336+
assert charset is None
337+
338+
def test_header_with_additional_params(self):
339+
from custom_model_runner.datarobot_drum.drum.root_predictors.utils import (
340+
get_mimetype_charset_from_content_type_header,
341+
)
342+
343+
mimetype, charset = get_mimetype_charset_from_content_type_header(
344+
"text/plain; charset=iso-8859-1; format=flowed"
345+
)
346+
347+
assert mimetype == "text/plain"
348+
assert charset == "iso-8859-1"
349+
350+
def test_empty_header(self):
351+
from custom_model_runner.datarobot_drum.drum.root_predictors.utils import (
352+
get_mimetype_charset_from_content_type_header,
353+
)
354+
355+
mimetype, charset = get_mimetype_charset_from_content_type_header("")
356+
357+
assert mimetype == ""
358+
assert charset is None
359+
360+
def test_none_header(self):
361+
from custom_model_runner.datarobot_drum.drum.root_predictors.utils import (
362+
get_mimetype_charset_from_content_type_header,
363+
)
364+
365+
mimetype, charset = get_mimetype_charset_from_content_type_header(None)
366+
assert mimetype == ""
367+
assert charset is None

0 commit comments

Comments
 (0)