diff --git a/changes/20251030161706.feature b/changes/20251030161706.feature new file mode 100644 index 0000000000..311ae68b98 --- /dev/null +++ b/changes/20251030161706.feature @@ -0,0 +1 @@ +:sparkles: `serialization` Add support for decoding enum strings to their underlying values when using mapstructure tags diff --git a/utils/config/service_configuration.go b/utils/config/service_configuration.go index f0c5cf0cf7..646c95be3f 100644 --- a/utils/config/service_configuration.go +++ b/utils/config/service_configuration.go @@ -20,6 +20,7 @@ import ( "github.com/ARM-software/golang-utils/utils/field" "github.com/ARM-software/golang-utils/utils/keyring" "github.com/ARM-software/golang-utils/utils/reflection" + "github.com/ARM-software/golang-utils/utils/serialization/maps" //nolint:misspell ) const ( @@ -113,7 +114,12 @@ func LoadFromEnvironmentAndSystem(viperSession *viper.Viper, envVarPrefix string } // Merge together all the sources and unmarshal into struct - err = viperSession.Unmarshal(configurationToSet) + err = viperSession.Unmarshal(configurationToSet, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( + maps.CustomTypeHookFunc(), + // Keep these two as they are the default values used by viper and we don't want to override them + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + ))) if err != nil { err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "unable to fill configuration structure from the configuration session") return diff --git a/utils/config/service_configuration_test.go b/utils/config/service_configuration_test.go index 2c2da66a62..a567b8379c 100644 --- a/utils/config/service_configuration_test.go +++ b/utils/config/service_configuration_test.go @@ -26,6 +26,7 @@ import ( "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" "github.com/ARM-software/golang-utils/utils/keyring" + mapstest "github.com/ARM-software/golang-utils/utils/serialization/maps/testing" //nolint:misspell ) var ( @@ -39,13 +40,15 @@ var ( ) type DummyConfiguration struct { - Host string `mapstructure:"dummy_host"` - Port int `mapstructure:"port"` - DB string `mapstructure:"db"` - User string `mapstructure:"user"` - Password string `mapstructure:"password"` - Flag bool `mapstructure:"flag"` - HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"` + Host string `mapstructure:"dummy_host"` + Port int `mapstructure:"port"` + DB string `mapstructure:"db"` + User string `mapstructure:"user"` + Password string `mapstructure:"password"` + Flag bool `mapstructure:"flag"` + TestEnum mapstest.TestEnumWithUnmarshal `mapstructure:"enum"` + TestEnum1 mapstest.TestEnumWithoutUnmarshal `mapstructure:"enum1"` + HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"` } func (cfg *DummyConfiguration) Validate() error { @@ -55,6 +58,7 @@ func (cfg *DummyConfiguration) Validate() error { validation.Field(&cfg.DB, validation.Required), validation.Field(&cfg.User, validation.Required), validation.Field(&cfg.Password, validation.Required), + validation.Field(&cfg.TestEnum, validation.By(mapstest.ValidationFunc)), ) } @@ -189,6 +193,14 @@ func TestServiceConfigurationLoad(t *testing.T) { require.NoError(t, err) err = os.Setenv("TEST_DUMMY_CONFIG_USER", "a test user") require.NoError(t, err) + err = os.Setenv("TEST_DUMMY_CONFIG_ENUM", mapstest.TestEnumStringVer1) + require.NoError(t, err) + err = os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1) + require.NoError(t, err) + err = os.Setenv("TEST_DUMMY_CONFIG_ENUM1", "1") + require.NoError(t, err) + err = os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1") + require.NoError(t, err) err = os.Setenv("TEST_DUMMYCONFIG_DB", "a test db") require.NoError(t, err) err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB) @@ -407,6 +419,8 @@ func TestFlagsBinding(t *testing.T) { flagSet.Int("int", 0, "dummy int") flagSet.Duration("time", time.Second, "dummy time") flagSet.Bool("flag", false, "dummy flag") + flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum") + flagSet.String("enum1", "1", "dummy enum") err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host2"), flagSet.Lookup("host2")) require.NoError(t, err) err = BindFlagsToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host1"), flagSet.Lookup("host2")) @@ -419,6 +433,10 @@ func TestFlagsBinding(t *testing.T) { require.NoError(t, err) err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_USER", flagSet.Lookup("user1"), flagSet.Lookup("user2")) require.NoError(t, err) + err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM", flagSet.Lookup("enum")) + require.NoError(t, err) + err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM1", flagSet.Lookup("enum1")) + require.NoError(t, err) err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DB", flagSet.Lookup("db")) require.NoError(t, err) err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_DB", flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2")) @@ -476,6 +494,12 @@ func TestFlagsBinding(t *testing.T) { assert.Equal(t, expectedHost, configTest.TestConfig2.Host) assert.Equal(t, expectedPassword, configTest.TestConfig.Password) assert.Equal(t, expectedPassword, configTest.TestConfig2.Password) + assert.NotEqual(t, mapstest.TestEnumStringVer1, configTest.TestConfig2.TestEnum) + assert.NotEqual(t, mapstest.TestEnumStringVer0, configTest.TestConfig.TestEnum) + assert.Equal(t, mapstest.TestEnumWithUnmarshal1, configTest.TestConfig2.TestEnum) + assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum) + assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, configTest.TestConfig2.TestEnum1) + assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1) assert.Equal(t, expectedDB, configTest.TestConfig.DB) assert.Equal(t, aDifferentDB, configTest.TestConfig2.DB) assert.NotEqual(t, expectedDB, configTest.TestConfig2.DB) @@ -516,6 +540,8 @@ func TestFlagBindingDefaults(t *testing.T) { flagSet.Int("int", expectedInt, "dummy int") flagSet.Duration("time", expectedDuration, "dummy time") flagSet.Bool("flag", !DefaultDummyConfiguration().Flag, "dummy flag") + flagSet.String("enum", mapstest.TestEnumStringVer0, "dummy enum") + flagSet.String("enum1", "0", "dummy enum") err = BindFlagToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host")) require.NoError(t, err) err = BindFlagToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host2")) @@ -538,6 +564,10 @@ func TestFlagBindingDefaults(t *testing.T) { require.NoError(t, err) err = BindFlagToEnv(session, prefix, "DUMMY_Time", flagSet.Lookup("time")) require.NoError(t, err) + err = BindFlagToEnv(session, prefix, "DUMMY_enum", flagSet.Lookup("enum")) + require.NoError(t, err) + err = BindFlagToEnv(session, prefix, "DUMMY_enum1", flagSet.Lookup("enum1")) + require.NoError(t, err) err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB) // Should take precedence over flag default require.NoError(t, err) err = LoadFromViper(session, prefix, configTest, defaults) @@ -556,6 +586,10 @@ func TestFlagBindingDefaults(t *testing.T) { assert.Equal(t, expectedPassword, configTest.TestConfig2.Password) assert.Equal(t, aDifferentDB, configTest.TestConfig.DB) assert.Equal(t, expectedDB, configTest.TestConfig2.DB) + assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig2.TestEnum) + assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum) + assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig2.TestEnum1) + assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1) // Defaults from the default structure provided take precedence over defaults from flags when empty. assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig.Flag) assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig2.Flag) @@ -575,6 +609,8 @@ func TestGenerateEnvFile_Defaults(t *testing.T) { "TEST_PASSWORD": configTest.Password, "TEST_PORT": configTest.Port, "TEST_USER": configTest.User, + "TEST_ENUM": configTest.TestEnum, + "TEST_ENUM1": configTest.TestEnum1, } // Generate env file @@ -601,6 +637,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) { flagSet.String("password", "a password", "dummy password") flagSet.String("user", "a user", "dummy user") flagSet.String("db", "a db", "dummy db") + flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum") + flagSet.String("enum1", "1", "dummy enum") err = BindFlagToEnv(session, prefix, "TEST_DUMMY_HOST", flagSet.Lookup("host")) require.NoError(t, err) err = BindFlagToEnv(session, prefix, "PASSWORD", flagSet.Lookup("password")) @@ -611,6 +649,10 @@ func TestGenerateEnvFile_Populated(t *testing.T) { require.NoError(t, err) err = BindFlagToEnv(session, prefix, "USER", flagSet.Lookup("user")) require.NoError(t, err) + err = BindFlagToEnv(session, prefix, "ENUM", flagSet.Lookup("enum")) + require.NoError(t, err) + err = BindFlagToEnv(session, prefix, "ENUM1", flagSet.Lookup("enum1")) + require.NoError(t, err) err = flagSet.Set("host", expectedHost) require.NoError(t, err) err = flagSet.Set("password", expectedPassword) @@ -630,6 +672,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) { "TEST_PASSWORD": configTest.Password, "TEST_PORT": configTest.Port, "TEST_USER": configTest.User, + "TEST_ENUM": configTest.TestEnum, + "TEST_ENUM1": configTest.TestEnum1, } // Generate env file @@ -670,6 +714,8 @@ func TestGenerateEnvFile_Nested(t *testing.T) { "TEST_DEEP_CONFIG_DUMMYCONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig.Password, "TEST_DEEP_CONFIG_DUMMYCONFIG_PORT": configTest.TestConfigDeep.TestConfig.Port, "TEST_DEEP_CONFIG_DUMMYCONFIG_USER": configTest.TestConfigDeep.TestConfig.User, + "TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM": configTest.TestConfigDeep.TestConfig.TestEnum, + "TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM1": configTest.TestConfigDeep.TestConfig.TestEnum1, "TEST_DEEP_CONFIG_DUMMY_CONFIG_DB": configTest.TestConfigDeep.TestConfig2.DB, "TEST_DEEP_CONFIG_DUMMY_CONFIG_DUMMY_HOST": configTest.TestConfigDeep.TestConfig2.Host, "TEST_DEEP_CONFIG_DUMMY_CONFIG_FLAG": configTest.TestConfigDeep.TestConfig2.Flag, @@ -677,6 +723,8 @@ func TestGenerateEnvFile_Nested(t *testing.T) { "TEST_DEEP_CONFIG_DUMMY_CONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig2.Password, "TEST_DEEP_CONFIG_DUMMY_CONFIG_PORT": configTest.TestConfigDeep.TestConfig2.Port, "TEST_DEEP_CONFIG_DUMMY_CONFIG_USER": configTest.TestConfigDeep.TestConfig2.User, + "TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM": configTest.TestConfigDeep.TestConfig2.TestEnum, + "TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM1": configTest.TestConfigDeep.TestConfig2.TestEnum1, "TEST_DEEP_CONFIG_DUMMY_INT": configTest.TestConfigDeep.TestInt, "TEST_DUMMY_STRING": configTest.TestString, "TEST_DEEP_CONFIG_DUMMY_TIME": configTest.TestConfigDeep.TestTime, @@ -1040,3 +1088,57 @@ func loadEnvIntoEnvironment(t *testing.T, envPath string) (err error) { return } + +func TestCustomTypeHook_Success(t *testing.T) { + t.Cleanup(os.Clearenv) + os.Clearenv() + + cfg := &ConfigurationTest{} + defaults := DefaultConfiguration() + + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost)) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword)) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user")) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB)) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost)) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword)) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user")) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB)) + require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt))) + require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String())) + + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1)) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1")) + + err := Load("test", cfg, defaults) + require.NoError(t, err) + require.NoError(t, cfg.Validate()) + + assert.Equal(t, mapstest.TestEnumWithUnmarshal1, cfg.TestConfig.TestEnum) + assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, cfg.TestConfig.TestEnum1) +} + +func TestCustomTypeHook_InvalidValue(t *testing.T) { + t.Cleanup(os.Clearenv) + os.Clearenv() + + cfg := &ConfigurationTest{} + defaults := DefaultConfiguration() + + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost)) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword)) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user")) + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB)) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost)) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword)) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user")) + require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB)) + require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt))) + require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String())) + + require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", "4")) + + err := Load("test", cfg, defaults) + errortest.AssertError(t, err, commonerrors.ErrInvalid) + errortest.AssertErrorDescription(t, err, "structure failed validation") +} diff --git a/utils/serialization/maps/map.go b/utils/serialization/maps/map.go index 095656a44e..482ce71dba 100644 --- a/utils/serialization/maps/map.go +++ b/utils/serialization/maps/map.go @@ -1,13 +1,16 @@ package maps import ( + "encoding" "reflect" + "strconv" "time" "github.com/go-viper/mapstructure/v2" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/maps" + "github.com/ARM-software/golang-utils/utils/safecast" ) // ToMapFromPointer is like ToMap but deals with a pointer. @@ -72,7 +75,7 @@ func FromMapToPointer[T any](m map[string]string, o T) (err error) { err = mapstructureDecoder(expandedMap, o) if err != nil { - err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise upload request") + err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise the map") } return } @@ -148,11 +151,57 @@ func toTime(f reflect.Type, t reflect.Type, data any) (any, error) { } } +func toCustomTypeIntFallback(f reflect.Type, t reflect.Type, data any) (any, error) { + if f == nil || t == nil || f.Kind() != reflect.String { + return data, nil + } + if t.Kind() != reflect.Int { + return data, nil + } + + s, ok := data.(string) + if !ok { + return data, nil + } + + i, err := strconv.Atoi(s) + if err != nil { + return data, nil + } + + ptr := reflect.New(t).Elem() + ptr.SetInt(safecast.ToInt64(i)) + + return ptr.Interface(), nil +} + +func toCustomType(f reflect.Type, t reflect.Type, data any) (any, error) { + if f == nil || t == nil || f.Kind() != reflect.String { + return data, nil + } + + customType, ok := reflect.New(t).Interface().(encoding.TextUnmarshaler) + if !ok { + return toCustomTypeIntFallback(f, t, data) + } + + err := customType.UnmarshalText([]byte(data.(string))) // we know it is a string based on reflection + if err != nil { + return toCustomTypeIntFallback(f, t, data) + } + + return customType, nil +} + +func CustomTypeHookFunc() mapstructure.DecodeHookFunc { + return toCustomType +} + func mapstructureDecoder(input, result any) error { decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ WeaklyTypedInput: true, DecodeHook: mapstructure.ComposeDecodeHookFunc( - timeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()), + timeHookFunc(), CustomTypeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()), Result: result, }) if err != nil { diff --git a/utils/serialization/maps/map_test.go b/utils/serialization/maps/map_test.go index a63c83ab2c..25b85c3e52 100644 --- a/utils/serialization/maps/map_test.go +++ b/utils/serialization/maps/map_test.go @@ -10,6 +10,7 @@ import ( "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + mapstest "github.com/ARM-software/golang-utils/utils/serialization/maps/testing" //nolint:misspell ) type TestStruct0 struct { @@ -107,6 +108,18 @@ type TestStruct3WithTime struct { Struct TestStruct2WithTime } +type TestStructWithEnum struct { + Time time.Time `mapstructure:"some_time"` + TestEnum mapstest.TestEnumWithUnmarshal `mapstructure:"test_enum"` + Struct TestStruct2WithTime +} + +type TestStructWithEnumInvalid struct { + Time time.Time `mapstructure:"some_time"` + TestEnum mapstest.TestEnumWithoutUnmarshal `mapstructure:"test_enum"` + Struct TestStruct2WithTime +} + type TestStruct4 struct { Field1 string `mapstructure:"field_1"` Field2 string `mapstructure:"field_2"` @@ -170,6 +183,53 @@ func TestToMap(t *testing.T) { assert.WithinDuration(t, testStruct.Struct.Time, newStruct.Struct.Time, 0) assert.Equal(t, testStruct.Struct.Duration, newStruct.Struct.Duration) }) + t.Run("with custom type (has UnmarshalText)", func(t *testing.T) { + random, err := faker.RandomInt(0, 1000, 2) + require.NoError(t, err) + testStruct := TestStructWithEnum{ + Time: time.Now().UTC(), + TestEnum: mapstest.TestEnumWithUnmarshal1, + Struct: TestStruct2WithTime{ + Time: time.Unix(faker.RandomUnixTime(), 0), + Duration: time.Duration(random[1]) * time.Second, + }, + } + structMap, err := ToMap[TestStructWithEnum](&testStruct) + structMap["test_enum"] = mapstest.TestEnumStringVer1 // change to the alternate version for unmarshalling + require.NoError(t, err) + _, err = ToMapFromPointer[TestStructWithEnum](testStruct) + errortest.AssertError(t, err, commonerrors.ErrInvalid) + newStruct := TestStructWithEnum{} + require.NoError(t, FromMap[TestStructWithEnum](structMap, &newStruct)) + errortest.AssertError(t, FromMapToPointer[TestStructWithEnum](structMap, newStruct), commonerrors.ErrInvalid) + assert.WithinDuration(t, testStruct.Time, newStruct.Time, 0) + assert.Equal(t, testStruct.TestEnum, newStruct.TestEnum) + assert.WithinDuration(t, testStruct.Struct.Time, newStruct.Struct.Time, 0) + assert.Equal(t, testStruct.Struct.Duration, newStruct.Struct.Duration) + }) + t.Run("with custom type (no UnmarshalText)", func(t *testing.T) { + random, err := faker.RandomInt(0, 1000, 2) + require.NoError(t, err) + testStruct := TestStructWithEnumInvalid{ + Time: time.Now().UTC(), + TestEnum: mapstest.TestEnumWithoutUnmarshal1, + Struct: TestStruct2WithTime{ + Time: time.Unix(faker.RandomUnixTime(), 0), + Duration: time.Duration(random[1]) * time.Second, + }, + } + structMap, err := ToMap[TestStructWithEnumInvalid](&testStruct) + require.NoError(t, err) + _, err = ToMapFromPointer[TestStructWithEnumInvalid](testStruct) + errortest.AssertError(t, err, commonerrors.ErrInvalid) + newStruct := TestStructWithEnumInvalid{} + require.NoError(t, FromMap[TestStructWithEnumInvalid](structMap, &newStruct)) + errortest.AssertError(t, FromMapToPointer[TestStructWithEnumInvalid](structMap, newStruct), commonerrors.ErrInvalid) + assert.WithinDuration(t, testStruct.Time, newStruct.Time, 0) + assert.Equal(t, testStruct.TestEnum, newStruct.TestEnum) + assert.WithinDuration(t, testStruct.Struct.Time, newStruct.Struct.Time, 0) + assert.Equal(t, testStruct.Struct.Duration, newStruct.Struct.Duration) + }) t.Run("invalid", func(t *testing.T) { var testMap map[string]string testStruct := TestStruct3WithTime{} diff --git a/utils/serialization/maps/testing/testing.go b/utils/serialization/maps/testing/testing.go new file mode 100644 index 0000000000..ce5d4b656e --- /dev/null +++ b/utils/serialization/maps/testing/testing.go @@ -0,0 +1,44 @@ +package testing + +import ( + "github.com/ARM-software/golang-utils/utils/commonerrors" +) + +type TestEnumWithUnmarshal int + +const ( + TestEnumStringVer0 = "test0" + TestEnumStringVer1 = "test1" +) + +func (i *TestEnumWithUnmarshal) UnmarshalText(text []byte) error { + v, ok := map[string]TestEnumWithUnmarshal{ + TestEnumStringVer0: TestEnumWithUnmarshal0, + TestEnumStringVer1: TestEnumWithUnmarshal1, + }[string(text)] + if !ok { + return commonerrors.ErrInvalid + } + *i = v + return nil +} + +func ValidationFunc(value any) error { + e, ok := value.(TestEnumWithUnmarshal) + if !ok || (e != TestEnumWithUnmarshal0 && e != TestEnumWithUnmarshal1) { + return commonerrors.ErrInvalid + } + return nil +} + +const ( + TestEnumWithUnmarshal0 TestEnumWithUnmarshal = iota + TestEnumWithUnmarshal1 +) + +type TestEnumWithoutUnmarshal int + +const ( + TestEnumWithoutUnmarshal0 TestEnumWithoutUnmarshal = iota + TestEnumWithoutUnmarshal1 +)