Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 120 additions & 65 deletions internal/assertions/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
//
// Len also fails if the object has a type that len() does not accept.
//
// The asserted object can be a string, a slice, a map, an array or a channel.
// The asserted object can be a string, a slice, a map, an array, pointer to array or a channel.
//
// See also [reflect.Len].
//
Expand Down Expand Up @@ -301,12 +301,16 @@ func MapNotContainsT[Map ~map[K]V, K comparable, V any](t T, m Map, key K, msgAn
return true
}

const unsupportedCollectionType = "%q has an unsupported type %s"

// Subset asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
//
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// nil values are considered as empty sets.
//
// # Usage
//
// assertions.Subset(t, []int{1, 2, 3}, []int{1, 2})
Expand All @@ -324,37 +328,46 @@ func Subset(t T, list, subset any, msgAndArgs ...any) (ok bool) {
h.Helper()
}

if subset == nil {
return true // we consider nil to be equal to the nil set
}
subsetType := reflect.TypeOf(subset)
listType := reflect.TypeOf(list)

listKind := reflect.TypeOf(list).Kind()
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
if subsetType == nil {
if listType == nil { // ∅ ⊂ ∅
return true // we consider nil to be equal to the nil set
}

listKind := listType.Kind()
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf(unsupportedCollectionType, list, listKind), msgAndArgs...)
}

return true
}

subsetKind := reflect.TypeOf(subset).Kind()
subsetKind := subsetType.Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
return Fail(t, fmt.Sprintf(unsupportedCollectionType, subset, subsetKind), msgAndArgs...)
}

if listType == nil {
subsetList := reflect.ValueOf(subset)
if subsetList.Len() == 0 {
return true
}

return Fail(t, fmt.Sprintf("%q is not a subset of the empty set", subset), msgAndArgs...)
}

listKind := listType.Kind()
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf(unsupportedCollectionType, list, listKind), msgAndArgs...)
}

if subsetKind == reflect.Map && listKind == reflect.Map {
subsetMap := reflect.ValueOf(subset)
actualMap := reflect.ValueOf(list)

for _, k := range subsetMap.MapKeys() {
ev := subsetMap.MapIndex(k)
av := actualMap.MapIndex(k)

if !av.IsValid() {
return Fail(t, fmt.Sprintf("%s does not contain %s", truncatingFormat("%#v", list), truncatingFormat("%#v", subset)), msgAndArgs...)
}
if !ObjectsAreEqual(ev.Interface(), av.Interface()) {
return Fail(t, fmt.Sprintf("%s does not contain %s", truncatingFormat("%#v", list), truncatingFormat("%#v", subset)), msgAndArgs...)
}
}

return true
return isSubsetMap(t, list, subset, subsetMap, actualMap, msgAndArgs...)
}

subsetList := reflect.ValueOf(subset)
Expand All @@ -366,18 +379,7 @@ func Subset(t T, list, subset any, msgAndArgs ...any) (ok bool) {
subsetList = reflect.ValueOf(keys)
}

for i := range subsetList.Len() {
element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", list), msgAndArgs...)
}
if !found {
return Fail(t, fmt.Sprintf("%s does not contain %#v", truncatingFormat("%#v", list), element), msgAndArgs...)
}
}

return true
return isSubsetList(t, list, subsetList, msgAndArgs...)
}

// SliceSubsetT asserts that a slice of comparable elements contains all the elements given in the subset.
Expand Down Expand Up @@ -426,37 +428,44 @@ func NotSubset(t T, list, subset any, msgAndArgs ...any) (ok bool) {
if h, ok := t.(H); ok {
h.Helper()
}
if subset == nil {
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
const emptySetMessage = "nil is the empty set which is a subset of every set"

subsetType := reflect.TypeOf(subset)
listType := reflect.TypeOf(list)

if subsetType == nil {
return Fail(t, emptySetMessage, msgAndArgs...)
}

listKind := reflect.TypeOf(list).Kind()
if listType == nil {
subsetKind := subsetType.Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf(unsupportedCollectionType, subset, subsetKind), msgAndArgs...)
}

subsetList := reflect.ValueOf(subset)
if subsetList.Len() != 0 {
return true
}

return Fail(t, emptySetMessage, msgAndArgs...)
}

listKind := listType.Kind()
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
return Fail(t, fmt.Sprintf(unsupportedCollectionType, list, listKind), msgAndArgs...)
}

subsetKind := reflect.TypeOf(subset).Kind()
subsetKind := subsetType.Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
return Fail(t, fmt.Sprintf(unsupportedCollectionType, subset, subsetKind), msgAndArgs...)
}

if subsetKind == reflect.Map && listKind == reflect.Map {
subsetMap := reflect.ValueOf(subset)
actualMap := reflect.ValueOf(list)

for _, k := range subsetMap.MapKeys() {
ev := subsetMap.MapIndex(k)
av := actualMap.MapIndex(k)

if !av.IsValid() {
return true
}
if !ObjectsAreEqual(ev.Interface(), av.Interface()) {
return true
}
}

return Fail(t, fmt.Sprintf("%s is a subset of %s", truncatingFormat("%q", subset), truncatingFormat("%q", list)), msgAndArgs...)
return isNotSubsetMap(t, list, subset, subsetMap, actualMap, msgAndArgs...)
}

subsetList := reflect.ValueOf(subset)
Expand All @@ -468,18 +477,7 @@ func NotSubset(t T, list, subset any, msgAndArgs ...any) (ok bool) {
subsetList = reflect.ValueOf(keys)
}

for i := range subsetList.Len() {
element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("%q could not be applied builtin len()", list), msgAndArgs...)
}
if !found {
return true
}
}

return Fail(t, fmt.Sprintf("%s is a subset of %s", truncatingFormat("%q", subset), truncatingFormat("%q", list)), msgAndArgs...)
return isNotSubsetList(t, list, subset, subsetList, msgAndArgs...)
}

// SliceNotSubsetT asserts that a slice of comparable elements does not contain all the elements given in the subset.
Expand Down Expand Up @@ -649,6 +647,63 @@ func NotElementsMatchT[E comparable](t T, listA, listB []E, msgAndArgs ...any) (
return true
}

func isSubsetMap(t T, list, subset any, subsetMap, actualMap reflect.Value, msgAndArgs ...any) bool {
for _, k := range subsetMap.MapKeys() {
ev := subsetMap.MapIndex(k)
av := actualMap.MapIndex(k)

if !av.IsValid() {
return Fail(t, fmt.Sprintf("%s does not contain %s", truncatingFormat("%#v", list), truncatingFormat("%#v", subset)), msgAndArgs...)
}
if !ObjectsAreEqual(ev.Interface(), av.Interface()) {
return Fail(t, fmt.Sprintf("%s does not contain %s", truncatingFormat("%#v", list), truncatingFormat("%#v", subset)), msgAndArgs...)
}
}

return true
}

func isNotSubsetMap(t T, list, subset any, subsetMap, actualMap reflect.Value, msgAndArgs ...any) bool {
for _, k := range subsetMap.MapKeys() {
ev := subsetMap.MapIndex(k)
av := actualMap.MapIndex(k)

if !av.IsValid() {
return true
}

if !ObjectsAreEqual(ev.Interface(), av.Interface()) {
return true
}
}

return Fail(t, fmt.Sprintf("%s is a subset of %s", truncatingFormat("%q", subset), truncatingFormat("%q", list)), msgAndArgs...)
}

func isSubsetList(t T, list any, subsetList reflect.Value, msgAndArgs ...any) bool {
for i := range subsetList.Len() {
element := subsetList.Index(i).Interface()
_, found := containsElement(list, element) // containsElement will work for this type: no need to check the ok bool
if !found {
return Fail(t, fmt.Sprintf("%s does not contain %#v", truncatingFormat("%#v", list), element), msgAndArgs...)
}
}

return true
}

func isNotSubsetList(t T, list, subset any, subsetList reflect.Value, msgAndArgs ...any) bool {
for i := range subsetList.Len() {
element := subsetList.Index(i).Interface()
_, found := containsElement(list, element)
if !found {
return true
}
}

return Fail(t, fmt.Sprintf("%s is a subset of %s", truncatingFormat("%q", subset), truncatingFormat("%q", list)), msgAndArgs...)
}

// containsElement tries to loop over the list check if the list includes the element.
//
// return (false, false) if impossible.
Expand Down
Loading
Loading