Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251030161706.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `serialization` Add support for decoding enum strings to their underlying values when using mapstructure tags
8 changes: 7 additions & 1 deletion utils/config/service_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/ARM-software/golang-utils/utils/field"
"github.com/ARM-software/golang-utils/utils/keyring"
"github.com/ARM-software/golang-utils/utils/reflection"
"github.com/ARM-software/golang-utils/utils/serialization/maps" //nolint:misspell
)

const (
Expand Down Expand Up @@ -113,7 +114,12 @@ func LoadFromEnvironmentAndSystem(viperSession *viper.Viper, envVarPrefix string
}

// Merge together all the sources and unmarshal into struct
err = viperSession.Unmarshal(configurationToSet)
err = viperSession.Unmarshal(configurationToSet, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
maps.CustomTypeHookFunc(),
// Keep these two as they are the default values used by viper and we don't want to override them
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
)))
if err != nil {
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "unable to fill configuration structure from the configuration session")
return
Expand Down
116 changes: 109 additions & 7 deletions utils/config/service_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
"github.com/ARM-software/golang-utils/utils/keyring"
mapstest "github.com/ARM-software/golang-utils/utils/serialization/maps/testing" //nolint:misspell
)

var (
Expand All @@ -39,13 +40,15 @@ var (
)

type DummyConfiguration struct {
Host string `mapstructure:"dummy_host"`
Port int `mapstructure:"port"`
DB string `mapstructure:"db"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Flag bool `mapstructure:"flag"`
HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"`
Host string `mapstructure:"dummy_host"`
Port int `mapstructure:"port"`
DB string `mapstructure:"db"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Flag bool `mapstructure:"flag"`
TestEnum mapstest.TestEnumWithUnmarshal `mapstructure:"enum"`
TestEnum1 mapstest.TestEnumWithoutUnmarshal `mapstructure:"enum1"`
HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"`
}

func (cfg *DummyConfiguration) Validate() error {
Expand All @@ -55,6 +58,7 @@ func (cfg *DummyConfiguration) Validate() error {
validation.Field(&cfg.DB, validation.Required),
validation.Field(&cfg.User, validation.Required),
validation.Field(&cfg.Password, validation.Required),
validation.Field(&cfg.TestEnum, validation.By(mapstest.ValidationFunc)),
)
}

Expand Down Expand Up @@ -189,6 +193,14 @@ func TestServiceConfigurationLoad(t *testing.T) {
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_USER", "a test user")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_ENUM", mapstest.TestEnumStringVer1)
require.NoError(t, err)
err = os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1)
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_ENUM1", "1")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMYCONFIG_DB", "a test db")
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB)
Expand Down Expand Up @@ -407,6 +419,8 @@ func TestFlagsBinding(t *testing.T) {
flagSet.Int("int", 0, "dummy int")
flagSet.Duration("time", time.Second, "dummy time")
flagSet.Bool("flag", false, "dummy flag")
flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum")
flagSet.String("enum1", "1", "dummy enum")
err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host2"), flagSet.Lookup("host2"))
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host1"), flagSet.Lookup("host2"))
Expand All @@ -419,6 +433,10 @@ func TestFlagsBinding(t *testing.T) {
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_USER", flagSet.Lookup("user1"), flagSet.Lookup("user2"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM", flagSet.Lookup("enum"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM1", flagSet.Lookup("enum1"))
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DB", flagSet.Lookup("db"))
require.NoError(t, err)
err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_DB", flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"))
Expand Down Expand Up @@ -476,6 +494,12 @@ func TestFlagsBinding(t *testing.T) {
assert.Equal(t, expectedHost, configTest.TestConfig2.Host)
assert.Equal(t, expectedPassword, configTest.TestConfig.Password)
assert.Equal(t, expectedPassword, configTest.TestConfig2.Password)
assert.NotEqual(t, mapstest.TestEnumStringVer1, configTest.TestConfig2.TestEnum)
assert.NotEqual(t, mapstest.TestEnumStringVer0, configTest.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithUnmarshal1, configTest.TestConfig2.TestEnum)
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, configTest.TestConfig2.TestEnum1)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1)
assert.Equal(t, expectedDB, configTest.TestConfig.DB)
assert.Equal(t, aDifferentDB, configTest.TestConfig2.DB)
assert.NotEqual(t, expectedDB, configTest.TestConfig2.DB)
Expand Down Expand Up @@ -516,6 +540,8 @@ func TestFlagBindingDefaults(t *testing.T) {
flagSet.Int("int", expectedInt, "dummy int")
flagSet.Duration("time", expectedDuration, "dummy time")
flagSet.Bool("flag", !DefaultDummyConfiguration().Flag, "dummy flag")
flagSet.String("enum", mapstest.TestEnumStringVer0, "dummy enum")
flagSet.String("enum1", "0", "dummy enum")
err = BindFlagToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host2"))
Expand All @@ -538,6 +564,10 @@ func TestFlagBindingDefaults(t *testing.T) {
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_Time", flagSet.Lookup("time"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_enum", flagSet.Lookup("enum"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "DUMMY_enum1", flagSet.Lookup("enum1"))
require.NoError(t, err)
err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB) // Should take precedence over flag default
require.NoError(t, err)
err = LoadFromViper(session, prefix, configTest, defaults)
Expand All @@ -556,6 +586,10 @@ func TestFlagBindingDefaults(t *testing.T) {
assert.Equal(t, expectedPassword, configTest.TestConfig2.Password)
assert.Equal(t, aDifferentDB, configTest.TestConfig.DB)
assert.Equal(t, expectedDB, configTest.TestConfig2.DB)
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig2.TestEnum)
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig2.TestEnum1)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1)
// Defaults from the default structure provided take precedence over defaults from flags when empty.
assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig.Flag)
assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig2.Flag)
Expand All @@ -575,6 +609,8 @@ func TestGenerateEnvFile_Defaults(t *testing.T) {
"TEST_PASSWORD": configTest.Password,
"TEST_PORT": configTest.Port,
"TEST_USER": configTest.User,
"TEST_ENUM": configTest.TestEnum,
"TEST_ENUM1": configTest.TestEnum1,
}

// Generate env file
Expand All @@ -601,6 +637,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
flagSet.String("password", "a password", "dummy password")
flagSet.String("user", "a user", "dummy user")
flagSet.String("db", "a db", "dummy db")
flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum")
flagSet.String("enum1", "1", "dummy enum")
err = BindFlagToEnv(session, prefix, "TEST_DUMMY_HOST", flagSet.Lookup("host"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "PASSWORD", flagSet.Lookup("password"))
Expand All @@ -611,6 +649,10 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "USER", flagSet.Lookup("user"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "ENUM", flagSet.Lookup("enum"))
require.NoError(t, err)
err = BindFlagToEnv(session, prefix, "ENUM1", flagSet.Lookup("enum1"))
require.NoError(t, err)
err = flagSet.Set("host", expectedHost)
require.NoError(t, err)
err = flagSet.Set("password", expectedPassword)
Expand All @@ -630,6 +672,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
"TEST_PASSWORD": configTest.Password,
"TEST_PORT": configTest.Port,
"TEST_USER": configTest.User,
"TEST_ENUM": configTest.TestEnum,
"TEST_ENUM1": configTest.TestEnum1,
}

// Generate env file
Expand Down Expand Up @@ -670,13 +714,17 @@ func TestGenerateEnvFile_Nested(t *testing.T) {
"TEST_DEEP_CONFIG_DUMMYCONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig.Password,
"TEST_DEEP_CONFIG_DUMMYCONFIG_PORT": configTest.TestConfigDeep.TestConfig.Port,
"TEST_DEEP_CONFIG_DUMMYCONFIG_USER": configTest.TestConfigDeep.TestConfig.User,
"TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM": configTest.TestConfigDeep.TestConfig.TestEnum,
"TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM1": configTest.TestConfigDeep.TestConfig.TestEnum1,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_DB": configTest.TestConfigDeep.TestConfig2.DB,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_DUMMY_HOST": configTest.TestConfigDeep.TestConfig2.Host,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_FLAG": configTest.TestConfigDeep.TestConfig2.Flag,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_HEALTHCHECK_PERIOD": configTest.TestConfigDeep.TestConfig2.HealthCheckPeriod,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig2.Password,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_PORT": configTest.TestConfigDeep.TestConfig2.Port,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_USER": configTest.TestConfigDeep.TestConfig2.User,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM": configTest.TestConfigDeep.TestConfig2.TestEnum,
"TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM1": configTest.TestConfigDeep.TestConfig2.TestEnum1,
"TEST_DEEP_CONFIG_DUMMY_INT": configTest.TestConfigDeep.TestInt,
"TEST_DUMMY_STRING": configTest.TestString,
"TEST_DEEP_CONFIG_DUMMY_TIME": configTest.TestConfigDeep.TestTime,
Expand Down Expand Up @@ -1040,3 +1088,57 @@ func loadEnvIntoEnvironment(t *testing.T, envPath string) (err error) {

return
}

func TestCustomTypeHook_Success(t *testing.T) {
t.Cleanup(os.Clearenv)
os.Clearenv()

cfg := &ConfigurationTest{}
defaults := DefaultConfiguration()

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt)))
require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String()))

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1"))

err := Load("test", cfg, defaults)
require.NoError(t, err)
require.NoError(t, cfg.Validate())

assert.Equal(t, mapstest.TestEnumWithUnmarshal1, cfg.TestConfig.TestEnum)
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, cfg.TestConfig.TestEnum1)
}

func TestCustomTypeHook_InvalidValue(t *testing.T) {
t.Cleanup(os.Clearenv)
os.Clearenv()

cfg := &ConfigurationTest{}
defaults := DefaultConfiguration()

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user"))
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB))
require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt)))
require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String()))

require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", "4"))

err := Load("test", cfg, defaults)
errortest.AssertError(t, err, commonerrors.ErrInvalid)
errortest.AssertErrorDescription(t, err, "structure failed validation")
}
53 changes: 51 additions & 2 deletions utils/serialization/maps/map.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package maps

import (
"encoding"
"reflect"
"strconv"
"time"

"github.com/go-viper/mapstructure/v2"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/maps"
"github.com/ARM-software/golang-utils/utils/safecast"
)

// ToMapFromPointer is like ToMap but deals with a pointer.
Expand Down Expand Up @@ -72,7 +75,7 @@ func FromMapToPointer[T any](m map[string]string, o T) (err error) {

err = mapstructureDecoder(expandedMap, o)
if err != nil {
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise upload request")
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise the map")
}
return
}
Expand Down Expand Up @@ -148,11 +151,57 @@ func toTime(f reflect.Type, t reflect.Type, data any) (any, error) {
}
}

func toCustomTypeIntFallback(f reflect.Type, t reflect.Type, data any) (any, error) {
if f.Kind() != reflect.String {
return data, nil
}
if t.Kind() != reflect.Int {
return data, nil
}

s, ok := data.(string)
if !ok {
return data, nil
}

i, err := strconv.Atoi(s)
if err != nil {
return data, nil
}

ptr := reflect.New(t).Elem()
ptr.SetInt(safecast.ToInt64(i))

return ptr.Interface(), nil
}

func toCustomType(f reflect.Type, t reflect.Type, data any) (any, error) {
if f.Kind() != reflect.String {
return data, nil
}

customType, ok := reflect.New(t).Interface().(encoding.TextUnmarshaler)
if !ok {
return toCustomTypeIntFallback(f, t, data)
}

err := customType.UnmarshalText([]byte(data.(string))) // we know it is a string based on reflection
if err != nil {
return toCustomTypeIntFallback(f, t, data)
}

return customType, nil
}

func CustomTypeHookFunc() mapstructure.DecodeHookFunc {
return toCustomType
}

func mapstructureDecoder(input, result any) error {
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
timeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()),
timeHookFunc(), CustomTypeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()),
Result: result,
})
if err != nil {
Expand Down
Loading
Loading