Skip to content

Commit cf1641e

Browse files
committed
Merge branch 'main' into ltx1-dev
2 parents 554c7bd + 9997c59 commit cf1641e

File tree

6 files changed

+42
-53
lines changed

6 files changed

+42
-53
lines changed

src/maxdiffusion/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@
6464
],
6565
}
6666

67+
if is_flax_available():
68+
from flax import config as flax_config
69+
70+
flax_config.update("flax_always_shard_variable", False)
71+
6772
try:
6873
if not is_onnx_available():
6974
raise OptionalDependencyNotAvailable()

src/maxdiffusion/configuration_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,11 @@ def load_config(
376376
if os.path.isfile(pretrained_model_name_or_path):
377377
config_file = pretrained_model_name_or_path
378378
elif os.path.isdir(pretrained_model_name_or_path):
379-
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
379+
if subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
380+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
381+
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
380382
# Load from a PyTorch checkpoint
381383
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
382-
elif subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
383-
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
384384
else:
385385
raise EnvironmentError(f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}.")
386386
else:

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -444,19 +444,21 @@ def loss_fn(params):
444444
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
445445
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
446446

447-
model_pred = model(
448-
hidden_states=noisy_latents,
449-
timestep=timesteps,
450-
encoder_hidden_states=encoder_hidden_states,
451-
deterministic=False,
452-
rngs=nnx.Rngs(dropout_rng),
453-
)
447+
with jax.named_scope("forward_pass"):
448+
model_pred = model(
449+
hidden_states=noisy_latents,
450+
timestep=timesteps,
451+
encoder_hidden_states=encoder_hidden_states,
452+
deterministic=False,
453+
rngs=nnx.Rngs(dropout_rng),
454+
)
454455

455-
training_target = scheduler.training_target(latents, noise, timesteps)
456-
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
457-
loss = (training_target - model_pred) ** 2
458-
loss = loss * training_weight
459-
loss = jnp.mean(loss)
456+
with jax.named_scope("loss"):
457+
training_target = scheduler.training_target(latents, noise, timesteps)
458+
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
459+
loss = (training_target - model_pred) ** 2
460+
loss = loss * training_weight
461+
loss = jnp.mean(loss)
460462

461463
return loss
462464

src/maxdiffusion/utils/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import importlib
1515
import os
1616

17-
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
17+
from huggingface_hub.constants import HF_HOME, HUGGINGFACE_HUB_CACHE
1818
from packaging import version
1919

2020
from .import_utils import is_peft_available
@@ -34,7 +34,7 @@
3434
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
3535
DIFFUSERS_CACHE = default_cache_path
3636
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
37-
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
37+
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
3838
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
3939

4040
# Below should be `True` if the current version of `peft` and `transformers` are compatible with

src/maxdiffusion/utils/dynamic_modules_utils.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,14 @@
2121
import re
2222
import shutil
2323
import sys
24+
import tempfile
2425
from pathlib import Path
2526
from typing import Dict, Optional, Union
2627
from urllib import request
2728

28-
from huggingface_hub import HfFolder, hf_hub_download, model_info
29-
import huggingface_hub
29+
from huggingface_hub import get_token, hf_hub_download, model_info
3030
from packaging import version
3131

32-
cached_download = None
33-
3432
from .. import __version__
3533
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
3634

@@ -42,24 +40,6 @@
4240

4341
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4442

45-
# https://github.com/huggingface/huggingface_hub/releases/tag/v0.26.0
46-
# `cached_download(), url_to_filename(), filename_to_url() methods are now completely removed.
47-
# From now on, you will have to use hf_hub_download() to benefit from the new cache layout.`
48-
if hasattr(huggingface_hub, "__version__"):
49-
current_version = version.parse(huggingface_hub.__version__)
50-
target_version = version.parse("0.26.0")
51-
52-
if current_version < target_version:
53-
try:
54-
from huggingface_hub import cached_download
55-
56-
except ImportError:
57-
logger.error(
58-
f"huggingface_hub version {current_version} is below 0.26.0, but 'cached_download' could not be imported. It might have been removed or deprecated in this version as well."
59-
)
60-
else:
61-
logger.error("Could not determine huggingface_hub version. Unable to conditionally import 'cached_download'.")
62-
6343

6444
def get_diffusers_versions():
6545
url = "https://pypi.org/pypi/diffusers/json"
@@ -303,15 +283,17 @@ def get_cached_module_file(
303283
# community pipeline on GitHub
304284
github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path)
305285
try:
306-
resolved_module_file = cached_download(
307-
github_url,
308-
cache_dir=cache_dir,
309-
force_download=force_download,
310-
proxies=proxies,
311-
resume_download=resume_download,
312-
local_files_only=local_files_only,
313-
use_auth_token=False,
314-
)
286+
# Given that cached download has been removed, try using just urlopen
287+
fd, resolved_module_file = tempfile.mkstemp(dir=cache_dir)
288+
try:
289+
response = request.urlopen(github_url)
290+
with os.fdopen(fd, "wb") as f:
291+
f.write(response.read())
292+
except Exception:
293+
os.remove(resolved_module_file)
294+
raise EnvironmentError(
295+
f"Failed to download community pipeline from {github_url}. Please check if the url is correct."
296+
)
315297
submodule = "git"
316298
module_file = pretrained_model_name_or_path + ".py"
317299
except EnvironmentError:
@@ -328,7 +310,7 @@ def get_cached_module_file(
328310
proxies=proxies,
329311
resume_download=resume_download,
330312
local_files_only=local_files_only,
331-
use_auth_token=use_auth_token,
313+
token=use_auth_token,
332314
)
333315
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
334316
except EnvironmentError:
@@ -356,7 +338,7 @@ def get_cached_module_file(
356338
if isinstance(use_auth_token, str):
357339
token = use_auth_token
358340
elif use_auth_token is True:
359-
token = HfFolder.get_token()
341+
token = get_token()
360342
else:
361343
token = None
362344

src/maxdiffusion/utils/hub_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from typing import Dict, Optional, Union
2525
from uuid import uuid4
2626

27-
from huggingface_hub import HfFolder, ModelCard, ModelCardData, create_repo, hf_hub_download, upload_folder, whoami
27+
from huggingface_hub import ModelCard, ModelCardData, create_repo, get_token, hf_hub_download, upload_folder, whoami
2828
from huggingface_hub.file_download import REGEX_COMMIT_HASH
2929
from huggingface_hub.utils import (
3030
EntryNotFoundError,
@@ -92,7 +92,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
9292

9393
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
9494
if token is None:
95-
token = HfFolder.get_token()
95+
token = get_token()
9696
if organization is None:
9797
username = whoami(token)["name"]
9898
return f"{username}/{model_id}"
@@ -288,7 +288,7 @@ def _get_model_file(
288288
proxies=proxies,
289289
resume_download=resume_download,
290290
local_files_only=local_files_only,
291-
use_auth_token=use_auth_token,
291+
token=use_auth_token,
292292
user_agent=user_agent,
293293
subfolder=subfolder,
294294
revision=revision or commit_hash,

0 commit comments

Comments
 (0)