Skip to content
42 changes: 42 additions & 0 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def __init__(
self._log_max_query_length = connection.log_max_query_length
self._inner_cursor: SnowflakeCursorBase | None = None
self._prefetch_hook = None
self._stats_data: dict[str, int] | None = (
None # Stores stats from response for DML operations
)

self._rownumber: int | None = None

self.reset()
Expand Down Expand Up @@ -454,6 +458,26 @@ def _description_internal(self) -> list[ResultMetadataV2]:
def rowcount(self) -> int | None:
    """Return the accumulated row count, or ``None`` while it is negative.

    Returns:
        ``self._total_rowcount`` when it is zero or positive, otherwise
        ``None``.
    """
    total = self._total_rowcount
    # Negative totals are reported to callers as "no count available".
    if total < 0:
        return None
    return total

@property
def rows_affected(self) -> RowsAffected | None:
    """Return the detailed rows-affected statistics stored on this cursor.

    The result is a ``RowsAffected`` NamedTuple with the fields:
    - num_rows_inserted: value of the ``numRowsInserted`` stat
    - num_rows_deleted: value of the ``numRowsDeleted`` stat
    - num_rows_updated: value of the ``numRowsUpdated`` stat
    - num_dml_duplicates: value of the ``numDmlDuplicates`` stat

    A field is ``None`` when the stored stats do not contain the matching
    key. When no stats are stored at all, every field is ``None``; this
    implementation always builds a tuple and does not return ``None`` itself.
    """
    # Treat "no stats stored" as an empty mapping so both cases share a single
    # keyword-based construction (no positional all-None tuple to keep in sync
    # with the field order).
    stats = self._stats_data or {}
    return RowsAffected(
        num_rows_inserted=stats.get("numRowsInserted"),
        num_rows_deleted=stats.get("numRowsDeleted"),
        num_rows_updated=stats.get("numRowsUpdated"),
        num_dml_duplicates=stats.get("numDmlDuplicates"),
    )

@property
def rownumber(self) -> int | None:
    """Return the stored row number, or ``None`` when it is unset or negative.

    ``_rownumber`` is annotated ``int | None``. Comparing ``None >= 0`` raises
    ``TypeError``, so the ``None`` case is checked before the sign test.

    Returns:
        ``self._rownumber`` when it is a non-negative integer, else ``None``.
    """
    current = self._rownumber
    if current is None or current < 0:
        return None
    return current
Expand Down Expand Up @@ -1201,6 +1225,10 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
self._rownumber = -1
self._result_state = ResultState.VALID

# Extract rows_affected from stats object if available (for DML operations like CTAS, INSERT, UPDATE, DELETE)
self._stats_data = data.get("stats", None)
logger.debug(f"Execution stats: {self.rows_affected}")

# don't update the row count when the result is returned from `describe` method
if is_dml and "rowset" in data and len(data["rowset"]) > 0:
updated_rows = 0
Expand Down Expand Up @@ -2007,3 +2035,17 @@ def __getattr__(name):
)
return None
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


class RowsAffected(NamedTuple):
    """Per-category row counters reported for a DML operation.

    Every field defaults to ``None``. A ``None`` value means that particular
    statistic is unknown, i.e. it was not returned by the backend service.

    Originally introduced to expose DML data of CTAS statements - SNOW-295953.
    """

    # Rows inserted by the statement, if reported.
    num_rows_inserted: int | None = None
    # Rows deleted by the statement, if reported.
    num_rows_deleted: int | None = None
    # Rows updated by the statement, if reported.
    num_rows_updated: int | None = None
    # DML duplicates counter, if reported.
    num_dml_duplicates: int | None = None
87 changes: 86 additions & 1 deletion test/integ/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ def test_client_fetch_threads_setting(conn_cnx):
@pytest.mark.skipolddriver
@pytest.mark.parametrize("disable_request_pooling", [True, False])
def test_ocsp_and_rest_pool_isolation(conn_cnx, disable_request_pooling):
"""Each connections SessionManager is isolated; OCSP picks the right one."""
"""Each connection's SessionManager is isolated; OCSP picks the right one."""
from snowflake.connector.ssl_wrap_socket import get_current_session_manager

#
Expand Down Expand Up @@ -1892,3 +1892,88 @@ def test_snowflake_version():
assert re.match(
version_pattern, conn.snowflake_version
), f"snowflake_version should match pattern 'x.y.z', but got '{conn.snowflake_version}'"


@pytest.mark.skipolddriver
def test_ctas_rows_affected_from_stats(conn_cnx):
    """Check cursor.rowcount and cursor.rows_affected after a CTAS statement.

    The statement copies three literal rows into a temp table, so the insert
    counter is expected to be 3 and the other counters 0, while rowcount
    stays at 1.
    """
    with conn_cnx() as conn, conn.cursor() as cur:
        cur.execute(
            "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)"
        )
        assert (
            cur.rowcount == 1
        ), f"Expected rowcount 1 for CTAS, got {cur.rowcount}"

        # The detailed counters are exposed as a NamedTuple on rows_affected.
        assert (
            cur.rows_affected is not None
        ), "rows_affected should not be None for CTAS"
        assert (
            cur.rows_affected.num_rows_inserted == 3
        ), f"Expected num_rows_inserted=3, got {cur.rows_affected.num_rows_inserted}"
        assert cur.rows_affected.num_rows_deleted == 0
        assert cur.rows_affected.num_rows_updated == 0
        assert cur.rows_affected.num_dml_duplicates == 0


@pytest.mark.skipolddriver
def test_create_view_rows_affected_from_stats(conn_cnx):
    """Check that CREATE VIEW yields a rows_affected tuple with all-None fields."""
    with conn_cnx() as conn, conn.cursor() as cur:
        cur.execute(
            "create temp view test_view_stats as select col1 from values (1), (2), (3) as t(col1)"
        )
        assert (
            cur.rowcount == 1
        ), f"Expected rowcount 1 for VIEW, got {cur.rowcount}"

        # A view creation carries no DML stats: the tuple exists, but every
        # counter in it is None.
        view_stats = cur.rows_affected
        assert view_stats is not None
        assert view_stats.num_rows_inserted is None
        assert view_stats.num_rows_deleted is None
        assert view_stats.num_rows_updated is None
        assert view_stats.num_dml_duplicates is None


@pytest.mark.skipolddriver
def test_cvas_separate_cursors_rows_affected_from_stats(conn_cnx):
    """Check rows_affected for a CVAS run on a different cursor than the CTAS.

    One cursor creates the source table via CTAS; a second cursor on the same
    connection then creates a view over it and is inspected.
    """
    with conn_cnx() as conn:
        with conn.cursor() as table_cur:
            table_cur.execute(
                "create temp table test_table (col1 int) as select col1 from values (1), (2), (3) as t(col1)"
            )

        with conn.cursor() as view_cur:
            view_cur.execute(
                "create temp view test_view as select col1 from test_table"
            )
            assert (
                view_cur.rowcount == 1
            ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend"

            # The view statement reports no DML stats.
            view_stats = view_cur.rows_affected
            assert view_stats is not None
            assert view_stats.num_rows_inserted is None
            assert view_stats.num_rows_deleted is None
            assert view_stats.num_rows_updated is None
            assert view_stats.num_dml_duplicates is None


@pytest.mark.skipolddriver
def test_cvas_one_cursor_rows_affected_from_stats(conn_cnx):
    """Check that rows_affected is cleared when a cursor is reused for a CVAS.

    The same cursor first runs a CTAS (which reports DML stats) and then a
    CREATE VIEW (which does not). After the second statement no counter from
    the first one may remain visible.
    """
    with conn_cnx() as conn, conn.cursor() as cur:
        cur.execute(
            "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)"
        )
        # Precondition: the CTAS must actually populate the stats. Without this
        # check the all-None assertions below would also pass if stats were
        # never set, so the test could not detect stale values.
        assert (
            cur.rows_affected.num_rows_inserted == 3
        ), "CTAS should populate rows_affected before the cursor is reused"

        cur.execute("create temp view test_view as select col1 from test_ctas_stats")
        assert (
            cur.rowcount == 1
        ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend"

        # The view statement reports no DML stats, so nothing from the CTAS
        # may remain.
        assert cur.rows_affected is not None
        assert cur.rows_affected.num_rows_inserted is None
        assert cur.rows_affected.num_rows_deleted is None
        assert cur.rows_affected.num_rows_updated is None
        assert cur.rows_affected.num_dml_duplicates is None
Loading