Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 60 additions & 20 deletions sqlite_minutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def close(self):
def get_last_rowid(self):
res = next(self.execute('SELECT last_insert_rowid()'), None)
if res is None: return None
return int(res[0])
return int(res[0])

@contextlib.contextmanager
def ensure_autocommit_off(self):
Expand Down Expand Up @@ -845,14 +845,23 @@ def sort_key(p):
)

column_defs = []
# All minidata tables get a primary key: https://docs.fastht.ml/explains/minidataapi.html#creating-tables
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minutils does not support the spec - that's fastlite. We shouldn't add a pk here automatically IMO, unless I'm missing something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll work to get tests to pass without it.

column_names = [x[0] for x in column_items]
if pk is None and 'id' not in column_names:
column_items.insert(0, ('id', int))
column_names.insert(0, 'id')
if pk is None:
pk = ['id']

# ensure pk is a tuple
single_pk = None
if isinstance(pk, list) and len(pk) == 1 and isinstance(pk[0], str):
pk = pk[0]
if isinstance(pk, str):
single_pk = pk
if pk not in [c[0] for c in column_items]:
if pk not in column_names:
column_items.insert(0, (pk, int))

for column_name, column_type in column_items:
column_extras = []
if column_name == single_pk:
Expand Down Expand Up @@ -945,6 +954,15 @@ def create_table(
if transform and self[name].exists():
table = cast(Table, self[name])
should_transform = False
# Has the primary key changed?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought changing the pk already worked? IIRC I used that feature during development of solveit...

current_pks = table.pks
desired_pk = None
if isinstance(pk, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You know, with a ternary op, these 4 lines would be just one line... ;)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh n/m I see this was just a move of some lines Simon wrote.

desired_pk = [pk]
elif pk:
desired_pk = list(pk)
if desired_pk and current_pks != desired_pk:
should_transform = True
# First add missing columns and figure out columns to drop
existing_columns = table.columns_dict
missing_columns = dict(
Expand All @@ -955,6 +973,14 @@ def create_table(
columns_to_drop = [
column for column in existing_columns if column not in columns
]
# If no primary key was specified and id was added automatically, we prevent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can remove this then?

# it from being deleted here or Sqlite will complain that a primary key is
# being dropped. We delete the ID column once a new PK has been set.
delete_id_column = False
if table.pks == ['id'] and 'id' in columns_to_drop:
columns_to_drop.remove('id')
delete_id_column = True

if columns_to_drop:
for col_name in columns_to_drop: table.drop_column(col_name)
if missing_columns:
Expand All @@ -968,15 +994,6 @@ def create_table(
and list(existing_columns)[: len(column_order)] != column_order
):
should_transform = True
# Has the primary key changed?
current_pks = table.pks
desired_pk = None
if isinstance(pk, str):
desired_pk = [pk]
elif pk:
desired_pk = list(pk)
if desired_pk and current_pks != desired_pk:
should_transform = True
# Any not-null changes?
current_not_null = {c.name for c in table.columns if c.notnull}
desired_not_null = set(not_null) if not_null else set()
Expand All @@ -994,6 +1011,9 @@ def create_table(
defaults=defaults,
pk=pk,
)
# There has been set a primary key that isn't ['id'].It is now safe to drop the ID column.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and remove this.

if delete_id_column and table.pks != ['id']:
table.drop_column('id')
return table
sql = self.create_table_sql(
name=name,
Expand Down Expand Up @@ -1463,7 +1483,7 @@ def pks(self) -> List[str]:
"Primary key columns for this table."
names = [column.name for column in self.columns if int(column.is_pk)]
if not names:
names = ["rowid"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to force people to not use rowid do we?

names = ["id"]
return names

@property
Expand Down Expand Up @@ -2183,7 +2203,7 @@ def add_column(
fk_col = pks[0].name
fk_col_type = pks[0].type
else:
fk_col = "rowid"
fk_col = "id"
fk_col_type = "INTEGER"
if col_type is None:
col_type = str
Expand Down Expand Up @@ -2925,6 +2945,11 @@ def insert_chunk(
ignore,
)

if isinstance(pk, Union[str, None]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK so this one could definitely be a ternary op!

pks = [pk]
else:
pks = pk

records = []
for query, params in queries_and_params:
try:
Expand Down Expand Up @@ -2989,17 +3014,32 @@ def insert_chunk(
# around for multiple queries/records are returned for what should
# be single SQL call operations.
self.last_pk = records[0][hash_id]
elif (rid := self.db.get_last_rowid()) is not None:
self.last_pk = self.last_rowid = rid
# self.last_rowid will be 0 if a "INSERT OR IGNORE" happened
elif len(records) > 0 and any(pks):
# Pks provided as an argument
last_record = records[-1]
self.last_pk = tuple(last_record[pk] for pk in pks)
if len(self.last_pk) == 1:
self.last_pk = self.last_pk[0]

if (hash_id or pk) and not upsert:
row = list(self.rows_where("rowid = ?", [rid]))[0]
if hash_id:
self.last_pk = row[hash_id]
self.last_pk = last_record[hash_id]
elif isinstance(pk, str):
self.last_pk = row[pk]
self.last_pk = last_record[pk]
else:
self.last_pk = tuple(row[p] for p in pk)
self.last_pk = tuple(last_record[p] for p in pk)
elif len(records):
# No pks provided, so we use the table's defaults
last_record = records[-1]
self.last_pk = tuple(last_record[pk] for pk in self.pks)
if len(self.last_pk) == 1:
self.last_pk = self.last_pk[0]

# Setting last_rowid to preserve API so we don't break backwards
# compatibility for users.
# TODO: Consider turning this into a property that warns upon usage
self.last_rowid = self.last_pk

return records

def insert(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_memory_name():
db1 = Database(memory_name="shared")
db2 = Database(memory_name="shared")
db1["dogs"].insert({"name": "Cleo"})
assert list(db2["dogs"].rows) == [{"name": "Cleo"}]
assert list(db2["dogs"].rows) == [{'id': 1, "name": "Cleo"}]


def test_sqlite_version():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_conversions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
def test_insert_conversion(fresh_db):
table = fresh_db["table"]
table.insert({"foo": "bar"}, conversions={"foo": "upper(?)"})
assert [{"foo": "BAR"}] == list(table.rows)
assert [{'id':1, "foo": "BAR"}] == list(table.rows)


def test_insert_all_conversion(fresh_db):
table = fresh_db["table"]
table.insert_all([{"foo": "bar"}], conversions={"foo": "upper(?)"})
assert [{"foo": "BAR"}] == list(table.rows)
assert [{'id':1, "foo": "BAR"}] == list(table.rows)


def test_upsert_conversion(fresh_db):
Expand Down Expand Up @@ -38,4 +38,4 @@ def test_update_conversion(fresh_db):
def test_table_constructor_conversion(fresh_db):
table = fresh_db.table("table", conversions={"bar": "upper(?)"})
table.insert({"bar": "baz"})
assert [{"bar": "BAZ"}] == list(table.rows)
assert [{'id':1,"bar": "BAZ"}] == list(table.rows)
Loading