99 changes: 93 additions & 6 deletions keras/src/saving/file_editor.py
@@ -455,33 +455,120 @@ def resave_weights(self, filepath):
def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
metadata = metadata or {}

# ------------------------------------------------------
# Collect metadata for this HDF5 group
# ------------------------------------------------------
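# h5py exposes HDF5 attributes through the dict-like .attrs mapping.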
object_metadata = {}
for k, v in data.attrs.items():
object_metadata[k] = v
if object_metadata:
metadata[inner_path] = object_metadata

result = collections.OrderedDict()

# ------------------------------------------------------
# Iterate over all keys in this HDF5 group
# ------------------------------------------------------
for key in data.keys():
# IMPORTANT: never mutate inner_path; use a local variable so that
# sibling keys are resolved against the same parent path.
current_inner_path = f"{inner_path}/{key}"
value = data[key]

# ------------------------------------------------------
# CASE 1 — HDF5 GROUP → RECURSE
# ------------------------------------------------------
if isinstance(value, h5py.Group):
# Skip empty groups
if len(value) == 0:
continue

# Skip empty "vars" groups
if "vars" in value.keys() and len(value["vars"]) == 0:
continue

if hasattr(value, "keys"):
# Recurse into "vars" subgroup when present
if "vars" in value.keys():
result[key], metadata = self._extract_weights_from_store(
value["vars"], metadata=metadata, inner_path=inner_path
value["vars"],
metadata=metadata,
inner_path=current_inner_path,
)
else:
# Recurse normally
result[key], metadata = self._extract_weights_from_store(
value,
metadata=metadata,
inner_path=current_inner_path,
)
else:
result[key] = value[()]

continue # finished processing this key

# ------------------------------------------------------
# CASE 2 — HDF5 DATASET → SAFE LOADING
# ------------------------------------------------------

# Skip any objects that are not proper datasets
if not hasattr(value, "shape") or not hasattr(value, "dtype"):
continue

shape = value.shape
dtype = value.dtype

# ------------------------------------------------------
# Validate SHAPE (avoid malformed / malicious metadata)
# ------------------------------------------------------
try:
# No negative dims
if any(dim < 0 for dim in shape):
raise ValueError(
"Negative dimension in HDF5 dataset shape."
)

# Prevent absurdly high-rank tensors
if len(shape) > 64:
raise ValueError("HDF5 dataset rank too large (>64).")

# Ensure product does not overflow
num_elems = int(np.prod(shape))
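# NOTE: np.prod works in a fixed-width integer dtype, so an
# overflow wraps silently; the sign check below is a heuristic
# guard rather than an exact overflow detection.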
if num_elems < 0:
raise ValueError(
"Overflow in dataset shape multiplication."
)

except Exception as e:
raise ValueError(
"Malformed HDF5 dataset shape encountered in .keras file; "
"refusing to load."
) from e

# ------------------------------------------------------
# Validate TOTAL memory size
# ------------------------------------------------------
max_bytes = 1 << 30 # 1 GiB

try:
size_bytes = num_elems * dtype.itemsize
except Exception as e:
raise ValueError(
"Malformed HDF5 dtype encountered in .keras file; "
"refusing to load."
) from e

if size_bytes > max_bytes:
raise ValueError(
f"HDF5 dataset too large to load safely "
f"({size_bytes} bytes; limit is {max_bytes})."
)

# ------------------------------------------------------
# SAFE — load dataset (guaranteed ≤ 1 GiB)
# ------------------------------------------------------
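# Note that value[()] materializes the full dataset as an in-memory
# NumPy array, which is why its size is bounded above.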
result[key] = value[()]

# ------------------------------------------------------
# Return final tree and metadata
# ------------------------------------------------------
return result, metadata

def _generate_filepath_info(self, rich_style=False):
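
For context, the sketch below (not part of this change) shows how the same shape/size guard could be exercised against an arbitrary HDF5 weights file before handing it to the editor. The filename model.weights.h5 and the check_dataset helper are illustrative assumptions; only public h5py/numpy APIs are used.

import h5py
import numpy as np

MAX_BYTES = 1 << 30  # 1 GiB, matching the limit used in the diff above


def check_dataset(name, obj):
    # Groups and other node types carry no payload; only datasets are checked.
    if not isinstance(obj, h5py.Dataset):
        return
    shape = obj.shape
    if any(dim < 0 for dim in shape) or len(shape) > 64:
        raise ValueError(f"Malformed shape for dataset {name!r}: {shape}")
    size_bytes = int(np.prod(shape)) * obj.dtype.itemsize
    if size_bytes > MAX_BYTES:
        raise ValueError(
            f"Dataset {name!r} is {size_bytes} bytes; limit is {MAX_BYTES}."
        )


with h5py.File("model.weights.h5", "r") as f:
    # visititems() calls check_dataset for every group and dataset in the file.
    f.visititems(check_dataset)

Capping individual datasets this way bounds peak memory during the value[()] reads above, since h5py loads the entire dataset into a NumPy array at that point.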