diff --git a/changes/20251014182838.feature b/changes/20251014182838.feature new file mode 100644 index 0000000000..36286d9e92 --- /dev/null +++ b/changes/20251014182838.feature @@ -0,0 +1 @@ +:sparkles: [field] Added utilities to return nil if a value is empty as opposed to a pointer to an empty value diff --git a/changes/20251014182937.bugfix b/changes/20251014182937.bugfix new file mode 100644 index 0000000000..839099c333 --- /dev/null +++ b/changes/20251014182937.bugfix @@ -0,0 +1 @@ +:bug: `[config]` escape mapstructure special tags when reporting an validation error diff --git a/utils/config/validation.go b/utils/config/validation.go index 17d4d1566b..153a3ef992 100644 --- a/utils/config/validation.go +++ b/utils/config/validation.go @@ -6,8 +6,14 @@ package config import ( "reflect" + "strings" + + "github.com/ARM-software/golang-utils/utils/collection" + fieldUtils "github.com/ARM-software/golang-utils/utils/field" ) +var specialMapstructureTags = []string{"squash", "remain", "omitempty", "omitzero"} // See https://pkg.go.dev/github.com/go-viper/mapstructure/v2#section-readme + // ValidateEmbedded uses reflection to find embedded structs and validate them func ValidateEmbedded(cfg Validator) error { r := reflect.ValueOf(cfg).Elem() @@ -32,10 +38,27 @@ func ValidateEmbedded(cfg Validator) error { func wrapFieldValidationError(field reflect.StructField, err error) error { mapStructureStr, hasTag := field.Tag.Lookup("mapstructure") - mapStructure := &mapStructureStr + mapStructure := fieldUtils.ToOptionalStringOrNilIfEmpty(processMapStructureString(mapStructureStr)) if !hasTag { mapStructure = nil } err = WrapFieldValidationError(field.Name, mapStructure, nil, err) return err } + +// mapstructure has some special tags which need to be accounted for. +func processMapStructureString(str string) string { + processedStr := strings.TrimSpace(str) + if processedStr == "-" { + return "" + } + + elements := strings.Split(processedStr, ",") + if len(elements) == 1 { + return processedStr + } + elements = collection.GenericRemove(func(str1, str2 string) bool { + return strings.EqualFold(strings.TrimSpace(str1), strings.TrimSpace(str2)) + }, elements, specialMapstructureTags...) + return strings.TrimSpace(strings.Join(elements, ",")) +} diff --git a/utils/config/validation_test.go b/utils/config/validation_test.go new file mode 100644 index 0000000000..9b311237fe --- /dev/null +++ b/utils/config/validation_test.go @@ -0,0 +1,51 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_processMapStructureString(t *testing.T) { + tests := []struct { + mapstructureTag string + expectedProcessedTag string + }{ + {}, + { + mapstructureTag: " ", + }, + { + mapstructureTag: " - ", + }, + { + mapstructureTag: " , omitzero ", + }, + { + mapstructureTag: " ,omitempty , omitzero , SQUASH ", + }, + { + mapstructureTag: "test ,omitempty , omitzero , squash ", + expectedProcessedTag: "test", + }, + { + mapstructureTag: "person_name", + expectedProcessedTag: "person_name", + }, + { + mapstructureTag: " person_name ", + expectedProcessedTag: "person_name", + }, + { + mapstructureTag: " person_name ,remain ", + expectedProcessedTag: "person_name", + }, + } + + for i := range tests { + test := tests[i] + t.Run(test.mapstructureTag, func(t *testing.T) { + assert.Equal(t, test.expectedProcessedTag, processMapStructureString(test.mapstructureTag)) + }) + } +} diff --git a/utils/field/fields.go b/utils/field/fields.go index 3772063f34..5a9e24dd5e 100644 --- a/utils/field/fields.go +++ b/utils/field/fields.go @@ -6,103 +6,157 @@ // package field provides utilities to set structure fields. It was inspired by the kubernetes package https://pkg.go.dev/k8s.io/utils/pointer. package field -import "time" +import ( + "time" + + "github.com/ARM-software/golang-utils/utils/value" +) // ToOptionalInt returns a pointer to an int func ToOptionalInt(f int) *int { - return ToOptional(f) + return ToOptional[int](f) +} + +// ToOptionalIntOrNilIfEmpty returns a pointer to an int unless it is empty and in that case returns nil. +func ToOptionalIntOrNilIfEmpty(f int) *int { + return ToOptionalOrNilIfEmpty[int](f) } // OptionalInt returns the value of an optional field or else // returns defaultValue. func OptionalInt(ptr *int, defaultValue int) int { - return Optional(ptr, defaultValue) + return Optional[int](ptr, defaultValue) } // ToOptionalInt32 returns a pointer to an int32. func ToOptionalInt32(f int32) *int32 { - return ToOptional(f) + return ToOptional[int32](f) +} + +// ToOptionalInt32OrNilIfEmpty returns a pointer to an int32 unless it is empty and in that case returns nil. +func ToOptionalInt32OrNilIfEmpty(f int32) *int32 { + return ToOptionalOrNilIfEmpty[int32](f) } // OptionalInt32 returns the value of an optional field or else // returns defaultValue. func OptionalInt32(ptr *int32, defaultValue int32) int32 { - return Optional(ptr, defaultValue) + return Optional[int32](ptr, defaultValue) } // ToOptionalUint returns a pointer to an uint func ToOptionalUint(f uint) *uint { - return ToOptional(f) + return ToOptional[uint](f) +} + +// ToOptionalUintOrNilIfEmpty returns a pointer to a Uint unless it is empty and in that case returns nil. +func ToOptionalUintOrNilIfEmpty(f uint) *uint { + return ToOptionalOrNilIfEmpty[uint](f) } // OptionalUint returns the value of an optional field or else returns defaultValue. func OptionalUint(ptr *uint, defaultValue uint) uint { - return Optional(ptr, defaultValue) + return Optional[uint](ptr, defaultValue) } // ToOptionalUint32 returns a pointer to an uint32. func ToOptionalUint32(f uint32) *uint32 { - return ToOptional(f) + return ToOptional[uint32](f) +} + +// ToOptionalUint32OrNilIfEmpty returns a pointer to an Uint32 unless it is empty and in that case returns nil. +func ToOptionalUint32OrNilIfEmpty(f uint32) *uint32 { + return ToOptionalOrNilIfEmpty[uint32](f) } // OptionalUint32 returns the value of an optional field or else returns defaultValue. func OptionalUint32(ptr *uint32, defaultValue uint32) uint32 { - return Optional(ptr, defaultValue) + return Optional[uint32](ptr, defaultValue) } // ToOptionalInt64 returns a pointer to an int64. func ToOptionalInt64(f int64) *int64 { - return ToOptional(f) + return ToOptional[int64](f) +} + +// ToOptionalInt64OrNilIfEmpty returns a pointer to an int64 unless it is empty and in that case returns nil. +func ToOptionalInt64OrNilIfEmpty(f int64) *int64 { + return ToOptionalOrNilIfEmpty[int64](f) } // OptionalInt64 returns the value of an optional field or else returns defaultValue. func OptionalInt64(ptr *int64, defaultValue int64) int64 { - return Optional(ptr, defaultValue) + return Optional[int64](ptr, defaultValue) } // ToOptionalUint64 returns a pointer to an uint64. func ToOptionalUint64(f uint64) *uint64 { - return ToOptional(f) + return ToOptional[uint64](f) +} + +// ToOptionalUint64OrNilIfEmpty returns a pointer to an Uint64 unless it is empty and in that case returns nil. +func ToOptionalUint64OrNilIfEmpty(f uint64) *uint64 { + return ToOptionalOrNilIfEmpty[uint64](f) } // OptionalUint64 returns the value of an optional field or else returns defaultValue. func OptionalUint64(ptr *uint64, defaultValue uint64) uint64 { - return Optional(ptr, defaultValue) + return Optional[uint64](ptr, defaultValue) } // ToOptionalBool returns a pointer to a bool. func ToOptionalBool(b bool) *bool { - return ToOptional(b) + return ToOptional[bool](b) +} + +// ToOptionalBoolOrNilIfEmpty returns a pointer to a boolean unless it is empty and in that case returns nil. +func ToOptionalBoolOrNilIfEmpty(f bool) *bool { + return ToOptionalOrNilIfEmpty[bool](f) } // OptionalBool returns the value of an optional field or else returns defaultValue. func OptionalBool(ptr *bool, defaultValue bool) bool { - return Optional(ptr, defaultValue) + return Optional[bool](ptr, defaultValue) } // ToOptionalString returns a pointer to a string. func ToOptionalString(s string) *string { - return ToOptional(s) + return ToOptional[string](s) +} + +// ToOptionalStringOrNilIfEmpty returns a pointer to a string unless it is empty and in that case returns nil. +func ToOptionalStringOrNilIfEmpty(f string) *string { + return ToOptionalOrNilIfEmpty[string](f) } // OptionalString returns the value of an optional field or else returns defaultValue. func OptionalString(ptr *string, defaultValue string) string { - return Optional(ptr, defaultValue) + return Optional[string](ptr, defaultValue) } -// ToOptionalAny returns a pointer to a object. +// ToOptionalAny returns a pointer to an object. func ToOptionalAny(a any) *any { - return ToOptional(a) + return ToOptional[any](a) +} + +// ToOptionalAnyOrNilIfEmpty returns a pointer to an object unless it is empty and in that case returns nil. +func ToOptionalAnyOrNilIfEmpty(f any) *any { + return ToOptionalOrNilIfEmpty[any](f) } // OptionalAny returns the value of an optional field or else returns defaultValue. func OptionalAny(ptr *any, defaultValue any) any { - return Optional(ptr, defaultValue) + return Optional[any](ptr, defaultValue) } // ToOptionalFloat32 returns a pointer to a float32. func ToOptionalFloat32(f float32) *float32 { - return ToOptional(f) + return ToOptional[float32](f) +} + +// ToOptionalFloat32OrNilIfEmpty returns a pointer to a float32 unless it is empty and in that case returns nil. +func ToOptionalFloat32OrNilIfEmpty(f float32) *float32 { + return ToOptionalOrNilIfEmpty[float32](f) } // OptionalFloat32 returns the value of an optional field or else returns defaultValue. @@ -112,32 +166,47 @@ func OptionalFloat32(ptr *float32, defaultValue float32) float32 { // ToOptionalFloat64 returns a pointer to a float64. func ToOptionalFloat64(f float64) *float64 { - return ToOptional(f) + return ToOptional[float64](f) +} + +// ToOptionalFloat64OrNilIfEmpty returns a pointer to a float64 unless it is empty and in that case returns nil. +func ToOptionalFloat64OrNilIfEmpty(f float64) *float64 { + return ToOptionalOrNilIfEmpty[float64](f) } // OptionalFloat64 returns the value of an optional field or else returns defaultValue. func OptionalFloat64(ptr *float64, defaultValue float64) float64 { - return Optional(ptr, defaultValue) + return Optional[float64](ptr, defaultValue) } // ToOptionalDuration returns a pointer to a Duration. func ToOptionalDuration(f time.Duration) *time.Duration { - return ToOptional(f) + return ToOptional[time.Duration](f) +} + +// ToOptionalDurationOrNilIfEmpty returns a pointer to a duration unless it is empty and in that case returns nil. +func ToOptionalDurationOrNilIfEmpty(f time.Duration) *time.Duration { + return ToOptionalOrNilIfEmpty[time.Duration](f) } // OptionalDuration returns the value of an optional field or else returns defaultValue. func OptionalDuration(ptr *time.Duration, defaultValue time.Duration) time.Duration { - return Optional(ptr, defaultValue) + return Optional[time.Duration](ptr, defaultValue) } // ToOptionalTime returns a pointer to a Time. func ToOptionalTime(f time.Time) *time.Time { - return ToOptional(f) + return ToOptional[time.Time](f) +} + +// ToOptionalTimeOrNilIfEmpty returns a pointer to a time unless it is empty and in that case returns nil. +func ToOptionalTimeOrNilIfEmpty(f time.Time) *time.Time { + return ToOptionalOrNilIfEmpty[time.Time](f) } // OptionalTime returns the value of an optional field or else returns defaultValue. func OptionalTime(ptr *time.Time, defaultValue time.Time) time.Time { - return Optional(ptr, defaultValue) + return Optional[time.Time](ptr, defaultValue) } // ToOptional returns a pointer to the given field value. @@ -145,6 +214,14 @@ func ToOptional[T any](v T) *T { return &v } +// ToOptionalOrNilIfEmpty returns a pointer to the given field value unless it is empty and in that case returns nil. +func ToOptionalOrNilIfEmpty[T any](v T) *T { + if value.IsEmpty(v) { + return nil + } + return ToOptional[T](v) +} + // Optional returns the value of an optional field or else returns defaultValue. func Optional[T any](ptr *T, defaultValue T) T { if ptr != nil { diff --git a/utils/field/fields_test.go b/utils/field/fields_test.go index 49a03ea6ba..308a8fb457 100644 --- a/utils/field/fields_test.go +++ b/utils/field/fields_test.go @@ -11,16 +11,18 @@ import ( "github.com/go-faker/faker/v4" "github.com/stretchr/testify/assert" + "github.com/ARM-software/golang-utils/utils/reflection" "github.com/ARM-software/golang-utils/utils/safecast" ) func TestOptionalField(t *testing.T) { tests := []struct { - fieldType string - value any - defaultValue any - setFunction func(any) any - getFunction func(any, any) any + fieldType string + value any + defaultValue any + setFunction func(any) any + setFunctionOrNil func(any) any + getFunction func(any, any) any }{ { fieldType: "Int", @@ -29,6 +31,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalInt(a.(int)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalIntOrNilIfEmpty(a.(int)) + }, getFunction: func(a any, a2 any) any { var ptr *int if a != nil { @@ -44,6 +49,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalUint(a.(uint)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalUintOrNilIfEmpty(a.(uint)) + }, getFunction: func(a any, a2 any) any { var ptr *uint if a != nil { @@ -59,6 +67,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalInt32(a.(int32)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalInt32OrNilIfEmpty(a.(int32)) + }, getFunction: func(a any, a2 any) any { var ptr *int32 if a != nil { @@ -74,6 +85,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalUint32(a.(uint32)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalUint32OrNilIfEmpty(a.(uint32)) + }, getFunction: func(a any, a2 any) any { var ptr *uint32 if a != nil { @@ -89,6 +103,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalInt64(a.(int64)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalInt64OrNilIfEmpty(a.(int64)) + }, getFunction: func(a any, a2 any) any { var ptr *int64 if a != nil { @@ -104,6 +121,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalUint64(a.(uint64)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalUint64OrNilIfEmpty(a.(uint64)) + }, getFunction: func(a any, a2 any) any { var ptr *uint64 if a != nil { @@ -119,6 +139,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalFloat32(a.(float32)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalFloat32OrNilIfEmpty(a.(float32)) + }, getFunction: func(a any, a2 any) any { var ptr *float32 if a != nil { @@ -134,6 +157,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalFloat64(a.(float64)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalFloat64OrNilIfEmpty(a.(float64)) + }, getFunction: func(a any, a2 any) any { var ptr *float64 if a != nil { @@ -149,6 +175,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalBool(a.(bool)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalBoolOrNilIfEmpty(a.(bool)) + }, getFunction: func(a any, a2 any) any { var ptr *bool if a != nil { @@ -164,6 +193,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalString(a.(string)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalStringOrNilIfEmpty(a.(string)) + }, getFunction: func(a any, a2 any) any { var ptr *string if a != nil { @@ -179,6 +211,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalDuration(a.(time.Duration)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalDurationOrNilIfEmpty(a.(time.Duration)) + }, getFunction: func(a any, a2 any) any { var ptr *time.Duration if a != nil { @@ -194,6 +229,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalTime(a.(time.Time)) }, + setFunctionOrNil: func(a any) any { + return ToOptionalTimeOrNilIfEmpty(a.(time.Time)) + }, getFunction: func(a any, a2 any) any { var ptr *time.Time if a != nil { @@ -209,6 +247,9 @@ func TestOptionalField(t *testing.T) { setFunction: func(a any) any { return ToOptionalAny(a) }, + setFunctionOrNil: func(a any) any { + return ToOptionalAnyOrNilIfEmpty(a) + }, getFunction: func(a any, a2 any) any { var ptr *any if a != nil { @@ -225,6 +266,14 @@ func TestOptionalField(t *testing.T) { assert.NotNil(t, to) assert.Equal(t, test.defaultValue, test.getFunction(nil, test.defaultValue)) assert.Equal(t, test.value, test.getFunction(to, test.defaultValue)) + to2 := test.setFunctionOrNil(test.value) + if reflection.IsEmpty(test.value) { + assert.Nil(t, to2) + assert.Equal(t, test.defaultValue, test.getFunction(to2, test.defaultValue)) + } else { + assert.NotNil(t, to2) + assert.Equal(t, test.value, test.getFunction(to2, test.defaultValue)) + } }) } } diff --git a/utils/reflection/reflection.go b/utils/reflection/reflection.go index 5d568c37c9..decb407205 100644 --- a/utils/reflection/reflection.go +++ b/utils/reflection/reflection.go @@ -5,12 +5,11 @@ package reflection import ( - "fmt" "reflect" - "strings" "unsafe" "github.com/ARM-software/golang-utils/utils/commonerrors" + valueUtils "github.com/ARM-software/golang-utils/utils/value" ) func GetUnexportedStructureField(structure interface{}, fieldName string) interface{} { @@ -21,9 +20,7 @@ func GetStructureField(field reflect.Value) interface{} { if !field.IsValid() { return nil } - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). //nolint:gosec // this conversion is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 - Elem(). - Interface() + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() //nolint:gosec // this conversion is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 } func SetUnexportedStructureField(structure interface{}, fieldName string, value interface{}) { SetStructureField(fetchStructureField(structure, fieldName), value) @@ -32,9 +29,7 @@ func SetStructureField(field reflect.Value, value interface{}) { if !field.IsValid() { return } - reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). //nolint:gosec // this conversion is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 - Elem(). - Set(reflect.ValueOf(value)) + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) //nolint:gosec // this conversion is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 } func fetchStructureField(structure interface{}, fieldName string) reflect.Value { @@ -62,7 +57,7 @@ func GetStructField(structure interface{}, fieldName string) (interface{}, bool) } } -// SetStructField attempts to set a field of a structure to the given vaule +// SetStructField attempts to set a field of a structure to the given value // It returns nil or an error, in case the field doesn't exist on the structure // or the value and the field have different types func SetStructField(structure interface{}, fieldName string, value interface{}) error { @@ -70,12 +65,12 @@ func SetStructField(structure interface{}, fieldName string, value interface{}) Field := ValueStructure.Elem().FieldByName(fieldName) // Test field exists on structure if !Field.IsValid() { - return fmt.Errorf("error with field [%v]: %w", fieldName, commonerrors.ErrInvalid) + return commonerrors.Newf(commonerrors.ErrInvalid, "error with field [%v]", fieldName) } // test field is settable if !Field.CanSet() { - return fmt.Errorf("error with unsettable field [%v]: %w", fieldName, commonerrors.ErrUnsupported) + return commonerrors.Newf(commonerrors.ErrUnsupported, "error with unsettable field [%v]", fieldName) } // Helper variables @@ -101,7 +96,7 @@ func SetStructField(structure interface{}, fieldName string, value interface{}) // Check that the underlying types are the same (e.g. no int and string) if fieldUnderlyingType != valueUnderlyingType { - return fmt.Errorf("conflicting types, field [%v] and value [%v]: %w", fieldKind, valueKind, commonerrors.ErrConflict) + return commonerrors.Newf(commonerrors.ErrConflict, "conflicting types, field [%v] and value [%v]", fieldKind, valueKind) } if fieldKind == reflect.Ptr { @@ -198,48 +193,19 @@ func InheritsFrom(object interface{}, parentType reflect.Type) bool { // IsEmpty checks whether a value is empty i.e. "", nil, 0, [], {}, false, etc. // For Strings, a string is considered empty if it is "" or if it only contains whitespaces func IsEmpty(value any) bool { - if value == nil { - return true - } - if valueStr, ok := value.(string); ok { - return len(strings.TrimSpace(valueStr)) == 0 - } - if valueStrPtr, ok := value.(*string); ok { - if valueStrPtr == nil { - return true - } - return len(strings.TrimSpace(*valueStrPtr)) == 0 - } - if valueBool, ok := value.(bool); ok { - // if set to true, then value is not empty - return !valueBool - } - objValue := reflect.ValueOf(value) - switch objValue.Kind() { - case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: - return objValue.Len() == 0 - case reflect.Ptr: - if objValue.IsNil() { - return true - } - deref := objValue.Elem().Interface() - return IsEmpty(deref) - default: - zero := reflect.Zero(objValue.Type()) - return reflect.DeepEqual(value, zero.Interface()) - } + return valueUtils.IsEmpty(value) } // ToStructPtr returns an instance of the pointer (interface) to the object obj. func ToStructPtr(obj reflect.Value) (val interface{}, err error) { if !obj.IsValid() { - err = fmt.Errorf("%w: obj value [%v] is not valid", commonerrors.ErrUnsupported, obj) + err = commonerrors.Newf(commonerrors.ErrUnsupported, "obj value [%v] is not valid", obj) return } vp := reflect.New(obj.Type()) if !vp.CanInterface() || !obj.CanInterface() { - err = fmt.Errorf("%w: cannot get the value of the object pointer of type %T", commonerrors.ErrUnsupported, obj.Type()) + err = commonerrors.Newf(commonerrors.ErrUnsupported, "cannot get the value of the object pointer of type %T", obj.Type()) return } vp.Elem().Set(obj) diff --git a/utils/reflection/reflection_test.go b/utils/reflection/reflection_test.go index b7101526a2..31e1645af6 100644 --- a/utils/reflection/reflection_test.go +++ b/utils/reflection/reflection_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" "github.com/ARM-software/golang-utils/utils/field" ) @@ -245,7 +246,7 @@ func TestSetStructField_InvalidField(t *testing.T) { err := SetStructField(&testStructure, "Title", "NEW_title") assert.NotNil(t, err) - assert.Equal(t, err, fmt.Errorf("error with field [%v]: %w", "Title", commonerrors.ErrInvalid)) + errortest.AssertError(t, err, commonerrors.ErrInvalid) } func TestSetStructField_UnsettableField(t *testing.T) { @@ -259,7 +260,7 @@ func TestSetStructField_UnsettableField(t *testing.T) { err := SetStructField(&testStructure, "unexported", "NEW_title") assert.NotNil(t, err) - assert.Equal(t, err, fmt.Errorf("error with unsettable field [%v]: %w", "unexported", commonerrors.ErrUnsupported)) + errortest.AssertError(t, err, commonerrors.ErrUnsupported) assert.NotEqual(t, testStructure.unexported, "NEW_title") assert.Equal(t, testStructure.unexported, "unsettable_field") } @@ -277,7 +278,7 @@ func TestSetStructField_FieldAndValueDifferentTypes(t *testing.T) { err := SetStructField(&testStructure, "Title", 133) assert.NotNil(t, err) - assert.Equal(t, err, fmt.Errorf("conflicting types, field [%v] and value [%v]: %w", reflect.ValueOf(testStructure).FieldByName("Title").Type().Kind(), reflect.TypeOf(123), commonerrors.ErrConflict)) + errortest.AssertError(t, err, commonerrors.ErrConflict) } func TestInheritsFrom(t *testing.T) { diff --git a/utils/value/empty.go b/utils/value/empty.go new file mode 100644 index 0000000000..df26060f98 --- /dev/null +++ b/utils/value/empty.go @@ -0,0 +1,41 @@ +package value + +import ( + "reflect" + "strings" +) + +// IsEmpty checks whether a value is empty i.e. "", nil, 0, [], {}, false, etc. +// For Strings, a string is considered empty if it is "" or if it only contains whitespaces +func IsEmpty(value any) bool { + if value == nil { + return true + } + if valueStr, ok := value.(string); ok { + return len(strings.TrimSpace(valueStr)) == 0 + } + if valueStrPtr, ok := value.(*string); ok { + if valueStrPtr == nil { + return true + } + return len(strings.TrimSpace(*valueStrPtr)) == 0 + } + if valueBool, ok := value.(bool); ok { + // if set to true, then value is not empty + return !valueBool + } + objValue := reflect.ValueOf(value) + switch objValue.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: + return objValue.Len() == 0 + case reflect.Ptr: + if objValue.IsNil() { + return true + } + deref := objValue.Elem().Interface() + return IsEmpty(deref) + default: + zero := reflect.Zero(objValue.Type()) + return reflect.DeepEqual(value, zero.Interface()) + } +} diff --git a/utils/value/empty_test.go b/utils/value/empty_test.go new file mode 100644 index 0000000000..b01224d458 --- /dev/null +++ b/utils/value/empty_test.go @@ -0,0 +1,136 @@ +package value + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestIsEmpty(t *testing.T) { + type testInterface interface { + } + var testEmptyPtr testInterface + emptyStr := "" + whiteSpace := " " + + aFilledChannel := make(chan struct{}, 1) + aFilledChannel <- struct{}{} + tests := []struct { + value interface{} + isEmpty bool + differsFromAssertEmpty bool + }{ + { + value: nil, + isEmpty: true, + }, + { + value: 0, + isEmpty: true, + }, + { + value: uint(0), + isEmpty: true, + }, + { + value: float64(0), + isEmpty: true, + }, + { + value: "", + isEmpty: true, + }, + { + value: " ", + isEmpty: true, + differsFromAssertEmpty: true, + }, + { + value: (*string)(nil), + isEmpty: true, + }, + { + value: &emptyStr, + isEmpty: true, + }, + { + value: &whiteSpace, + isEmpty: true, + differsFromAssertEmpty: true, + }, + { + value: false, + isEmpty: true, + }, + { + value: []string{}, + isEmpty: true, + }, + { + value: []int64{}, + isEmpty: true, + }, + { + value: []int64{int64(0)}, + isEmpty: false, + }, + { + value: "blah", + isEmpty: false, + }, + { + value: 1, + isEmpty: false, + }, + { + value: true, + isEmpty: false, + }, + { + value: testEmptyPtr, + isEmpty: true, + }, + { + value: map[string]string{}, + isEmpty: true, + }, + { + value: map[string]interface{}{}, + isEmpty: true, + }, + { + value: map[string]interface{}{"foo": "bar"}, + isEmpty: false, + }, + { + value: time.Time{}, + isEmpty: true, + }, + { + value: time.Now(), + isEmpty: false, + }, + { + value: make(chan struct{}), + isEmpty: true, + }, + { + value: aFilledChannel, + isEmpty: false, + }, + } + + for i := range tests { + test := tests[i] + t.Run(fmt.Sprintf("subtest #%v (%v)", i, test.value), func(t *testing.T) { + assert.Equal(t, test.isEmpty, IsEmpty(test.value)) + if test.isEmpty && !test.differsFromAssertEmpty { + assert.Empty(t, test.value) + } else { + assert.NotEmpty(t, test.value) + } + }) + } +}