diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go
index b5fa511169..2d74bfb2d6 100644
--- a/sql/analyzer/costed_index_scan.go
+++ b/sql/analyzer/costed_index_scan.go
@@ -15,12 +15,15 @@
 package analyzer
 
 import (
+	"cmp"
 	"fmt"
 	"slices"
 	"sort"
 	"strings"
 	"time"
 
+	"github.com/shopspring/decimal"
+
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression"
 	"github.com/dolthub/go-mysql-server/sql/expression/function/spatial"
@@ -826,6 +829,184 @@ type indexScanRangeBuilder struct {
 	leftover []sql.Expression
 }
 
+// keysToRangeColl builds a range collection containing one closed point range
+// [key, key] per distinct key, in ascending key order. keys is sorted and
+// compacted in place. An empty key set yields a nil collection.
+// NOTE(review): keys are handed to sql.ClosedRangeColumnExpr with their Go
+// type N and are not converted to typ; confirm range consumers accept that
+// for narrower integer columns.
+func keysToRangeColl[N cmp.Ordered](keys []N, typ sql.Type) sql.MySQLRangeCollection {
+	if len(keys) == 0 {
+		return nil
+	}
+	slices.Sort(keys)
+	keys = slices.Compact(keys)
+	// TODO: for integers, if len(keys) - 1 == keys[len(keys)-1] - keys[0],
+	// then we can just have one continuous range. unsure if this is worth it
+	res := make(sql.MySQLRangeCollection, len(keys))
+	for i, key := range keys {
+		res[i] = sql.MySQLRange{
+			sql.ClosedRangeColumnExpr(key, key, typ),
+		}
+	}
+	return res
+}
+
+// setToIntRangeColl converts the literal values of an IN set into point ranges
+// for an index on a signed integer column. Floats and decimals with a
+// fractional part can never equal an integer key and are dropped from the set.
+// The second result is false when a value is not exactly representable as an
+// int64 (unsigned value above the int64 maximum; out-of-range, infinite or NaN
+// float; oversized decimal) or has an unhandled type, so the caller falls back
+// to the generic range builder instead of using a wrapped-around key.
+func setToIntRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
+	const maxInt64 = 1<<63 - 1
+	keys := make([]int64, 0, len(setVals))
+	for _, val := range setVals {
+		switch v := val.(type) {
+		case int:
+			keys = append(keys, int64(v))
+		case int8:
+			keys = append(keys, int64(v))
+		case int16:
+			keys = append(keys, int64(v))
+		case int32:
+			keys = append(keys, int64(v))
+		case int64:
+			keys = append(keys, v)
+		case uint:
+			if uint64(v) > maxInt64 {
+				return nil, false
+			}
+			keys = append(keys, int64(v))
+		case uint8:
+			keys = append(keys, int64(v))
+		case uint16:
+			keys = append(keys, int64(v))
+		case uint32:
+			keys = append(keys, int64(v))
+		case uint64:
+			if v > maxInt64 {
+				return nil, false
+			}
+			keys = append(keys, int64(v))
+		// float32, float64, and decimal are ok as long as they don't round
+		case float32:
+			// written as a negated range test so that NaN also bails out
+			if !(v >= -(1<<63) && v < 1<<63) {
+				return nil, false
+			}
+			if key := int64(v); float32(key) == v {
+				keys = append(keys, key)
+			}
+		case float64:
+			if !(v >= -(1<<63) && v < 1<<63) {
+				return nil, false
+			}
+			if key := int64(v); float64(key) == v {
+				keys = append(keys, key)
+			}
+		case decimal.Decimal:
+			if !v.Equal(v.Truncate(0)) {
+				// fractional value: matches no integer key
+				continue
+			}
+			key := v.IntPart()
+			if !v.Equal(decimal.NewFromInt(key)) {
+				// integral but outside the int64 range
+				return nil, false
+			}
+			keys = append(keys, key)
+		default:
+			// resort to default behavior for types that require more conversion
+			return nil, false
+		}
+	}
+
+	return keysToRangeColl(keys, typ), true
+}
+
+// setToUintRangeColl converts the literal values of an IN set into point
+// ranges for an index on an unsigned integer column. Floats and decimals with a
+// fractional part are dropped, as in setToIntRangeColl. The second result is
+// false when a value is not exactly representable as a uint64 (negative
+// value; out-of-range, infinite or NaN float; decimal beyond the int64 range)
+// or has an unhandled type, so the caller falls back to the generic range
+// builder instead of using a wrapped-around key.
+func setToUintRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
+	keys := make([]uint64, 0, len(setVals))
+	for _, val := range setVals {
+		switch v := val.(type) {
+		case int:
+			if v < 0 {
+				return nil, false
+			}
+			keys = append(keys, uint64(v))
+		case int8:
+			if v < 0 {
+				return nil, false
+			}
+			keys = append(keys, uint64(v))
+		case int16:
+			if v < 0 {
+				return nil, false
+			}
+			keys = append(keys, uint64(v))
+		case int32:
+			if v < 0 {
+				return nil, false
+			}
+			keys = append(keys, uint64(v))
+		case int64:
+			if v < 0 {
+				return nil, false
+			}
+			keys = append(keys, uint64(v))
+		case uint:
+			keys = append(keys, uint64(v))
+		case uint8:
+			keys = append(keys, uint64(v))
+		case uint16:
+			keys = append(keys, uint64(v))
+		case uint32:
+			keys = append(keys, uint64(v))
+		case uint64:
+			keys = append(keys, v)
+		// float32, float64, and decimal are ok as long as they don't round
+		case float32:
+			// written as a negated range test so that NaN also bails out
+			if !(v >= 0 && v < 1<<64) {
+				return nil, false
+			}
+			if key := uint64(v); float32(key) == v {
+				keys = append(keys, key)
+			}
+		case float64:
+			if !(v >= 0 && v < 1<<64) {
+				return nil, false
+			}
+			if key := uint64(v); float64(key) == v {
+				keys = append(keys, key)
+			}
+		case decimal.Decimal:
+			if !v.Equal(v.Truncate(0)) {
+				// fractional value: matches no integer key
+				continue
+			}
+			key := v.IntPart()
+			if key < 0 || !v.Equal(decimal.NewFromInt(key)) {
+				// negative, or integral but outside the int64 range
+				return nil, false
+			}
+			keys = append(keys, uint64(key))
+		default:
+			// resort to default behavior for types that require more conversion
+			return nil, false
+		}
+	}
+	return keysToRangeColl(keys, typ), true
+}
+
 // buildRangeCollection converts our representation of the best index scan
 // into the format that represents an index lookup, a list of sql.Range.
 func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRangeCollection, error) {
@@ -839,6 +1020,25 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa
 	case *iScanOr:
 		ranges, err = b.rangeBuildOr(f, inScan)
 	case *iScanLeaf:
+		// When the filter is a simple IN, we can skip costly checks like building the RangeTree.
+		if f.Op() == sql.IndexScanOpInSet {
+			cets := b.idx.ColumnExpressionTypes()
+			if len(cets) == 1 {
+				typ := cets[0].Type
+				var ok bool
+				// TODO: it's possible to apply this optimization to other
+				// numeric types (float32, float64, decimal).
+				if types.IsSigned(typ) {
+					if ranges, ok = setToIntRangeColl(f.setValues, typ); ok {
+						return ranges, nil
+					}
+				} else if types.IsUnsigned(typ) {
+					if ranges, ok = setToUintRangeColl(f.setValues, typ); ok {
+						return ranges, nil
+					}
+				}
+			}
+		}
 		ranges, err = b.rangeBuildLeaf(f, inScan)
 	default:
 		return nil, fmt.Errorf("unknown indexFilter type: %T", f)
@@ -1429,14 +1629,14 @@ func newLeaf(ctx *sql.Context, id indexScanId, e sql.Expression, underlying stri
 
 	if op == sql.IndexScanOpInSet || op == sql.IndexScanOpNotInSet {
 		tup := right.(expression.Tuple)
-		var litSet []interface{}
+		litSet := make([]any, len(tup))
 		var litType sql.Type
-		for _, lit := range tup {
+		for i, lit := range tup {
 			value, err := lit.Eval(ctx, nil)
 			if err != nil {
 				return nil, false
 			}
-			litSet = append(litSet, value)
+			litSet[i] = value
 			if litType == nil {
 				litType = lit.Type()
 			}
diff --git a/sql/index_builder.go b/sql/index_builder.go
index 5a0582c623..26e848660e 100644
--- a/sql/index_builder.go
+++ b/sql/index_builder.go
@@ -42,15 +42,17 @@ type MySQLIndexBuilder struct {
 // NewMySQLIndexBuilder returns a new MySQLIndexBuilder. Used internally to construct a range that will later be passed to
 // integrators through the Index function NewLookup.
 func NewMySQLIndexBuilder(idx Index) *MySQLIndexBuilder {
-	colExprTypes := make(map[string]Type)
-	ranges := make(map[string][]MySQLRangeColumnExpr)
-	for _, cet := range idx.ColumnExpressionTypes() {
+	cets := idx.ColumnExpressionTypes()
+	colExprTypes := make(map[string]Type, len(cets))
+	ranges := make(map[string][]MySQLRangeColumnExpr, len(cets))
+	for _, cet := range cets {
 		typ := cet.Type
 		if _, ok := typ.(StringType); ok {
 			typ = typ.Promote()
 		}
-		colExprTypes[strings.ToLower(cet.Expression)] = typ
-		ranges[strings.ToLower(cet.Expression)] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
+		expr := strings.ToLower(cet.Expression)
+		colExprTypes[expr] = typ
+		ranges[expr] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
 	}
 	return &MySQLIndexBuilder{
 		idx: idx,
@@ -118,15 +120,21 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, k
 	for i, k := range keys {
 		// if converting from float to int results in rounding, then it's empty range
 		if t, ok := colTyp.(NumberType); ok && !t.IsFloat() {
-			f, c := floor(k), ceil(k)
-			switch k.(type) {
-			case float32, float64:
-				if f != c {
+			// floats at or beyond +/-2^63 are always integral and int64(k) is not well
+			// defined for them, so only other values (NaN included) are tested for rounding
+			switch k := k.(type) {
+			case float32:
+				if !(k <= -(1<<63) || k >= 1<<63) && float32(int64(k)) != k {
+					potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
+					continue
+				}
+			case float64:
+				if !(k <= -(1<<63) || k >= 1<<63) && float64(int64(k)) != k {
 					potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
 					continue
 				}
 			case decimal.Decimal:
-				if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) {
+				if !k.Equal(k.Truncate(0)) {
 					potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
 					continue
 				}
@@ -493,7 +501,6 @@ func (b *MySQLIndexBuilder) updateCol(ctx *Context, colExpr string, potentialRan
 	var newRanges []MySQLRangeColumnExpr
 	for _, currentRange := range currentRanges {
 		for _, potentialRange := range potentialRanges {
-
 			newRange, ok, err := currentRange.TryIntersect(potentialRange)
 			if err != nil {
 				b.isInvalid = true