diff --git a/src/sqlacodegen/cli.py b/src/sqlacodegen/cli.py index 7de594a4..ce18d7ae 100644 --- a/src/sqlacodegen/cli.py +++ b/src/sqlacodegen/cli.py @@ -25,15 +25,15 @@ except ImportError: pgvector = None -from sqlacodegen.generate_factory_fixtures import export_factory_fixtures from sqlacodegen.risclog_generators import ( parse_aggregate_row, parse_extension_row, parse_function_row, parse_policy_row, + parse_publication_row, parse_trigger_row, ) -from sqlacodegen.seed_export import export_pgdata_py, get_table_dependency_order +from sqlacodegen.seed_export import export_pgdata_py if sys.version_info < (3, 10): from importlib_metadata import entry_points, version @@ -245,6 +245,14 @@ class ExportDict(TypedDict, total=False): "parse_row_func": parse_extension_row, "file": "pg_extensions.py", }, + { + "title": "Publications", + "entities_varname": "all_publications", + "template": "ALEMBIC_PUBLICATION_TEMPLATE", + "statement": "ALEMBIC_PUBLICATION_STATEMENT", + "parse_row_func": parse_publication_row, + "file": "pg_publications.py", + }, ] # ----------- Export-Loop ------------ @@ -335,11 +343,16 @@ class ExportDict(TypedDict, total=False): # ----------- PGData SEED Export separat ------------ if args.outfile_dir: + all_view_names = set() + for schema in schemas: + all_view_names |= set(inspector.get_view_names(schema=schema)) + dest_pg_path = Path(str(parent), "pg_seeds.py") export_pgdata_py( engine=engine, metadata=metadata_tables, out_path=dest_pg_path, + view_table_names=all_view_names, ) print(f"PGData Seed geschrieben nach: {dest_pg_path.as_posix()}") @@ -366,19 +379,3 @@ def make_dynamic_models(metadata: MetaData) -> dict[str, type[Any]]: model = type(class_name, (Base,), {"__table__": table}) models_by_table[table.name] = model return models_by_table - - Base = getattr(generator, "base", None) - if Base is not None: - models = get_all_models(Base) - models_by_table = {m.__tablename__: m for m in models} - else: - models_by_table = make_dynamic_models(metadata_tables) - - dependency_order = get_table_dependency_order(metadata_tables) - - export_factory_fixtures( - models_by_table=models_by_table, - factories_path=Path(parent) / "factories.py", - dependency_order=dependency_order, - ) - print(f"Factories & Fixtures geschrieben nach: {parent.as_posix()}") diff --git a/src/sqlacodegen/generate_factory_fixtures.py b/src/sqlacodegen/generate_factory_fixtures.py deleted file mode 100644 index 9eb5d3d1..00000000 --- a/src/sqlacodegen/generate_factory_fixtures.py +++ /dev/null @@ -1,54 +0,0 @@ -import re -from collections.abc import Sequence -from pathlib import Path -from typing import Any - -FACTORY_HEADER = """\ -# AUTO-GENERATED BY sqlacodegen -from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory -from polyfactory.pytest_plugin import register_fixture -{model_imports} -""" - -FACTORY_TEMPLATE = """\ -@register_fixture(name="{fixture_name}") -class {class_name}Factory(SQLAlchemyFactory[{class_name}]): - __model__ = {class_name}{set_relationships} -""" - - -def camel_to_snake(name: str) -> str: - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - -def render_factory(model: type[Any]) -> str: - has_fk = bool(getattr(model.__table__, "foreign_keys", [])) - set_relationships = "\n __set_relationships__ = True" if has_fk else "" - class_name = model.__name__ - fixture_name = f"{camel_to_snake(class_name)}_factory" - return FACTORY_TEMPLATE.format( - fixture_name=fixture_name, - class_name=class_name, - set_relationships=set_relationships, - ) - - -def export_factory_fixtures( - models_by_table: dict[str, type[Any]], - factories_path: Path, - dependency_order: Sequence[str], -) -> None: - model_names = {models_by_table[table].__name__ for table in dependency_order} - import_statement = ( - "from risclog.claimxdb.database import (\n " - + ",\n ".join(sorted(model_names)) - + "\n)" - ) - factories_lines = [FACTORY_HEADER.format(model_imports=import_statement)] - - for table in dependency_order: - model = models_by_table[table] - factories_lines.append(render_factory(model)) - - factories_path.write_text("\n".join(factories_lines), encoding="utf-8") diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 7d9bc542..030ba93f 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -1225,6 +1225,21 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str: column = column_attr.column rendered_column = self.render_column(column, column_attr.name != column.name) + is_uuid_pk = ( + column.primary_key + and getattr(column.type, "python_type", None) + in (str, bytes) # meistens str + and "Uuid" in str(column.type) + and not getattr(column, "default", None) + and not getattr(column, "server_default", None) + ) + if is_uuid_pk and "default=uuid4" not in rendered_column: + # default=uuid4 in den Column-Aufruf einfügen (vor letztem ")") + if rendered_column.endswith(")"): + # Einfachstes Pattern: vor ")" einfügen + rendered_column = rendered_column[:-1] + ", default=uuid4)" + self.add_literal_import("uuid", "uuid4") + def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]: column_type = column.type pre: list[str] = [] diff --git a/src/sqlacodegen/risclog_generators.py b/src/sqlacodegen/risclog_generators.py index 98a8efb1..cd8e0637 100644 --- a/src/sqlacodegen/risclog_generators.py +++ b/src/sqlacodegen/risclog_generators.py @@ -104,7 +104,29 @@ class {classname}(PortalObject): # type: ignore[misc] definition=\"\"\"{definition}\"\"\", ) """ - +ALEMBIC_PUBLICATION_TEMPLATE = """{varname} = PGPublication( + name={name!r}, + tables={tables!r}, + publish={publish!r}, +) +""" +ALEMBIC_PUBLICATION_STATEMENT = """ +SELECT + p.pubname, + array_remove(array_agg(pt.relname), NULL) as tables, + ( + CASE WHEN p.pubinsert THEN 'insert' ELSE '' END || + CASE WHEN p.pubupdate THEN ', update' ELSE '' END || + CASE WHEN p.pubdelete THEN ', delete' ELSE '' END || + CASE WHEN p.pubtruncate THEN ', truncate' ELSE '' + END + ) as publish +FROM + pg_publication p + LEFT JOIN pg_publication_rel pr ON pr.prpubid = p.oid + LEFT JOIN pg_class pt ON pt.oid = pr.prrelid +GROUP BY p.pubname, p.pubinsert, p.pubupdate, p.pubdelete, p.pubtruncate +""" ALEMBIC_FUNCTION_STATEMENT = """SELECT pg_get_functiondef(p.oid) AS func FROM @@ -126,7 +148,6 @@ class {classname}(PortalObject): # type: ignore[misc] p.proname; """ - ALEMBIC_POLICIES_STATEMENT = """SELECT pol.polname AS policy_name, ns.nspname AS schema_name, @@ -246,6 +267,29 @@ class {classname}(PortalObject): # type: ignore[misc] """ +def parse_publication_row( + row: dict[str, Any], + template_def: str, + schema: str | None, +) -> tuple[str, str] | None: + name = row.get("pubname") + tables = row.get("tables") or [] + if isinstance(tables, str): + tables = [t.strip() for t in tables.split(",") if t.strip()] + publish = row.get("publish") or "" + owner = row.get("owner") or None + + varname = f"{name}".lower() + code = template_def.format( + varname=varname, + name=name, + tables=tables, + publish=publish, + owner=owner, + ) + return code, varname + + def finalize_alembic_utils( pg_alembic_definition: list[str], entities: list[str], @@ -259,6 +303,7 @@ def finalize_alembic_utils( "all_sequences": "from alembic_utils.pg_sequence import PGSequence", "all_extensions": "from alembic_utils.pg_extension import PGExtension", "all_aggregates": "from alembic_utils.pg_aggregate import PGAggregate", + "all_publications": "from risclog.claimxdb.alembic.object_ops import PGPublication", } import_stmt = imports.get( entities_name or "all_views", @@ -297,12 +342,15 @@ def parse_function_row( schema = schema or "public" name = name.lower() - return template_def.format( - varname=name, - schema=schema, - signature=signature, - definition=unescape_sql_string(squash_whitespace(definition)), - ), name + return ( + template_def.format( + varname=name, + schema=schema, + signature=signature, + definition=unescape_sql_string(squash_whitespace(definition)), + ), + name, + ) def parse_policy_row( @@ -394,7 +442,6 @@ def parse_aggregate_row( initcond = row.get("initcond") schema_val = schema or row.get("schema") or "public" - # Baue die Definition als lesbare String-Config: definition_parts = [] if sfunc: definition_parts.append(f"SFUNC = {sfunc}") @@ -449,9 +496,11 @@ def parse_sequence_row( f"MINVALUE {row['minimum_value']}", f"MAXVALUE {row['maximum_value']}", f"CACHE {row['cache_size']}", - "CYCLE" - if str(row.get("cycle", "")).lower() in ("yes", "true", "on", "1") - else "NO CYCLE", + ( + "CYCLE" + if str(row.get("cycle", "")).lower() in ("yes", "true", "on", "1") + else "NO CYCLE" + ), ] definition = "\n ".join(parts) @@ -642,48 +691,60 @@ def clx_generate_base(self: "TablesGenerator") -> None: TablesGenerator.generate_base = clx_generate_base # type: ignore[method-assign] +def unqualify(colname: str) -> str: + if isinstance(colname, str): + return colname.split(".")[-1] + return str(colname) + + def clx_render_index(self: "TablesGenerator", index: Index) -> str: - elements = [] + from sqlalchemy.sql.elements import TextClause + + args = [repr(index.name)] + kwargs: dict[str, Any] = {} opclass_map = {} - if index.columns: + # --- Columns --- + if getattr(index, "columns", None) and len(index.columns) > 0: for col in index.columns: - elements.append(repr(col.name)) - + args.append(repr(unqualify(col.name))) + # Operator-Class GIN/TRGM if ( "postgresql" in index.dialect_options and index.dialect_options["postgresql"].get("using") == "gin" - and hasattr(col, "type") ): coltype = getattr(col.type, "python_type", None) if isinstance( col.type, (satypes.String, satypes.Text, satypes.Unicode) ) or (coltype and coltype is str): - opclass_map[col.name] = "gin_trgm_ops" - - elif getattr(index, "expressions", None): + opclass_map[unqualify(col.name)] = "gin_trgm_ops" + # --- Expressions/TextClause --- + elif getattr(index, "expressions", None) and len(index.expressions) > 0: for expr in index.expressions: - expr_str = str(expr).strip() - elements.append(f"text({expr_str!r})") - - if ( - "postgresql" in index.dialect_options - and index.dialect_options["postgresql"].get("using") == "gin" - ): + if isinstance(expr, TextClause): + expr_str = str(expr) + # GIN/TRGM als Suffix if ( - "::tsvector" not in expr_str - and "array" not in expr_str.lower() - and "json" not in expr_str.lower() + "postgresql" in index.dialect_options + and index.dialect_options["postgresql"].get("using") == "gin" + and not expr_str.rstrip().endswith("gin_trgm_ops") ): - opclass_map[expr_str] = "gin_trgm_ops" - - if not elements: - print( - f"# WARNING: Skipped index {getattr(index, 'name', None)!r} on table {getattr(index.table, 'name', None)!r} (no columns or expressions)." - ) - return "" - - kwargs: dict[str, Any] = {} + expr_str = f"{expr_str} gin_trgm_ops" + args.append(f"text({expr_str!r})") + else: + expr_str = str(expr) + m = re.match(r"^upper\(\((\w+)\)::text\)$", expr_str) + if ( + m + and "postgresql" in index.dialect_options + and index.dialect_options["postgresql"].get("using") == "gin" + ): + args.append(f"text('upper(({m.group(1)})::text) gin_trgm_ops')") + else: + args.append(f"text({expr_str!r})") + else: + # Fallback + pass if index.unique: kwargs["unique"] = True @@ -695,11 +756,10 @@ def clx_render_index(self: "TablesGenerator", index: Index) -> str: kwargs["postgresql_using"] = ( f"'{using}'" if isinstance(using, str) else using ) - if opclass_map: kwargs["postgresql_ops"] = opclass_map - return render_callable("Index", repr(index.name), *elements, kwargs=kwargs) + return render_callable("Index", *args, kwargs=kwargs) TablesGenerator.render_index = clx_render_index # type: ignore[method-assign] @@ -708,9 +768,12 @@ def clx_render_index(self: "TablesGenerator", index: Index) -> str: def clx_render_table(self: "TablesGenerator", table: Table) -> str: args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"] kwargs: dict[str, object] = {} + + # Columns for column in table.columns: args.append(self.render_column(column, True, is_table=True)) + # Constraints for constraint in sorted(table.constraints, key=get_constraint_sort_key): if uses_default_name(constraint): if isinstance(constraint, PrimaryKeyConstraint): @@ -720,18 +783,25 @@ def clx_render_table(self: "TablesGenerator", table: Table) -> str: continue args.append(self.render_constraint(constraint)) + # Indices for index in sorted(table.indexes, key=lambda i: str(i.name or "")): - if len(index.columns) > 1 or not uses_default_name(index): - idx_code = self.render_index(index) - if idx_code.strip() and idx_code is not None: - args.append(idx_code) + orig_columns = getattr(index, "columns", []) + if orig_columns: + table.indexes.remove(index) + columns = [table.c[unqualify(col.name)] for col in orig_columns] + new_index = Index(index.name, *columns, **index.kwargs) + table.append_constraint(new_index) + idx_code = self.render_index(index) + if idx_code.strip() and idx_code is not None: + args.append(idx_code) if table.schema: - kwargs["schema"] = repr(table.schema) + kwargs["schema"] = table.schema + # Table comment table_comment = getattr(table, "comment", None) if table_comment: - kwargs["comment"] = repr(table.comment) + kwargs["comment"] = table_comment return render_callable("Table", *args, kwargs=kwargs, indentation=" ") @@ -911,7 +981,6 @@ def get_table_managed_sequences(metadata: MetaData) -> set[str]: for column in table.columns: default = getattr(column, "default", None) if default is not None: - # Sequence kann als Default o. direkt als ServerDefault hinterlegt sein if hasattr(default, "name"): seq_names.add(default.name) if hasattr(column, "sequence") and column.sequence is not None: @@ -935,10 +1004,8 @@ def generate_alembic_utils_sequences( sql = globals()[statement] template_def = globals()[template] - - # Hole alle aus DB result: list[dict[str, Any]] = fetch_all_mappings(conn, sql, {"schema": schema}) - # Finde alle, die von Tables verwaltet werden + entities = [ parsed for row in result @@ -1014,6 +1081,7 @@ def render_view_classes( has_id = False for col in table.columns: + sa_type = sa_type_from_column(col) if col.name == "id": has_id = True @@ -1071,11 +1139,11 @@ def render_table_args(self, table: Table) -> str: continue args.append(self.render_constraint(constraint)) + # NEU: ALLE Indexe (egal ob "special" oder nicht) for index in sorted(table.indexes, key=lambda i: str(i.name or "")): - if len(index.columns) > 1 or not uses_default_name(index): - idx_code = self.render_index(index) - if idx_code.strip() and idx_code is not None: - args.append(idx_code) + idx_code = self.render_index(index) + if idx_code.strip() and idx_code is not None: + args.append(idx_code) if table.schema: kwargs["schema"] = table.schema @@ -1137,14 +1205,14 @@ def render_models(self, models: list[Model]) -> tuple[str, list[str] | None]: # "ARRAY": ("sqlalchemy", "ARRAY"), } - # Ergänzung: Ermittlung aller Extension-Objekte (Tabellen, Views, etc.) def get_extension_object_names( conn: Connection, schemas: set[str | None] ) -> Any: extension_objs = set() for schema in schemas: result = conn.execute( - text(""" + text( + """ SELECT c.relname FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid @@ -1155,13 +1223,13 @@ def get_extension_object_names( JOIN pg_extension e ON d.refobjid = e.oid WHERE d.objid = c.oid AND d.deptype = 'e' ) - """), + """ + ), {"schema": schema}, ) extension_objs |= {row[0] for row in result} return extension_objs - # Hole nur einmal die Extension-Objekte aus der DB conn = self.bind.connect() if hasattr(self.bind, "connect") else self.bind EXTENSION_OBJECTS = get_extension_object_names(conn, schemas) @@ -1206,13 +1274,12 @@ def get_extension_object_names( schema = table.schema schema_views = views_by_schema.get(schema, set()) - # **Hier: Filter für System- und Extension-Objekte** if table.schema and table.schema.startswith("pg_"): - continue # Skip Postgres System-Views + continue if table.name.startswith("pg_"): - continue # Skip system views + continue if table.name in EXTENSION_OBJECTS: - continue # Skip Extension-Objekte + continue for col in table.columns: sa_type = sa_type_from_column(col) @@ -1250,6 +1317,8 @@ def get_extension_object_names( self.add_literal_import("sqlalchemy", "text") self.add_literal_import("sqlalchemy", "FetchedValue") - return "\n\n".join(rendered), finalize_alembic_utils( - pg_alembic_definition, entities, entities_name - ) if pg_alembic_definition else None + return "\n\n".join(rendered), ( + finalize_alembic_utils(pg_alembic_definition, entities, entities_name) + if pg_alembic_definition + else None + ) diff --git a/src/sqlacodegen/seed_export.py b/src/sqlacodegen/seed_export.py index 6aa85d93..61e15f40 100644 --- a/src/sqlacodegen/seed_export.py +++ b/src/sqlacodegen/seed_export.py @@ -67,6 +67,9 @@ def get_table_dependency_order(metadata: MetaData) -> list[str]: from collections import defaultdict graph: dict[str, set[str]] = defaultdict(set) + for table in metadata.tables.values(): + graph[table.name] + for table in metadata.tables.values(): name = table.name for fk in table.foreign_keys: @@ -74,25 +77,38 @@ def get_table_dependency_order(metadata: MetaData) -> list[str]: if parent != name: graph[name].add(parent) - visited: set[str] = set() - result: list[str] = [] + try: + from graphlib import TopologicalSorter - def visit(node: str) -> None: - if node in visited: - return - visited.add(node) - for dep in graph[node]: - visit(dep) - result.append(node) + ts = TopologicalSorter(graph) + order = list(ts.static_order()) + except ImportError: + visited: set[str] = set() + result: list[str] = [] - for table in metadata.tables.values(): - visit(table.name) - return result[::-1] + def visit(node: str) -> None: + if node in visited: + return + visited.add(node) + for dep in graph[node]: + visit(dep) + result.append(node) + + for node in graph: + visit(node) + order = result[::-1] + + return order def export_pgdata_py( - engine: Engine, metadata: MetaData, out_path: Path, max_rows: int | None = None + engine: Engine, + metadata: MetaData, + out_path: Path, + max_rows: int | None = None, + view_table_names: set[str] | None = None, ) -> None: + view_table_names = view_table_names or set() order = get_table_dependency_order(metadata) data: dict[str, list[dict[str, Any]]] = {} @@ -100,6 +116,8 @@ def export_pgdata_py( for name in order: if name not in metadata.tables: continue + if name in view_table_names: + continue table = metadata.tables[name] stmt = select(table) if max_rows is not None: