Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251030161706.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `serialization` Add support for decoding enum strings to their underlying values when using mapstructure tags
8 changes: 7 additions & 1 deletion utils/config/service_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/ARM-software/golang-utils/utils/field"
"github.com/ARM-software/golang-utils/utils/keyring"
"github.com/ARM-software/golang-utils/utils/reflection"
"github.com/ARM-software/golang-utils/utils/serialization/maps" //nolint:misspell
)

const (
Expand Down Expand Up @@ -113,7 +114,12 @@ func LoadFromEnvironmentAndSystem(viperSession *viper.Viper, envVarPrefix string
}

// Merge together all the sources and unmarshal into struct
err = viperSession.Unmarshal(configurationToSet)
err = viperSession.Unmarshal(configurationToSet, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
maps.CustomTypeHookFunc(),
// Keep these two as they are the default values used by viper and we don't want to override them
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
)))
if err != nil {
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "unable to fill configuration structure from the configuration session")
return
Expand Down
116 changes: 109 additions & 7 deletions utils/config/service_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
"github.com/ARM-software/golang-utils/utils/keyring"
mapstest "github.com/ARM-software/golang-utils/utils/serialization/maps/testing" //nolint:misspell
)

var (
Expand All @@ -39,13 +40,15 @@ var (
)

type DummyConfiguration struct {
Host string `mapstructure:"dummy_host"`
Port int `mapstructure:"port"`
DB string `mapstructure:"db"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Flag bool `mapstructure:"flag"`
HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"`
Host string `mapstructure:"dummy_host"`
Port int `mapstructure:"port"`
DB string `mapstructure:"db"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Flag bool `mapstructure:"flag"`
TestEnum mapstest.TestEnumWithUnmarshal `mapstructure:"enum"`
TestEnum1 mapstest.TestEnumWithoutUnmarshal `mapstructure:"enum1"`
HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"`
}

func (cfg *DummyConfiguration) Validate() error {
Expand All @@ -55,6 +58,7 @@ func (cfg *DummyConfiguration) Validate() error {
validation.Field(&cfg.DB, validation.Required),
validation.Field(&cfg.User, validation.Required),
validation.Field(&cfg.Password, validation.Required),
validation.Field(&cfg.TestEnum, validation.By(mapstest.ValidationFunc)),
)
}

Expand Down Expand Up @@ -189,6 +193,14 @@ func TestServiceConfigurationLoad(t *testing.T) {
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_USER", "a test user")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_ENUM", mapstest.TestEnumStringVer1)
require.NoError(t, err)
err = os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1)
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_ENUM1", "1")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMYCONFIG_DB", "a test db")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB)
Expand Down Expand Up @@ -407,6 +419,8 @@ func TestFlagsBinding(t *testing.T) {
flagSet.Int("int", 0, "dummy int")
flagSet.Duration("time", time.Second, "dummy time")
flagSet.Bool("flag", false, "dummy flag")
flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum")
flagSet.String("enum1", "1", "dummy enum")
err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host2"), flagSet.Lookup("host2"))
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host1"), flagSet.Lookup("host2"))
Expand All @@ -419,6 +433,10 @@ func TestFlagsBinding(t *testing.T) {
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_USER", flagSet.Lookup("user1"), flagSet.Lookup("user2"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM", flagSet.Lookup("enum"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM1", flagSet.Lookup("enum1"))
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DB", flagSet.Lookup("db"))
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_DB", flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"))
Expand Down Expand Up @@ -476,6 +494,12 @@ func TestFlagsBinding(t *testing.T) {
assert.Equal(t, expectedHost, configTest.TestConfig2.Host)
assert.Equal(t, expectedPassword, configTest.TestConfig.Password)
assert.Equal(t, expectedPassword, configTest.TestConfig2.Password)
assert.NotEqual(t, mapstest.TestEnumStringVer1, configTest.TestConfig2.TestEnum)
assert.NotEqual(t, mapstest.TestEnumStringVer0, configTest.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithUnmarshal1, configTest.TestConfig2.TestEnum)
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, configTest.TestConfig2.TestEnum1)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1)
assert.Equal(t, expectedDB, configTest.TestConfig.DB)
assert.Equal(t, aDifferentDB, configTest.TestConfig2.DB)
assert.NotEqual(t, expectedDB, configTest.TestConfig2.DB)
Expand Down Expand Up @@ -516,6 +540,8 @@ func TestFlagBindingDefaults(t *testing.T) {
flagSet.Int("int", expectedInt, "dummy int")
flagSet.Duration("time", expectedDuration, "dummy time")
flagSet.Bool("flag", !DefaultDummyConfiguration().Flag, "dummy flag")
flagSet.String("enum", mapstest.TestEnumStringVer0, "dummy enum")
flagSet.String("enum1", "0", "dummy enum")
err = BindFlagToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host2"))
Expand All @@ -538,6 +564,10 @@ func TestFlagBindingDefaults(t *testing.T) {
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_Time", flagSet.Lookup("time"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_enum", flagSet.Lookup("enum"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_enum1", flagSet.Lookup("enum1"))
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB) // Should take precedence over flag default
require.NoError(t, err)
err = LoadFromViper(session, prefix, configTest, defaults)
Expand All @@ -556,6 +586,10 @@ func TestFlagBindingDefaults(t *testing.T) {
assert.Equal(t, expectedPassword, configTest.TestConfig2.Password)
assert.Equal(t, aDifferentDB, configTest.TestConfig.DB)
assert.Equal(t, expectedDB, configTest.TestConfig2.DB)
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig2.TestEnum)
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig2.TestEnum1)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1)
// Defaults from the default structure provided take precedence over defaults from flags when empty.
assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig.Flag)
assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig2.Flag)
Expand All @@ -575,6 +609,8 @@ func TestGenerateEnvFile_Defaults(t *testing.T) {
"TEST_PASSWORD": configTest.Password,
"TEST_PORT": configTest.Port,
"TEST_USER": configTest.User,
"TEST_ENUM": configTest.TestEnum,
"TEST_ENUM1": configTest.TestEnum1,
}

// Generate env file
Expand All @@ -601,6 +637,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
flagSet.String("password", "a password", "dummy password")
flagSet.String("user", "a user", "dummy user")
flagSet.String("db", "a db", "dummy db")
flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum")
flagSet.String("enum1", "1", "dummy enum")
err = BindFlagToEnv(session, prefix, "TEST_DUMMY_HOST", flagSet.Lookup("host"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "PASSWORD", flagSet.Lookup("password"))
Expand All @@ -611,6 +649,10 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "USER", flagSet.Lookup("user"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "ENUM", flagSet.Lookup("enum"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "ENUM1", flagSet.Lookup("enum1"))
require.NoError(t, err)
err = flagSet.Set("host", expectedHost)
require.NoError(t, err)
err = flagSet.Set("password", expectedPassword)
Expand All @@ -630,6 +672,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
"TEST_PASSWORD": configTest.Password,
"TEST_PORT": configTest.Port,
"TEST_USER": configTest.User,
"TEST_ENUM": configTest.TestEnum,
"TEST_ENUM1": configTest.TestEnum1,
}

// Generate env file
Expand Down Expand Up @@ -670,13 +714,17 @@ func TestGenerateEnvFile_Nested(t *testing.T) {
"TEST_DEEP_CONFIG_DUMMYCONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig.Password,
"TEST_DEEP_CONFIG_DUMMYCONFIG_PORT": configTest.TestConfigDeep.TestConfig.Port,
"TEST_DEEP_CONFIG_DUMMYCONFIG_USER": configTest.TestConfigDeep.TestConfig.User,
"TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM": configTest.TestConfigDeep.TestConfig.TestEnum,
"TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM1": configTest.TestConfigDeep.TestConfig.TestEnum1,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_DB": configTest.TestConfigDeep.TestConfig2.DB,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_DUMMY_HOST": configTest.TestConfigDeep.TestConfig2.Host,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_FLAG": configTest.TestConfigDeep.TestConfig2.Flag,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_HEALTHCHECK_PERIOD": configTest.TestConfigDeep.TestConfig2.HealthCheckPeriod,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig2.Password,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_PORT": configTest.TestConfigDeep.TestConfig2.Port,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_USER": configTest.TestConfigDeep.TestConfig2.User,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM": configTest.TestConfigDeep.TestConfig2.TestEnum,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM1": configTest.TestConfigDeep.TestConfig2.TestEnum1,
"TEST_DEEP_CONFIG_DUMMY_INT": configTest.TestConfigDeep.TestInt,
"TEST_DUMMY_STRING": configTest.TestString,
"TEST_DEEP_CONFIG_DUMMY_TIME": configTest.TestConfigDeep.TestTime,
Expand Down Expand Up @@ -1040,3 +1088,57 @@ func loadEnvIntoEnvironment(t *testing.T, envPath string) (err error) {

return
}

func TestCustomTypeHook_Success(t *testing.T) {
t.Cleanup(os.Clearenv)
os.Clearenv()

cfg := &ConfigurationTest{}
defaults := DefaultConfiguration()

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt)))
require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String()))

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1"))

err := Load("test", cfg, defaults)
require.NoError(t, err)
require.NoError(t, cfg.Validate())

assert.Equal(t, mapstest.TestEnumWithUnmarshal1, cfg.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, cfg.TestConfig.TestEnum1)
}

func TestCustomTypeHook_InvalidValue(t *testing.T) {
t.Cleanup(os.Clearenv)
os.Clearenv()

cfg := &ConfigurationTest{}
defaults := DefaultConfiguration()

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt)))
require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String()))

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", "4"))

err := Load("test", cfg, defaults)
errortest.AssertError(t, err, commonerrors.ErrInvalid)
errortest.AssertErrorDescription(t, err, "structure failed validation")
}
53 changes: 51 additions & 2 deletions utils/serialization/maps/map.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package maps

import (
"encoding"
"reflect"
"strconv"
"time"

"github.com/go-viper/mapstructure/v2"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/maps"
"github.com/ARM-software/golang-utils/utils/safecast"
)

// ToMapFromPointer is like ToMap but deals with a pointer.
Expand Down Expand Up @@ -72,7 +75,7 @@ func FromMapToPointer[T any](m map[string]string, o T) (err error) {

err = mapstructureDecoder(expandedMap, o)
if err != nil {
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise upload request")
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise the map")
}
return
}
Expand Down Expand Up @@ -148,11 +151,57 @@ func toTime(f reflect.Type, t reflect.Type, data any) (any, error) {
}
}

func toCustomTypeIntFallback(f reflect.Type, t reflect.Type, data any) (any, error) {
if f == nil || t == nil || f.Kind() != reflect.String {
return data, nil
}
if t.Kind() != reflect.Int {
return data, nil
}

s, ok := data.(string)
if !ok {
return data, nil
}

i, err := strconv.Atoi(s)
if err != nil {
return data, nil
}

ptr := reflect.New(t).Elem()
ptr.SetInt(safecast.ToInt64(i))

return ptr.Interface(), nil
}

func toCustomType(f reflect.Type, t reflect.Type, data any) (any, error) {
if f == nil || t == nil || f.Kind() != reflect.String {
return data, nil
}

customType, ok := reflect.New(t).Interface().(encoding.TextUnmarshaler)
if !ok {
return toCustomTypeIntFallback(f, t, data)
}

err := customType.UnmarshalText([]byte(data.(string))) // we know it is a string based on reflection
if err != nil {
return toCustomTypeIntFallback(f, t, data)
}

return customType, nil
}

func CustomTypeHookFunc() mapstructure.DecodeHookFunc {
return toCustomType
}

func mapstructureDecoder(input, result any) error {
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
timeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()),
timeHookFunc(), CustomTypeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()),
Result: result,
})
if err != nil {
Expand Down
Loading
Loading