Skip to content

Commit

Permalink
fix: fix java client memory leak and remove prepared statement (#147)
Browse files Browse the repository at this point in the history
* fix: fix java client memory leak and remove prepared statement
  • Loading branch information
Ma1oneZhang authored Sep 24, 2024
1 parent c1455d0 commit 8e3b35e
Showing 1 changed file with 74 additions and 207 deletions.
281 changes: 74 additions & 207 deletions zh_CN/development-guide/arrow-flight-sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -693,48 +693,33 @@ import org.slf4j.LoggerFactory;

import java.util.*;


public class SqlRunner {

private static final Logger log = LoggerFactory.getLogger(SqlRunner.class);

public static void main(String[] args) {
BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final Location clientLocation = Location.forGrpcInsecure("127.0.0.1", 8360);

FlightClient client = FlightClient.builder(allocator, clientLocation).build();
FlightSqlClient sqlClient = new FlightSqlClient(client);

Optional<CredentialCallOption> credentialCallOption = client.authenticateBasicToken("admin", "public");
final CallHeaders headers = new FlightCallHeaders();
Set<CallOption> options = new HashSet<>();
headers.insert("database", "test");

credentialCallOption.ifPresent(options::add);
options.add(new HeaderCallOption(headers));
CallOption[] callOptions = options.toArray(new CallOption[0]);

try {
final FlightInfo info = sqlClient.execute("create database test", callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("create database result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
static void run_flight_sql() throws Exception {
try (BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
final Location clientLocation = Location.forGrpcInsecure("127.0.0.1", 8360);
try (FlightClient client = FlightClient.builder(allocator, clientLocation).build();
FlightSqlClient sqlClient = new FlightSqlClient(client)) {

Optional<CredentialCallOption> credentialCallOption = client.authenticateBasicToken("admin", "public");
CallHeaders headers = new FlightCallHeaders();
headers.insert("database", "test");

Set<CallOption> options = new HashSet<>();
credentialCallOption.ifPresent(options::add);
options.add(new HeaderCallOption(headers));
try {
String query = "create database test";
executeQuery(sqlClient, query, options);
} catch (Exception e){
e.printStackTrace();
throw e;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
} catch (Exception e){
throw new RuntimeException(e);
}

try {
String query = "CREATE TABLE test.sx1 (" +
try {
String query = "CREATE TABLE test.sx1 (" +
"ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP," +
"sid INT32," +
"value REAL," +
Expand All @@ -743,190 +728,69 @@ public class SqlRunner {
")" +
"PARTITION BY HASH(sid) PARTITIONS 32" +
"ENGINE=TimeSeries";
final FlightInfo info = sqlClient.execute(query, callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("create table result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
executeQuery(sqlClient, query, options);
} catch (Exception e){
e.printStackTrace();
throw e;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
} catch (Exception e){
throw new RuntimeException(e);
}

try {
String query = "INSERT INTO test.sx1 (sid, value, flag) VALUES (1, 1.1, 1);";
final FlightInfo info = sqlClient.execute(query, callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("insert result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
} catch (Exception e){
throw new RuntimeException(e);
}

try {
String query = "SELECT count(*) from test.sx1;";
final FlightInfo info = sqlClient.execute(query, callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("select result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
try {
String query = "INSERT INTO test.sx1 (sid, value, flag) VALUES (1, 1.1, 1);";
executeQuery(sqlClient, query, options);
} catch (Exception e){
e.printStackTrace();
throw e;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
} catch (Exception e){
throw new RuntimeException(e);
}

// prepared statement insert in java
// `insert into table sx1 (sid, value, flag) values(?, ?, ?);`
// The ? is a placeholder for the actual value that will be inserted.
// The actual value is set using the setParameters method.
// please make sure each col has same length

// need callOptions argument here~
try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare("insert into table sx1 (sid, value, flag) values(?, ?, ?);", callOptions)) {
IntVector sids = new IntVector("sid",allocator);
sids.allocateNew();
Float4Vector values = new Float4Vector("value",allocator);
values.allocateNew();
TinyIntVector flags = new TinyIntVector("flag",allocator);

sids.setSafe(0,1);
values.setSafe(0, 1.0F);
flags.setSafe(0, (byte)1);

List<Field> fields = Arrays.asList(sids.getField(), values.getField(), flags.getField());
List<FieldVector> fieldVectors = Arrays.asList(sids, values, flags);
VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(fields, fieldVectors);
vectorSchemaRoot.setRowCount(1);
preparedStatement.setParameters(vectorSchemaRoot);

// need callOptions argument here~
final FlightInfo info = preparedStatement.execute(callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();

// need callOptions argument here~
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("prepared statement get result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
// insert with multiple values
try {
String query = "INSERT INTO test.sx1 (sid, value, flag) VALUES (1, 1.1, 1), (1, 1.1, 1), (1, 1.1, 1), (1, 1.1, 1);";
executeQuery(sqlClient, query, options);
} catch (Exception e){
e.printStackTrace();
throw e;
}
} catch (Exception e) {
throw new RuntimeException(e);
}

// need callOptions argument here~
preparedStatement.close(callOptions);
}
// batch prepared statement insertion
try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare("insert into table sx1 (sid, value, flag) values(?, ?, ?);", callOptions)) {
IntVector sids = new IntVector("sid",allocator);
sids.allocateNew();
Float4Vector values = new Float4Vector("value",allocator);
values.allocateNew();
TinyIntVector flags = new TinyIntVector("flag",allocator);

for (int i = 0;i < 100;i ++){
sids.setSafe(i, i);
values.setSafe(i, (float) i);
flags.setSafe(i, (byte)i);
}

List<Field> fields = Arrays.asList(sids.getField(), values.getField(), flags.getField());
List<FieldVector> fieldVectors = Arrays.asList(sids, values, flags);
VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(fields, fieldVectors);
// remember set right row count
vectorSchemaRoot.setRowCount(100);
preparedStatement.setParameters(vectorSchemaRoot);
final FlightInfo info = preparedStatement.execute(callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("prepared batch insert statement get result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
try {
String query = "SELECT count(*) from test.sx1;";
executeQuery(sqlClient, query, options);
} catch (Exception e){
e.printStackTrace();
throw e;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
preparedStatement.close(callOptions);
}
}

// need callOptions argument here~
try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare("select count(*) from test.sx1 where sid = ?;", callOptions)) {
IntVector sids = new IntVector("sid",allocator);
sids.allocateNew();

sids.setSafe(0,1);
List<Field> fields = Arrays.asList(sids.getField());
List<FieldVector> fieldVectors = Arrays.asList(sids);
VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(fields, fieldVectors);
vectorSchemaRoot.setRowCount(1);
preparedStatement.setParameters(vectorSchemaRoot);

// need callOptions argument here~
final FlightInfo info = preparedStatement.execute(callOptions);
final Ticket ticket = info.getEndpoints().get(0).getTicket();

// need callOptions argument here~
try (FlightStream stream = sqlClient.getStream(ticket, callOptions)) {
int n = 0;
while (stream.next()) {
System.out.println("prepared statement get result:");
List<FieldVector> vectors = stream.getRoot().getFieldVectors();
for (int i = 0; i < vectors.size(); i++) {
System.out.printf("%d %d %s\n", n, i , vectors.get(i));
}
n++;
private static void executeQuery(FlightSqlClient sqlClient, String query, Set<CallOption> options) throws Exception {
final FlightInfo info = sqlClient.execute(query, options.toArray(new CallOption[0]));
final Ticket ticket = info.getEndpoints().get(0).getTicket();
try (FlightStream stream = sqlClient.getStream(ticket, options.toArray(new CallOption[0]))) {
while (stream.next()) {
try (VectorSchemaRoot schemaRoot = stream.getRoot()) {
// // How to get single element
// // You can cast the FieldVector class to some class Like TinyIntVector and so on.
// // You can get the type mapping from arrow official website
// List<FieldVector> vectors = schemaRoot.getFieldVectors();
// for (int i = 0; i < vectors.size(); i++) {
// System.out.printf("Col :%d %s\n", i, vectors.get(i));
// }
log.info(schemaRoot.contentToTSVString());
}
} catch (Exception e) {
throw new RuntimeException(e);
}
preparedStatement.close(callOptions);
}
}

public static void main(String[] args) throws Exception {
run_flight_sql();
}
}

// deps
/*
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<build>
<plugins>
Expand Down Expand Up @@ -966,12 +830,6 @@ public class SqlRunner {
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.32</version>
<scope>provided</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.arrow/arrow-flight -->
<dependency>
Expand Down Expand Up @@ -1008,6 +866,15 @@ public class SqlRunner {
<artifactId>flight-core</artifactId>
<version>${arrow.version}</version>
</dependency>
<!-- Add it for example logging, you can remove it when you wanna use your own logger -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>2.0.9</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</project>
*/
Expand Down

0 comments on commit 8e3b35e

Please sign in to comment.