diff --git a/cfgs/pipeline/split_inference.yaml b/cfgs/pipeline/split_inference.yaml
index b60b6cb..368f7e7 100755
--- a/cfgs/pipeline/split_inference.yaml
+++ b/cfgs/pipeline/split_inference.yaml
@@ -36,6 +36,10 @@ codec:
   nn_task_part2:
     dump_results: False
     output_results_dir: "${codec.output_dir}/output_results"
+    dump_features: False
+    feature_dir: "${..output_dir_root}/features_pre_nn_part2/${dataset.datacatalog}/${dataset.config.dataset_name}"
+    dump_features_hash: False
+    hash_format: md5
   conformance:
     save_conformance_files: False
     subsample_ratio: 9
diff --git a/compressai_vision/pipelines/base.py b/compressai_vision/pipelines/base.py
index 0aa7a44..52a396e 100755
--- a/compressai_vision/pipelines/base.py
+++ b/compressai_vision/pipelines/base.py
@@ -49,6 +49,11 @@
     min_max_normalization,
 )
 from compressai_vision.model_wrappers import BaseWrapper
+from compressai_vision.utils import (
+    FileLikeHasher,
+    contiguous_features,
+    freeze_zip_timestamps,
+)
 
 
 class Parts(Enum):
@@ -349,6 +354,12 @@ def _from_features_to_output(
         """performs the inference of the 2nd part of the NN model"""
         output_results_dir = self.configs["nn_task_part2"].output_results_dir
 
+        seq_name = (
+            seq_name
+            if seq_name is not None
+            else os.path.splitext(os.path.basename(x.get("file_name", "")))[0]
+        )
+
         results_file = f"{output_results_dir}/{seq_name}{self._output_ext}"
 
         assert "data" in x
@@ -374,6 +385,40 @@
             for k, v in zip(vision_model.split_layer_list, x["data"].values())
         }
 
+        if (
+            self.configs["nn_task_part2"].dump_features
+            or self.configs["nn_task_part2"].dump_features_hash
+        ):
+            feature_dir = self.configs["nn_task_part2"].feature_dir
+            self._create_folder(feature_dir)
+
+            dump_feature_hash = self.configs["nn_task_part2"].dump_features_hash
+            hash_format = self.configs["nn_task_part2"].hash_format
+
+            feature_output_ext = (
+                f".{hash_format}" if dump_feature_hash else self._output_ext
+            )
+            path = f"{feature_dir}/{seq_name}{feature_output_ext}"
+
+            features_file = (
+                FileLikeHasher(path, hash_format) if dump_feature_hash else path
+            )
+
+            self.logger.debug(f"dumping features prior to nn part2 in: {feature_dir}")
+
+            # [TODO] align with nn_task_part1 dump features
+            features_to_dump = contiguous_features(x)
+
+            with freeze_zip_timestamps():
+                if dump_feature_hash:
+                    torch.save(features_to_dump, features_file, pickle_protocol=4)
+                else:
+                    with open(features_file, "wb") as f:
+                        torch.save(features_to_dump, f, pickle_protocol=4)
+
+            if hasattr(features_file, "close"):
+                features_file.close()
+
         results = vision_model.features_to_output(x, self.device_nn_part2)
         if self.configs["nn_task_part2"].dump_results:
             self._create_folder(output_results_dir)
diff --git a/compressai_vision/utils/__init__.py b/compressai_vision/utils/__init__.py
index ffbf039..ca0b9eb 100644
--- a/compressai_vision/utils/__init__.py
+++ b/compressai_vision/utils/__init__.py
@@ -29,6 +29,7 @@
 from . import dataio, git, pip, system
 from .external_exec import get_max_num_cpus
+from .hash import FileLikeHasher, contiguous_features, freeze_zip_timestamps
 from .misc import dict_sum, dl_to_ld, ld_to_dl, metric_tracking, time_measure, to_cpu
 
 __all__ = [
@@ -43,4 +44,7 @@
     "dict_sum",
     "dl_to_ld",
     "ld_to_dl",
+    "FileLikeHasher",
+    "freeze_zip_timestamps",
+    "contiguous_features",
 ]
diff --git a/compressai_vision/utils/hash.py b/compressai_vision/utils/hash.py
new file mode 100644
index 0000000..1bd74e6
--- /dev/null
+++ b/compressai_vision/utils/hash.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2022-2024, InterDigital Communications, Inc
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+#   contributors may be used to endorse or promote products derived from this
+#   software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import hashlib
+import zipfile
+
+from collections import OrderedDict
+from collections.abc import Mapping, Sequence
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+
+class FileLikeHasher:
+    def __init__(self, fn, algo: str = "md5"):
+        self._h = hashlib.new(algo)
+        self._fn = fn
+        self._nbytes = 0
+
+    def write(self, byts):
+        self._h.update(byts)
+        self._nbytes += len(byts)
+        return len(byts)
+
+    def flush(self):
+        pass
+
+    def close(self):
+        with open(self._fn, "w") as f:
+            f.write(self._h.hexdigest())
+            f.write("\n")
+
+
+@contextmanager
+def freeze_zip_timestamps(
+    fixed: Tuple[int, int, int, int, int, int] = (1980, 1, 1, 0, 0, 0),
+):
+    _orig_init = zipfile.ZipInfo.__init__
+
+    def _patched(self, *args, **kwargs):
+        _orig_init(self, *args, **kwargs)
+        self.date_time = fixed  # fixed DOS-epoch timestamp for deterministic archives
+
+    zipfile.ZipInfo.__init__ = _patched
+    try:
+        yield
+    finally:
+        zipfile.ZipInfo.__init__ = _orig_init
+
+
+def contiguous_features(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj.to("cpu").contiguous().clone()
+
+    if isinstance(obj, Mapping):
+        return OrderedDict(
+            (k, contiguous_features(v))
+            for k, v in sorted(obj.items(), key=lambda item: str(item[0]))
+            if not str(k).startswith("file")
+        )
+
+    if isinstance(obj, set):
+        return tuple(sorted(obj, key=str))
+
+    if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
+        return type(obj)(contiguous_features(v) for v in obj)
+
+    return obj
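
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): how the three new utilities compose,
# mirroring what _from_features_to_output does when `dump_features_hash: True`
# and `hash_format: md5` are set in cfgs/pipeline/split_inference.yaml. The
# layer names, tensor shapes, and output filename below are illustrative
# assumptions, not project fixtures.
# ---------------------------------------------------------------------------
import torch

from compressai_vision.utils import (
    FileLikeHasher,
    contiguous_features,
    freeze_zip_timestamps,
)

# Hypothetical split-point payload, keyed like the pipeline's `x` dict.
x = {
    "file_name": "seq_0001.png",  # "file*" keys are stripped before hashing
    "data": {
        "p2": torch.randn(1, 256, 64, 64),
        "p3": torch.randn(1, 256, 32, 32),
    },
}

# Canonical form: tensors moved to CPU and made contiguous, mappings sorted
# by key, so the serialized bytes do not depend on device or dict order.
features_to_dump = contiguous_features(x)

# Stream torch.save through the hasher instead of a file: close() writes only
# the hex digest to seq_0001.md5, never the (potentially large) features.
features_file = FileLikeHasher("seq_0001.md5", "md5")

# Pin zipfile.ZipInfo timestamps to 1980-01-01 for the duration of the save,
# so any archive members written through Python's zipfile module are
# byte-identical across runs and the digest is reproducible.
with freeze_zip_timestamps():
    torch.save(features_to_dump, features_file, pickle_protocol=4)
features_file.close()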