Skip to content

Commit b766b0a

Browse files
committed
convert: enable loading pt/pth/safetensors for all archs #12
1 parent 838b485 commit b766b0a

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

scripts/convert.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def convert_tensor_2d(self, tensor: Tensor):
7575
self.conv2d_weights.append(self._index)
7676
return tensor
7777

78-
def add_int32(self, name: str, value: int):
79-
print("*", name, "=", value)
80-
super().add_int32(name, value)
78+
def add_int32(self, key: str, val: int):
79+
print("*", key, "=", val)
80+
super().add_int32(key, val)
8181

8282
def set_tensor_layout(self, layout: TensorLayout):
8383
print("*", f"{self.arch}.tensor_data_layout", "=", layout.value)
@@ -201,7 +201,7 @@ def convert_sam(input_filepath: Path, writer: Writer):
201201
writer.add_license("apache-2.0")
202202
writer.set_tensor_layout_default(TensorLayout.nchw)
203203

204-
model: dict[str, Tensor] = torch.load(input_filepath, map_location="cpu", weights_only=True)
204+
model = load_model(input_filepath)
205205

206206
for key, tensor in model.items():
207207
name = key
@@ -286,8 +286,7 @@ def convert_birefnet(input_filepath: Path, writer: Writer):
286286
writer.add_license("mit")
287287
writer.set_tensor_layout_default(TensorLayout.nchw)
288288

289-
weights = safetensors.safe_open(input_filepath, "pt")
290-
model: dict[str, Tensor] = {k: weights.get_tensor(k) for k in weights.keys()}
289+
model = load_model(input_filepath)
291290

292291
x = model["bb.layers.0.blocks.0.attn.proj.bias"]
293292
if x.shape[0] == 96:
@@ -360,7 +359,7 @@ def convert_depth_anything(input_filepath: Path, writer: Writer):
360359
writer.add_license("cc-by-nc-4.0")
361360
writer.set_tensor_layout_default(TensorLayout.nchw)
362361

363-
model: dict[str, Tensor] = load_model(input_filepath)
362+
model = load_model(input_filepath)
364363

365364
if "pretrained.cls_token" in model:
366365
print("The converter is written for the transformers (.safetensors) version of the model.")
@@ -411,7 +410,7 @@ def convert_migan(input_filepath: Path, writer: Writer):
411410
writer.add_license("mit")
412411
writer.set_tensor_layout_default(TensorLayout.nchw)
413412

414-
model: dict[str, Tensor] = torch.load(input_filepath, weights_only=True)
413+
model = load_model(input_filepath)
415414

416415
if "encoder.b512.fromrgb.weight" in model:
417416
writer.add_int32("migan.image_size", 512)

0 commit comments

Comments
 (0)