Skip to content

Commit 0535d80

Browse files
authored
Merge pull request #3169 from actiontech/sqlfash_api_update_ce
Sqlfash api update ce
2 parents 457f72b + a56ff97 commit 0535d80

File tree

11 files changed

+552
-196
lines changed

11 files changed

+552
-196
lines changed

sqle/driver/mysql/mysql.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/actiontech/sqle/sqle/driver/mysql/plocale"
1515
rulepkg "github.com/actiontech/sqle/sqle/driver/mysql/rule"
1616
_ "github.com/actiontech/sqle/sqle/driver/mysql/rule/ai"
17+
aiutil "github.com/actiontech/sqle/sqle/driver/mysql/rule/ai/util"
1718
"github.com/actiontech/sqle/sqle/driver/mysql/session"
1819
"github.com/actiontech/sqle/sqle/driver/mysql/util"
1920
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
@@ -639,6 +640,83 @@ func (p *PluginProcessor) GetDriverMetas() (*driverV2.DriverMetas, error) {
639640
return metas, nil
640641
}
641642

643+
func (i *MysqlDriverImpl) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
644+
node, err := util.ParseOneSql(sql)
645+
if err != nil {
646+
return nil, err
647+
}
648+
649+
if _, ok := node.(*ast.SelectStmt); !ok {
650+
log.NewEntry().Errorf("get selectivity of sql columns failed, sql is not a select statement, sql: %s", sql)
651+
return nil, nil
652+
}
653+
654+
selectVisitor := &util.SelectVisitor{}
655+
node.Accept(selectVisitor)
656+
657+
result := make(map[string]map[string]float32)
658+
659+
for _, selectNode := range selectVisitor.SelectList {
660+
if selectNode.From == nil || selectNode.From.TableRefs == nil {
661+
continue
662+
}
663+
664+
// 获取表别名映射关系
665+
aliasInfo := aiutil.GetTableAliasInfoFromJoin(selectNode.From.TableRefs)
666+
aliasMap := make(map[string]string)
667+
allTables := make([]string, 0, len(aliasInfo))
668+
669+
for _, alias := range aliasInfo {
670+
if alias.TableAliasName != "" {
671+
aliasMap[alias.TableAliasName] = alias.TableName
672+
}
673+
allTables = append(allTables, alias.TableName)
674+
}
675+
676+
// 提取列并按表分组
677+
tableColumns := util.ExtractColumnsFromSelectStmt(selectNode, aliasMap, allTables)
678+
679+
// 遍历每个表,获取其列的选择性
680+
for tableName, columnSet := range tableColumns {
681+
columns := make([]string, 0, len(columnSet))
682+
for colName := range columnSet {
683+
columns = append(columns, colName)
684+
}
685+
686+
if len(columns) == 0 {
687+
continue
688+
}
689+
690+
// 构造 TableName 对象
691+
var schemaName string
692+
for _, alias := range aliasInfo {
693+
if alias.TableName == tableName {
694+
schemaName = alias.SchemaName
695+
break
696+
}
697+
}
698+
tableNameObj := util.NewTableName(schemaName, tableName)
699+
700+
columnSelectivityMap, err := i.Ctx.GetSelectivityOfColumns(tableNameObj, columns)
701+
if err != nil {
702+
log.NewEntry().Errorf("get selectivity of columns failed, table: %s, columns: %v, error: %v", tableName, columns, err)
703+
continue
704+
}
705+
706+
if result[tableName] == nil {
707+
result[tableName] = make(map[string]float32)
708+
}
709+
for columnName, selectivity := range columnSelectivityMap {
710+
if selectivity > 0 {
711+
result[tableName][columnName] = float32(selectivity)
712+
}
713+
}
714+
}
715+
}
716+
717+
return result, nil
718+
}
719+
642720
func (p *PluginProcessor) Open(l *logrus.Entry, cfg *driverV2.Config) (driver.Plugin, error) {
643721
return NewInspect(l, cfg)
644722
}

sqle/driver/mysql/util/parser_helper.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,3 +931,106 @@ func ConvertAliasToTable(alias string, tables []*ast.TableSource) (*ast.TableNam
931931
}
932932
return nil, errors.New("can not find table")
933933
}
934+
935+
// TableColumnMap 表示按表分组的列名集合
936+
type TableColumnMap map[string]map[string]struct{}
937+
938+
// ExtractColumnsFromSelectStmt 从 SELECT 语句中提取列,并按表分组
939+
// 参数:
940+
// - selectStmt: SELECT 语句节点
941+
// - aliasMap: 表别名到实际表名的映射
942+
// - allTables: 所有涉及的表名列表(用于处理无表前缀的列)
943+
//
944+
// 返回:按表名分组的列名集合
945+
func ExtractColumnsFromSelectStmt(selectStmt *ast.SelectStmt, aliasMap map[string]string, allTables []string) TableColumnMap {
946+
tableColumns := make(TableColumnMap)
947+
948+
// 收集 SELECT 列表中的所有列别名
949+
selectAliases := make(map[string]struct{})
950+
if selectStmt.Fields != nil {
951+
for _, field := range selectStmt.Fields.Fields {
952+
if field.AsName.L != "" {
953+
selectAliases[field.AsName.L] = struct{}{}
954+
}
955+
}
956+
}
957+
958+
// 辅助函数:从表达式中提取列并按表分组
959+
extractColumnsFromExpr := func(expr ast.Node, skipAliases bool) {
960+
if expr == nil {
961+
return
962+
}
963+
columnVisitor := &ColumnNameVisitor{}
964+
expr.Accept(columnVisitor)
965+
966+
for _, colExpr := range columnVisitor.ColumnNameList {
967+
if colExpr.Name == nil {
968+
continue
969+
}
970+
971+
// 如果需要跳过别名且当前列名是一个别名,则跳过
972+
if skipAliases {
973+
if _, isAlias := selectAliases[colExpr.Name.Name.L]; isAlias && colExpr.Name.Table.L == "" {
974+
continue
975+
}
976+
}
977+
978+
var targetTableName string
979+
980+
// 如果列有表前缀(可能是别名或实际表名)
981+
if colExpr.Name.Table.L != "" {
982+
// 先尝试从别名映射中查找
983+
if actualTable, exists := aliasMap[colExpr.Name.Table.L]; exists {
984+
targetTableName = actualTable
985+
} else {
986+
// 如果不是别名,就当作实际表名
987+
targetTableName = colExpr.Name.Table.L
988+
}
989+
}
990+
991+
if targetTableName != "" {
992+
if tableColumns[targetTableName] == nil {
993+
tableColumns[targetTableName] = make(map[string]struct{})
994+
}
995+
tableColumns[targetTableName][colExpr.Name.Name.L] = struct{}{}
996+
} else {
997+
// 没有表前缀的列,可能属于任何表
998+
// 在多表查询中,尝试将该列添加到所有表
999+
for _, tableName := range allTables {
1000+
if tableColumns[tableName] == nil {
1001+
tableColumns[tableName] = make(map[string]struct{})
1002+
}
1003+
tableColumns[tableName][colExpr.Name.Name.L] = struct{}{}
1004+
}
1005+
}
1006+
}
1007+
}
1008+
1009+
// 从 SELECT Fields 提取列(包括聚合函数内的列)
1010+
if selectStmt.Fields != nil {
1011+
for _, field := range selectStmt.Fields.Fields {
1012+
extractColumnsFromExpr(field.Expr, false)
1013+
}
1014+
}
1015+
1016+
// 从 WHERE 条件提取列
1017+
if selectStmt.Where != nil {
1018+
extractColumnsFromExpr(selectStmt.Where, false)
1019+
}
1020+
1021+
// 从 GROUP BY 提取列(需要跳过别名引用)
1022+
if selectStmt.GroupBy != nil {
1023+
for _, item := range selectStmt.GroupBy.Items {
1024+
extractColumnsFromExpr(item.Expr, true)
1025+
}
1026+
}
1027+
1028+
// 从 HAVING 提取列
1029+
if selectStmt.Having != nil {
1030+
extractColumnsFromExpr(selectStmt.Having.Expr, false)
1031+
}
1032+
1033+
// 注意:不从 ORDER BY 提取,因为可能包含别名引用
1034+
1035+
return tableColumns
1036+
}

sqle/driver/plugin_adapter_v1.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,7 @@ func (s *PluginImplV1) GetDatabaseObjectDDL(ctx context.Context, objInfos []*dri
353353
func (s *PluginImplV1) GetDatabaseDiffModifySQL(ctx context.Context, calibratedDSN *driverV2.DSN, objInfos []*driverV2.DatabasCompareSchemaInfo) ([]*driverV2.DatabaseDiffModifySQLResult, error) {
354354
return nil, fmt.Errorf("unimplement this method")
355355
}
356+
357+
func (p *PluginImplV1) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
358+
return nil, fmt.Errorf("unimplement this method")
359+
}

sqle/driver/plugin_adapter_v2.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,25 @@ func (s *PluginImplV2) GetDatabaseDiffModifySQL(ctx context.Context, calibratedD
725725
}
726726
return dbDiffSQLs, nil
727727
}
728+
729+
func (s *PluginImplV2) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
730+
api := "GetSelectivityOfSQLColumns"
731+
s.preLog(api)
732+
resp, err := s.client.GetSelectivityOfSQLColumns(ctx, &protoV2.GetSelectivityOfSQLColumnsRequest{
733+
Session: s.Session,
734+
Sql: sql,
735+
})
736+
s.afterLog(api, err)
737+
if err != nil {
738+
return nil, err
739+
}
740+
result := make(map[string]map[string]float32, len(resp.Selectivity))
741+
for _, v := range resp.Selectivity {
742+
colMap := make(map[string]float32, len(v.SelectivityOfColumns))
743+
for k, sel := range v.SelectivityOfColumns {
744+
colMap[k] = sel
745+
}
746+
result[v.TableName] = colMap
747+
}
748+
return result, nil
749+
}

sqle/driver/plugin_interface.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type Plugin interface {
5151
Backup(ctx context.Context, backupStrategy string, sql string, backupMaxRows uint64) (backupSqls []string, executeResult string, err error)
5252

5353
RecommendBackupStrategy(ctx context.Context, sql string) (*RecommendBackupStrategyRes, error)
54+
GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string] /*table name*/ map[string] /*column name*/ float32, error)
5455
}
5556

5657
type RecommendBackupStrategyRes struct {

sqle/driver/v2/driver_grpc_server.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,3 +674,27 @@ func (d *DriverGrpcServer) GetDatabaseDiffModifySQL(ctx context.Context, req *pr
674674
SchemaDiffModify: scheamDiff,
675675
}, nil
676676
}
677+
678+
func (d *DriverGrpcServer) GetSelectivityOfSQLColumns(ctx context.Context, req *protoV2.GetSelectivityOfSQLColumnsRequest) (*protoV2.GetSelectivityOfSQLColumnsResponse, error) {
679+
driver, err := d.getDriverBySession(req.Session)
680+
if err != nil {
681+
return &protoV2.GetSelectivityOfSQLColumnsResponse{}, err
682+
}
683+
selectivity, err := driver.GetSelectivityOfSQLColumns(ctx, req.Sql)
684+
if err != nil {
685+
return &protoV2.GetSelectivityOfSQLColumnsResponse{}, err
686+
}
687+
protoSelectivity := make([]*protoV2.SelectivityOfSQLColumns, 0, len(selectivity))
688+
for tableName, colMap := range selectivity {
689+
// 直接将 map[string]float32 赋值,无需合并操作
690+
merged := make(map[string]float32, len(colMap))
691+
for col, val := range colMap {
692+
merged[col] = val
693+
}
694+
protoSelectivity = append(protoSelectivity, &protoV2.SelectivityOfSQLColumns{
695+
TableName: tableName,
696+
SelectivityOfColumns: merged,
697+
})
698+
}
699+
return &protoV2.GetSelectivityOfSQLColumnsResponse{Selectivity: protoSelectivity}, nil
700+
}

sqle/driver/v2/driver_interface.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ type Driver interface {
9999

100100
Backup(ctx context.Context, req *BackupReq) (*BackupRes, error)
101101
RecommendBackupStrategy(ctx context.Context, req *RecommendBackupStrategyReq) (*RecommendBackupStrategyRes, error)
102+
GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string] /*table name*/ map[string] /*column name*/ float32, error)
102103
}
103104

104105
const (

0 commit comments

Comments
 (0)