Skip to content

Commit 865ace5

Browse files
committed
sqlite: implement async Database.prototype.exec
1 parent fa34979 commit 865ace5

File tree

4 files changed

+183
-27
lines changed

4 files changed

+183
-27
lines changed

src/node_sqlite.cc

Lines changed: 108 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,52 @@ class CustomAggregate {
424424
Global<Function> result_fn_;
425425
};
426426

427+
template <typename T>
428+
class SQLiteAsyncTask : public ThreadPoolWork {
429+
public:
430+
explicit SQLiteAsyncTask(
431+
Environment* env,
432+
Database* db,
433+
Local<Promise::Resolver> resolver,
434+
std::function<T()> work,
435+
std::function<void(T, Local<Promise::Resolver>)> after)
436+
: ThreadPoolWork(env, "node_sqlite_async_work"),
437+
env_(env),
438+
db_(db),
439+
work_(work),
440+
after_(after) {
441+
resolver_.Reset(env->isolate(), resolver);
442+
}
443+
444+
void DoThreadPoolWork() override {
445+
if (work_) {
446+
result_ = work_();
447+
}
448+
}
449+
450+
void AfterThreadPoolWork(int status) override {
451+
Isolate* isolate = env_->isolate();
452+
HandleScope handle_scope(isolate);
453+
Local<Promise::Resolver> resolver =
454+
Local<Promise::Resolver>::New(isolate, resolver_);
455+
456+
if (after_) {
457+
after_(result_, resolver);
458+
Finalize();
459+
}
460+
}
461+
462+
void Finalize() { db_->RemoveAsyncTask(this); }
463+
464+
private:
465+
Environment* env_;
466+
Database* db_;
467+
Global<Promise::Resolver> resolver_;
468+
std::function<T()> work_ = nullptr;
469+
std::function<void(T, Local<Promise::Resolver>)> after_ = nullptr;
470+
T result_;
471+
};
472+
427473
class BackupJob : public ThreadPoolWork {
428474
public:
429475
explicit BackupJob(Environment* env,
@@ -650,10 +696,10 @@ void UserDefinedFunction::xDestroy(void* self) {
650696
}
651697

652698
Database::Database(Environment* env,
653-
Local<Object> object,
654-
DatabaseOpenConfiguration&& open_config,
655-
bool open,
656-
bool allow_load_extension)
699+
Local<Object> object,
700+
DatabaseOpenConfiguration&& open_config,
701+
bool open,
702+
bool allow_load_extension)
657703
: BaseObject(env, object), open_config_(std::move(open_config)) {
658704
MakeWeak();
659705
connection_ = nullptr;
@@ -674,6 +720,14 @@ void Database::RemoveBackup(BackupJob* job) {
674720
backups_.erase(job);
675721
}
676722

723+
void Database::AddAsyncTask(ThreadPoolWork* async_task) {
724+
async_tasks_.insert(async_task);
725+
}
726+
727+
void Database::RemoveAsyncTask(ThreadPoolWork* async_task) {
728+
async_tasks_.erase(async_task);
729+
}
730+
677731
void Database::DeleteSessions() {
678732
// all attached sessions need to be deleted before the database is closed
679733
// https://www.sqlite.org/session/sqlite3session_create.html
@@ -685,6 +739,7 @@ void Database::DeleteSessions() {
685739

686740
Database::~Database() {
687741
FinalizeBackups();
742+
async_tasks_.clear();
688743

689744
if (IsOpen()) {
690745
FinalizeStatements();
@@ -841,7 +896,8 @@ std::optional<std::string> ValidateDatabasePath(Environment* env,
841896
return std::nullopt;
842897
}
843898

844-
inline void DatabaseNew(const FunctionCallbackInfo<Value>& args, bool async = true) {
899+
inline void DatabaseNew(const FunctionCallbackInfo<Value>& args,
900+
bool async = true) {
845901
Environment* env = Environment::GetCurrent(args);
846902
if (!args.IsConstructCall()) {
847903
THROW_ERR_CONSTRUCT_CALL_REQUIRED(env);
@@ -1053,8 +1109,7 @@ void Database::IsOpenGetter(const FunctionCallbackInfo<Value>& args) {
10531109
args.GetReturnValue().Set(db->IsOpen());
10541110
}
10551111

1056-
void Database::IsTransactionGetter(
1057-
const FunctionCallbackInfo<Value>& args) {
1112+
void Database::IsTransactionGetter(const FunctionCallbackInfo<Value>& args) {
10581113
Database* db;
10591114
ASSIGN_OR_RETURN_UNWRAP(&db, args.This());
10601115
Environment* env = Environment::GetCurrent(args);
@@ -1116,13 +1171,47 @@ void Database::Exec(const FunctionCallbackInfo<Value>& args) {
11161171
return;
11171172
}
11181173

1174+
Isolate* isolate = env->isolate();
1175+
auto sql = Utf8Value(isolate, args[0].As<String>()).ToString();
1176+
auto task = [sql, db]() -> int {
1177+
return sqlite3_exec(
1178+
db->connection_, sql.c_str(), nullptr, nullptr, nullptr);
1179+
};
1180+
11191181
if (db->open_config_.get_async()) {
1120-
// TODO(geeksilva97): Support async by returning a Promise
1121-
std::cout << "This is async" << std::endl;
1182+
Local<Promise::Resolver> resolver;
1183+
if (!Promise::Resolver::New(env->context()).ToLocal(&resolver)) {
1184+
return;
1185+
}
1186+
1187+
auto after = [db, env, isolate](int exec_result,
1188+
Local<Promise::Resolver> resolver) {
1189+
if (exec_result != SQLITE_OK) {
1190+
if (db->ShouldIgnoreSQLiteError()) {
1191+
db->SetIgnoreNextSQLiteError(false);
1192+
return;
1193+
}
1194+
1195+
Local<Object> e;
1196+
if (!CreateSQLiteError(isolate, db->Connection()).ToLocal(&e)) {
1197+
return;
1198+
}
1199+
1200+
resolver->Reject(env->context(), e).FromJust();
1201+
return;
1202+
}
1203+
1204+
resolver->Resolve(env->context(), Undefined(env->isolate())).FromJust();
1205+
};
1206+
1207+
auto* work = new SQLiteAsyncTask<int>(env, db, resolver, task, after);
1208+
work->ScheduleWork();
1209+
db->AddAsyncTask(work);
1210+
args.GetReturnValue().Set(resolver->GetPromise());
1211+
return;
11221212
}
11231213

1124-
Utf8Value sql(env->isolate(), args[0].As<String>());
1125-
int r = sqlite3_exec(db->connection_, *sql, nullptr, nullptr, nullptr);
1214+
int r = task();
11261215
CHECK_ERROR_OR_THROW(env->isolate(), db, r, SQLITE_OK, void());
11271216
}
11281217

@@ -1776,8 +1865,7 @@ void Database::ApplyChangeset(const FunctionCallbackInfo<Value>& args) {
17761865
THROW_ERR_SQLITE_ERROR(env->isolate(), r);
17771866
}
17781867

1779-
void Database::EnableLoadExtension(
1780-
const FunctionCallbackInfo<Value>& args) {
1868+
void Database::EnableLoadExtension(const FunctionCallbackInfo<Value>& args) {
17811869
Database* db;
17821870
ASSIGN_OR_RETURN_UNWRAP(&db, args.This());
17831871
Environment* env = Environment::GetCurrent(args);
@@ -2470,8 +2558,9 @@ Local<FunctionTemplate> StatementSync::GetConstructorTemplate(
24702558
return tmpl;
24712559
}
24722560

2473-
BaseObjectPtr<StatementSync> StatementSync::Create(
2474-
Environment* env, BaseObjectPtr<Database> db, sqlite3_stmt* stmt) {
2561+
BaseObjectPtr<StatementSync> StatementSync::Create(Environment* env,
2562+
BaseObjectPtr<Database> db,
2563+
sqlite3_stmt* stmt) {
24752564
Local<Object> obj;
24762565
if (!GetConstructorTemplate(env)
24772566
->InstanceTemplate()
@@ -2730,12 +2819,15 @@ void DefineConstants(Local<Object> target) {
27302819
NODE_DEFINE_CONSTANT(target, SQLITE_CHANGESET_FOREIGN_KEY);
27312820
}
27322821

2733-
void DefineAsyncInterface(Isolate* isolate, Local<Object> target, Local<Context> context) {
2822+
void DefineAsyncInterface(Isolate* isolate,
2823+
Local<Object> target,
2824+
Local<Context> context) {
27342825
Local<FunctionTemplate> db_async_tmpl =
27352826
NewFunctionTemplate(isolate, Database::NewAsync);
27362827
db_async_tmpl->InstanceTemplate()->SetInternalFieldCount(
27372828
Database::kInternalFieldCount);
27382829

2830+
SetProtoMethod(isolate, db_async_tmpl, "close", Database::Close);
27392831
SetProtoMethod(isolate, db_async_tmpl, "exec", Database::Exec);
27402832
SetConstructorFunction(context, target, "Database", db_async_tmpl);
27412833
}

src/node_sqlite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "base_object.h"
77
#include "node_mem.h"
88
#include "sqlite3.h"
9+
#include "threadpoolwork-inl.h"
910
#include "util.h"
1011

1112
#include <map>
@@ -114,6 +115,8 @@ class Database : public BaseObject {
114115
void FinalizeStatements();
115116
void RemoveBackup(BackupJob* backup);
116117
void AddBackup(BackupJob* backup);
118+
void AddAsyncTask(ThreadPoolWork* async_task);
119+
void RemoveAsyncTask(ThreadPoolWork* async_task);
117120
void FinalizeBackups();
118121
void UntrackStatement(StatementSync* statement);
119122
bool IsOpen();
@@ -149,6 +152,7 @@ class Database : public BaseObject {
149152
bool ignore_next_sqlite_error_;
150153

151154
std::set<BackupJob*> backups_;
155+
std::set<ThreadPoolWork*> async_tasks_;
152156
std::set<sqlite3_session*> sessions_;
153157
std::unordered_set<StatementSync*> statements_;
154158

test/parallel/test-sqlite-database-async.js

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { skipIfSQLiteMissing } from '../common/index.mjs';
2+
import tmpdir from '../common/tmpdir.js';
3+
import { suite, test } from 'node:test';
4+
import { join } from 'node:path';
5+
import { Database } from 'node:sqlite';
6+
skipIfSQLiteMissing();
7+
8+
tmpdir.refresh();
9+
10+
let cnt = 0;
11+
function nextDb() {
12+
return join(tmpdir.path, `database-${cnt++}.db`);
13+
}
14+
15+
suite('Database.prototype.exec()', () => {
16+
test('executes SQL', async (t) => {
17+
const db = new Database(nextDb());
18+
t.after(() => { db.close(); });
19+
const result = await db.exec(`
20+
CREATE TABLE data(
21+
key INTEGER PRIMARY KEY,
22+
val INTEGER
23+
) STRICT;
24+
INSERT INTO data (key, val) VALUES (1, 2);
25+
INSERT INTO data (key, val) VALUES (8, 9);
26+
`);
27+
t.assert.strictEqual(result, undefined);
28+
});
29+
30+
test('reports errors from SQLite', async (t) => {
31+
const db = new Database(nextDb());
32+
t.after(() => { db.close(); });
33+
34+
await t.assert.rejects(db.exec('CREATE TABLEEEE'), {
35+
code: 'ERR_SQLITE_ERROR',
36+
message: /syntax error/,
37+
});
38+
});
39+
40+
test('throws if the URL does not have the file: scheme', (t) => {
41+
t.assert.throws(() => {
42+
new Database(new URL('http://example.com'));
43+
}, {
44+
code: 'ERR_INVALID_URL_SCHEME',
45+
message: 'The URL must be of scheme file:',
46+
});
47+
});
48+
49+
test('throws if database is not open', (t) => {
50+
const db = new Database(nextDb(), { open: false });
51+
52+
t.assert.throws(() => {
53+
db.exec();
54+
}, {
55+
code: 'ERR_INVALID_STATE',
56+
message: /database is not open/,
57+
});
58+
});
59+
60+
test('throws if sql is not a string', (t) => {
61+
const db = new Database(nextDb());
62+
t.after(() => { db.close(); });
63+
64+
t.assert.throws(() => {
65+
db.exec();
66+
}, {
67+
code: 'ERR_INVALID_ARG_TYPE',
68+
message: /The "sql" argument must be a string/,
69+
});
70+
});
71+
});

0 commit comments

Comments
 (0)