|
24 | 24 | limitations under the License. |
25 | 25 | """ |
26 | 26 |
|
| 27 | +# Standard library |
27 | 28 | from typing import List, Optional |
28 | 29 | import functools |
29 | 30 |
|
30 | | -from .beans import BlockData, JavaClassDesc, JavaInstance |
| 31 | +# Numpy (optional) |
| 32 | +try: |
| 33 | + import numpy |
| 34 | +except ImportError: |
| 35 | + numpy = None |
| 36 | + |
| 37 | + |
| 38 | +# Javaobj |
| 39 | +from .api import ObjectTransformer |
| 40 | +from .beans import JavaClassDesc, JavaInstance |
31 | 41 | from .core import JavaStreamParser |
32 | 42 | from .stream import DataStreamReader |
33 | | -from ..constants import TerminalCode |
| 43 | +from ..constants import TerminalCode, TypeCode |
34 | 44 | from ..utils import to_bytes, log_error, log_debug, read_struct, read_string |
35 | 45 |
|
36 | 46 |
|
@@ -405,7 +415,7 @@ def do_period(self, data): |
405 | 415 | return data |
406 | 416 |
|
407 | 417 |
|
408 | | -class DefaultObjectTransformer: |
| 418 | +class DefaultObjectTransformer(ObjectTransformer): |
409 | 419 |
|
410 | 420 | KNOWN_TRANSFORMERS = ( |
411 | 421 | JavaBool, |
@@ -454,3 +464,39 @@ def create_instance(self, classdesc): |
454 | 464 |
|
455 | 465 | log_debug(">>> java_object: {0}".format(java_object)) |
456 | 466 | return java_object |
| 467 | + |
| 468 | + |
| 469 | +class NumpyArrayTransformer(ObjectTransformer): |
| 470 | + """ |
| 471 | + Loads arrays as numpy arrays if possible |
| 472 | + """ |
| 473 | + |
| 474 | + # Convertion of a Java type char to its NumPy equivalent |
| 475 | + NUMPY_TYPE_MAP = { |
| 476 | + TypeCode.TYPE_BYTE: "B", |
| 477 | + TypeCode.TYPE_CHAR: "b", |
| 478 | + TypeCode.TYPE_DOUBLE: ">d", |
| 479 | + TypeCode.TYPE_FLOAT: ">f", |
| 480 | + TypeCode.TYPE_INTEGER: ">i", |
| 481 | + TypeCode.TYPE_LONG: ">l", |
| 482 | + TypeCode.TYPE_SHORT: ">h", |
| 483 | + TypeCode.TYPE_BOOLEAN: ">B", |
| 484 | + } |
| 485 | + |
| 486 | + def load_array(self, reader, field_type, size): |
| 487 | + # type: (DataStreamReader, TypeCode, int) -> Optional[list] |
| 488 | + """ |
| 489 | + Loads a Java array, if possible |
| 490 | + """ |
| 491 | + if numpy is not None: |
| 492 | + try: |
| 493 | + dtype = self.NUMPY_TYPE_MAP[field_type] |
| 494 | + except KeyError: |
| 495 | + # Unhandled data type |
| 496 | + return None |
| 497 | + else: |
| 498 | + return numpy.fromfile( |
| 499 | + reader.file_descriptor, dtype=dtype, count=size, |
| 500 | + ) |
| 501 | + |
| 502 | + return None |
0 commit comments