Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions cassis/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TypeCheckError,
TypeSystem,
TypeSystemMode,
load_typesystem,
)

_validator_optional_string = validators.optional(validators.instance_of(str))
Expand Down Expand Up @@ -832,6 +833,119 @@ def _copy(self) -> "Cas":
result._xmi_id_generator = self._xmi_id_generator
return result

def deep_copy(self, copy_typesystem: bool = False) -> "Cas":
"""
Create and return a deep copy of this CAS object.
All feature structures, views, and sofas are copied. If `copy_typesystem` is True, the typesystem is also deep-copied;
otherwise, the original typesystem is shared between the original and the copy.
Args:
copy_typesystem (bool): Whether to copy the original typesystem or not. If True, the typesystem is deep-copied.
Returns:
Cas: A deep copy of this CAS object.
"""
ts = self.typesystem
if copy_typesystem:
ts = self.typesystem.to_xml()
ts = load_typesystem(ts)

cas_copy = Cas(ts,
document_language=self.document_language,
lenient=self._lenient,
sofa_mime=self.sofa_mime,
)

cas_copy._views = {}
cas_copy._sofas = {}

for sofa in self.sofas:

sofa_copy = Sofa(
sofaID=sofa.sofaID,
sofaNum=sofa.sofaNum,
type=ts.get_type(sofa.type.name),
xmiID=sofa.xmiID,
)
sofa_copy.mimeType = sofa.mimeType
sofa_copy.sofaArray = sofa.sofaArray
sofa_copy.sofaString = sofa.sofaString
sofa_copy.sofaURI = sofa.sofaURI

cas_copy._sofas[sofa_copy.sofaID] = sofa_copy
cas_copy._views[sofa_copy.sofaID] = View(sofa=sofa_copy)

# removes the _IntialView created with the initialization of the copied CAS
cas_copy._current_view = cas_copy._views["_InitialView"]

references = dict()
referenced_arrays = dict()

all_copied_fs = dict()
referenced_view = {}

for fs in self._find_all_fs():

# the referenced view is required when adding the fs to the copied cas later
if hasattr(fs, 'sofa') and fs.sofa and hasattr(fs, 'xmiID') and fs.xmiID:
referenced_view[fs.xmiID] = fs.sofa.sofaID

t = ts.get_type(fs.type.name)
fs_copy = t()

for feature in t.all_features:
if ts.is_primitive(feature.rangeType):
fs_copy[feature.name] = fs.get(feature.name)
elif ts.is_primitive_collection(feature.rangeType):
fs_copy[feature.name] = ts.get_type(feature.rangeType.name)()
fs_copy[feature.name].elements = fs.get(feature.name).elements
elif ts.is_array(feature.rangeType):
fs_copy[feature.name] = ts.get_type(TYPE_NAME_FS_ARRAY)()
# collect referenced xmiIDs for mapping later
referenced_list = []
for item in fs[feature.name].elements:
if hasattr(item, 'xmiID') and item.xmiID is not None:
referenced_list.append(item.xmiID)
referenced_arrays.setdefault(fs.xmiID, {})
referenced_arrays[fs.xmiID][feature.name] = referenced_list
elif feature.rangeType.name == TYPE_NAME_SOFA:
# ignore sofa references
pass
else:
if hasattr(fs[feature.name], 'xmiID') and fs[feature.name].xmiID is not None:
references.setdefault(feature.name, [])
references[feature.name].append((fs.xmiID, fs[feature.name].xmiID))
else:
warnings.warn(f"Original non-primitive feature \"{feature.name}\" was and not copied from feature structure {fs.xmiID}.")

fs_copy.xmiID = fs.xmiID
all_copied_fs[fs_copy.xmiID] = fs_copy

# set references to single objects
for feature, pairs in references.items():
for current_ID, reference_ID in pairs:
try:
all_copied_fs[current_ID][feature] = all_copied_fs[reference_ID]
except KeyError as e:
warnings.warn(f"Reference {reference_ID} not found for feature '{feature}' of feature structure {current_ID}")

# set references for objects in arrays
for current_ID, arrays in referenced_arrays.items():
for feature, referenced_list in arrays.items():
elements = [all_copied_fs[reference_ID] for reference_ID in referenced_list]
all_copied_fs[current_ID][feature].elements = elements

# add feature structures to the appropriate views
feature_structures = sorted(all_copied_fs.values(), key=lambda f: f.xmiID, reverse=False)
for item in all_copied_fs.values():
if hasattr(item, 'xmiID') and item.xmiID is not None:
view_name = referenced_view.get(item.xmiID)
if view_name is not None:
cas_copy._current_view = cas_copy._views[view_name]
cas_copy.add(item, keep_id=True)

cas_copy._xmi_id_generator = IdGenerator(initial_id=self._xmi_id_generator._next_id)
cas_copy._sofa_num_generator = IdGenerator(initial_id=self._sofa_num_generator._next_id)
return cas_copy


def _sort_func(a: FeatureStructure) -> Tuple[int, int, int]:
d = a.__slots__
Expand Down
62 changes: 62 additions & 0 deletions tests/test_cas.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to add randomized tests for the deep copy. We have test_multi_type_random_serialization_deserialization and test_multi_feature_random_serialization_deserialization which create randomized CASes for (de)serialization and compare them afterwards. These can be used as templates for randomized deep_copy tests.

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AnnotationHasNoSofa,
)
from tests.fixtures import *
from tests.test_files.test_cas_generators import MultiFeatureRandomCasGenerator, MultiTypeRandomCasGenerator

# Cas

Expand Down Expand Up @@ -540,3 +541,64 @@ def test_covered_text_on_annotation_without_sofa():

with pytest.raises(AnnotationHasNoSofa):
ann.get_covered_text()


def test_deep_copy_without_typesystem(small_xmi, small_typesystem_xml):
org = load_cas_from_xmi(small_xmi, typesystem=load_typesystem(small_typesystem_xml))
copy = org.deep_copy(copy_typesystem=False)

assert org != copy
assert len(copy.to_json(pretty_print=True)) == len(org.to_json(pretty_print=True))
assert copy.to_json(pretty_print=True) == org.to_json(pretty_print=True)

assert org.typesystem == copy.typesystem


def test_deep_copy_with_typesystem(small_xmi, small_typesystem_xml):
org = load_cas_from_xmi(small_xmi, typesystem=load_typesystem(small_typesystem_xml))
copy = org.deep_copy(copy_typesystem=True)

assert org != copy
assert len(copy.to_json(pretty_print=True)) == len(org.to_json(pretty_print=True))
assert copy.to_json(pretty_print=True) == org.to_json(pretty_print=True)


assert org.typesystem != copy.typesystem
assert len(org.typesystem.to_xml()) == len(copy.typesystem.to_xml())
assert org.typesystem.to_xml() == copy.typesystem.to_xml()


def test_random_multi_type_random_deep_copy():
generator = MultiTypeRandomCasGenerator()
for i in range(0, 10):
generator.size = (i + 1) * 10
generator.type_count = i + 1
typesystem = generator.generate_type_system()
org = generator.generate_cas(typesystem)
print(f"CAS size: {sum(len(view.get_all_annotations()) for view in org.views)}")
copy = org.deep_copy(copy_typesystem=True)

org_text = org.to_xmi(pretty_print=True)
copy_text = copy.to_xmi(pretty_print=True)

assert org != copy
assert len(org_text) == len(copy_text)
assert org_text == copy_text


def test_random_multi_feature_deep_copy():
generator = MultiFeatureRandomCasGenerator()
for i in range(0, 10):
generator.size = (i + 1) * 10
typesystem = generator.generate_type_system()
org = generator.generate_cas(typesystem)
print(f"CAS size: {sum(len(view.get_all_annotations()) for view in org.views)}")
copy = org.deep_copy(copy_typesystem=True)

org_text = org.to_xmi(pretty_print=True)
copy_text = copy.to_xmi(pretty_print=True)

assert org != copy
assert len(org_text) == len(copy_text)
assert org_text == copy_text