diff --git a/mindsdb_sql_parser/__about__.py b/mindsdb_sql_parser/__about__.py index 59f3ddd..58de06c 100644 --- a/mindsdb_sql_parser/__about__.py +++ b/mindsdb_sql_parser/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql_parser' __package_name__ = 'mindsdb_sql_parser' -__version__ = '0.13.6' +__version__ = '0.13.7' __description__ = "Mindsdb SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql_parser/ast/select/identifier.py b/mindsdb_sql_parser/ast/select/identifier.py index cfaed3b..b94474b 100644 --- a/mindsdb_sql_parser/ast/select/identifier.py +++ b/mindsdb_sql_parser/ast/select/identifier.py @@ -21,29 +21,20 @@ def path_str_to_parts(path_str: str): return parts, is_quoted -RESERVED_KEYWORDS = { - 'PERSIST', 'IF', 'EXISTS', 'NULLS', 'FIRST', 'LAST', - 'ORDER', 'BY', 'GROUP', 'PARTITION' +# Here is a hardcoded set of keywords that can be used as identifiers without escaping. +# For example, in a query like this: select {keyword} from tbl +# If there is a need to update this list, an example code to retrieve all keywords can be found here in v0.13.6 +keywords_to_escape = { + "VALUES", "DESCRIBE", "THEN", "WRITE", "WITH", "INSERT", "DROP", "CROSS", + "SET", "ASC", "IS", "IN", "NOT", "INTO", "WINDOW", "ALTER", "WHERE", + "DISTINCT", "USE", "INNER", "COLLATE", "FOR", "USING", "FULL", "LIKE", + "JOIN", "SELECT", "OVER", "CASE", "LIMIT", "END", "UNION", "DELETE", + "HAVING", "OUTER", "FROM", "AS", "CHARACTER", "INTERSECT", "CONVERT", + "WHEN", "OR", "AND", "UPDATE", "BETWEEN", "DESC", "EXPLAIN", "SHOW", + "EXCEPT", "LEFT", "ELSE", "READ", "RIGHT" } -_reserved_keywords: set[str] = None - - -def get_reserved_words() -> set[str]: - global _reserved_keywords - - if _reserved_keywords is None: - from mindsdb_sql_parser.lexer import MindsDBLexer - - _reserved_keywords = RESERVED_KEYWORDS - for word in MindsDBLexer.tokens: - if '_' not in word: - # exclude combinations - _reserved_keywords.add(word) - return _reserved_keywords - - class Identifier(ASTNode): def __init__( self, path_str=None, parts=None, is_outer=False, with_rollup=False, @@ -77,7 +68,6 @@ def append(self, other: "Identifier") -> None: self.is_quoted += other.is_quoted def iter_parts_str(self): - reserved_words = get_reserved_words() for part, is_quoted in zip(self.parts, self.is_quoted): if isinstance(part, Star): part = str(part) @@ -85,7 +75,7 @@ def iter_parts_str(self): if ( is_quoted or not no_wrap_identifier_regex.fullmatch(part) - or part.upper() in reserved_words + or part.upper() in keywords_to_escape ): part = f'`{part}`' yield part diff --git a/tests/test_base_sql/test_select_structure.py b/tests/test_base_sql/test_select_structure.py index 67d1614..707f233 100644 --- a/tests/test_base_sql/test_select_structure.py +++ b/tests/test_base_sql/test_select_structure.py @@ -738,7 +738,29 @@ def test_partial_backticks(self): sql = "SELECT `integration`.`some table`.column" ast = parse_sql(sql) - expected_ast = Select(targets=[Identifier(parts=['integration', 'some table', 'column']),],) + expected_ast = Select( + targets=[ + Identifier( + parts=['integration', 'some table', 'column'], + is_quoted=[True, True, False] + ), + ], + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_keyword_escaping(self): + sql = "select ID, `ID`, `VALUES`" + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Identifier(parts=['ID'], is_quoted=[False]), + Identifier(parts=['ID'], is_quoted=[True]), + Identifier(parts=['VALUES'], is_quoted=[True]), + ], + ) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast)