diff --git a/cassis/cas.py b/cassis/cas.py index 1125abc..c370159 100644 --- a/cassis/cas.py +++ b/cassis/cas.py @@ -23,6 +23,7 @@ TypeCheckError, TypeSystem, TypeSystemMode, + load_typesystem, ) _validator_optional_string = validators.optional(validators.instance_of(str)) @@ -832,6 +833,119 @@ def _copy(self) -> "Cas": result._xmi_id_generator = self._xmi_id_generator return result + def deep_copy(self, copy_typesystem: bool = False) -> "Cas": + """ + Create and return a deep copy of this CAS object. + All feature structures, views, and sofas are copied. If `copy_typesystem` is True, the typesystem is also deep-copied; + otherwise, the original typesystem is shared between the original and the copy. + Args: + copy_typesystem (bool): Whether to copy the original typesystem or not. If True, the typesystem is deep-copied. + Returns: + Cas: A deep copy of this CAS object. + """ + ts = self.typesystem + if copy_typesystem: + ts = self.typesystem.to_xml() + ts = load_typesystem(ts) + + cas_copy = Cas(ts, + document_language=self.document_language, + lenient=self._lenient, + sofa_mime=self.sofa_mime, + ) + + cas_copy._views = {} + cas_copy._sofas = {} + + for sofa in self.sofas: + + sofa_copy = Sofa( + sofaID=sofa.sofaID, + sofaNum=sofa.sofaNum, + type=ts.get_type(sofa.type.name), + xmiID=sofa.xmiID, + ) + sofa_copy.mimeType = sofa.mimeType + sofa_copy.sofaArray = sofa.sofaArray + sofa_copy.sofaString = sofa.sofaString + sofa_copy.sofaURI = sofa.sofaURI + + cas_copy._sofas[sofa_copy.sofaID] = sofa_copy + cas_copy._views[sofa_copy.sofaID] = View(sofa=sofa_copy) + + # removes the _IntialView created with the initialization of the copied CAS + cas_copy._current_view = cas_copy._views["_InitialView"] + + references = dict() + referenced_arrays = dict() + + all_copied_fs = dict() + referenced_view = {} + + for fs in self._find_all_fs(): + + # the referenced view is required when adding the fs to the copied cas later + if hasattr(fs, 'sofa') and fs.sofa and hasattr(fs, 'xmiID') and fs.xmiID: + referenced_view[fs.xmiID] = fs.sofa.sofaID + + t = ts.get_type(fs.type.name) + fs_copy = t() + + for feature in t.all_features: + if ts.is_primitive(feature.rangeType): + fs_copy[feature.name] = fs.get(feature.name) + elif ts.is_primitive_collection(feature.rangeType): + fs_copy[feature.name] = ts.get_type(feature.rangeType.name)() + fs_copy[feature.name].elements = fs.get(feature.name).elements + elif ts.is_array(feature.rangeType): + fs_copy[feature.name] = ts.get_type(TYPE_NAME_FS_ARRAY)() + # collect referenced xmiIDs for mapping later + referenced_list = [] + for item in fs[feature.name].elements: + if hasattr(item, 'xmiID') and item.xmiID is not None: + referenced_list.append(item.xmiID) + referenced_arrays.setdefault(fs.xmiID, {}) + referenced_arrays[fs.xmiID][feature.name] = referenced_list + elif feature.rangeType.name == TYPE_NAME_SOFA: + # ignore sofa references + pass + else: + if hasattr(fs[feature.name], 'xmiID') and fs[feature.name].xmiID is not None: + references.setdefault(feature.name, []) + references[feature.name].append((fs.xmiID, fs[feature.name].xmiID)) + else: + warnings.warn(f"Original non-primitive feature \"{feature.name}\" was and not copied from feature structure {fs.xmiID}.") + + fs_copy.xmiID = fs.xmiID + all_copied_fs[fs_copy.xmiID] = fs_copy + + # set references to single objects + for feature, pairs in references.items(): + for current_ID, reference_ID in pairs: + try: + all_copied_fs[current_ID][feature] = all_copied_fs[reference_ID] + except KeyError as e: + warnings.warn(f"Reference {reference_ID} not found for feature '{feature}' of feature structure {current_ID}") + + # set references for objects in arrays + for current_ID, arrays in referenced_arrays.items(): + for feature, referenced_list in arrays.items(): + elements = [all_copied_fs[reference_ID] for reference_ID in referenced_list] + all_copied_fs[current_ID][feature].elements = elements + + # add feature structures to the appropriate views + feature_structures = sorted(all_copied_fs.values(), key=lambda f: f.xmiID, reverse=False) + for item in all_copied_fs.values(): + if hasattr(item, 'xmiID') and item.xmiID is not None: + view_name = referenced_view.get(item.xmiID) + if view_name is not None: + cas_copy._current_view = cas_copy._views[view_name] + cas_copy.add(item, keep_id=True) + + cas_copy._xmi_id_generator = IdGenerator(initial_id=self._xmi_id_generator._next_id) + cas_copy._sofa_num_generator = IdGenerator(initial_id=self._sofa_num_generator._next_id) + return cas_copy + def _sort_func(a: FeatureStructure) -> Tuple[int, int, int]: d = a.__slots__ diff --git a/tests/test_cas.py b/tests/test_cas.py index 670db07..317a50f 100644 --- a/tests/test_cas.py +++ b/tests/test_cas.py @@ -11,6 +11,7 @@ AnnotationHasNoSofa, ) from tests.fixtures import * +from tests.test_files.test_cas_generators import MultiFeatureRandomCasGenerator, MultiTypeRandomCasGenerator # Cas @@ -540,3 +541,64 @@ def test_covered_text_on_annotation_without_sofa(): with pytest.raises(AnnotationHasNoSofa): ann.get_covered_text() + + +def test_deep_copy_without_typesystem(small_xmi, small_typesystem_xml): + org = load_cas_from_xmi(small_xmi, typesystem=load_typesystem(small_typesystem_xml)) + copy = org.deep_copy(copy_typesystem=False) + + assert org != copy + assert len(copy.to_json(pretty_print=True)) == len(org.to_json(pretty_print=True)) + assert copy.to_json(pretty_print=True) == org.to_json(pretty_print=True) + + assert org.typesystem == copy.typesystem + + +def test_deep_copy_with_typesystem(small_xmi, small_typesystem_xml): + org = load_cas_from_xmi(small_xmi, typesystem=load_typesystem(small_typesystem_xml)) + copy = org.deep_copy(copy_typesystem=True) + + assert org != copy + assert len(copy.to_json(pretty_print=True)) == len(org.to_json(pretty_print=True)) + assert copy.to_json(pretty_print=True) == org.to_json(pretty_print=True) + + + assert org.typesystem != copy.typesystem + assert len(org.typesystem.to_xml()) == len(copy.typesystem.to_xml()) + assert org.typesystem.to_xml() == copy.typesystem.to_xml() + + +def test_random_multi_type_random_deep_copy(): + generator = MultiTypeRandomCasGenerator() + for i in range(0, 10): + generator.size = (i + 1) * 10 + generator.type_count = i + 1 + typesystem = generator.generate_type_system() + org = generator.generate_cas(typesystem) + print(f"CAS size: {sum(len(view.get_all_annotations()) for view in org.views)}") + copy = org.deep_copy(copy_typesystem=True) + + org_text = org.to_xmi(pretty_print=True) + copy_text = copy.to_xmi(pretty_print=True) + + assert org != copy + assert len(org_text) == len(copy_text) + assert org_text == copy_text + + +def test_random_multi_feature_deep_copy(): + generator = MultiFeatureRandomCasGenerator() + for i in range(0, 10): + generator.size = (i + 1) * 10 + typesystem = generator.generate_type_system() + org = generator.generate_cas(typesystem) + print(f"CAS size: {sum(len(view.get_all_annotations()) for view in org.views)}") + copy = org.deep_copy(copy_typesystem=True) + + org_text = org.to_xmi(pretty_print=True) + copy_text = copy.to_xmi(pretty_print=True) + + assert org != copy + assert len(org_text) == len(copy_text) + assert org_text == copy_text +