From e40b1edb16c5370f684b7db334c6051d4ec20224 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 26 Nov 2025 15:19:38 -0800 Subject: [PATCH 1/6] unique IN values --- sql/analyzer/costed_index_scan.go | 22 +++++++++++++++++++--- sql/index_builder.go | 13 +++++++------ 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index b5fa511169..91d89d1a05 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -839,6 +839,18 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa case *iScanOr: ranges, err = b.rangeBuildOr(f, inScan) case *iScanLeaf: + // TODO: special case for in set. can skip overlapping ranges since it's a series of equality checks + // TODO: sequential integers can be converted to a single partition, but i guess that's harder? + if f.Op() == sql.IndexScanOpInSet { + bb := sql.NewMySQLIndexBuilder(b.idx) + b.rangeBuildDefaultLeaf(bb, f, inScan) + if _, err := bb.Build(b.ctx); err != nil { + return nil, err + } + ranges = bb.Ranges(b.ctx) + return ranges, nil + } + ranges, err = b.rangeBuildLeaf(f, inScan) default: return nil, fmt.Errorf("unknown indexFilter type: %T", f) @@ -1429,14 +1441,18 @@ func newLeaf(ctx *sql.Context, id indexScanId, e sql.Expression, underlying stri if op == sql.IndexScanOpInSet || op == sql.IndexScanOpNotInSet { tup := right.(expression.Tuple) - var litSet []interface{} + litSet := make(map[any]struct{}, len(tup)) + setVals := make([]any, 0, len(tup)) var litType sql.Type for _, lit := range tup { value, err := lit.Eval(ctx, nil) if err != nil { return nil, false } - litSet = append(litSet, value) + if _, ok = litSet[value]; !ok { + litSet[value] = struct{}{} + setVals = append(setVals, value) + } if litType == nil { litType = lit.Type() } @@ -1445,7 +1461,7 @@ func newLeaf(ctx *sql.Context, id indexScanId, e sql.Expression, underlying stri id: id, gf: gf, op: op, - setValues: litSet, + setValues: setVals, litType: litType, underlying: underlying, }, true diff --git a/sql/index_builder.go b/sql/index_builder.go index 5a0582c623..08ba231ba6 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -42,15 +42,17 @@ type MySQLIndexBuilder struct { // NewMySQLIndexBuilder returns a new MySQLIndexBuilder. Used internally to construct a range that will later be passed to // integrators through the Index function NewLookup. func NewMySQLIndexBuilder(idx Index) *MySQLIndexBuilder { - colExprTypes := make(map[string]Type) - ranges := make(map[string][]MySQLRangeColumnExpr) - for _, cet := range idx.ColumnExpressionTypes() { + cets := idx.ColumnExpressionTypes() + colExprTypes := make(map[string]Type, len(cets)) + ranges := make(map[string][]MySQLRangeColumnExpr, len(cets)) + for _, cet := range cets { typ := cet.Type if _, ok := typ.(StringType); ok { typ = typ.Promote() } - colExprTypes[strings.ToLower(cet.Expression)] = typ - ranges[strings.ToLower(cet.Expression)] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)} + expr := strings.ToLower(cet.Expression) + colExprTypes[expr] = typ + ranges[expr] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)} } return &MySQLIndexBuilder{ idx: idx, @@ -493,7 +495,6 @@ func (b *MySQLIndexBuilder) updateCol(ctx *Context, colExpr string, potentialRan var newRanges []MySQLRangeColumnExpr for _, currentRange := range currentRanges { for _, potentialRange := range potentialRanges { - newRange, ok, err := currentRange.TryIntersect(potentialRange) if err != nil { b.isInvalid = true From 97b69d50d33546b7172d653abdeeb97b9a19b54c Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 2 Dec 2025 15:30:09 -0800 Subject: [PATCH 2/6] test --- enginetest/memory_engine_test.go | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index f1ec7b45d0..b447ecc83b 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,23 +200,18 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() + //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "AS OF propagates to nested CALLs", - SetUpScript: []string{}, + Name: "aaaaa", + SetUpScript: []string{ + "create table t (i int primary key, j int);", + "insert into t values (1, 1), (2, 2), (3, 3);", + }, Assertions: []queries.ScriptTestAssertion{ { - Query: "create procedure create_proc() create table t (i int primary key, j int);", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, - }, - { - Query: "call create_proc()", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, + Query: "select * from t where i in (1, 2, 3)", + Expected: []sql.Row{}, }, }, }, From d1037bb1d0f5c940e5c9b976530b34d4e5b6f722 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 4 Dec 2025 15:45:06 -0800 Subject: [PATCH 3/6] rangetree and overlap check for some in queries --- sql/analyzer/costed_index_scan.go | 74 ++++++++++++++++++++++++------- sql/index_builder.go | 14 +++--- 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index 91d89d1a05..89608edc2d 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/expression/function/spatial" @@ -826,6 +828,53 @@ type indexScanRangeBuilder struct { leftover []sql.Expression } +func castToInt64(v any) (int64, bool) { + switch v := v.(type) { + case int: + return int64(v), true + case int8: + return int64(v), true + case int16: + return int64(v), true + case int32: + return int64(v), true + case int64: + return v, true + case float32, float64, decimal.Decimal: + // TODO: return an empty range here + return 0, false + default: + return 0, false + } +} + +func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType) (sql.MySQLRangeCollection, bool) { + if len(colExprTypes) != 1 { + return nil, false + } + typ := colExprTypes[0].Type + if !types.IsSigned(typ) { + return nil, false + } + var ok bool + keys := make([]int64, len(setVals)) + for i, val := range setVals { + keys[i], ok = castToInt64(val) + if !ok { + return nil, false + } + } + slices.Sort(keys) + slices.Compact(keys) + res := make(sql.MySQLRangeCollection, len(keys)) + for i, key := range keys { + res[i] = sql.MySQLRange{ + sql.ClosedRangeColumnExpr(key, key, typ), + } + } + return res, true +} + // buildRangeCollection converts our representation of the best index scan // into the format that represents an index lookup, a list of sql.Range. func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRangeCollection, error) { @@ -839,18 +888,13 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa case *iScanOr: ranges, err = b.rangeBuildOr(f, inScan) case *iScanLeaf: - // TODO: special case for in set. can skip overlapping ranges since it's a series of equality checks - // TODO: sequential integers can be converted to a single partition, but i guess that's harder? + // TODO: special case for in set. can skip building range tree and overlapping range check since it's a series of equality checks if f.Op() == sql.IndexScanOpInSet { - bb := sql.NewMySQLIndexBuilder(b.idx) - b.rangeBuildDefaultLeaf(bb, f, inScan) - if _, err := bb.Build(b.ctx); err != nil { - return nil, err + cets := b.idx.ColumnExpressionTypes() + if ranges, ok := setToSignedIntRange(f.setValues, cets); ok { + return ranges, nil } - ranges = bb.Ranges(b.ctx) - return ranges, nil } - ranges, err = b.rangeBuildLeaf(f, inScan) default: return nil, fmt.Errorf("unknown indexFilter type: %T", f) @@ -1441,18 +1485,14 @@ func newLeaf(ctx *sql.Context, id indexScanId, e sql.Expression, underlying stri if op == sql.IndexScanOpInSet || op == sql.IndexScanOpNotInSet { tup := right.(expression.Tuple) - litSet := make(map[any]struct{}, len(tup)) - setVals := make([]any, 0, len(tup)) + litSet := make([]any, len(tup)) var litType sql.Type - for _, lit := range tup { + for i, lit := range tup { value, err := lit.Eval(ctx, nil) if err != nil { return nil, false } - if _, ok = litSet[value]; !ok { - litSet[value] = struct{}{} - setVals = append(setVals, value) - } + litSet[i] = value if litType == nil { litType = lit.Type() } @@ -1461,7 +1501,7 @@ func newLeaf(ctx *sql.Context, id indexScanId, e sql.Expression, underlying stri id: id, gf: gf, op: op, - setValues: setVals, + setValues: litSet, litType: litType, underlying: underlying, }, true diff --git a/sql/index_builder.go b/sql/index_builder.go index 08ba231ba6..26e848660e 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -120,15 +120,19 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, k for i, k := range keys { // if converting from float to int results in rounding, then it's empty range if t, ok := colTyp.(NumberType); ok && !t.IsFloat() { - f, c := floor(k), ceil(k) - switch k.(type) { - case float32, float64: - if f != c { + switch k := k.(type) { + case float32: + if float32(int64(k)) != k { + potentialRanges[i] = EmptyRangeColumnExpr(colTyp) + continue + } + case float64: + if float64(int64(k)) != k { potentialRanges[i] = EmptyRangeColumnExpr(colTyp) continue } case decimal.Decimal: - if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) { + if !k.Equal(decimal.NewFromInt(k.IntPart())) { potentialRanges[i] = EmptyRangeColumnExpr(colTyp) continue } From 2446563527f862abaa6dc269af65f50f9a1fde8a Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 5 Dec 2025 11:03:29 -0800 Subject: [PATCH 4/6] more optimizing --- sql/analyzer/costed_index_scan.go | 74 ++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index 89608edc2d..e0b66e9066 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -828,26 +828,6 @@ type indexScanRangeBuilder struct { leftover []sql.Expression } -func castToInt64(v any) (int64, bool) { - switch v := v.(type) { - case int: - return int64(v), true - case int8: - return int64(v), true - case int16: - return int64(v), true - case int32: - return int64(v), true - case int64: - return v, true - case float32, float64, decimal.Decimal: - // TODO: return an empty range here - return 0, false - default: - return 0, false - } -} - func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType) (sql.MySQLRangeCollection, bool) { if len(colExprTypes) != 1 { return nil, false @@ -856,22 +836,64 @@ func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType) if !types.IsSigned(typ) { return nil, false } - var ok bool - keys := make([]int64, len(setVals)) - for i, val := range setVals { - keys[i], ok = castToInt64(val) - if !ok { + + keys := make([]int64, 0, len(setVals)) + for _, val := range setVals { + switch v := val.(type) { + case int: + keys = append(keys, int64(v)) + case int8: + keys = append(keys, int64(v)) + case int16: + keys = append(keys, int64(v)) + case int32: + keys = append(keys, int64(v)) + case int64: + keys = append(keys, v) + case uint: + keys = append(keys, int64(v)) + case uint8: + keys = append(keys, int64(v)) + case uint16: + keys = append(keys, int64(v)) + case uint32: + keys = append(keys, int64(v)) + case uint64: + keys = append(keys, int64(v)) + // float32, float64, and decimal are ok as long as they don't round + case float32: + key := int64(v) + if float32(key) == v { + keys = append(keys, key) + } + case float64: + key := int64(v) + if float64(key) == v { + keys = append(keys, key) + } + case decimal.Decimal: + key := v.IntPart() + if v.Equal(decimal.NewFromInt(key)) { + keys = append(keys, key) + } + default: + // resort to default behavior for types that require more conversion return nil, false } } + slices.Sort(keys) - slices.Compact(keys) + keys = slices.Compact(keys) res := make(sql.MySQLRangeCollection, len(keys)) for i, key := range keys { res[i] = sql.MySQLRange{ sql.ClosedRangeColumnExpr(key, key, typ), } } + + if len(res) == 0 { + return nil, true + } return res, true } From e4e22e20de9f9e94fe03e723631fd8a3b7516094 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 5 Dec 2025 12:09:20 -0800 Subject: [PATCH 5/6] use generics and apply to unsigned --- sql/analyzer/costed_index_scan.go | 98 ++++++++++++++++++++++++------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index e0b66e9066..2d74bfb2d6 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -15,6 +15,7 @@ package analyzer import ( + "cmp" "fmt" "slices" "sort" @@ -828,15 +829,24 @@ type indexScanRangeBuilder struct { leftover []sql.Expression } -func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType) (sql.MySQLRangeCollection, bool) { - if len(colExprTypes) != 1 { - return nil, false +func keysToRangeColl[N cmp.Ordered](keys []N, typ sql.Type) sql.MySQLRangeCollection { + slices.Sort(keys) + keys = slices.Compact(keys) + // TODO: for integers, if len(keys) - 1 == keys[len(keys)-1] - keys[0], + // then we can just have one continuous range. unsure if this is worth + res := make(sql.MySQLRangeCollection, len(keys)) + for i, key := range keys { + res[i] = sql.MySQLRange{ + sql.ClosedRangeColumnExpr(key, key, typ), + } } - typ := colExprTypes[0].Type - if !types.IsSigned(typ) { - return nil, false + if len(res) == 0 { + return nil } + return res +} +func setToIntRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) { keys := make([]int64, 0, len(setVals)) for _, val := range setVals { switch v := val.(type) { @@ -882,19 +892,55 @@ func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType) } } - slices.Sort(keys) - keys = slices.Compact(keys) - res := make(sql.MySQLRangeCollection, len(keys)) - for i, key := range keys { - res[i] = sql.MySQLRange{ - sql.ClosedRangeColumnExpr(key, key, typ), - } - } + return keysToRangeColl(keys, typ), true +} - if len(res) == 0 { - return nil, true +func setToUintRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) { + keys := make([]uint64, 0, len(setVals)) + for _, val := range setVals { + switch v := val.(type) { + case int: + keys = append(keys, uint64(v)) + case int8: + keys = append(keys, uint64(v)) + case int16: + keys = append(keys, uint64(v)) + case int32: + keys = append(keys, uint64(v)) + case int64: + keys = append(keys, uint64(v)) + case uint: + keys = append(keys, uint64(v)) + case uint8: + keys = append(keys, uint64(v)) + case uint16: + keys = append(keys, uint64(v)) + case uint32: + keys = append(keys, uint64(v)) + case uint64: + keys = append(keys, v) + // float32, float64, and decimal are ok as long as they don't round + case float32: + key := uint64(v) + if float32(key) == v { + keys = append(keys, key) + } + case float64: + key := uint64(v) + if float64(key) == v { + keys = append(keys, key) + } + case decimal.Decimal: + key := v.IntPart() + if v.Equal(decimal.NewFromInt(key)) { + keys = append(keys, uint64(key)) + } + default: + // resort to default behavior for types that require more conversion + return nil, false + } } - return res, true + return keysToRangeColl(keys, typ), true } // buildRangeCollection converts our representation of the best index scan @@ -910,11 +956,23 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa case *iScanOr: ranges, err = b.rangeBuildOr(f, inScan) case *iScanLeaf: - // TODO: special case for in set. can skip building range tree and overlapping range check since it's a series of equality checks + // When the filter is a simple IN, we can skip costly checks like building the RangeTree. if f.Op() == sql.IndexScanOpInSet { cets := b.idx.ColumnExpressionTypes() - if ranges, ok := setToSignedIntRange(f.setValues, cets); ok { - return ranges, nil + if len(cets) == 1 { + typ := cets[0].Type + var ok bool + // TODO: it's possible to apply this optimization to other + // numeric types (float32, float64, decimal). + if types.IsSigned(typ) { + if ranges, ok = setToIntRangeColl(f.setValues, typ); ok { + return ranges, nil + } + } else if types.IsUnsigned(typ) { + if ranges, ok = setToUintRangeColl(f.setValues, typ); ok { + return ranges, nil + } + } } } ranges, err = b.rangeBuildLeaf(f, inScan) From 9579daa26420cc5d4109f78516f5adc32e52fba1 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 5 Dec 2025 12:49:26 -0800 Subject: [PATCH 6/6] revert --- enginetest/memory_engine_test.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index b447ecc83b..f1ec7b45d0 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,18 +200,23 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - //t.Skip() + t.Skip() var scripts = []queries.ScriptTest{ { - Name: "aaaaa", - SetUpScript: []string{ - "create table t (i int primary key, j int);", - "insert into t values (1, 1), (2, 2), (3, 3);", - }, + Name: "AS OF propagates to nested CALLs", + SetUpScript: []string{}, Assertions: []queries.ScriptTestAssertion{ { - Query: "select * from t where i in (1, 2, 3)", - Expected: []sql.Row{}, + Query: "create procedure create_proc() create table t (i int primary key, j int);", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "call create_proc()", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, }, }, },