From d0919fa7643033d3ef736a1a98a06731216558d8 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:58:18 +1200 Subject: [PATCH 1/4] Fix bug with repeated checkpoint path Code was calling `os.path.join()` too many times and causing the path to be repeated unnecessarily. Signed-off-by: Wei Ji <23487320+weiji14@users.noreply.github.com> --- fms_fsdp/utils/checkpointing_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 5381dc9d..83c9b59f 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -25,12 +25,8 @@ def get_latest(targdir, qualifier=lambda x: True): If directory is empty or nonexistent or no items qualify, return None.""" if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: latest = max( - [ - os.path.join(targdir, x) - for x in os.listdir(targdir) - if qualifier(os.path.join(targdir, x)) - ], - key=lambda path: int(path.split("/")[-1].split("_")[1]), + [x for x in os.listdir(targdir) if qualifier(os.path.join(targdir, x))], + key=lambda path: int(path.split("_")[1]), ) return os.path.join(targdir, latest) return None From 4a2b4b878f724e584efea980646a58c7286b470b Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 28 Aug 2024 00:59:22 +0100 Subject: [PATCH 2/4] Remove os.path.join in get_oldest function also Signed-off-by: Wei Ji <23487320+weiji14@users.noreply.github.com> --- fms_fsdp/utils/checkpointing_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 83c9b59f..97a5d80b 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -37,11 +37,7 @@ def get_oldest(targdir, qualifier=lambda x: True): If directory is empty or nonexistent or no items qualify, return None.""" if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: oldest = min( - [ - os.path.join(targdir, x) - for x in os.listdir(targdir) - if qualifier(os.path.join(targdir, x)) - ], + [x for x in os.listdir(targdir) if qualifier(os.path.join(targdir, x))], key=os.path.getctime, ) return os.path.join(targdir, oldest) From f082164a6170df40f42f0469738fe302f303846e Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 28 Aug 2024 00:59:49 +0100 Subject: [PATCH 3/4] Add unit tests for get_latest and get_oldest functions When a list of 3 files are passed into the get_oldest and get_latest functions, ensure that the files that were created first and last are returned respectively. Signed-off-by: Wei Ji <23487320+weiji14@users.noreply.github.com> --- tests/test_utils_checkpointing.py | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_utils_checkpointing.py diff --git a/tests/test_utils_checkpointing.py b/tests/test_utils_checkpointing.py new file mode 100644 index 00000000..d730cd2b --- /dev/null +++ b/tests/test_utils_checkpointing.py @@ -0,0 +1,36 @@ +""" +Tests for functions in utils/checkpointing_utils.py +""" + +import os +import tempfile + +from fms_fsdp.utils.checkpointing_utils import get_latest, get_oldest + + +def test_get_oldest(): + """ + Ensure that the get_oldest function returns the name of the file with the oldest + timestamp (i.e. that was created first). + """ + with tempfile.TemporaryDirectory() as tempdir: + for i in range(3): + filename = os.path.join(tempdir, f"file_{i}") + print("random content", file=open(file=filename, mode="w")) + + oldest_filename = get_oldest(targdir=tempdir) + assert oldest_filename.endswith("file_0") + + +def test_get_latest(): + """ + Ensure that the get_latest function returns the name of the file with the latest + integer suffix (i.e. that was created last). + """ + with tempfile.TemporaryDirectory() as tempdir: + for i in range(3): + filename = os.path.join(tempdir, f"file_{i}") + print("random content", file=open(file=filename, mode="w")) + + latest_filename = get_latest(targdir=tempdir) + assert latest_filename.endswith("file_2") From cf39fc7b13e5d418622d23a093d186af46a4dc30 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:26:27 +1200 Subject: [PATCH 4/4] Run os.path.getctime on full path instead of just filename Signed-off-by: Wei Ji <23487320+weiji14@users.noreply.github.com> --- fms_fsdp/utils/checkpointing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 97a5d80b..91050e24 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -38,7 +38,7 @@ def get_oldest(targdir, qualifier=lambda x: True): if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: oldest = min( [x for x in os.listdir(targdir) if qualifier(os.path.join(targdir, x))], - key=os.path.getctime, + key=lambda path: os.path.getctime(os.path.join(targdir, path)), ) return os.path.join(targdir, oldest) return None