|
20 | 20 | import os |
21 | 21 | import io |
22 | 22 | import uuid |
| 23 | +import urllib.parse |
23 | 24 |
|
24 | 25 | from typing import Union, Optional, IO, AnyStr, Any |
25 | 26 |
|
26 | 27 | try: |
27 | | - from typing import Iterator, Callable |
| 28 | + from typing import Iterator, Callable, Mapping |
28 | 29 | except ImportError: |
29 | | - from collections.abc import Iterator, Callable |
| 30 | + from collections.abc import Iterator, Callable, Mapping |
30 | 31 |
|
31 | 32 | from typing_extensions import ClassVar, Literal |
32 | 33 |
|
|
40 | 41 | CachedPath, |
41 | 42 | CachedStringIO, |
42 | 43 | CachedBytesIO, |
| 44 | + CachedRequest, |
43 | 45 | ) |
44 | 46 | from ..caches.abstract import CacheAbstract |
45 | 47 | from ..caches.memory import CacheQueue |
46 | 48 | from .utilities import no_cache, get_server |
47 | 49 | from ..utilities import StreamFinalizer |
48 | 50 |
|
| 51 | +from .reqstream import DeferredRequestStream |
| 52 | + |
49 | 53 |
|
50 | 54 | __all__ = ("ServiceData",) |
51 | 55 |
|
@@ -118,6 +122,69 @@ def allowed_cross_origin(self) -> str: |
118 | 122 | cross-origin data delivery will not be used.""" |
119 | 123 | return self.__allowed_cross_origin |
120 | 124 |
|
| 125 | + def register_request( |
| 126 | + self, |
| 127 | + url: Union[str, urllib.parse.ParseResult], |
| 128 | + headers: Optional[Mapping[str, str]] = None, |
| 129 | + file_name_fallback: str = "", |
| 130 | + download: bool = True, |
| 131 | + ) -> str: |
| 132 | + """Register a remote request to the cache. |
| 133 | +
|
| 134 | + Arguments |
| 135 | + --------- |
| 136 | + url: `str | urllib.parse.ParseResult` |
| 137 | + The string-form or parsed URL of the remote file. |
| 138 | +
|
| 139 | + headers: `Mapping[str, str] | None` |
| 140 | + The customized headers for this request. If not specified, will not use |
| 141 | + any headers. |
| 142 | +
|
| 143 | + file_name_fallback: `str` |
| 144 | + The fall-back file name. If specified, it will be used as the saved file |
| 145 | + name when the remote service does not provide the file name. |
| 146 | +
|
| 147 | + download: `bool` |
| 148 | + If `True`, `&download=true` is appended to the returned address to mark it as a downloadable link. |
| 149 | +
|
| 150 | + Returns |
| 151 | + ------- |
| 152 | + #1: `str` |
| 153 | + The URL that would be used for accessing this temporarily cached file. |
| 154 | + """ |
| 155 | + url_parsed = urllib.parse.urlparse(url) if isinstance(url, str) else url |
| 156 | + file_name = ( |
| 157 | + url_parsed.path.replace("\\", "/").rsplit("/", maxsplit=1)[-1].strip() |
| 158 | + ) |
| 159 | + if not file_name: |
| 160 | + file_name = "Unknown" |
| 161 | + |
| 162 | + info = CachedFileInfo( |
| 163 | + type="request", |
| 164 | + data_size=0, |
| 165 | + file_name=file_name, |
| 166 | + content_type="", |
| 167 | + mime_type="application/octet-stream", |
| 168 | + one_time_service=False, |
| 169 | + ) |
| 170 | + data = CachedRequest( |
| 171 | + type="request", |
| 172 | + url=url_parsed.geturl(), |
| 173 | + headers=dict(headers) if headers else dict(), |
| 174 | + file_name_fallback=file_name_fallback, |
| 175 | + ) |
| 176 | + |
| 177 | + uid = uuid.uuid4().hex |
| 178 | + |
| 179 | + if isinstance(self.__cache, CacheQueue): |
| 180 | + cache = self.__cache.mirror |
| 181 | + else: |
| 182 | + cache = self.__cache |
| 183 | + cache.dump(key=uid, info=info, data=data) |
| 184 | + return "{0}?uid={1}{2}".format( |
| 185 | + self.__addr, uid, "&download=true" if download else "" |
| 186 | + ) |
| 187 | + |
121 | 188 | def register( |
122 | 189 | self, |
123 | 190 | fobj: Union[str, os.PathLike, io.StringIO, io.BytesIO], |
@@ -229,6 +296,11 @@ def loader() -> IO[Any]: |
229 | 296 | _data = data() |
230 | 297 | if _data["type"] == "path": |
231 | 298 | fobj = open(_data["path"], "rb") |
| 299 | + elif _data["type"] == "request": |
| 300 | + raise TypeError( |
| 301 | + "service: Should not use the request data in file-based data " |
| 302 | + "loader." |
| 303 | + ) |
232 | 304 | else: |
233 | 305 | fobj = _data["data"] |
234 | 306 | fobj.seek(0, io.SEEK_SET) |
@@ -264,7 +336,9 @@ def _stream_add_headers( |
264 | 336 | """Private method of `stream()` |
265 | 337 |
|
266 | 338 | Add customized headers to the data service response.""" |
267 | | - resp.headers["Content-Length"] = str(info["data_size"]) |
| 339 | + data_size = info["data_size"] |
| 340 | + if isinstance(data_size, str) or (isinstance(data_size, int) and data_size > 0): |
| 341 | + resp.headers["Content-Length"] = str(data_size) |
268 | 342 | if self.__allowed_cross_origin: |
269 | 343 | resp.headers["Access-Control-Allow-Origin"] = self.__allowed_cross_origin |
270 | 344 | resp.headers["Access-Control-Allow-Credentials"] = "true" |
@@ -305,41 +379,53 @@ def stream(self, uid: str, download: bool = False) -> flask.Response: |
305 | 379 |
|
306 | 380 | info, deferred = self.__cache.load(uid) |
307 | 381 |
|
308 | | - if info["data_size"] <= 0: |
309 | | - raise FileNotFoundError( |
310 | | - "services: The requested file {0} is empty.".format(uid) |
311 | | - ) |
312 | | - |
313 | 382 | file_type = info["type"] |
314 | | - if file_type not in ("path", "str", "bytes"): |
| 383 | + if file_type not in ("path", "str", "bytes", "request"): |
315 | 384 | raise TypeError( |
316 | 385 | "service: Cannot recognize the type of fobj: " "{0}".format(file_type) |
317 | 386 | ) |
318 | 387 |
|
319 | | - one_time_service = info["one_time_service"] |
320 | | - at_closed = self._stream_get_at_closed(cache=self.cache, uid=uid) |
| 388 | + if file_type != "request" and info["data_size"] <= 0: |
| 389 | + raise FileNotFoundError( |
| 390 | + "services: The requested file {0} is empty.".format(uid) |
| 391 | + ) |
321 | 392 |
|
322 | | - def provider(_deferred: Callable[[], IO[AnyStr]]) -> Iterator[AnyStr]: |
323 | | - """Streaming data provider.""" |
| 393 | + if info["type"] == "request": |
| 394 | + val = deferred() |
| 395 | + if val["type"] != "request": |
| 396 | + raise TypeError( |
| 397 | + "service: The data type ({0}) and the info type ({1}) does not " |
| 398 | + "match.".format(info["type"], val["type"]) |
| 399 | + ) |
| 400 | + streamer = DeferredRequestStream(info=info, data=val) |
| 401 | + stream = streamer.provide(chunk_size=self.__chunk_size) |
| 402 | + info = streamer.info |
| 403 | + else: |
| 404 | + one_time_service = info["one_time_service"] |
| 405 | + at_closed = self._stream_get_at_closed(cache=self.cache, uid=uid) |
| 406 | + |
| 407 | + def provider(_deferred: Callable[[], IO[AnyStr]]) -> Iterator[AnyStr]: |
| 408 | + """Streaming data provider.""" |
324 | 409 |
|
325 | | - with StreamFinalizer( |
326 | | - _deferred(), callback_on_exit=at_closed if one_time_service else None |
327 | | - ) as _fobj: |
328 | | - data = _fobj.read(self.__chunk_size) |
329 | | - while data: |
330 | | - yield data |
| 410 | + with StreamFinalizer( |
| 411 | + _deferred(), |
| 412 | + callback_on_exit=at_closed if one_time_service else None, |
| 413 | + ) as _fobj: |
331 | 414 | data = _fobj.read(self.__chunk_size) |
| 415 | + while data: |
| 416 | + yield data |
| 417 | + data = _fobj.read(self.__chunk_size) |
| 418 | + |
| 419 | + stream = provider(self._stream_data_to_loader(deferred)) |
332 | 420 |
|
333 | 421 | resp = flask.Response( |
334 | | - flask.stream_with_context(provider(self._stream_data_to_loader(deferred))), |
| 422 | + flask.stream_with_context(stream), |
335 | 423 | content_type=( |
336 | 424 | "application/octet-stream" if download else info["content_type"] |
337 | 425 | ), |
338 | 426 | mimetype=info["mime_type"], |
339 | 427 | ) |
340 | 428 | self._stream_add_headers(resp, info=info, uid=uid, download=download) |
341 | | - print(resp.headers) |
342 | | - |
343 | 429 | return resp |
344 | 430 |
|
345 | 431 | def serve( |
|
0 commit comments