Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 139 additions & 3 deletions sql/analyzer/costed_index_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
package analyzer

import (
"cmp"
"fmt"
"slices"
"sort"
"strings"
"time"

"github.com/shopspring/decimal"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/expression/function/spatial"
Expand Down Expand Up @@ -826,6 +829,120 @@ type indexScanRangeBuilder struct {
leftover []sql.Expression
}

// keysToRangeColl turns a list of lookup keys into a range collection made of
// one closed, single-value range ([key, key]) per distinct key.
//
// The keys slice is sorted and de-duplicated in place, so callers must not
// rely on its contents afterwards. A nil collection is returned when there
// are no keys.
func keysToRangeColl[N cmp.Ordered](keys []N, typ sql.Type) sql.MySQLRangeCollection {
	// Sorting first lets Compact remove every duplicate, and yields the
	// resulting ranges in ascending key order.
	slices.Sort(keys)
	distinct := slices.Compact(keys)
	if len(distinct) == 0 {
		return nil
	}

	// TODO: for integer keys, if len(distinct)-1 == distinct[len(distinct)-1]-distinct[0]
	// the keys are contiguous and could be expressed as one continuous range;
	// it is unclear whether that special case is worth the extra work.
	coll := make(sql.MySQLRangeCollection, 0, len(distinct))
	for _, k := range distinct {
		point := sql.MySQLRange{sql.ClosedRangeColumnExpr(k, k, typ)}
		coll = append(coll, point)
	}
	return coll
}

// setToIntRangeColl builds the range collection for an IN (...) filter over a
// single signed integer column directly from the evaluated set values, without
// going through the general range-building path.
//
// Each value is handled as follows:
//   - signed integers are used as-is;
//   - unsigned integers are used when they fit in an int64;
//   - float32, float64 and decimal values are used when they are integral and
//     fit in an int64; fractional values are skipped because they can never
//     equal an integer key;
//   - any value that is integral but does not fit in an int64, or has any
//     other type, makes the function give up.
//
// The boolean result is false when the function gives up, in which case the
// caller should resort to the default behavior. Otherwise the collection built
// by keysToRangeColl is returned (nil when no usable key remains).
func setToIntRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
	const (
		// Largest value representable as an int64.
		maxInt64 = 1<<63 - 1
		// 2^63 is exactly representable as a float64 and is the first value
		// that no longer fits in an int64; -2^63 is the smallest that does.
		twoPow63 float64 = 1 << 63
	)

	keys := make([]int64, 0, len(setVals))
	for _, val := range setVals {
		switch v := val.(type) {
		case int:
			keys = append(keys, int64(v))
		case int8:
			keys = append(keys, int64(v))
		case int16:
			keys = append(keys, int64(v))
		case int32:
			keys = append(keys, int64(v))
		case int64:
			keys = append(keys, v)
		case uint:
			// A plain int64(v) would wrap large values to negative keys.
			if uint64(v) > maxInt64 {
				return nil, false
			}
			keys = append(keys, int64(v))
		case uint8:
			keys = append(keys, int64(v))
		case uint16:
			keys = append(keys, int64(v))
		case uint32:
			keys = append(keys, int64(v))
		case uint64:
			// A plain int64(v) would wrap large values to negative keys.
			if v > maxInt64 {
				return nil, false
			}
			keys = append(keys, int64(v))
		// float32, float64, and decimal are ok as long as they don't round
		case float32:
			f := float64(v) // widening a float32 is exact
			// Converting an out-of-range (or NaN) float to an integer is
			// implementation-dependent in Go, so check the range first.
			if !(f >= -twoPow63 && f < twoPow63) {
				return nil, false
			}
			if key := int64(f); float64(key) == f {
				keys = append(keys, key)
			}
		case float64:
			// Converting an out-of-range (or NaN) float to an integer is
			// implementation-dependent in Go, so check the range first.
			if !(v >= -twoPow63 && v < twoPow63) {
				return nil, false
			}
			if key := int64(v); float64(key) == v {
				keys = append(keys, key)
			}
		case decimal.Decimal:
			key := v.IntPart()
			switch {
			case v.Equal(decimal.NewFromInt(key)):
				keys = append(keys, key)
			case v.Equal(v.Truncate(0)):
				// Integral, but IntPart could not represent it as an int64.
				return nil, false
			}
		default:
			// resort to default behavior for types that require more conversion
			return nil, false
		}
	}

	return keysToRangeColl(keys, typ), true
}

// setToUintRangeColl builds the range collection for an IN (...) filter over a
// single unsigned integer column directly from the evaluated set values,
// without going through the general range-building path.
//
// Each value is handled as follows:
//   - unsigned integers are used as-is;
//   - signed integers are used when they are non-negative;
//   - float32, float64 and decimal values are used when they are integral and
//     fit in the key type; fractional decimals and non-negative fractional
//     floats are skipped because they can never equal an integer key;
//   - any value that is negative, too large for the key type, or of any other
//     type makes the function give up.
//
// Integral decimals above the int64 range are treated as too large, because
// the conversion goes through Decimal.IntPart.
//
// The boolean result is false when the function gives up, in which case the
// caller should resort to the default behavior. Otherwise the collection built
// by keysToRangeColl is returned (nil when no usable key remains).
func setToUintRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
	// 2^64 is exactly representable as a float64 and is the first value that
	// no longer fits in a uint64.
	const twoPow64 float64 = 1 << 64

	keys := make([]uint64, 0, len(setVals))
	for _, val := range setVals {
		switch v := val.(type) {
		// A plain uint64(v) would wrap negative values to huge keys, so every
		// signed case rejects them first.
		case int:
			if v < 0 {
				return nil, false
			}
			keys = append(keys, uint64(v))
		case int8:
			if v < 0 {
				return nil, false
			}
			keys = append(keys, uint64(v))
		case int16:
			if v < 0 {
				return nil, false
			}
			keys = append(keys, uint64(v))
		case int32:
			if v < 0 {
				return nil, false
			}
			keys = append(keys, uint64(v))
		case int64:
			if v < 0 {
				return nil, false
			}
			keys = append(keys, uint64(v))
		case uint:
			keys = append(keys, uint64(v))
		case uint8:
			keys = append(keys, uint64(v))
		case uint16:
			keys = append(keys, uint64(v))
		case uint32:
			keys = append(keys, uint64(v))
		case uint64:
			keys = append(keys, v)
		// float32, float64, and decimal are ok as long as they don't round
		case float32:
			f := float64(v) // widening a float32 is exact
			// Converting a negative, out-of-range, or NaN float to an unsigned
			// integer is implementation-dependent in Go, so check the range first.
			if !(f >= 0 && f < twoPow64) {
				return nil, false
			}
			if key := uint64(f); float64(key) == f {
				keys = append(keys, key)
			}
		case float64:
			// Converting a negative, out-of-range, or NaN float to an unsigned
			// integer is implementation-dependent in Go, so check the range first.
			if !(v >= 0 && v < twoPow64) {
				return nil, false
			}
			if key := uint64(v); float64(key) == v {
				keys = append(keys, key)
			}
		case decimal.Decimal:
			key := v.IntPart()
			switch {
			case key >= 0 && v.Equal(decimal.NewFromInt(key)):
				keys = append(keys, uint64(key))
			case v.Equal(v.Truncate(0)):
				// Integral, but either negative or too large for IntPart:
				// wrapping or dropping it here could change the query result.
				return nil, false
			}
		default:
			// resort to default behavior for types that require more conversion
			return nil, false
		}
	}
	return keysToRangeColl(keys, typ), true
}

// buildRangeCollection converts our representation of the best index scan
// into the format that represents an index lookup, a list of sql.Range.
func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRangeCollection, error) {
Expand All @@ -839,6 +956,25 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa
case *iScanOr:
ranges, err = b.rangeBuildOr(f, inScan)
case *iScanLeaf:
// When the filter is a simple IN, we can skip costly checks like building the RangeTree.
if f.Op() == sql.IndexScanOpInSet {
cets := b.idx.ColumnExpressionTypes()
if len(cets) == 1 {
typ := cets[0].Type
var ok bool
// TODO: it's possible to apply this optimization to other
// numeric types (float32, float64, decimal).
if types.IsSigned(typ) {
if ranges, ok = setToIntRangeColl(f.setValues, typ); ok {
return ranges, nil
}
} else if types.IsUnsigned(typ) {
if ranges, ok = setToUintRangeColl(f.setValues, typ); ok {
return ranges, nil
}
}
}
}
ranges, err = b.rangeBuildLeaf(f, inScan)
default:
return nil, fmt.Errorf("unknown indexFilter type: %T", f)
Expand Down Expand Up @@ -1429,14 +1565,14 @@ func newLeaf(ctx *sql.Context, id indexScanId, e sql.Expression, underlying stri

if op == sql.IndexScanOpInSet || op == sql.IndexScanOpNotInSet {
tup := right.(expression.Tuple)
var litSet []interface{}
litSet := make([]any, len(tup))
var litType sql.Type
for _, lit := range tup {
for i, lit := range tup {
value, err := lit.Eval(ctx, nil)
if err != nil {
return nil, false
}
litSet = append(litSet, value)
litSet[i] = value
if litType == nil {
litType = lit.Type()
}
Expand Down
27 changes: 16 additions & 11 deletions sql/index_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ type MySQLIndexBuilder struct {
// NewMySQLIndexBuilder returns a new MySQLIndexBuilder. Used internally to construct a range that will later be passed to
// integrators through the Index function NewLookup.
func NewMySQLIndexBuilder(idx Index) *MySQLIndexBuilder {
colExprTypes := make(map[string]Type)
ranges := make(map[string][]MySQLRangeColumnExpr)
for _, cet := range idx.ColumnExpressionTypes() {
cets := idx.ColumnExpressionTypes()
colExprTypes := make(map[string]Type, len(cets))
ranges := make(map[string][]MySQLRangeColumnExpr, len(cets))
for _, cet := range cets {
typ := cet.Type
if _, ok := typ.(StringType); ok {
typ = typ.Promote()
}
colExprTypes[strings.ToLower(cet.Expression)] = typ
ranges[strings.ToLower(cet.Expression)] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
expr := strings.ToLower(cet.Expression)
colExprTypes[expr] = typ
ranges[expr] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
}
return &MySQLIndexBuilder{
idx: idx,
Expand Down Expand Up @@ -118,15 +120,19 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, k
for i, k := range keys {
// if converting from float to int results in rounding, then it's empty range
if t, ok := colTyp.(NumberType); ok && !t.IsFloat() {
f, c := floor(k), ceil(k)
switch k.(type) {
case float32, float64:
if f != c {
switch k := k.(type) {
case float32:
if float32(int64(k)) != k {
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
case float64:
if float64(int64(k)) != k {
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
case decimal.Decimal:
if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) {
if !k.Equal(decimal.NewFromInt(k.IntPart())) {
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
Expand Down Expand Up @@ -493,7 +499,6 @@ func (b *MySQLIndexBuilder) updateCol(ctx *Context, colExpr string, potentialRan
var newRanges []MySQLRangeColumnExpr
for _, currentRange := range currentRanges {
for _, potentialRange := range potentialRanges {

newRange, ok, err := currentRange.TryIntersect(potentialRange)
if err != nil {
b.isInvalid = true
Expand Down
Loading