diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index aee24985c..ca971c0ef 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -11,7 +11,6 @@ on: - 'WORKSPACE' - 'src/**' pull_request: - branches: [ main ] paths-ignore: - 'WORKSPACE' - 'src/**' diff --git a/launchable/utils/click.py b/launchable/utils/click.py index 1114cfafb..255a94123 100644 --- a/launchable/utils/click.py +++ b/launchable/utils/click.py @@ -29,14 +29,20 @@ class PercentageType(ParamType): def convert(self, value: str, param: Optional[click.core.Parameter], ctx: Optional[click.core.Context]): try: + missing_percent = False if value.endswith('%'): x = float(value[:-1]) / 100 if 0 <= x <= 100: return x + else: + missing_percent = True except ValueError: pass - self.fail("Expected percentage like 50% but got '{}'".format(value), param, ctx) + msg = "Expected percentage like 50% but got '{}'".format(value) + if missing_percent and sys.platform.startswith("win"): + msg += " ('%' is a special character in batch files, so please write '50%%' to pass in '50%')" + self.fail(msg, param, ctx) class DurationType(ParamType): diff --git a/tests/utils/test_click.py b/tests/utils/test_click.py index b04029d64..9ee41600f 100644 --- a/tests/utils/test_click.py +++ b/tests/utils/test_click.py @@ -1,4 +1,5 @@ import datetime +import sys from datetime import timezone from typing import Sequence, Tuple from unittest import TestCase @@ -7,7 +8,51 @@ from click.testing import CliRunner from dateutil.tz import tzlocal -from launchable.utils.click import DATETIME_WITH_TZ, KEY_VALUE, convert_to_seconds +from launchable.utils.click import DATETIME_WITH_TZ, KEY_VALUE, PercentageType, convert_to_seconds + + +class PercentageTypeTest(TestCase): + ERROR_MSG = "Expected percentage like 50% but got" + WINDOWS_ERROR_MSG = "please write '50%%' to pass in '50%'" + + def test_invalid_value_windows(self): + pct = PercentageType() + orig_platform = sys.platform + sys.platform = "win32" + try: + with self.assertRaises(click.BadParameter) as cm: + pct.convert("50", None, None) + msg = str(cm.exception) + self.assertIn(self.ERROR_MSG + " '50'", msg) + self.assertIn(self.WINDOWS_ERROR_MSG, msg) + finally: + sys.platform = orig_platform + + def test_invalid_value_non_windows(self): + pct = PercentageType() + orig_platform = sys.platform + sys.platform = "linux" + try: + with self.assertRaises(click.BadParameter) as cm: + pct.convert("50", None, None) + msg = str(cm.exception) + self.assertIn(self.ERROR_MSG + " '50'", msg) + self.assertNotIn(self.WINDOWS_ERROR_MSG, msg) + finally: + sys.platform = orig_platform + + def test_invalid_float(self): + pct = PercentageType() + with self.assertRaises(click.BadParameter) as cm: + pct.convert("abc%", None, None) + msg = str(cm.exception) + self.assertIn(self.ERROR_MSG + " 'abc%'", msg) + + def test_valid(self): + pct = PercentageType() + self.assertEqual(pct.convert("50%", None, None), 0.5) + self.assertEqual(pct.convert("0%", None, None), 0.0) + self.assertEqual(pct.convert("100%", None, None), 1.0) class DurationTypeTest(TestCase):