diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 236f146..fb2fe0a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,27 @@ jobs: done echo "sqld is ready!" + - name: Install sqlean + run: | + SQLEAN_VERSION=0.27.2 + case "${{ runner.os }}" in + Linux) OS_SLUG=linux-x86 ;; + macOS) OS_SLUG=macos-x86 ;; + esac + curl -sL -o sqlean-${OS_SLUG}.zip \ + https://github.com/nalgeon/sqlean/releases/download/${SQLEAN_VERSION}/sqlean-${OS_SLUG}.zip + unzip -q sqlean-${OS_SLUG}.zip -d . + echo "$PWD" >> $GITHUB_PATH + + - name: Set extension env + run: | + case "${{ runner.os }}" in + Linux) ext=sqlean.so ;; + macOS) ext=sqlean.dylib ;; + esac + echo "LIBSQL_TEST_EXTENSION=$PWD/$ext" >> $GITHUB_ENV + echo "LIBSQL_TEST_EXTENSION_ENTRY=sqlite3_sqlean_init" >> $GITHUB_ENV + - name: Build run: go build -v ./... diff --git a/libsql.go b/libsql.go index 055c088..252e8c8 100644 --- a/libsql.go +++ b/libsql.go @@ -424,6 +424,20 @@ type conn struct { nativePtr C.libsql_connection_t } +func (c *conn) LoadExtension(lib string, entry string) error { + libCString := C.CString(lib) + defer C.free(unsafe.Pointer(libCString)) + entryCString := C.CString(entry) + defer C.free(unsafe.Pointer(entryCString)) + + var errMsg *C.char + statusCode := C.libsql_load_extension(c.nativePtr, libCString, entryCString, &errMsg) + if statusCode != 0 { + return libsqlError(fmt.Sprintf("failed to load extension %s with entry point %s", lib, entry), statusCode, errMsg) + } + return nil +} + func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { return c.PrepareContext(context.Background(), query) } diff --git a/libsql_test.go b/libsql_test.go index 9b0d139..716ff8f 100644 --- a/libsql_test.go +++ b/libsql_test.go @@ -1359,3 +1359,71 @@ func TestErrorRowsNext(t *testing.T) { } }) } + +// To run this, set LIBSQL_TEST_EXTENSION to the full path of a valid SQLite extension +// and (optionally) LIBSQL_TEST_EXTENSION_ENTRY to its init symbol (defaults to "sqlite3_extension_init"). +func TestLoadExtension_Existing(t *testing.T) { + extPath := os.Getenv("LIBSQL_TEST_EXTENSION") + if extPath == "" { + t.Skip("LIBSQL_TEST_EXTENSION not set; skipping existing‐extension load test") + } + entryPoint := os.Getenv("LIBSQL_TEST_EXTENSION_ENTRY") + if entryPoint == "" { + entryPoint = "sqlite3_extension_init" + } + + db, err := sql.Open("libsql", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx := context.Background() + sqlConn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer sqlConn.Close() + + err = sqlConn.Raw(func(driverConn any) error { + cImpl, ok := driverConn.(*conn) + if !ok { + return fmt.Errorf("unexpected driverConn type %T", driverConn) + } + return cImpl.LoadExtension(extPath, entryPoint) + }) + + if err != nil { + t.Fatalf("failed to load existing extension %q (entry %q): %v", extPath, entryPoint, err) + } +} + +func TestLoadExtension_Nonexistent(t *testing.T) { + db, err := sql.Open("libsql", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx := context.Background() + sqlConn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer sqlConn.Close() + + err = sqlConn.Raw(func(driverConn any) error { + cImpl, ok := driverConn.(*conn) + if !ok { + return fmt.Errorf("unexpected driverConn type %T", driverConn) + } + return cImpl.LoadExtension("nonexistent_extension.so", "entry_point") + }) + + if err == nil { + t.Fatal("expected error loading nonexistent extension, got nil") + } + if !strings.Contains(err.Error(), "failed to load extension") { + t.Fatalf("unexpected error loading extension: %v", err) + } +}