Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/sqlacodegen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
parse_extension_row,
parse_function_row,
parse_policy_row,
parse_sequence_row,
parse_trigger_row,
)
from sqlacodegen.seed_export import export_pgdata_py, get_table_dependency_order
Expand Down Expand Up @@ -246,14 +245,6 @@ class ExportDict(TypedDict, total=False):
"parse_row_func": parse_extension_row,
"file": "pg_extensions.py",
},
{
"title": "Sequences",
"entities_varname": "all_sequences",
"template": "ALEMBIC_SEQUENCE_TEMPLATE",
"statement": "ALEMBIC_SEQUENCE_STATEMENT",
"parse_row_func": parse_sequence_row,
"file": "pg_sequences.py",
},
]

# ----------- Export-Loop ------------
Expand Down
40 changes: 30 additions & 10 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from sqlalchemy.dialects.postgresql import DOMAIN, JSONB
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.schema import Sequence as SASequence
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.type_api import UserDefinedType
from sqlalchemy.types import TypeEngine
Expand Down Expand Up @@ -231,10 +232,10 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
elif isinstance(column.type, DOMAIN):
self.add_import(column.type.data_type.__class__)

if column.default:
if column.default is not None:
self.add_import(column.default)

if column.server_default:
if column.server_default is not None:
if isinstance(column.server_default, (Computed, Identity)):
self.add_import(column.server_default)
elif isinstance(column.server_default, DefaultClause):
Expand Down Expand Up @@ -449,7 +450,7 @@ def render_column(
for fk in dedicated_fks:
args.append(self.render_constraint(fk))

if column.default:
if column.default is not None:
args.append(repr(column.default))

if column.key != column.name:
Expand Down Expand Up @@ -483,7 +484,7 @@ def render_column(
)
elif isinstance(column.server_default, Identity):
args.append(repr(column.server_default))
elif column.server_default:
elif column.server_default is not None:
kwargs["server_default"] = repr(column.server_default)

comment = getattr(column, "comment", None)
Expand Down Expand Up @@ -624,6 +625,24 @@ def find_free_name(

return name

def get_pg_sequence_parameters(
self, bind: Engine, schema: str, name: str
) -> dict[str, Any]:
sql = """
SELECT start_value, increment_by, min_value, max_value
FROM pg_sequences
WHERE schemaname = :schema AND sequencename = :name
"""
with bind.connect() as conn:
row = (
conn.execute(
sqlalchemy.text(sql), {"schema": schema or "public", "name": name}
)
.mappings()
.first()
)
return dict(row) if row else {}

def fix_column_types(self, table: Table) -> None:
"""Adjust the reflected column types."""
# Detect check constraints for boolean and enum columns
Expand Down Expand Up @@ -667,19 +686,20 @@ def fix_column_types(self, table: Table) -> None:
pass

# PostgreSQL specific fix: detect sequences from server_default
if column.server_default and self.bind.dialect.name == "postgresql":
if (
column.server_default is not None
and self.bind.dialect.name == "postgresql"
):
if isinstance(column.server_default, DefaultClause) and isinstance(
column.server_default.arg, TextClause
):
schema, seqname = decode_postgresql_sequence(
column.server_default.arg
)
if seqname:
# Add an explicit sequence
if seqname != f"{column.table.name}_{column.name}_seq":
column.default = sqlalchemy.Sequence(seqname, schema=schema)

column.server_default = None
column.default = SASequence(
name=seqname, schema=schema, start=1, increment=1
)

def get_adapted_type(self, coltype: Any) -> Any:
compiled_type = coltype.compile(self.bind.engine.dialect)
Expand Down
144 changes: 133 additions & 11 deletions src/sqlacodegen/risclog_generators.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import re
from pprint import pformat
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, cast

from sqlalchemy import (
Column,
Computed,
DefaultClause,
ForeignKeyConstraint,
Identity,
Index,
MetaData,
PrimaryKeyConstraint,
Table,
UniqueConstraint,
inspect,
text,
types,
)
from sqlalchemy import types as satypes
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import next_value

from sqlacodegen.generators import (
Base,
Expand All @@ -34,7 +38,7 @@
from sqlacodegen.generators import TablesGenerator

EXCLUDED_TABLES = {"tmp_functest", "accesslogfailed"}
INCLUDED_POLICY_ROLES = {"brokeruser"}
INCLUDED_POLICY_ROLES = {"brokeruser", " clx_readonly", "clx"}
BASE_META_DATA = Base(
literal_imports=[
LiteralImport(
Expand Down Expand Up @@ -229,7 +233,15 @@ class {classname}(PortalObject): # type: ignore[misc]
JOIN pg_catalog.pg_sequences ps
ON s.sequence_schema = ps.schemaname
AND s.sequence_name = ps.sequencename
WHERE s.sequence_schema NOT IN ('pg_catalog', 'information_schema')
JOIN pg_class t
ON t.relkind = 'r' AND t.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = s.sequence_schema)
JOIN pg_attribute a
ON a.attrelid = t.oid
JOIN pg_attrdef d
ON d.adrelid = t.oid AND d.adnum = a.attnum
WHERE
s.sequence_schema NOT IN ('pg_catalog', 'information_schema')
AND pg_get_expr(d.adbin, d.adrelid) LIKE '%nextval(%' || s.sequence_name || '%'
ORDER BY s.sequence_schema, s.sequence_name;
"""

Expand Down Expand Up @@ -428,7 +440,7 @@ def parse_sequence_row(
) -> tuple[str, str]:
schema_val = row["schema"] or schema or "public"
signature = row["sequence_name"]
varname = f"{signature}_sequence".lower()
varname = signature.lower()

parts = [
f"AS {row['data_type']}",
Expand Down Expand Up @@ -644,9 +656,9 @@ def clx_render_index(self: "TablesGenerator", index: Index) -> str:
and hasattr(col, "type")
):
coltype = getattr(col.type, "python_type", None)
if isinstance(col.type, (types.String, types.Text, types.Unicode)) or (
coltype and coltype is str
):
if isinstance(
col.type, (satypes.String, satypes.Text, satypes.Unicode)
) or (coltype and coltype is str):
opclass_map[col.name] = "gin_trgm_ops"

elif getattr(index, "expressions", None):
Expand Down Expand Up @@ -727,6 +739,117 @@ def clx_render_table(self: "TablesGenerator", table: Table) -> str:
TablesGenerator.render_table = clx_render_table # type: ignore[method-assign]


def clx_render_column(
self: "TablesGenerator",
column: Column[Any],
show_name: bool,
is_table: bool = False,
) -> str:
args = []
kwargs: dict[str, Any] = {}
kwarg = []
is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
dedicated_fks = [
c
for c in column.foreign_keys
if c.constraint
and len(c.constraint.columns) == 1
and uses_default_name(c.constraint)
]
is_unique = any(
isinstance(c, UniqueConstraint)
and set(c.columns) == {column}
and uses_default_name(c)
for c in column.table.constraints
)
is_unique = is_unique or any(
i.unique and set(i.columns) == {column} and uses_default_name(i)
for i in column.table.indexes
)
is_primary = (
any(
isinstance(c, PrimaryKeyConstraint)
and column.name in c.columns
and uses_default_name(c)
for c in column.table.constraints
)
or column.primary_key
)
has_index = any(
set(i.columns) == {column} and uses_default_name(i)
for i in column.table.indexes
)

if show_name:
args.append(repr(column.name))

# Render the column type if there are no foreign keys on it or any of them
# points back to itself
if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
args.append(self.render_column_type(column.type))

for fk in dedicated_fks:
args.append(self.render_constraint(fk))

if column.default is not None:
args.append(repr(column.default))

if column.key != column.name:
kwargs["key"] = column.key
if is_primary:
kwargs["primary_key"] = True
if not column.nullable and not is_sole_pk and is_table:
kwargs["nullable"] = False

if is_unique:
column.unique = True
kwargs["unique"] = True
if has_index:
column.index = True
kwarg.append("index")
kwargs["index"] = True

# --- SERVER DEFAULT HANDLING ---
if isinstance(column.server_default, DefaultClause):
kwargs["server_default"] = render_callable(
"text", repr(cast(TextClause, column.server_default.arg).text)
)
elif isinstance(column.server_default, Computed):
expression = str(column.server_default.sqltext)

computed_kwargs = {}
if column.server_default.persisted is not None:
computed_kwargs["persisted"] = column.server_default.persisted

args.append(
render_callable("Computed", repr(expression), kwargs=computed_kwargs)
)
elif isinstance(column.server_default, Identity):
args.append(repr(column.server_default))
elif isinstance(column.server_default, next_value):
# --------- NEU: Sequence/next_value ----------
seq = column.server_default.sequence
if seq is not None:
default_schema = "public"
kwargs["server_default"] = (
f"Sequence({seq.name!r}{f', schema={seq.schema!r}' if seq.schema else f', schema={default_schema!r}'}).next_value()"
)
else:
kwargs["server_default"] = "None # Sequence not detected"
elif column.server_default is not None:
kwargs["server_default"] = repr(column.server_default)
# -----------------------------------------------

comment = getattr(column, "comment", None)
if comment:
kwargs["comment"] = repr(comment)

return self.render_column_callable(is_table, *args, **kwargs)


TablesGenerator.render_column = clx_render_column # type: ignore[method-assign]


def clx_generate(self: "TablesGenerator") -> tuple[str, list[str] | None]:
self.generate_base()

Expand Down Expand Up @@ -816,12 +939,10 @@ def generate_alembic_utils_sequences(
# Hole alle aus DB
result: list[dict[str, Any]] = fetch_all_mappings(conn, sql, {"schema": schema})
# Finde alle, die von Tables verwaltet werden
managed = get_table_managed_sequences(self.metadata)
entities = [
parsed
for row in result
if row["sequence_name"] not in managed
and (parsed := parse_row_func(row, template_def, schema)) is not None
if (parsed := parse_row_func(row, template_def, schema)) is not None
]

code = [code for code, _ in entities]
Expand Down Expand Up @@ -1127,6 +1248,7 @@ def get_extension_object_names(
self.add_literal_import(module, name)

self.add_literal_import("sqlalchemy", "text")
self.add_literal_import("sqlalchemy", "FetchedValue")

return "\n\n".join(rendered), finalize_alembic_utils(
pg_alembic_definition, entities, entities_name
Expand Down
24 changes: 1 addition & 23 deletions src/sqlacodegen/seed_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Any

from sqlalchemy import MetaData, select, text
from sqlalchemy import MetaData, select
from sqlalchemy.engine import Engine


Expand Down Expand Up @@ -111,33 +111,11 @@ def export_pgdata_py(
result.append(d)
data[name] = result

sequence_rows = conn.execute(
text("""
SELECT sequence_schema, sequence_name
FROM information_schema.sequences
WHERE sequence_schema NOT IN ('pg_catalog', 'information_schema')
ORDER BY sequence_schema, sequence_name
""")
).fetchall()
raw_sql_stmts = []
for row in sequence_rows:
schema = row.sequence_schema
seq_name = row.sequence_name
lastval = conn.execute(
text(f"SELECT last_value FROM {schema}.{seq_name}")
).scalar()
raw_sql_stmts.append(
f" SELECT setval('{schema}.{seq_name}', {lastval}, false);"
)
raw_sql_str = "\n".join(raw_sql_stmts)

seed_block, imports = data_as_code(data)
lines: list[str] = []
for imp in sorted(imports):
lines.append(imp)
lines.append("\n\nall_seeds = {\n" + seed_block + "\n}")

lines.append('\nall_seeds[\'sql_next_values\'] = """\n' + raw_sql_str + '\n"""\n')

with open(out_path, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
5 changes: 4 additions & 1 deletion src/sqlacodegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
Table,
)

_re_postgresql_nextval_sequence = re.compile(r"nextval\('(.+)'::regclass\)")
_re_postgresql_nextval_sequence = re.compile(
r"nextval\(\(?'(?P<seq>[\w\.]+)'(::text)?\)?::regclass\)"
)

_re_postgresql_sequence_delimiter = re.compile(r'(.*?)([."]|$)')


Expand Down
Loading