diff --git a/src/fastcs/datatypes/string.py b/src/fastcs/datatypes/string.py index 93c53aae..e4deb15b 100644 --- a/src/fastcs/datatypes/string.py +++ b/src/fastcs/datatypes/string.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Any from fastcs.datatypes.datatype import DataType @@ -8,7 +9,11 @@ class String(DataType[str]): """`DataType` mapping to builtin ``str``.""" length: int | None = None - """Maximum length of string to display in transports""" + """Maximum length of string to display in transports. Must be >=1 or None.""" + + def __post_init__(self): + if self.length is not None and self.length < 1: + raise ValueError("String length must be >= 1") @property def dtype(self) -> type[str]: @@ -17,3 +22,12 @@ def dtype(self) -> type[str]: @property def initial_value(self) -> str: return "" + + def validate(self, value: Any) -> str: + """Truncate string to maximum length + + Returns: + The string, truncated to the maximum length if set + + """ + return super().validate(value)[: self.length] diff --git a/src/fastcs/transports/epics/ca/util.py b/src/fastcs/transports/epics/ca/util.py index 0c483a92..2f19d250 100644 --- a/src/fastcs/transports/epics/ca/util.py +++ b/src/fastcs/transports/epics/ca/util.py @@ -146,7 +146,10 @@ def cast_to_epics_type(datatype: DataType[DType_T], value: DType_T) -> Any: else: # enum backed by string record return datatype.validate(value).name case String() as string: - return value[: string.length] + if string.length is not None: + return value[: string.length] + else: + return value[:DEFAULT_STRING_WAVEFORM_LENGTH] case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): return value case _: diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index e49e80eb..b0b26d56 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -104,3 +104,12 @@ def test_dataset_equal(fastcs_datatype: DataType, value1, value2, expected): ) def test_dataset_all_equal(fastcs_datatype: DataType, values, expected): assert fastcs_datatype.all_equal(values) is expected + + +def test_string_length(): + assert String(length=10).validate("12345678901") == "1234567890" + + assert String().validate("12345678901") == "12345678901" + + with pytest.raises(ValueError): + String(length=0) diff --git a/tests/transports/epics/ca/test_ca_util.py b/tests/transports/epics/ca/test_ca_util.py index 8c23d8e0..2993b1e6 100644 --- a/tests/transports/epics/ca/test_ca_util.py +++ b/tests/transports/epics/ca/test_ca_util.py @@ -77,7 +77,8 @@ class ShortMixedEnum(enum.Enum): (Int(), 4, 4), (Float(), 1.0, 1.0), (Bool(), True, True), - (String(), "hey", "hey"), + (String(), "a" * 257, "a" * 256), + (String(length=3), "1234", "123"), # shorter enums can be represented by integers from 0-15 (Enum(ShortMixedEnum), ShortMixedEnum.STRING_MEMBER, 0), (Enum(ShortMixedEnum), ShortMixedEnum.INT_MEMBER, 1),