Skip to content

Commit 2582834

Browse files
[WIP] Add rollback for postgres.
1 parent d792f27 commit 2582834

File tree

2 files changed

+19
-29
lines changed

2 files changed

+19
-29
lines changed

pandas/io/sql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,8 @@ def _exists_temporary(self):
983983
_ = self.pd_sql.read_query(query)
984984
return True
985985
except ProgrammingError:
986+
# Some DBMS (e.g. postgres) require a rollback after a caught exception
987+
self.pd_sql.execute("rollback")
986988
return False
987989

988990
def exists(self):

pandas/tests/io/test_sql.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,7 +1247,6 @@ def test_read_procedure(conn, request):
12471247
# GH 7324
12481248
# Although it is more an api test, it is added to the
12491249
# mysql tests as sqlite does not have stored procedures
1250-
from sqlalchemy import text
12511250
from sqlalchemy.engine import Engine
12521251

12531252
df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})
@@ -2387,7 +2386,6 @@ def test_read_sql_delegate(conn, request):
23872386

23882387
def test_not_reflect_all_tables(sqlite_conn):
23892388
conn = sqlite_conn
2390-
from sqlalchemy import text
23912389
from sqlalchemy.engine import Engine
23922390

23932391
# create invalid table
@@ -2532,7 +2530,6 @@ def test_query_by_text_obj(conn, request):
25322530
# WIP : GH10846
25332531
conn_name = conn
25342532
conn = request.getfixturevalue(conn)
2535-
from sqlalchemy import text
25362533

25372534
if "postgres" in conn_name:
25382535
name_text = text('select * from iris where "Name"=:name')
@@ -3199,7 +3196,6 @@ def test_get_schema_create_table(conn, request, test_frame3):
31993196

32003197
conn = request.getfixturevalue(conn)
32013198

3202-
from sqlalchemy import text
32033199
from sqlalchemy.engine import Engine
32043200

32053201
tbl = "test_get_schema_create_table"
@@ -4357,7 +4353,7 @@ def test_xsqlite_if_exists(sqlite_buildin):
43574353
drop_table(table_name, sqlite_buildin)
43584354

43594355

4360-
@pytest.mark.parametrize("conn", mysql_connectable)
4356+
@pytest.mark.parametrize("conn", mysql_connectable + postgresql_connectable)
43614357
def test_exists_temporary_table(conn, test_frame1, request):
43624358
conn = request.getfixturevalue(conn)
43634359

@@ -4376,26 +4372,22 @@ def test_exists_temporary_table(conn, test_frame1, request):
43764372
assert True if table.exists() else False
43774373

43784374

4379-
@pytest.mark.parametrize("conn", mysql_connectable)
4375+
@pytest.mark.parametrize("conn", mysql_connectable + postgresql_connectable)
43804376
def test_to_sql_temporary_table_replace(conn, test_frame1, request):
43814377
conn = request.getfixturevalue(conn)
43824378

4383-
query = """
4384-
CREATE TEMPORARY TABLE test_frame1 (
4385-
`INDEX` TEXT,
4386-
A FLOAT(53),
4387-
B FLOAT(53),
4388-
C FLOAT(53),
4389-
D FLOAT(53)
4390-
)
4391-
"""
4392-
43934379
if isinstance(conn, Connection):
43944380
con = conn
43954381
else:
43964382
con = conn.connect()
43974383

4398-
con.execute(text(query))
4384+
test_frame1.to_sql(
4385+
name="test_frame1",
4386+
con=con,
4387+
if_exists="fail",
4388+
index=False,
4389+
prefixes=["TEMPORARY"],
4390+
)
43994391

44004392
test_frame1.to_sql(
44014393
name="test_frame1",
@@ -4410,26 +4402,22 @@ def test_to_sql_temporary_table_replace(conn, test_frame1, request):
44104402
assert_frame_equal(test_frame1, df_test)
44114403

44124404

4413-
@pytest.mark.parametrize("conn", mysql_connectable)
4405+
@pytest.mark.parametrize("conn", mysql_connectable + postgresql_connectable)
44144406
def test_to_sql_temporary_table_fail(conn, test_frame1, request):
44154407
conn = request.getfixturevalue(conn)
44164408

4417-
query = """
4418-
CREATE TEMPORARY TABLE test_frame1 (
4419-
`INDEX` TEXT,
4420-
A FLOAT(53),
4421-
B FLOAT(53),
4422-
C FLOAT(53),
4423-
D FLOAT(53)
4424-
)
4425-
"""
4426-
44274409
if isinstance(conn, Connection):
44284410
con = conn
44294411
else:
44304412
con = conn.connect()
44314413

4432-
con.execute(text(query))
4414+
test_frame1.to_sql(
4415+
name="test_frame1",
4416+
con=con,
4417+
if_exists="fail",
4418+
index=False,
4419+
prefixes=["TEMPORARY"],
4420+
)
44334421

44344422
with pytest.raises(ValueError, match=r"Table 'test_frame1' already exists."):
44354423
test_frame1.to_sql(

0 commit comments

Comments
 (0)