From 063283a8bf3d9258d8cc2360e15b64759ae04d23 Mon Sep 17 00:00:00 2001 From: Marcus Steinbach Date: Wed, 9 Jul 2025 07:32:32 +0200 Subject: [PATCH] fix: finalize fork --- src/sqlacodegen/cli.py | 9 -- src/sqlacodegen/generators.py | 40 +++++-- src/sqlacodegen/risclog_generators.py | 144 ++++++++++++++++++++++++-- src/sqlacodegen/seed_export.py | 24 +---- src/sqlacodegen/utils.py | 5 +- 5 files changed, 168 insertions(+), 54 deletions(-) diff --git a/src/sqlacodegen/cli.py b/src/sqlacodegen/cli.py index 87298d72..7de594a4 100644 --- a/src/sqlacodegen/cli.py +++ b/src/sqlacodegen/cli.py @@ -31,7 +31,6 @@ parse_extension_row, parse_function_row, parse_policy_row, - parse_sequence_row, parse_trigger_row, ) from sqlacodegen.seed_export import export_pgdata_py, get_table_dependency_order @@ -246,14 +245,6 @@ class ExportDict(TypedDict, total=False): "parse_row_func": parse_extension_row, "file": "pg_extensions.py", }, - { - "title": "Sequences", - "entities_varname": "all_sequences", - "template": "ALEMBIC_SEQUENCE_TEMPLATE", - "statement": "ALEMBIC_SEQUENCE_STATEMENT", - "parse_row_func": parse_sequence_row, - "file": "pg_sequences.py", - }, ] # ----------- Export-Loop ------------ diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index abdaa506..7d9bc542 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -41,6 +41,7 @@ from sqlalchemy.dialects.postgresql import DOMAIN, JSONB from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import CompileError +from sqlalchemy.schema import Sequence as SASequence from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.type_api import UserDefinedType from sqlalchemy.types import TypeEngine @@ -231,10 +232,10 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: elif isinstance(column.type, DOMAIN): self.add_import(column.type.data_type.__class__) - if column.default: + if column.default is not None: self.add_import(column.default) - if column.server_default: + if column.server_default is not None: if isinstance(column.server_default, (Computed, Identity)): self.add_import(column.server_default) elif isinstance(column.server_default, DefaultClause): @@ -449,7 +450,7 @@ def render_column( for fk in dedicated_fks: args.append(self.render_constraint(fk)) - if column.default: + if column.default is not None: args.append(repr(column.default)) if column.key != column.name: @@ -483,7 +484,7 @@ def render_column( ) elif isinstance(column.server_default, Identity): args.append(repr(column.server_default)) - elif column.server_default: + elif column.server_default is not None: kwargs["server_default"] = repr(column.server_default) comment = getattr(column, "comment", None) @@ -624,6 +625,24 @@ def find_free_name( return name + def get_pg_sequence_parameters( + self, bind: Engine, schema: str, name: str + ) -> dict[str, Any]: + sql = """ + SELECT start_value, increment_by, min_value, max_value + FROM pg_sequences + WHERE schemaname = :schema AND sequencename = :name + """ + with bind.connect() as conn: + row = ( + conn.execute( + sqlalchemy.text(sql), {"schema": schema or "public", "name": name} + ) + .mappings() + .first() + ) + return dict(row) if row else {} + def fix_column_types(self, table: Table) -> None: """Adjust the reflected column types.""" # Detect check constraints for boolean and enum columns @@ -667,7 +686,10 @@ def fix_column_types(self, table: Table) -> None: pass # PostgreSQL specific fix: detect sequences from server_default - if column.server_default and self.bind.dialect.name == "postgresql": + if ( + column.server_default is not None + and self.bind.dialect.name == "postgresql" + ): if isinstance(column.server_default, DefaultClause) and isinstance( column.server_default.arg, TextClause ): @@ -675,11 +697,9 @@ def fix_column_types(self, table: Table) -> None: column.server_default.arg ) if seqname: - # Add an explicit sequence - if seqname != f"{column.table.name}_{column.name}_seq": - column.default = sqlalchemy.Sequence(seqname, schema=schema) - - column.server_default = None + column.default = SASequence( + name=seqname, schema=schema, start=1, increment=1 + ) def get_adapted_type(self, coltype: Any) -> Any: compiled_type = coltype.compile(self.bind.engine.dialect) diff --git a/src/sqlacodegen/risclog_generators.py b/src/sqlacodegen/risclog_generators.py index 6232c26a..98a8efb1 100644 --- a/src/sqlacodegen/risclog_generators.py +++ b/src/sqlacodegen/risclog_generators.py @@ -1,10 +1,13 @@ import re from pprint import pformat -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, cast from sqlalchemy import ( Column, + Computed, + DefaultClause, ForeignKeyConstraint, + Identity, Index, MetaData, PrimaryKeyConstraint, @@ -12,10 +15,11 @@ UniqueConstraint, inspect, text, - types, ) from sqlalchemy import types as satypes from sqlalchemy.engine import Connection, Engine +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.functions import next_value from sqlacodegen.generators import ( Base, @@ -34,7 +38,7 @@ from sqlacodegen.generators import TablesGenerator EXCLUDED_TABLES = {"tmp_functest", "accesslogfailed"} -INCLUDED_POLICY_ROLES = {"brokeruser"} +INCLUDED_POLICY_ROLES = {"brokeruser", " clx_readonly", "clx"} BASE_META_DATA = Base( literal_imports=[ LiteralImport( @@ -229,7 +233,15 @@ class {classname}(PortalObject): # type: ignore[misc] JOIN pg_catalog.pg_sequences ps ON s.sequence_schema = ps.schemaname AND s.sequence_name = ps.sequencename -WHERE s.sequence_schema NOT IN ('pg_catalog', 'information_schema') +JOIN pg_class t + ON t.relkind = 'r' AND t.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = s.sequence_schema) +JOIN pg_attribute a + ON a.attrelid = t.oid +JOIN pg_attrdef d + ON d.adrelid = t.oid AND d.adnum = a.attnum +WHERE + s.sequence_schema NOT IN ('pg_catalog', 'information_schema') + AND pg_get_expr(d.adbin, d.adrelid) LIKE '%nextval(%' || s.sequence_name || '%' ORDER BY s.sequence_schema, s.sequence_name; """ @@ -428,7 +440,7 @@ def parse_sequence_row( ) -> tuple[str, str]: schema_val = row["schema"] or schema or "public" signature = row["sequence_name"] - varname = f"{signature}_sequence".lower() + varname = signature.lower() parts = [ f"AS {row['data_type']}", @@ -644,9 +656,9 @@ def clx_render_index(self: "TablesGenerator", index: Index) -> str: and hasattr(col, "type") ): coltype = getattr(col.type, "python_type", None) - if isinstance(col.type, (types.String, types.Text, types.Unicode)) or ( - coltype and coltype is str - ): + if isinstance( + col.type, (satypes.String, satypes.Text, satypes.Unicode) + ) or (coltype and coltype is str): opclass_map[col.name] = "gin_trgm_ops" elif getattr(index, "expressions", None): @@ -727,6 +739,117 @@ def clx_render_table(self: "TablesGenerator", table: Table) -> str: TablesGenerator.render_table = clx_render_table # type: ignore[method-assign] +def clx_render_column( + self: "TablesGenerator", + column: Column[Any], + show_name: bool, + is_table: bool = False, +) -> str: + args = [] + kwargs: dict[str, Any] = {} + kwarg = [] + is_sole_pk = column.primary_key and len(column.table.primary_key) == 1 + dedicated_fks = [ + c + for c in column.foreign_keys + if c.constraint + and len(c.constraint.columns) == 1 + and uses_default_name(c.constraint) + ] + is_unique = any( + isinstance(c, UniqueConstraint) + and set(c.columns) == {column} + and uses_default_name(c) + for c in column.table.constraints + ) + is_unique = is_unique or any( + i.unique and set(i.columns) == {column} and uses_default_name(i) + for i in column.table.indexes + ) + is_primary = ( + any( + isinstance(c, PrimaryKeyConstraint) + and column.name in c.columns + and uses_default_name(c) + for c in column.table.constraints + ) + or column.primary_key + ) + has_index = any( + set(i.columns) == {column} and uses_default_name(i) + for i in column.table.indexes + ) + + if show_name: + args.append(repr(column.name)) + + # Render the column type if there are no foreign keys on it or any of them + # points back to itself + if not dedicated_fks or any(fk.column is column for fk in dedicated_fks): + args.append(self.render_column_type(column.type)) + + for fk in dedicated_fks: + args.append(self.render_constraint(fk)) + + if column.default is not None: + args.append(repr(column.default)) + + if column.key != column.name: + kwargs["key"] = column.key + if is_primary: + kwargs["primary_key"] = True + if not column.nullable and not is_sole_pk and is_table: + kwargs["nullable"] = False + + if is_unique: + column.unique = True + kwargs["unique"] = True + if has_index: + column.index = True + kwarg.append("index") + kwargs["index"] = True + + # --- SERVER DEFAULT HANDLING --- + if isinstance(column.server_default, DefaultClause): + kwargs["server_default"] = render_callable( + "text", repr(cast(TextClause, column.server_default.arg).text) + ) + elif isinstance(column.server_default, Computed): + expression = str(column.server_default.sqltext) + + computed_kwargs = {} + if column.server_default.persisted is not None: + computed_kwargs["persisted"] = column.server_default.persisted + + args.append( + render_callable("Computed", repr(expression), kwargs=computed_kwargs) + ) + elif isinstance(column.server_default, Identity): + args.append(repr(column.server_default)) + elif isinstance(column.server_default, next_value): + # --------- NEU: Sequence/next_value ---------- + seq = column.server_default.sequence + if seq is not None: + default_schema = "public" + kwargs["server_default"] = ( + f"Sequence({seq.name!r}{f', schema={seq.schema!r}' if seq.schema else f', schema={default_schema!r}'}).next_value()" + ) + else: + kwargs["server_default"] = "None # Sequence not detected" + elif column.server_default is not None: + kwargs["server_default"] = repr(column.server_default) + # ----------------------------------------------- + + comment = getattr(column, "comment", None) + if comment: + kwargs["comment"] = repr(comment) + + return self.render_column_callable(is_table, *args, **kwargs) + + +TablesGenerator.render_column = clx_render_column # type: ignore[method-assign] + + def clx_generate(self: "TablesGenerator") -> tuple[str, list[str] | None]: self.generate_base() @@ -816,12 +939,10 @@ def generate_alembic_utils_sequences( # Hole alle aus DB result: list[dict[str, Any]] = fetch_all_mappings(conn, sql, {"schema": schema}) # Finde alle, die von Tables verwaltet werden - managed = get_table_managed_sequences(self.metadata) entities = [ parsed for row in result - if row["sequence_name"] not in managed - and (parsed := parse_row_func(row, template_def, schema)) is not None + if (parsed := parse_row_func(row, template_def, schema)) is not None ] code = [code for code, _ in entities] @@ -1127,6 +1248,7 @@ def get_extension_object_names( self.add_literal_import(module, name) self.add_literal_import("sqlalchemy", "text") + self.add_literal_import("sqlalchemy", "FetchedValue") return "\n\n".join(rendered), finalize_alembic_utils( pg_alembic_definition, entities, entities_name diff --git a/src/sqlacodegen/seed_export.py b/src/sqlacodegen/seed_export.py index 65f72287..6aa85d93 100644 --- a/src/sqlacodegen/seed_export.py +++ b/src/sqlacodegen/seed_export.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any -from sqlalchemy import MetaData, select, text +from sqlalchemy import MetaData, select from sqlalchemy.engine import Engine @@ -111,33 +111,11 @@ def export_pgdata_py( result.append(d) data[name] = result - sequence_rows = conn.execute( - text(""" - SELECT sequence_schema, sequence_name - FROM information_schema.sequences - WHERE sequence_schema NOT IN ('pg_catalog', 'information_schema') - ORDER BY sequence_schema, sequence_name - """) - ).fetchall() - raw_sql_stmts = [] - for row in sequence_rows: - schema = row.sequence_schema - seq_name = row.sequence_name - lastval = conn.execute( - text(f"SELECT last_value FROM {schema}.{seq_name}") - ).scalar() - raw_sql_stmts.append( - f" SELECT setval('{schema}.{seq_name}', {lastval}, false);" - ) - raw_sql_str = "\n".join(raw_sql_stmts) - seed_block, imports = data_as_code(data) lines: list[str] = [] for imp in sorted(imports): lines.append(imp) lines.append("\n\nall_seeds = {\n" + seed_block + "\n}") - lines.append('\nall_seeds[\'sql_next_values\'] = """\n' + raw_sql_str + '\n"""\n') - with open(out_path, "w", encoding="utf-8") as f: f.write("\n".join(lines)) diff --git a/src/sqlacodegen/utils.py b/src/sqlacodegen/utils.py index beea46b9..c21de0c9 100644 --- a/src/sqlacodegen/utils.py +++ b/src/sqlacodegen/utils.py @@ -17,7 +17,10 @@ Table, ) -_re_postgresql_nextval_sequence = re.compile(r"nextval\('(.+)'::regclass\)") +_re_postgresql_nextval_sequence = re.compile( + r"nextval\(\(?'(?P[\w\.]+)'(::text)?\)?::regclass\)" +) + _re_postgresql_sequence_delimiter = re.compile(r'(.*?)([."]|$)')