diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index c13ab242c..b388de2c2 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -418,6 +418,10 @@ def __init__( self._log_max_query_length = connection.log_max_query_length self._inner_cursor: SnowflakeCursorBase | None = None self._prefetch_hook = None + self._stats_data: dict[str, int] | None = ( + None # Stores stats from response for DML operations + ) + self._rownumber: int | None = None self.reset() @@ -454,6 +458,26 @@ def _description_internal(self) -> list[ResultMetadataV2]: def rowcount(self) -> int | None: return self._total_rowcount if self._total_rowcount >= 0 else None + @property + def stats(self) -> QueryResultStats | None: + """Returns detailed rows affected statistics for DML operations. + + Returns a NamedTuple with fields: + - num_rows_inserted: Number of rows inserted + - num_rows_deleted: Number of rows deleted + - num_rows_updated: Number of rows updated + + Returns None on each position if no DML stats are available. + """ + if self._stats_data is None: + return QueryResultStats(None, None, None, None) + return QueryResultStats( + num_rows_inserted=self._stats_data.get("numRowsInserted", None), + num_rows_deleted=self._stats_data.get("numRowsDeleted", None), + num_rows_updated=self._stats_data.get("numRowsUpdated", None), + num_dml_duplicates=self._stats_data.get("numDmlDuplicates", None), + ) + @property def rownumber(self) -> int | None: return self._rownumber if self._rownumber >= 0 else None @@ -1201,6 +1225,10 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None: self._rownumber = -1 self._result_state = ResultState.VALID + # Extract stats object if available (for DML operations like CTAS, INSERT, UPDATE, DELETE) + self._stats_data = data.get("stats", None) + logger.debug(f"Execution DML stats: {self.stats}") + # don't update the row count when the result is returned from `describe` method if is_dml and "rowset" in data and len(data["rowset"]) > 0: updated_rows = 0 @@ -2007,3 +2035,17 @@ def __getattr__(name): ) return None raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class QueryResultStats(NamedTuple): + """ + Statistics for rows affected by a DML operation. + None value expresses particular statistic being unknown - not returned by the backend service. + + Added in the first place to expose DML data of CTAS statements - SNOW-295953 + """ + + num_rows_inserted: int | None = None + num_rows_deleted: int | None = None + num_rows_updated: int | None = None + num_dml_duplicates: int | None = None diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py index b5d490d34..0fed8aa9e 100644 --- a/test/integ/test_connection.py +++ b/test/integ/test_connection.py @@ -1067,7 +1067,7 @@ def test_client_fetch_threads_setting(conn_cnx): @pytest.mark.skipolddriver @pytest.mark.parametrize("disable_request_pooling", [True, False]) def test_ocsp_and_rest_pool_isolation(conn_cnx, disable_request_pooling): - """Each connection’s SessionManager is isolated; OCSP picks the right one.""" + """Each connection's SessionManager is isolated; OCSP picks the right one.""" from snowflake.connector.ssl_wrap_socket import get_current_session_manager # @@ -1896,3 +1896,86 @@ def test_snowflake_version(): assert re.match( version_pattern, conn.snowflake_version ), f"snowflake_version should match pattern 'x.y.z', but got '{conn.snowflake_version}'" + + +@pytest.mark.skipolddriver +def test_ctas_stats(conn_cnx): + """Test that cursor.rowcount and cursor.stats work for CTAS operations.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + assert ( + cur.rowcount == 1 + ), f"Expected rowcount 1 for CTAS, got {cur.rowcount}" + # stats should contain the details as a NamedTuple + assert cur.stats is not None, "stats should not be None for CTAS" + assert ( + cur.stats.num_rows_inserted == 3 + ), f"Expected num_rows_inserted=3, got {cur.stats.num_rows_inserted}" + assert cur.stats.num_rows_deleted == 0 + assert cur.stats.num_rows_updated == 0 + assert cur.stats.num_dml_duplicates == 0 + + +@pytest.mark.skipolddriver +def test_create_view_stats(conn_cnx): + """Test that cursor.stats returns None fields for VIEW operations.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp view test_view_stats as select col1 from values (1), (2), (3) as t(col1)" + ) + assert ( + cur.rowcount == 1 + ), f"Expected rowcount 1 for VIEW, got {cur.rowcount}" + # VIEW operations don't return DML stats, all fields should be None + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +@pytest.mark.skipolddriver +def test_cvas_separate_cursors_stats(conn_cnx): + """Test cursor.stats with CVAS in separate cursor from the one used for CTAS of the table.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp table test_table (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + with conn.cursor() as cur: + cur.execute("create temp view test_view as select col1 from test_table") + assert ( + cur.rowcount == 1 + ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend" + # VIEW operations don't return DML stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +@pytest.mark.skipolddriver +def test_cvas_one_cursor_stats(conn_cnx): + """Test cursor.stats with CVAS in the same cursor - make sure it's cleaned up after usage.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + cur.execute( + "create temp view test_view as select col1 from test_ctas_stats" + ) + assert ( + cur.rowcount == 1 + ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend" + # VIEW operations don't return DML stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None