@@ -5,7 +5,9 @@ use log::{debug, error};
55use once_cell:: sync:: OnceCell ;
66use regex:: { Regex , RegexSet } ;
77use sqlparser:: ast:: Statement :: { Query , StartTransaction } ;
8- use sqlparser:: ast:: { BinaryOperator , Expr , SetExpr , Value } ;
8+ use sqlparser:: ast:: {
9+ BinaryOperator , Expr , Ident , JoinConstraint , JoinOperator , SetExpr , TableFactor , Value ,
10+ } ;
911use sqlparser:: dialect:: PostgreSqlDialect ;
1012use sqlparser:: parser:: Parser ;
1113
@@ -403,20 +405,67 @@ impl QueryRouter {
403405
404406 /// A `selection` is the `WHERE` clause. This parses
405407 /// the clause and extracts the sharding key, if present.
406- fn selection_parser ( & self , expr : & Expr ) -> Vec < i64 > {
408+ fn selection_parser ( & self , expr : & Expr , table_names : & Vec < Vec < Ident > > ) -> Vec < i64 > {
407409 let mut result = Vec :: new ( ) ;
408410 let mut found = false ;
409411
412+ let sharding_key = self
413+ . pool_settings
414+ . automatic_sharding_key
415+ . as_ref ( )
416+ . unwrap ( )
417+ . split ( "." )
418+ . map ( |ident| Ident :: new ( ident) )
419+ . collect :: < Vec < Ident > > ( ) ;
420+
421+ // Sharding key must be always fully qualified
422+ assert_eq ! ( sharding_key. len( ) , 2 ) ;
423+
410424 // This parses `sharding_key = 5`. But it's technically
411425 // legal to write `5 = sharding_key`. I don't judge the people
412426 // who do that, but I think ORMs will still use the first variant,
413427 // so we can leave the second as a TODO.
414428 if let Expr :: BinaryOp { left, op, right } = expr {
415429 match & * * left {
416- Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( left) ) ,
430+ Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( left, table_names ) ) ,
417431 Expr :: Identifier ( ident) => {
418- found =
419- ident. value == * self . pool_settings . automatic_sharding_key . as_ref ( ) . unwrap ( ) ;
432+ // Only if we're dealing with only one table
433+ // and there is no ambiguity
434+ if & ident. value == & sharding_key[ 1 ] . value {
435+ // Sharding key is unique enough, don't worry about
436+ // table names.
437+ if & sharding_key[ 0 ] . value == "*" {
438+ found = true ;
439+ } else if table_names. len ( ) == 1 {
440+ let table = & table_names[ 0 ] ;
441+
442+ if table. len ( ) == 1 {
443+ // Table is not fully qualified, e.g.
444+ // SELECT * FROM t WHERE sharding_key = 5
445+ // Make sure the table name from the sharding key matches
446+ // the table name from the query.
447+ found = & sharding_key[ 0 ] . value == & table[ 0 ] . value ;
448+ } else if table. len ( ) == 2 {
449+ // Table name is fully qualified with the schema: e.g.
450+ // SELECT * FROM public.t WHERE sharding_key = 5
451+ // Ignore the schema (TODO: at some point, we want schema support)
452+ // and use the table name only.
453+ found = & sharding_key[ 0 ] . value == & table[ 1 ] . value ;
454+ } else {
455+ debug ! ( "Got table name with more than two idents, which is not possible" ) ;
456+ }
457+ }
458+ }
459+ }
460+
461+ Expr :: CompoundIdentifier ( idents) => {
462+ // The key is fully qualified in the query,
463+ // it will exist or Postgres will throw an error.
464+ if idents. len ( ) == 2 {
465+ found = & sharding_key[ 0 ] . value == & idents[ 0 ] . value
466+ && & sharding_key[ 1 ] . value == & idents[ 1 ] . value ;
467+ }
468+ // TODO: key can have schema as well, e.g. public.data.id (len == 3)
420469 }
421470 _ => ( ) ,
422471 } ;
@@ -433,7 +482,7 @@ impl QueryRouter {
433482 } ;
434483
435484 match & * * right {
436- Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( right) ) ,
485+ Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( right, table_names ) ) ,
437486 Expr :: Value ( Value :: Number ( value, ..) ) => {
438487 if found {
439488 match value. parse :: < i64 > ( ) {
@@ -456,6 +505,7 @@ impl QueryRouter {
456505 /// Try to figure out which shard the query should go to.
457506 fn infer_shard ( & self , query : & sqlparser:: ast:: Query ) -> Option < usize > {
458507 let mut shards = BTreeSet :: new ( ) ;
508+ let mut exprs = Vec :: new ( ) ;
459509
460510 match & * query. body {
461511 SetExpr :: Query ( query) => {
@@ -467,27 +517,75 @@ impl QueryRouter {
467517 } ;
468518 }
469519
520+ // SELECT * FROM ...
521+ // We understand that pretty well.
470522 SetExpr :: Select ( select) => {
471- match & select. selection {
472- Some ( selection) => {
473- let sharding_keys = self . selection_parser ( selection) ;
523+ // Collect all table names from the query.
524+ let mut table_names = Vec :: new ( ) ;
474525
475- // TODO: Add support for prepared statements here.
476- // This should just give us the position of the value in the `B` message.
526+ for table in select. from . iter ( ) {
527+ match & table. relation {
528+ TableFactor :: Table { name, .. } => {
529+ table_names. push ( name. 0 . clone ( ) ) ;
530+ }
477531
478- let sharder = Sharder :: new (
479- self . pool_settings . shards ,
480- self . pool_settings . sharding_function ,
481- ) ;
532+ _ => ( ) ,
533+ } ;
482534
483- for value in sharding_keys {
484- let shard = sharder. shard ( value) ;
485- shards. insert ( shard) ;
486- }
535+ // Get table names from all the joins.
536+ for join in table. joins . iter ( ) {
537+ match & join. relation {
538+ TableFactor :: Table { name, .. } => {
539+ table_names. push ( name. 0 . clone ( ) ) ;
540+ }
541+
542+ _ => ( ) ,
543+ } ;
544+
545+ // We can filter results based on join conditions, e.g.
546+ // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
547+ match & join. join_operator {
548+ JoinOperator :: Inner ( inner_join) => match & inner_join {
549+ JoinConstraint :: On ( expr) => {
550+ // Parse the selection criteria later.
551+ exprs. push ( expr. clone ( ) ) ;
552+ }
553+
554+ _ => ( ) ,
555+ } ,
556+
557+ _ => ( ) ,
558+ } ;
559+ }
560+ }
561+
562+ // Parse the actual "FROM ..."
563+ match & select. selection {
564+ Some ( selection) => {
565+ exprs. push ( selection. clone ( ) ) ;
487566 }
488567
489568 None => ( ) ,
490569 } ;
570+
571+ // Look for sharding keys in either the join condition
572+ // or the selection.
573+ for expr in exprs. iter ( ) {
574+ let sharding_keys = self . selection_parser ( expr, & table_names) ;
575+
576+ // TODO: Add support for prepared statements here.
577+ // This should just give us the position of the value in the `B` message.
578+
579+ let sharder = Sharder :: new (
580+ self . pool_settings . shards ,
581+ self . pool_settings . sharding_function ,
582+ ) ;
583+
584+ for value in sharding_keys {
585+ let shard = sharder. shard ( value) ;
586+ shards. insert ( shard) ;
587+ }
588+ }
491589 }
492590 _ => ( ) ,
493591 } ;
@@ -825,7 +923,7 @@ mod test {
825923 query_parser_enabled : true ,
826924 primary_reads_enabled : false ,
827925 sharding_function : ShardingFunction :: PgBigintHash ,
828- automatic_sharding_key : Some ( String :: from ( "id" ) ) ,
926+ automatic_sharding_key : Some ( String :: from ( "test. id" ) ) ,
829927 healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
830928 healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
831929 ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -854,11 +952,6 @@ mod test {
854952 let q2 = simple_query ( "SET SERVER ROLE TO 'default'" ) ;
855953 assert ! ( qr. try_execute_command( & q2) != None ) ;
856954 assert_eq ! ( qr. active_role. unwrap( ) , pool_settings. default_role) ;
857-
858- // Here we go :)
859- let q3 = simple_query ( "SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)" ) ;
860- assert ! ( qr. infer( & q3) ) ;
861- assert_eq ! ( qr. shard( ) , 1 ) ;
862955 }
863956
864957 #[ test]
@@ -891,7 +984,7 @@ mod test {
891984 query_parser_enabled : true ,
892985 primary_reads_enabled : false ,
893986 sharding_function : ShardingFunction :: PgBigintHash ,
894- automatic_sharding_key : Some ( String :: from ( "id" ) ) ,
987+ automatic_sharding_key : None ,
895988 healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
896989 healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
897990 ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -920,4 +1013,56 @@ mod test {
9201013 assert ! ( qr. try_execute_command( & q2) == None ) ;
9211014 assert_eq ! ( qr. active_shard, Some ( 2 ) ) ;
9221015 }
1016+
1017+ #[ test]
1018+ fn test_automatic_sharding_key ( ) {
1019+ QueryRouter :: setup ( ) ;
1020+
1021+ let mut qr = QueryRouter :: new ( ) ;
1022+ qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1023+ qr. pool_settings . shards = 3 ;
1024+
1025+ assert ! ( qr. infer( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) ) ;
1026+ assert_eq ! ( qr. shard( ) , 2 ) ;
1027+
1028+ assert ! ( qr. infer( & simple_query(
1029+ "SELECT one, two, three FROM public.data WHERE id = 6"
1030+ ) ) ) ;
1031+ assert_eq ! ( qr. shard( ) , 0 ) ;
1032+
1033+ assert ! ( qr. infer( & simple_query(
1034+ "SELECT * FROM data
1035+ INNER JOIN t2 ON data.id = 5
1036+ AND t2.data_id = data.id
1037+ WHERE data.id = 5"
1038+ ) ) ) ;
1039+ assert_eq ! ( qr. shard( ) , 2 ) ;
1040+
1041+ // Shard did not move because we couldn't determine the sharding key since it could be ambiguous
1042+ // in the query.
1043+ assert ! ( qr. infer( & simple_query(
1044+ "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1045+ ) ) ) ;
1046+ assert_eq ! ( qr. shard( ) , 2 ) ;
1047+
1048+ assert ! ( qr. infer( & simple_query(
1049+ r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1050+ ) ) ) ;
1051+ assert_eq ! ( qr. shard( ) , 0 ) ;
1052+
1053+ assert ! ( qr. infer( & simple_query(
1054+ r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1055+ ) ) ) ;
1056+ assert_eq ! ( qr. shard( ) , 2 ) ;
1057+
1058+ // Super unique sharding key
1059+ qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1060+ assert ! ( qr. infer( & simple_query(
1061+ "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1062+ ) ) ) ;
1063+ assert_eq ! ( qr. shard( ) , 0 ) ;
1064+
1065+ assert ! ( qr. infer( & simple_query( "SELECT * FROM table_y WHERE another_key = 5" ) ) ) ;
1066+ assert_eq ! ( qr. shard( ) , 0 ) ;
1067+ }
9231068}
0 commit comments