diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBTablePreparedStatementIT.java similarity index 99% rename from integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java rename to integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBTablePreparedStatementIT.java index f06d46201aff2..c7c5902e8e2dc 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBTablePreparedStatementIT.java @@ -45,7 +45,7 @@ @RunWith(IoTDBTestRunner.class) @Category({TableLocalStandaloneIT.class, TableClusterIT.class}) -public class IoTDBPreparedStatementIT { +public class IoTDBTablePreparedStatementIT { private static final String DATABASE_NAME = "test"; private static final String[] sqls = new String[] { diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBTablePreparedStatementJDBCIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBTablePreparedStatementJDBCIT.java new file mode 100644 index 0000000000000..7f5da65fa0c0a --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBTablePreparedStatementJDBCIT.java @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.relational.it.db.it; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.TableClusterIT; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; +import org.apache.iotdb.itbase.runtime.ClusterTestConnection; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@RunWith(IoTDBTestRunner.class) +@Category({TableLocalStandaloneIT.class, TableClusterIT.class}) +public class IoTDBTablePreparedStatementJDBCIT { + + private static final String DATABASE_NAME = "test"; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().initClusterEnvironment(); + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE " + DATABASE_NAME); + statement.execute("USE " + DATABASE_NAME); + statement.execute( + "CREATE TABLE test_table(id STRING TAG, name STRING FIELD, value DOUBLE FIELD, " + + "int_value INT32 FIELD, long_value INT64 FIELD)"); + statement.execute( + "INSERT INTO test_table VALUES (2025-01-01T00:00:00, '1', 'Alice', 100.5, 10, 1000)"); + statement.execute( + "INSERT INTO test_table VALUES (2025-01-01T00:01:00, '2', 'Bob', 200.3, 20, 2000)"); + statement.execute( + "INSERT INTO test_table VALUES (2025-01-01T00:02:00, '3', 'Charlie', 300.7, 30, 3000)"); + } + } + + @AfterClass + public static void tearDown() { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + private Connection getConnection() throws SQLException { + Connection connection = EnvFactory.getEnv().getTableConnection(); + if (connection instanceof ClusterTestConnection) { + // Get the underlying real JDBC connection that supports prepareStatement + return ((ClusterTestConnection) connection).writeConnection.getUnderlyingConnection(); + } + return connection; + } + + @Test + public void testPreparedStatementWithIntParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE int_value = ?")) { + ps.setInt(1, 20); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals(20, rs.getInt("int_value")); + assertEquals("Bob", rs.getString("name")); + assertEquals(200.3, rs.getDouble("value"), 0.001); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithStringParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE name = ?")) { + ps.setString(1, "Charlie"); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals("3", rs.getString("id")); + assertEquals("Charlie", rs.getString("name")); + assertEquals(300.7, rs.getDouble("value"), 0.001); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithMultipleParameters() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE id = ? AND value < ?")) { + ps.setString(1, "2"); + ps.setDouble(2, 300.0); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals("2", rs.getString("id")); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementExecuteMultipleTimes() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE id = ?")) { + // First execution + ps.setString(1, "1"); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals("Alice", rs.getString("name")); + assertFalse(rs.next()); + } + + // Second execution with different parameter + ps.setString(1, "3"); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals("Charlie", rs.getString("name")); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithDoubleParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE value > ?")) { + ps.setDouble(1, 250.0); + try (ResultSet rs = ps.executeQuery()) { + int count = 0; + while (rs.next()) { + assertTrue(rs.getDouble("value") > 250.0); + count++; + } + // Only Charlie (300.7 > 250) satisfies the condition + assertEquals(1, count); + } + } + } + } + + @Test + public void testPreparedStatementWithLongParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE long_value = ?")) { + ps.setLong(1, 1000L); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1000L, rs.getLong("long_value")); + assertEquals("Alice", rs.getString("name")); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithFloatParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE value < ?")) { + ps.setFloat(1, 150.0f); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals("Alice", rs.getString("name")); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithBooleanParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + // Create table with boolean column + stmt.execute("CREATE TABLE bool_table(flag BOOLEAN FIELD)"); + stmt.execute("INSERT INTO bool_table VALUES (2025-01-01T00:00:00, true)"); + stmt.execute("INSERT INTO bool_table VALUES (2025-01-01T00:01:00, false)"); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM bool_table WHERE flag = ?")) { + ps.setBoolean(1, true); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertTrue(rs.getBoolean("flag")); + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithNullParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE name = ?")) { + ps.setNull(1, java.sql.Types.VARCHAR); + try (ResultSet rs = ps.executeQuery()) { + // No rows should match null + assertFalse(rs.next()); + } + } + } + } + + @Test + public void testPreparedStatementWithBinaryParameter() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + stmt.execute("CREATE TABLE blob_table(data BLOB FIELD)"); + + byte[] testData = new byte[] {0x01, 0x02, 0x03}; + stmt.execute("INSERT INTO blob_table VALUES (2025-01-01T00:00:00, X'010203')"); + + try (PreparedStatement queryPs = + connection.prepareStatement("SELECT data FROM blob_table WHERE data = ?")) { + queryPs.setBytes(1, testData); + + try (ResultSet rs = queryPs.executeQuery()) { + assertTrue(rs.next()); + assertArrayEquals(testData, rs.getBytes("data")); + } + } + } + } + + @Test + public void testPreparedStatementResultSetMetaData() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement( + "SELECT id, name, value, int_value, long_value FROM test_table WHERE id = ?")) { + ps.setString(1, "1"); + try (ResultSet rs = ps.executeQuery()) { + ResultSetMetaData metaData = rs.getMetaData(); + assertEquals(5, metaData.getColumnCount()); + assertEquals("id", metaData.getColumnLabel(1).toLowerCase()); + assertEquals("name", metaData.getColumnLabel(2).toLowerCase()); + assertEquals("value", metaData.getColumnLabel(3).toLowerCase()); + assertEquals("int_value", metaData.getColumnLabel(4).toLowerCase()); + assertEquals("long_value", metaData.getColumnLabel(5).toLowerCase()); + } + } + } + } + + @Test + public void testPreparedStatementParameterMetaData() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement( + "SELECT id, name, value FROM test_table WHERE id = ? AND value > ?")) { + ParameterMetaData metaData = ps.getParameterMetaData(); + assertEquals(2, metaData.getParameterCount()); + } + } + } + + @Test + public void testPreparedStatementInsert() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + stmt.execute("CREATE TABLE insert_test(id INT32 FIELD, name STRING FIELD)"); + + try (PreparedStatement ps = + connection.prepareStatement("INSERT INTO insert_test VALUES (?, ?, ?)")) { + ps.setLong(1, System.currentTimeMillis()); + ps.setInt(2, 100); + ps.setString(3, "TestName"); + int affected = ps.executeUpdate(); + assertTrue(affected >= 0); + } + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM insert_test WHERE id = 100")) { + assertTrue(rs.next()); + assertEquals("TestName", rs.getString("name")); + } + } + } + + @Test + public void testPreparedStatementAggregation() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT COUNT(*) as cnt FROM test_table WHERE value > ?")) { + ps.setDouble(1, 150.0); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals(2, rs.getLong("cnt")); // Bob and Charlie + } + } + } + } + + @Test + public void testPreparedStatementClearParameters() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps = + connection.prepareStatement("SELECT * FROM test_table WHERE id = ?")) { + ps.setString(1, "1"); + ps.clearParameters(); + // After clear, should be able to set new parameters + ps.setString(1, "2"); + try (ResultSet rs = ps.executeQuery()) { + assertTrue(rs.next()); + assertEquals("Bob", rs.getString("name")); + } + } + } + } + + @Test + public void testMultiplePreparedStatements() throws SQLException { + try (Connection connection = getConnection(); + Statement stmt = connection.createStatement()) { + stmt.execute("USE " + DATABASE_NAME); + + try (PreparedStatement ps1 = + connection.prepareStatement("SELECT * FROM test_table WHERE id = ?"); + PreparedStatement ps2 = + connection.prepareStatement( + "SELECT COUNT(*) as cnt FROM test_table WHERE value > ?")) { + // Execute first prepared statement + ps1.setString(1, "1"); + try (ResultSet rs = ps1.executeQuery()) { + assertTrue(rs.next()); + assertEquals("Alice", rs.getString("name")); + } + + // Execute second prepared statement + ps2.setDouble(1, 200.0); + try (ResultSet rs = ps2.executeQuery()) { + assertTrue(rs.next()); + assertEquals(2, rs.getLong("cnt")); + } + } + } + } +} diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java index 51ebf7d727a03..54b148f25e4c4 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java @@ -433,7 +433,11 @@ public CallableStatement prepareCall(String arg0, int arg1, int arg2, int arg3) @Override public PreparedStatement prepareStatement(String sql) throws SQLException { - return new IoTDBPreparedStatement(this, getClient(), sessionId, sql, zoneId, charset); + if (getSqlDialect().equals(Constant.TABLE_DIALECT)) { + return new IoTDBTablePreparedStatement(this, getClient(), sessionId, sql, zoneId, charset); + } else { + return new IoTDBPreparedStatement(this, getClient(), sessionId, sql, zoneId, charset); + } } @Override diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java index 93a922070db2b..8cb0a32417f27 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java @@ -55,7 +55,7 @@ public class IoTDBStatement implements Statement { private final IoTDBConnection connection; - private ResultSet resultSet = null; + protected ResultSet resultSet = null; private int fetchSize; private int maxRows = 0; @@ -66,7 +66,7 @@ public class IoTDBStatement implements Statement { * Timeout of query can be set by users. Unit: s. A negative number means using the default * configuration of server. And value 0 will disable the function of query timeout. */ - private int queryTimeout = -1; + protected int queryTimeout = -1; protected IClientRPCService.Iface client; private List batchSQLList; @@ -82,7 +82,7 @@ public class IoTDBStatement implements Statement { /** Add SQLWarnings to the warningChain if needed. */ private SQLWarning warningChain = null; - private long sessionId; + protected long sessionId; private long stmtId = -1; private long queryId = -1; diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBTablePreparedStatement.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBTablePreparedStatement.java new file mode 100644 index 0000000000000..c67fd554b2404 --- /dev/null +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBTablePreparedStatement.java @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.jdbc; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.rpc.RpcUtils; +import org.apache.iotdb.rpc.StatementExecutionException; +import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.rpc.stmt.PreparedParameterSerializer; +import org.apache.iotdb.service.rpc.thrift.IClientRPCService.Iface; +import org.apache.iotdb.service.rpc.thrift.TSDeallocatePreparedReq; +import org.apache.iotdb.service.rpc.thrift.TSExecutePreparedReq; +import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementResp; +import org.apache.iotdb.service.rpc.thrift.TSPrepareReq; +import org.apache.iotdb.service.rpc.thrift.TSPrepareResp; + +import org.apache.thrift.TException; +import org.apache.tsfile.common.conf.TSFileConfig; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.nio.charset.Charset; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class IoTDBTablePreparedStatement extends IoTDBStatement implements PreparedStatement { + + private static final Logger logger = LoggerFactory.getLogger(IoTDBTablePreparedStatement.class); + private static final String METHOD_NOT_SUPPORTED_STRING = "Method not supported"; + + private final String sql; + private final String preparedStatementName; + private final int parameterCount; + private final boolean serverSidePrepared; + + private final Object[] parameterValues; + private final int[] parameterTypes; + + // retain parameters for backward compatibility + private final Map parameters = new HashMap<>(); + + IoTDBTablePreparedStatement( + IoTDBConnection connection, + Iface client, + Long sessionId, + String sql, + ZoneId zoneId, + Charset charset) + throws SQLException { + super(connection, client, sessionId, zoneId, charset); + this.sql = sql; + this.preparedStatementName = generateStatementName(); + + if (isQueryStatement(sql)) { + // Send PREPARE request to server only for query statements + this.serverSidePrepared = true; + TSPrepareReq prepareReq = new TSPrepareReq(); + prepareReq.setSessionId(sessionId); + prepareReq.setSql(sql); + prepareReq.setStatementName(preparedStatementName); + + try { + TSPrepareResp resp = client.prepareStatement(prepareReq); + RpcUtils.verifySuccess(resp.getStatus()); + + this.parameterCount = resp.isSetParameterCount() ? resp.getParameterCount() : 0; + this.parameterValues = new Object[parameterCount]; + this.parameterTypes = new int[parameterCount]; + + for (int i = 0; i < parameterCount; i++) { + parameterTypes[i] = Types.NULL; + } + } catch (TException | StatementExecutionException e) { + throw new SQLException("Failed to prepare statement: " + e.getMessage(), e); + } + } else { + // For non-query statements, only keep text parameters for client-side substitution. + this.serverSidePrepared = false; + this.parameterCount = 0; + this.parameterValues = null; + this.parameterTypes = null; + } + } + + // Only for tests + IoTDBTablePreparedStatement( + IoTDBConnection connection, Iface client, Long sessionId, String sql, ZoneId zoneId) + throws SQLException { + this(connection, client, sessionId, sql, zoneId, TSFileConfig.STRING_CHARSET); + } + + private String generateStatementName() { + // StatementId is unique across all sessions in one IoTDB instance + return "jdbc_ps_" + getStmtId(); + } + + @Override + public void addBatch() throws SQLException { + super.addBatch(createCompleteSql(sql, parameters)); + } + + @Override + public void clearParameters() { + this.parameters.clear(); + if (serverSidePrepared) { + for (int i = 0; i < parameterCount; i++) { + parameterValues[i] = null; + parameterTypes[i] = Types.NULL; + } + } + } + + @Override + public boolean execute() throws SQLException { + if (isQueryStatement(sql)) { + TSExecuteStatementResp resp = executeInternal(); + return resp.isSetQueryDataSet() || resp.isSetQueryResult(); + } else { + return super.execute(createCompleteSql(sql, parameters)); + } + } + + private boolean isQueryStatement(String sql) { + if (sql == null) { + return false; + } + String trimmedSql = sql.trim().toUpperCase(); + return trimmedSql.startsWith("SELECT"); + } + + @Override + public ResultSet executeQuery() throws SQLException { + TSExecuteStatementResp resp = executeInternal(); + return processQueryResult(resp); + } + + @Override + public int executeUpdate() throws SQLException { + return super.executeUpdate(createCompleteSql(sql, parameters)); + } + + private TSExecuteStatementResp executeInternal() throws SQLException { + // Validate all parameters are set + for (int i = 0; i < parameterCount; i++) { + if (parameterTypes[i] == Types.NULL + && parameterValues[i] == null + && !parameters.containsKey(i + 1)) { + throw new SQLException("Parameter #" + (i + 1) + " is unset"); + } + } + + TSExecutePreparedReq req = new TSExecutePreparedReq(); + req.setSessionId(sessionId); + req.setStatementName(preparedStatementName); + req.setParameters( + PreparedParameterSerializer.serialize(parameterValues, parameterTypes, parameterCount)); + req.setStatementId(getStmtId()); + if (queryTimeout > 0) { + req.setTimeout(queryTimeout * 1000L); + } + + try { + TSExecuteStatementResp resp = client.executePreparedStatement(req); + RpcUtils.verifySuccess(resp.getStatus()); + return resp; + } catch (TException | StatementExecutionException e) { + throw new SQLException("Failed to execute prepared statement: " + e.getMessage(), e); + } + } + + private ResultSet processQueryResult(TSExecuteStatementResp resp) throws SQLException { + if (resp.isSetQueryDataSet() || resp.isSetQueryResult()) { + this.resultSet = + new IoTDBJDBCResultSet( + this, + resp.getColumns(), + resp.getDataTypeList(), + resp.columnNameIndexMap, + resp.ignoreTimeStamp, + client, + sql, + resp.queryId, + sessionId, + resp.queryResult, + resp.tracingInfo, + (long) queryTimeout * 1000, + resp.isSetMoreData() && resp.isMoreData(), + zoneId); + return resultSet; + } + return null; + } + + @Override + public void close() throws SQLException { + if (!isClosed() && serverSidePrepared) { + // Deallocate prepared statement on server only if it was prepared server-side + TSDeallocatePreparedReq req = new TSDeallocatePreparedReq(); + req.setSessionId(sessionId); + req.setStatementName(preparedStatementName); + + try { + TSStatus status = client.deallocatePreparedStatement(req); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + logger.warn("Failed to deallocate prepared statement: {}", status.getMessage()); + } + } catch (TException e) { + logger.warn("Error deallocating prepared statement", e); + } + } + super.close(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + if (resultSet != null) { + return resultSet.getMetaData(); + } + return null; + } + + @Override + public ParameterMetaData getParameterMetaData() { + return new ParameterMetaData() { + @Override + public int getParameterCount() { + return parameterCount; + } + + @Override + public int isNullable(int param) { + return ParameterMetaData.parameterNullableUnknown; + } + + @Override + public boolean isSigned(int param) { + if (!serverSidePrepared) { + return false; + } + int type = parameterTypes[param - 1]; + return type == Types.INTEGER + || type == Types.BIGINT + || type == Types.FLOAT + || type == Types.DOUBLE; + } + + @Override + public int getPrecision(int param) { + return 0; + } + + @Override + public int getScale(int param) { + return 0; + } + + @Override + public int getParameterType(int param) { + if (!serverSidePrepared) { + return Types.NULL; + } + return parameterTypes[param - 1]; + } + + @Override + public String getParameterTypeName(int param) { + return null; + } + + @Override + public String getParameterClassName(int param) { + return null; + } + + @Override + public int getParameterMode(int param) { + return ParameterMetaData.parameterModeIn; + } + + @Override + public T unwrap(Class iface) { + return null; + } + + @Override + public boolean isWrapperFor(Class iface) { + return false; + } + }; + } + + @Override + public void setNull(int parameterIndex, int sqlType) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, null, Types.NULL); + this.parameters.put(parameterIndex, "NULL"); + } + + @Override + public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + setNull(parameterIndex, sqlType); + } + + @Override + public void setBoolean(int parameterIndex, boolean x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.BOOLEAN); + this.parameters.put(parameterIndex, Boolean.toString(x)); + } + + @Override + public void setInt(int parameterIndex, int x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.INTEGER); + this.parameters.put(parameterIndex, Integer.toString(x)); + } + + @Override + public void setLong(int parameterIndex, long x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.BIGINT); + this.parameters.put(parameterIndex, Long.toString(x)); + } + + @Override + public void setFloat(int parameterIndex, float x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.FLOAT); + this.parameters.put(parameterIndex, Float.toString(x)); + } + + @Override + public void setDouble(int parameterIndex, double x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.DOUBLE); + this.parameters.put(parameterIndex, Double.toString(x)); + } + + @Override + public void setString(int parameterIndex, String x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.VARCHAR); + if (x == null) { + this.parameters.put(parameterIndex, null); + } else { + this.parameters.put(parameterIndex, "'" + escapeSingleQuotes(x) + "'"); + } + } + + @Override + public void setBytes(int parameterIndex, byte[] x) throws SQLException { + checkParameterIndex(parameterIndex); + setPreparedParameterValue(parameterIndex, x, Types.BINARY); + Binary binary = new Binary(x); + this.parameters.put(parameterIndex, binary.getStringValue(TSFileConfig.STRING_CHARSET)); + } + + @Override + public void setDate(int parameterIndex, Date x) throws SQLException { + checkParameterIndex(parameterIndex); + DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + String dateStr = dateFormat.format(x); + setPreparedParameterValue(parameterIndex, dateStr, Types.VARCHAR); + this.parameters.put(parameterIndex, "'" + dateStr + "'"); + } + + @Override + public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + setDate(parameterIndex, x); + } + + @Override + public void setTime(int parameterIndex, Time x) throws SQLException { + checkParameterIndex(parameterIndex); + try { + long time = x.getTime(); + String timeprecision = client.getProperties().getTimestampPrecision(); + switch (timeprecision.toLowerCase()) { + case "ms": + break; + case "us": + time = time * 1000; + break; + case "ns": + time = time * 1000000; + break; + default: + break; + } + setPreparedParameterValue(parameterIndex, time, Types.BIGINT); + this.parameters.put(parameterIndex, Long.toString(time)); + } catch (TException e) { + throw new SQLException("Failed to get time precision: " + e.getMessage(), e); + } + } + + @Override + public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + setTime(parameterIndex, x); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + checkParameterIndex(parameterIndex); + ZonedDateTime zonedDateTime = + ZonedDateTime.ofInstant(Instant.ofEpochMilli(x.getTime()), super.zoneId); + String tsStr = zonedDateTime.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME); + setPreparedParameterValue(parameterIndex, tsStr, Types.VARCHAR); + this.parameters.put(parameterIndex, tsStr); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + setTimestamp(parameterIndex, x); + } + + @Override + public void setObject(int parameterIndex, Object x) throws SQLException { + if (x == null) { + setNull(parameterIndex, Types.NULL); + } else if (x instanceof String) { + setString(parameterIndex, (String) x); + } else if (x instanceof Integer) { + setInt(parameterIndex, (Integer) x); + } else if (x instanceof Long) { + setLong(parameterIndex, (Long) x); + } else if (x instanceof Float) { + setFloat(parameterIndex, (Float) x); + } else if (x instanceof Double) { + setDouble(parameterIndex, (Double) x); + } else if (x instanceof Boolean) { + setBoolean(parameterIndex, (Boolean) x); + } else if (x instanceof Timestamp) { + setTimestamp(parameterIndex, (Timestamp) x); + } else if (x instanceof Date) { + setDate(parameterIndex, (Date) x); + } else if (x instanceof Time) { + setTime(parameterIndex, (Time) x); + } else if (x instanceof byte[]) { + setBytes(parameterIndex, (byte[]) x); + } else { + throw new SQLException( + String.format( + "Can't infer the SQL type for an instance of %s. Use setObject() with explicit type.", + x.getClass().getName())); + } + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + setObject(parameterIndex, x); + } + + @Override + public void setObject(int parameterIndex, Object parameterObj, int targetSqlType, int scale) + throws SQLException { + setObject(parameterIndex, parameterObj); + } + + private void checkParameterIndex(int index) throws SQLException { + if (!serverSidePrepared) { + return; + } + if (index < 1 || index > parameterCount) { + throw new SQLException( + "Parameter index out of range: " + index + " (expected 1-" + parameterCount + ")"); + } + } + + private void setPreparedParameterValue(int parameterIndex, Object value, int sqlType) { + if (!serverSidePrepared) { + return; + } + parameterValues[parameterIndex - 1] = value; + parameterTypes[parameterIndex - 1] = sqlType; + } + + private String escapeSingleQuotes(String value) { + return value.replace("'", "''"); + } + + @Override + public void setArray(int parameterIndex, Array x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + try { + byte[] bytes = ReadWriteIOUtils.readBytes(x, length); + setBytes(parameterIndex, bytes); + } catch (IOException e) { + throw new SQLException("Failed to read binary stream: " + e.getMessage(), e); + } + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setBlob(int parameterIndex, Blob x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream, long length) + throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setByte(int parameterIndex, byte x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, int length) + throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, long length) + throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setClob(int parameterIndex, Clob x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setClob(int parameterIndex, Reader reader) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value, long length) + throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setNClob(int parameterIndex, NClob value) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setNClob(int parameterIndex, Reader reader) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setNString(int parameterIndex, String value) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setRef(int parameterIndex, Ref x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setRowId(int parameterIndex, RowId x) throws SQLException { + throw new SQLException(METHOD_NOT_SUPPORTED_STRING); + } + + @Override + public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + throw new SQLException(METHOD_NOT_SUPPORTED_STRING); + } + + @Override + public void setShort(int parameterIndex, short x) throws SQLException { + setInt(parameterIndex, x); + } + + @Override + public void setURL(int parameterIndex, URL x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + @Override + public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); + } + + // ================== Helper Methods for Backward Compatibility ================== + + private String createCompleteSql(final String sql, Map parameters) + throws SQLException { + List parts = splitSqlStatement(sql); + + StringBuilder newSql = new StringBuilder(parts.get(0)); + for (int i = 1; i < parts.size(); i++) { + if (!parameters.containsKey(i)) { + throw new SQLException("Parameter #" + i + " is unset"); + } + newSql.append(parameters.get(i)); + newSql.append(parts.get(i)); + } + return newSql.toString(); + } + + private List splitSqlStatement(final String sql) { + List parts = new ArrayList<>(); + int apCount = 0; + int off = 0; + boolean skip = false; + + for (int i = 0; i < sql.length(); i++) { + char c = sql.charAt(i); + if (skip) { + skip = false; + continue; + } + switch (c) { + case '\'': + apCount++; + break; + case '\\': + skip = true; + break; + case '?': + if ((apCount & 1) == 0) { + parts.add(sql.substring(off, i)); + off = i + 1; + } + break; + default: + break; + } + } + parts.add(sql.substring(off)); + return parts; + } +} diff --git a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java index 2ae65dfed2aea..f80b8a83936cd 100644 --- a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java +++ b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java @@ -400,146 +400,4 @@ public void testInsertStatement4() throws Exception { "INSERT INTO root.ln.wf01.wt02(time,a,b,c,d,e,f) VALUES(2020-01-01T10:10:10,false,123,123234345,123.423,-1323.0,'abc')", argument.getValue().getStatement()); } - - // ========== Table Model SQL Injection Prevention Tests ========== - - @SuppressWarnings("resource") - @Test - public void testTableModelLoginInjectionWithComment() throws Exception { - // Login interface SQL injection attack 1: Using -- comments to bypass password checks - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE username = ? AND password = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "admin' --"); - ps.setString(2, "password"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE username = 'admin'' --' AND password = 'password'", - argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelLoginInjectionWithORCondition() throws Exception { - // Login interface SQL injection attack 2: Bypassing authentication by using 'OR '1'='1 - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE username = ? AND password = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "admin"); - ps.setString(2, "' OR '1'='1"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE username = 'admin' AND password = ''' OR ''1''=''1'", - argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelQueryWithMultipleInjectionVectors() throws Exception { - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE email = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "'; DROP TABLE users;"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE email = '''; DROP TABLE users;'", - argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelString1() throws Exception { - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE password = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "a'b"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals("SELECT * FROM users WHERE password = 'a''b'", argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelString2() throws Exception { - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE password = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "a\'b"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals("SELECT * FROM users WHERE password = 'a''b'", argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelString3() throws Exception { - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE password = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "a\\'b"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE password = 'a\\''b'", argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelString4() throws Exception { - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE password = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "a\\\'b"); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE password = 'a\\''b'", argument.getValue().getStatement()); - } - - @SuppressWarnings("resource") - @Test - public void testTableModelStringWithNull() throws Exception { - when(connection.getSqlDialect()).thenReturn("table"); - String sql = "SELECT * FROM users WHERE email = ?"; - IoTDBPreparedStatement ps = - new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, null); - ps.execute(); - - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals("SELECT * FROM users WHERE email = null", argument.getValue().getStatement()); - } } diff --git a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBTablePreparedStatementTest.java b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBTablePreparedStatementTest.java new file mode 100644 index 0000000000000..dd1c7ecbcb209 --- /dev/null +++ b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBTablePreparedStatementTest.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.jdbc; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.service.rpc.thrift.IClientRPCService.Iface; +import org.apache.iotdb.service.rpc.thrift.TSExecutePreparedReq; +import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementResp; +import org.apache.iotdb.service.rpc.thrift.TSPrepareReq; +import org.apache.iotdb.service.rpc.thrift.TSPrepareResp; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.time.ZoneId; + +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class IoTDBTablePreparedStatementTest { + + @Mock TSExecuteStatementResp execStatementResp; + @Mock TSStatus getOperationStatusResp; + private ZoneId zoneId = ZoneId.systemDefault(); + @Mock private IoTDBConnection connection; + @Mock private Iface client; + @Mock private TSStatus successStatus = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + private TSStatus Status_SUCCESS = new TSStatus(successStatus); + private long queryId; + private long sessionId; + + @Before + public void before() throws Exception { + MockitoAnnotations.initMocks(this); + when(connection.getSqlDialect()).thenReturn("table"); + when(execStatementResp.getStatus()).thenReturn(Status_SUCCESS); + when(execStatementResp.getQueryId()).thenReturn(queryId); + + // Mock for prepareStatement - dynamically calculate parameter count from SQL + when(client.prepareStatement(any(TSPrepareReq.class))) + .thenAnswer( + new Answer() { + @Override + public TSPrepareResp answer(InvocationOnMock invocation) throws Throwable { + TSPrepareReq req = invocation.getArgument(0); + String sql = req.getSql(); + int paramCount = countQuestionMarks(sql); + + TSPrepareResp resp = new TSPrepareResp(); + resp.setStatus(Status_SUCCESS); + resp.setParameterCount(paramCount); + return resp; + } + }); + + // Mock for executePreparedStatement + when(client.executePreparedStatement(any(TSExecutePreparedReq.class))) + .thenReturn(execStatementResp); + } + + /** Count the number of '?' placeholders in a SQL string, ignoring those inside quotes */ + private int countQuestionMarks(String sql) { + int count = 0; + boolean inSingleQuote = false; + boolean inDoubleQuote = false; + + for (int i = 0; i < sql.length(); i++) { + char c = sql.charAt(i); + + if (c == '\'' && !inDoubleQuote) { + // Check for escaped quote + if (i + 1 < sql.length() && sql.charAt(i + 1) == '\'') { + i++; // Skip escaped quote + } else { + inSingleQuote = !inSingleQuote; + } + } else if (c == '"' && !inSingleQuote) { + inDoubleQuote = !inDoubleQuote; + } else if (c == '?' && !inSingleQuote && !inDoubleQuote) { + count++; + } + } + + return count; + } + + // ========== Table Model SQL Injection Prevention Tests ========== + + @SuppressWarnings("resource") + @Test + public void testTableModelLoginInjectionWithComment() throws Exception { + // Login interface SQL injection attack 1: Using -- comments to bypass password checks + String sql = "SELECT * FROM users WHERE username = ? AND password = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "admin' --"); + ps.setString(2, "password"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // SQL injection is prevented by using prepared statements with parameterized queries + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelLoginInjectionWithORCondition() throws Exception { + // Login interface SQL injection attack 2: Bypassing authentication by using 'OR '1'='1 + String sql = "SELECT * FROM users WHERE username = ? AND password = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "admin"); + ps.setString(2, "' OR '1'='1"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // SQL injection is prevented by using prepared statements with parameterized queries + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelQueryWithMultipleInjectionVectors() throws Exception { + String sql = "SELECT * FROM users WHERE email = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "'; DROP TABLE users;"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // SQL injection is prevented by using prepared statements with parameterized queries + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelString1() throws Exception { + String sql = "SELECT * FROM users WHERE password = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "a'b"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelString2() throws Exception { + String sql = "SELECT * FROM users WHERE password = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "a\'b"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelString3() throws Exception { + String sql = "SELECT * FROM users WHERE password = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "a\\'b"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelString4() throws Exception { + String sql = "SELECT * FROM users WHERE password = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, "a\\\'b"); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); + } + + @SuppressWarnings("resource") + @Test + public void testTableModelStringWithNull() throws Exception { + String sql = "SELECT * FROM users WHERE email = ?"; + IoTDBTablePreparedStatement ps = + new IoTDBTablePreparedStatement(connection, client, sessionId, sql, zoneId); + ps.setString(1, null); + ps.execute(); + + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); + } +} diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/stmt/PreparedParameterSerializer.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/stmt/PreparedParameterSerializer.java new file mode 100644 index 0000000000000..a43df4f9a71c9 --- /dev/null +++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/stmt/PreparedParameterSerializer.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc.stmt; + +import org.apache.tsfile.enums.TSDataType; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.Types; +import java.util.ArrayList; +import java.util.List; + +/** Serializer for PreparedStatement parameters. */ +public class PreparedParameterSerializer { + + public static class DeserializedParam { + public final TSDataType type; + public final Object value; + + DeserializedParam(TSDataType type, Object value) { + this.type = type; + this.value = value; + } + + public boolean isNull() { + return type == TSDataType.UNKNOWN || value == null; + } + } + + private PreparedParameterSerializer() {} + + /** Serialize parameters to binary format. */ + public static ByteBuffer serialize(Object[] values, int[] jdbcTypes, int count) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + dos.writeInt(count); + for (int i = 0; i < count; i++) { + serializeParameter(dos, values[i], jdbcTypes[i]); + } + + dos.flush(); + return ByteBuffer.wrap(baos.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize parameters", e); + } + } + + private static void serializeParameter(DataOutputStream dos, Object value, int jdbcType) + throws IOException { + if (value == null || jdbcType == Types.NULL) { + dos.writeByte(TSDataType.UNKNOWN.serialize()); + return; + } + + switch (jdbcType) { + case Types.BOOLEAN: + dos.writeByte(TSDataType.BOOLEAN.serialize()); + dos.writeByte((Boolean) value ? 1 : 0); + break; + + case Types.INTEGER: + dos.writeByte(TSDataType.INT32.serialize()); + dos.writeInt(((Number) value).intValue()); + break; + + case Types.BIGINT: + dos.writeByte(TSDataType.INT64.serialize()); + dos.writeLong(((Number) value).longValue()); + break; + + case Types.FLOAT: + dos.writeByte(TSDataType.FLOAT.serialize()); + dos.writeFloat(((Number) value).floatValue()); + break; + + case Types.DOUBLE: + dos.writeByte(TSDataType.DOUBLE.serialize()); + dos.writeDouble(((Number) value).doubleValue()); + break; + + case Types.VARCHAR: + case Types.CHAR: + byte[] strBytes = ((String) value).getBytes(StandardCharsets.UTF_8); + dos.writeByte(TSDataType.STRING.serialize()); + dos.writeInt(strBytes.length); + dos.write(strBytes); + break; + + case Types.BINARY: + case Types.VARBINARY: + byte[] binBytes = (byte[]) value; + dos.writeByte(TSDataType.BLOB.serialize()); + dos.writeInt(binBytes.length); + dos.write(binBytes); + break; + + default: + byte[] defaultBytes = String.valueOf(value).getBytes(StandardCharsets.UTF_8); + dos.writeByte(TSDataType.STRING.serialize()); + dos.writeInt(defaultBytes.length); + dos.write(defaultBytes); + break; + } + } + + /** Deserialize parameters from binary format. */ + public static List deserialize(ByteBuffer buffer) { + if (buffer == null || buffer.remaining() == 0) { + return new ArrayList<>(); + } + + buffer.rewind(); + int count = buffer.getInt(); + if (count < 0 || count > buffer.remaining()) { + throw new IllegalArgumentException("Invalid parameter count: " + count); + } + + List result = new ArrayList<>(count); + + for (int i = 0; i < count; i++) { + byte typeCode = buffer.get(); + TSDataType type = TSDataType.deserialize(typeCode); + Object value = deserializeValue(buffer, type); + result.add(new DeserializedParam(type, value)); + } + + return result; + } + + private static Object deserializeValue(ByteBuffer buffer, TSDataType type) { + switch (type) { + case UNKNOWN: + return null; + case BOOLEAN: + return buffer.get() != 0; + case INT32: + return buffer.getInt(); + case INT64: + return buffer.getLong(); + case FLOAT: + return buffer.getFloat(); + case DOUBLE: + return buffer.getDouble(); + case TEXT: + case STRING: + int strLen = buffer.getInt(); + byte[] strBytes = new byte[strLen]; + buffer.get(strBytes); + return new String(strBytes, StandardCharsets.UTF_8); + case BLOB: + int binLen = buffer.getInt(); + byte[] binBytes = new byte[binLen]; + buffer.get(binBytes); + return binBytes; + default: + throw new IllegalArgumentException("Unsupported type: " + type); + } + } +} diff --git a/iotdb-client/service-rpc/src/test/java/org/apache/iotdb/rpc/stmt/PreparedParameterSerializerTest.java b/iotdb-client/service-rpc/src/test/java/org/apache/iotdb/rpc/stmt/PreparedParameterSerializerTest.java new file mode 100644 index 0000000000000..10e3cd4b67881 --- /dev/null +++ b/iotdb-client/service-rpc/src/test/java/org/apache/iotdb/rpc/stmt/PreparedParameterSerializerTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc.stmt; + +import org.apache.iotdb.rpc.stmt.PreparedParameterSerializer.DeserializedParam; + +import org.apache.tsfile.enums.TSDataType; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.sql.Types; +import java.util.List; + +import static org.apache.iotdb.rpc.stmt.PreparedParameterSerializer.deserialize; +import static org.apache.iotdb.rpc.stmt.PreparedParameterSerializer.serialize; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** Unit tests for {@link PreparedParameterSerializer}. */ +public class PreparedParameterSerializerTest { + + @Test + public void testEmptyParameterList() { + ByteBuffer buffer = serialize(new Object[0], new int[0], 0); + List result = deserialize(buffer); + + assertNotNull(result); + assertTrue(result.isEmpty()); + } + + @Test + public void testNullAndEmptyBuffer() { + assertTrue(deserialize(null).isEmpty()); + assertTrue(deserialize(ByteBuffer.allocate(0)).isEmpty()); + } + + @Test + public void testNullValue() { + ByteBuffer buffer = serialize(new Object[] {null}, new int[] {Types.VARCHAR}, 1); + List result = deserialize(buffer); + + assertEquals(1, result.size()); + assertTrue(result.get(0).isNull()); + } + + @Test + public void testAllDataTypes() { + Object[] values = {true, 42, 123456789L, 3.14f, 2.71828, "hello", new byte[] {1, 2, 3}}; + int[] types = { + Types.BOOLEAN, + Types.INTEGER, + Types.BIGINT, + Types.FLOAT, + Types.DOUBLE, + Types.VARCHAR, + Types.BINARY + }; + + ByteBuffer buffer = serialize(values, types, 7); + List result = deserialize(buffer); + + assertEquals(7, result.size()); + assertEquals(TSDataType.BOOLEAN, result.get(0).type); + assertEquals(true, result.get(0).value); + assertEquals(TSDataType.INT32, result.get(1).type); + assertEquals(42, result.get(1).value); + assertEquals(TSDataType.INT64, result.get(2).type); + assertEquals(123456789L, result.get(2).value); + assertEquals(TSDataType.FLOAT, result.get(3).type); + assertEquals(3.14f, (Float) result.get(3).value, 0.0001f); + assertEquals(TSDataType.DOUBLE, result.get(4).type); + assertEquals(2.71828, (Double) result.get(4).value, 0.00001); + assertEquals(TSDataType.STRING, result.get(5).type); + assertEquals("hello", result.get(5).value); + assertEquals(TSDataType.BLOB, result.get(6).type); + assertArrayEquals(new byte[] {1, 2, 3}, (byte[]) result.get(6).value); + } + + @Test + public void testUnicodeString() { + ByteBuffer buffer = serialize(new Object[] {"你好🌍"}, new int[] {Types.VARCHAR}, 1); + List result = deserialize(buffer); + + assertEquals("你好🌍", result.get(0).value); + } + + @Test + public void testMixedNullAndValues() { + Object[] values = {"hello", null, 42}; + int[] types = {Types.VARCHAR, Types.INTEGER, Types.INTEGER}; + + ByteBuffer buffer = serialize(values, types, 3); + List result = deserialize(buffer); + + assertEquals(3, result.size()); + assertEquals("hello", result.get(0).value); + assertTrue(result.get(1).isNull()); + assertEquals(42, result.get(2).value); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidParameterCount() { + ByteBuffer buffer = ByteBuffer.allocate(4); + buffer.putInt(-1); + buffer.flip(); + deserialize(buffer); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/OperationType.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/OperationType.java index e461f0cc45d16..9c44de9f5fdca 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/OperationType.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/OperationType.java @@ -52,7 +52,10 @@ public enum OperationType { EXECUTE_NON_QUERY_PLAN("executeNonQueryPlan"), SELECT_INTO("selectInto"), QUERY_LATENCY("queryLatency"), - WRITE_AUDIT_LOG("writeAuditLog"); + WRITE_AUDIT_LOG("writeAuditLog"), + PREPARE_STATEMENT("prepareStatement"), + EXECUTE_PREPARED_STATEMENT("executePreparedStatement"), + DEALLOCATE_PREPARED_STATEMENT("deallocatePreparedStatement"); private final String name; OperationType(String name) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java index 167a1fa914fd2..e46429523d81a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java @@ -75,6 +75,7 @@ import org.apache.iotdb.db.queryengine.plan.analyze.schema.ISchemaFetcher; import org.apache.iotdb.db.queryengine.plan.execution.ExecutionResult; import org.apache.iotdb.db.queryengine.plan.execution.IQueryExecution; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.PreparedStatementHelper; import org.apache.iotdb.db.queryengine.plan.parser.ASTVisitor; import org.apache.iotdb.db.queryengine.plan.parser.StatementGenerator; import org.apache.iotdb.db.queryengine.plan.planner.LocalExecutionPlanner; @@ -89,7 +90,17 @@ import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.cache.TableId; import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.cache.TreeDeviceSchemaCacheManager; import org.apache.iotdb.db.queryengine.plan.relational.security.TreeAccessCheckContext; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ParameterExtractor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BinaryLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DoubleLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Execute; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Identifier; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullLiteral; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SetSqlDialect; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Use; import org.apache.iotdb.db.queryengine.plan.relational.sql.parser.ParsingException; import org.apache.iotdb.db.queryengine.plan.relational.sql.parser.SqlParser; @@ -129,6 +140,8 @@ import org.apache.iotdb.db.utils.SetThreadName; import org.apache.iotdb.rpc.RpcUtils; import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.rpc.stmt.PreparedParameterSerializer; +import org.apache.iotdb.rpc.stmt.PreparedParameterSerializer.DeserializedParam; import org.apache.iotdb.service.rpc.thrift.ServerProperties; import org.apache.iotdb.service.rpc.thrift.TCreateTimeseriesUsingSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TPipeSubscribeReq; @@ -146,9 +159,11 @@ import org.apache.iotdb.service.rpc.thrift.TSCreateMultiTimeseriesReq; import org.apache.iotdb.service.rpc.thrift.TSCreateSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TSCreateTimeseriesReq; +import org.apache.iotdb.service.rpc.thrift.TSDeallocatePreparedReq; import org.apache.iotdb.service.rpc.thrift.TSDeleteDataReq; import org.apache.iotdb.service.rpc.thrift.TSDropSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteBatchStatementReq; +import org.apache.iotdb.service.rpc.thrift.TSExecutePreparedReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementResp; import org.apache.iotdb.service.rpc.thrift.TSFastLastDataQueryForOneDeviceReq; @@ -169,6 +184,8 @@ import org.apache.iotdb.service.rpc.thrift.TSLastDataQueryReq; import org.apache.iotdb.service.rpc.thrift.TSOpenSessionReq; import org.apache.iotdb.service.rpc.thrift.TSOpenSessionResp; +import org.apache.iotdb.service.rpc.thrift.TSPrepareReq; +import org.apache.iotdb.service.rpc.thrift.TSPrepareResp; import org.apache.iotdb.service.rpc.thrift.TSProtocolVersion; import org.apache.iotdb.service.rpc.thrift.TSPruneSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TSQueryDataSet; @@ -1488,6 +1505,174 @@ public TSStatus closeOperation(TSCloseOperationReq req) { COORDINATOR::cleanupQueryExecution); } + // ========================= PreparedStatement RPC Methods ========================= + + @Override + public TSPrepareResp prepareStatement(TSPrepareReq req) { + IClientSession clientSession = SESSION_MANAGER.getCurrSessionAndUpdateIdleTime(); + if (!SESSION_MANAGER.checkLogin(clientSession)) { + return new TSPrepareResp(getNotLoggedInStatus()); + } + + try { + String sql = req.getSql(); + String statementName = req.getStatementName(); + + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statement = + relationSqlParser.createStatement(sql, clientSession.getZoneId(), clientSession); + + if (statement == null) { + return new TSPrepareResp( + RpcUtils.getStatus(TSStatusCode.SQL_PARSE_ERROR, "Failed to parse SQL: " + sql)); + } + + int parameterCount = ParameterExtractor.getParameterCount(statement); + + PreparedStatementHelper.register(clientSession, statementName, statement); + + TSPrepareResp resp = new TSPrepareResp(RpcUtils.getStatus(TSStatusCode.SUCCESS_STATUS)); + resp.setParameterCount(parameterCount); + return resp; + } catch (Exception e) { + return new TSPrepareResp( + onQueryException( + e, OperationType.PREPARE_STATEMENT.getName(), TSStatusCode.INTERNAL_SERVER_ERROR)); + } + } + + @Override + public TSExecuteStatementResp executePreparedStatement(TSExecutePreparedReq req) { + boolean finished = false; + long queryId = Long.MIN_VALUE; + IClientSession clientSession = SESSION_MANAGER.getCurrSessionAndUpdateIdleTime(); + + if (!SESSION_MANAGER.checkLogin(clientSession)) { + return RpcUtils.getTSExecuteStatementResp(getNotLoggedInStatus()); + } + + long startTime = System.nanoTime(); + Throwable t = null; + try { + String statementName = req.getStatementName(); + + List rawParams = + PreparedParameterSerializer.deserialize(ByteBuffer.wrap(req.getParameters())); + List parameters = new ArrayList<>(rawParams.size()); + for (DeserializedParam param : rawParams) { + parameters.add(convertToLiteral(param)); + } + + Execute executeStatement = new Execute(new Identifier(statementName), parameters); + + queryId = SESSION_MANAGER.requestQueryId(clientSession, req.getStatementId()); + + long timeout = req.isSetTimeout() ? req.getTimeout() : config.getQueryTimeoutThreshold(); + ExecutionResult result = + COORDINATOR.executeForTableModel( + executeStatement, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + "EXECUTE " + statementName, + metadata, + timeout, + true); + + if (result.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode() + && result.status.code != TSStatusCode.REDIRECTION_RECOMMEND.getStatusCode()) { + finished = true; + return RpcUtils.getTSExecuteStatementResp(result.status); + } + + IQueryExecution queryExecution = COORDINATOR.getQueryExecution(queryId); + + try (SetThreadName threadName = new SetThreadName(result.queryId.getId())) { + TSExecuteStatementResp resp; + if (queryExecution != null && queryExecution.isQuery()) { + resp = createResponse(queryExecution.getDatasetHeader(), queryId); + resp.setStatus(result.status); + int fetchSize = + req.isSetFetchSize() ? req.getFetchSize() : config.getThriftMaxFrameSize(); + finished = setResultForPrepared.apply(resp, queryExecution, fetchSize); + resp.setMoreData(!finished); + } else { + finished = true; + resp = RpcUtils.getTSExecuteStatementResp(result.status); + } + return resp; + } + } catch (Exception e) { + finished = true; + t = e; + return RpcUtils.getTSExecuteStatementResp( + onQueryException( + e, + OperationType.EXECUTE_PREPARED_STATEMENT.getName(), + TSStatusCode.INTERNAL_SERVER_ERROR)); + } finally { + long currentOperationCost = System.nanoTime() - startTime; + if (finished) { + COORDINATOR.cleanupQueryExecution(queryId, null, t); + } + COORDINATOR.recordExecutionTime(queryId, currentOperationCost); + } + } + + @Override + public TSStatus deallocatePreparedStatement(TSDeallocatePreparedReq req) { + IClientSession clientSession = SESSION_MANAGER.getCurrSessionAndUpdateIdleTime(); + if (!SESSION_MANAGER.checkLogin(clientSession)) { + return getNotLoggedInStatus(); + } + + try { + PreparedStatementHelper.unregister(clientSession, req.getStatementName()); + return RpcUtils.getStatus(TSStatusCode.SUCCESS_STATUS); + } catch (Exception e) { + return onQueryException( + e, + OperationType.DEALLOCATE_PREPARED_STATEMENT.getName(), + TSStatusCode.INTERNAL_SERVER_ERROR); + } + } + + private Literal convertToLiteral(DeserializedParam param) { + if (param.isNull()) { + return new NullLiteral(); + } + + switch (param.type) { + case BOOLEAN: + return new BooleanLiteral((Boolean) param.value ? "true" : "false"); + case INT32: + case INT64: + return new LongLiteral(String.valueOf(param.value)); + case FLOAT: + return new DoubleLiteral((Float) param.value); + case DOUBLE: + return new DoubleLiteral((Double) param.value); + case TEXT: + case STRING: + return new StringLiteral((String) param.value); + case BLOB: + return new BinaryLiteral((byte[]) param.value); + default: + throw new IllegalArgumentException("Unknown parameter type: " + param.type); + } + } + + private final SelectResult setResultForPrepared = + (resp, queryExecution, fetchSize) -> { + // Use V2 format (queryResult) to match IoTDBTablePreparedStatement client + Pair, Boolean> pair = + QueryDataSetUtils.convertQueryResultByFetchSize(queryExecution, fetchSize); + resp.setQueryResult(pair.left); + return pair.right; + }; + + // ========================= End PreparedStatement RPC Methods ========================= + @Override public TSGetTimeZoneResp getTimeZone(long sessionId) { try { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java index 6f5f3f4846154..973f7cfa1f2cc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java @@ -19,9 +19,7 @@ package org.apache.iotdb.db.queryengine.plan.execution.config.session; -import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.protocol.session.IClientSession; -import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; @@ -54,19 +52,12 @@ public ListenableFuture execute(IConfigTaskExecutor configTask return future; } - // Remove the prepared statement - PreparedStatementInfo removedInfo = session.removePreparedStatement(statementName); - if (removedInfo == null) { - future.setException( - new SemanticException( - String.format("Prepared statement '%s' does not exist", statementName))); - return future; + try { + PreparedStatementHelper.unregister(session, statementName); + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + } catch (Exception e) { + future.setException(e); } - - // Release the memory allocated for this PreparedStatement from the shared MemoryBlock - PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); - - future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); return future; } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java index 62c59d5bf785e..c83808f4cecaf 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java @@ -19,9 +19,7 @@ package org.apache.iotdb.db.queryengine.plan.execution.config.session; -import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.protocol.session.IClientSession; -import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; @@ -58,27 +56,12 @@ public ListenableFuture execute(IConfigTaskExecutor configTask return future; } - // Check if prepared statement with the same name already exists - PreparedStatementInfo existingInfo = session.getPreparedStatement(statementName); - if (existingInfo != null) { - future.setException( - new SemanticException( - String.format("Prepared statement '%s' already exists.", statementName))); - return future; + try { + PreparedStatementHelper.register(session, statementName, sql); + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + } catch (Exception e) { + future.setException(e); } - - // Estimate memory size of the AST - long memorySizeInBytes = sql == null ? 0L : sql.ramBytesUsed(); - - // Allocate memory from CoordinatorMemoryManager - // This memory is shared across all sessions using a single MemoryBlock - PreparedStatementMemoryManager.getInstance().allocate(statementName, memorySizeInBytes); - - // Create and store the prepared statement info (AST is cached) - PreparedStatementInfo info = new PreparedStatementInfo(statementName, sql, memorySizeInBytes); - session.addPreparedStatement(statementName, info); - - future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); return future; } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementHelper.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementHelper.java new file mode 100644 index 0000000000000..f5ef405332d6a --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementHelper.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; + +/** Helper for prepared statement registration/unregistration. */ +public class PreparedStatementHelper { + + private PreparedStatementHelper() {} + + /** Registers a prepared statement in the session. */ + public static PreparedStatementInfo register( + IClientSession session, String statementName, Statement sql) { + if (session.getPreparedStatement(statementName) != null) { + throw new SemanticException( + String.format("Prepared statement '%s' already exists", statementName)); + } + + long memorySizeInBytes = sql == null ? 0L : sql.ramBytesUsed(); + + PreparedStatementMemoryManager.getInstance().allocate(statementName, memorySizeInBytes); + + PreparedStatementInfo info = new PreparedStatementInfo(statementName, sql, memorySizeInBytes); + session.addPreparedStatement(statementName, info); + + return info; + } + + /** Unregisters a prepared statement from the session. */ + public static PreparedStatementInfo unregister(IClientSession session, String statementName) { + PreparedStatementInfo removedInfo = session.removePreparedStatement(statementName); + if (removedInfo == null) { + throw new SemanticException( + String.format("Prepared statement '%s' does not exist", statementName)); + } + + PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); + + return removedInfo; + } +} diff --git a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift index 48afb89d33661..7ff12ceb02154 100644 --- a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift +++ b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift @@ -167,6 +167,34 @@ struct TSCloseOperationReq { 4: optional string preparedStatementName } +// PREPARE +struct TSPrepareReq { + 1: required i64 sessionId + 2: required string sql + 3: required string statementName +} + +struct TSPrepareResp { + 1: required common.TSStatus status + 2: optional i32 parameterCount +} + +// EXECUTE +struct TSExecutePreparedReq { + 1: required i64 sessionId + 2: required string statementName + 3: required binary parameters + 4: optional i32 fetchSize + 5: optional i64 timeout + 6: required i64 statementId +} + +// DEALLOCATE +struct TSDeallocatePreparedReq { + 1: required i64 sessionId + 2: required string statementName +} + struct TSFetchResultsReq{ 1: required i64 sessionId 2: required string statement @@ -576,6 +604,13 @@ service IClientRPCService { common.TSStatus closeOperation(1:TSCloseOperationReq req); + // PreparedStatement operations + TSPrepareResp prepareStatement(1:TSPrepareReq req); + + TSExecuteStatementResp executePreparedStatement(1:TSExecutePreparedReq req); + + common.TSStatus deallocatePreparedStatement(1:TSDeallocatePreparedReq req); + TSGetTimeZoneResp getTimeZone(1:i64 sessionId); common.TSStatus setTimeZone(1:TSSetTimeZoneReq req);