Skip to content

Commit 7ec026b

Browse files
committed
ruff check
1 parent 3a4463b commit 7ec026b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
391391
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
392392
if not is_sparse:
393393
# down_weight is copied to each split
394-
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
394+
ait_sd.update({k: down_weight for k in ait_down_keys})
395395

396396
# up_weight is split to each split
397397
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -534,7 +534,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
534534
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
535535

536536
# down_weight is copied to each split
537-
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
537+
ait_sd.update({k: down_weight for k in ait_down_keys})
538538

539539
# up_weight is split to each split
540540
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416

src/maxdiffusion/pipelines/pipeline_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def load_module(name, value):
473473
class_obj = import_flax_or_no_model(pipeline_module, class_name)
474474

475475
importable_classes = ALL_IMPORTABLE_CLASSES
476-
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
476+
class_candidates = {c: class_obj for c in importable_classes.keys()}
477477
else:
478478
# else we just import it from the library.
479479

0 commit comments

Comments
 (0)