Skip to content

Commit 9c0a3a9

Browse files
committed
Added support for Numpy arrays loading
javaobj loads a JavaArray bean which content is a numpy array instead of a list. This adds an indirection compared to the previous API. Fixes #33
1 parent 8539fce commit 9c0a3a9

File tree

5 files changed

+85
-13
lines changed

5 files changed

+85
-13
lines changed

javaobj/v2/api.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
from typing import Optional
2828

2929
from .beans import JavaClassDesc, JavaInstance
30-
31-
32-
class JavaStreamParser:
33-
pass
30+
from .stream import DataStreamReader
31+
from ..constants import TypeCode
3432

3533

3634
class ObjectTransformer:
@@ -49,4 +47,21 @@ def create_instance(self, classdesc):
4947
:param classdesc: The description of a Java class
5048
:return: The Python form of the object, or the original JavaObject
5149
"""
52-
raise NotImplementedError
50+
return None
51+
52+
def load_array(self, reader, field_type, size):
53+
# type: (DataStreamReader, TypeCode, int) -> Optional[list]
54+
"""
55+
Loads and returns the content of a Java array, if possible.
56+
57+
The result of this method must be the content of the array, i.e. a list
58+
or an array. It will be stored in a JavaArray bean created by the
59+
parser.
60+
61+
This method must return None if it can't handle the array.
62+
63+
:param reader: The data stream reader
64+
:param field_type: Type of the elements of the array
65+
:param size: Number of elements in the array
66+
"""
67+
return None

javaobj/v2/beans.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626

2727
from __future__ import absolute_import
2828

29-
from enum import Enum, IntEnum
29+
from enum import IntEnum
3030
from typing import Any, Dict, List, Optional, Set
3131
import logging
3232

33-
from .stream import DataStreamReader
3433
from ..constants import ClassDescFlags, TypeCode
3534
from ..modifiedutf8 import decode_modified_utf8, byte_to_int
3635
from ..utils import UNICODE_TYPE
@@ -518,7 +517,9 @@ def dump(self, indent=0):
518517
prefix = "\t" * indent
519518
sub_prefix = "\t" * (indent + 1)
520519
dump = [
521-
prefix + "[array 0x{0:x}: {1} items]".format(self.handle, len(self))
520+
"{0}[array 0x{1:x}: {2} items - stored as {3}]".format(
521+
prefix, self.handle, len(self), type(self.data).__name__
522+
)
522523
]
523524
for x in self:
524525
if isinstance(x, ParsedJavaContent):

javaobj/v2/core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,13 @@ def _do_array(self, type_code):
636636
raise ValueError("Invalid array size")
637637

638638
# Array content
639-
content = [self._read_field_value(field_type) for _ in range(size)]
639+
for transformer in self.__transformers:
640+
content = transformer.load_array(self.__reader, field_type, size)
641+
if content is not None:
642+
break
643+
else:
644+
content = [self._read_field_value(field_type) for _ in range(size)]
645+
640646
return JavaArray(handle, cd, field_type, content)
641647

642648
def _do_exception(self, type_code):

javaobj/v2/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .api import ObjectTransformer
1010
from .core import JavaStreamParser
11-
from .transformers import DefaultObjectTransformer
11+
from .transformers import DefaultObjectTransformer, NumpyArrayTransformer
1212

1313
# ------------------------------------------------------------------------------
1414

@@ -31,6 +31,10 @@ def load(file_object, *transformers, **kwargs):
3131
else:
3232
all_transformers.append(DefaultObjectTransformer())
3333

34+
if kwargs.get("use_numpy_arrays", False):
35+
# Use the numpy array transformer if requested
36+
all_transformers.append(NumpyArrayTransformer())
37+
3438
# Parse the object(s)
3539
parser = JavaStreamParser(file_object, all_transformers)
3640
contents = parser.run()

javaobj/v2/transformers.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,23 @@
2424
limitations under the License.
2525
"""
2626

27+
# Standard library
2728
from typing import List, Optional
2829
import functools
2930

30-
from .beans import BlockData, JavaClassDesc, JavaInstance
31+
# Numpy (optional)
32+
try:
33+
import numpy
34+
except ImportError:
35+
numpy = None
36+
37+
38+
# Javaobj
39+
from .api import ObjectTransformer
40+
from .beans import JavaClassDesc, JavaInstance
3141
from .core import JavaStreamParser
3242
from .stream import DataStreamReader
33-
from ..constants import TerminalCode
43+
from ..constants import TerminalCode, TypeCode
3444
from ..utils import to_bytes, log_error, log_debug, read_struct, read_string
3545

3646

@@ -405,7 +415,7 @@ def do_period(self, data):
405415
return data
406416

407417

408-
class DefaultObjectTransformer:
418+
class DefaultObjectTransformer(ObjectTransformer):
409419

410420
KNOWN_TRANSFORMERS = (
411421
JavaBool,
@@ -454,3 +464,39 @@ def create_instance(self, classdesc):
454464

455465
log_debug(">>> java_object: {0}".format(java_object))
456466
return java_object
467+
468+
469+
class NumpyArrayTransformer(ObjectTransformer):
470+
"""
471+
Loads arrays as numpy arrays if possible
472+
"""
473+
474+
# Convertion of a Java type char to its NumPy equivalent
475+
NUMPY_TYPE_MAP = {
476+
TypeCode.TYPE_BYTE: "B",
477+
TypeCode.TYPE_CHAR: "b",
478+
TypeCode.TYPE_DOUBLE: ">d",
479+
TypeCode.TYPE_FLOAT: ">f",
480+
TypeCode.TYPE_INTEGER: ">i",
481+
TypeCode.TYPE_LONG: ">l",
482+
TypeCode.TYPE_SHORT: ">h",
483+
TypeCode.TYPE_BOOLEAN: ">B",
484+
}
485+
486+
def load_array(self, reader, field_type, size):
487+
# type: (DataStreamReader, TypeCode, int) -> Optional[list]
488+
"""
489+
Loads a Java array, if possible
490+
"""
491+
if numpy is not None:
492+
try:
493+
dtype = self.NUMPY_TYPE_MAP[field_type]
494+
except KeyError:
495+
# Unhandled data type
496+
return None
497+
else:
498+
return numpy.fromfile(
499+
reader.file_descriptor, dtype=dtype, count=size,
500+
)
501+
502+
return None

0 commit comments

Comments
 (0)