Skip to content

Commit 0a57890

Browse files
gabewillendeanq
andauthored
Feature/E-2131 Utility function for resolving model-cache paths from Huggingface repositories (#377)
* added a utility for resolving model cache paths from a huggingface repository * Added a TODO for the `path_template` key word argument * added unit tests for model cache resolver * fixed module documentation * resolve to None when a repository is improperly formatted * fixed comment wording --------- Co-authored-by: Dean Quiñanola <dean.quinanola@runpod.io>
1 parent 527db3c commit 0a57890

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Utility function for transforming HuggingFace repositories into model-cache paths"""
2+
3+
import typing
4+
from runpod.serverless.modules.rp_logger import RunPodLogger
5+
6+
log = RunPodLogger()
7+
8+
9+
def resolve_model_cache_path_from_hugginface_repository(
10+
huggingface_repository: str,
11+
/,
12+
path_template: str = "/runpod/cache/{model}/{revision}", # TODO: Should we just hardcode this?
13+
) -> typing.Union[str, None]:
14+
"""
15+
Resolves the model-cache path for a HuggingFace model based on its repository string.
16+
17+
Args:
18+
huggingface_repository (str): Repository string in format "model_name:revision" or
19+
"org/model_name:revision". If no revision is specified,
20+
"main" is used. For example:
21+
- "runwayml/stable-diffusion-v1-5:experimental"
22+
- "runwayml/stable-diffusion-v1-5" (uses "main" revision)
23+
- "stable-diffusion-v1-5:main"
24+
path_template (str, optional): Template string for the cache path. Must contain {model}
25+
and {revision} placeholders. Defaults to "/runpod/cache/{model}/{revision}".
26+
27+
Returns:
28+
str | None: Absolute path where the model is cached, following the template provided in path_template. Returns None if no model name could be extracted.
29+
30+
Examples:
31+
>>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5:experimental")
32+
"/runpod/cache/runwayml/stable-diffusion-v1-5/experimental"
33+
>>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5")
34+
"/runpod/cache/runwayml/stable-diffusion-v1-5/main"
35+
>>> resolve_model_cache_path_from_hugginface_repository(":experimental")
36+
None
37+
"""
38+
model, *revision = huggingface_repository.rsplit(":", 1)
39+
if not model:
40+
# We could throw an exception here but returning None allows us to filter a list of repositories without needing a try/except block
41+
log.warn( # type: ignore in strict mode the typechecker complains about this method being partially unknown
42+
f'Unable to resolve the model-cache path for "{huggingface_repository}"'
43+
)
44+
return None
45+
return path_template.format(
46+
model=model, revision=revision[0] if revision else "main"
47+
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import unittest
2+
3+
from runpod.serverless.utils.rp_model_cache import (
4+
resolve_model_cache_path_from_hugginface_repository,
5+
)
6+
7+
8+
class TestModelCache(unittest.TestCase):
9+
"""Tests for rp_model_cache"""
10+
11+
def test_with_revision(self):
12+
"""Test with a revision"""
13+
path = resolve_model_cache_path_from_hugginface_repository(
14+
"runwayml/stable-diffusion-v1-5:experimental"
15+
)
16+
self.assertEqual(
17+
path, "/runpod/cache/runwayml/stable-diffusion-v1-5/experimental"
18+
)
19+
20+
def test_without_revision(self):
21+
"""Test without a revision"""
22+
path = resolve_model_cache_path_from_hugginface_repository(
23+
"runwayml/stable-diffusion-v1-5"
24+
)
25+
self.assertEqual(path, "/runpod/cache/runwayml/stable-diffusion-v1-5/main")
26+
27+
def test_with_multiple_colons(self):
28+
"""Test with multiple colons"""
29+
path = resolve_model_cache_path_from_hugginface_repository(
30+
"runwayml/stable-diffusion:v1-5:experimental"
31+
)
32+
self.assertEqual(
33+
path, "/runpod/cache/runwayml/stable-diffusion:v1-5/experimental"
34+
)
35+
36+
def test_with_custom_path_template(self):
37+
"""Test with a custom path template"""
38+
path = resolve_model_cache_path_from_hugginface_repository(
39+
"runwayml/stable-diffusion-v1-5:experimental",
40+
"/my-custom-model-cache/{model}/{revision}",
41+
)
42+
self.assertEqual(
43+
path, "/my-custom-model-cache/runwayml/stable-diffusion-v1-5/experimental"
44+
)
45+
46+
def test_with_missing_model_name(self):
47+
"""Test with a missing model name"""
48+
path = resolve_model_cache_path_from_hugginface_repository(":experimental")
49+
self.assertIsNone(path)
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)