remove pruning on data nodes
Signed-off-by: Marc Handalian <marc.handalian@gmail.com>
mch2 committed Jan 30, 2025
1 parent 396c871 commit a3b8941
Showing 9 changed files with 115 additions and 74 deletions.
6 changes: 4 additions & 2 deletions libs/datafusion/jni/src/lib.rs
@@ -11,7 +11,7 @@ use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
use futures::stream::TryStreamExt;
use jni::objects::{JByteArray, JClass, JObject, JString};
use jni::sys::jlong;
use jni::sys::{jint, jlong};
use jni::JNIEnv;
use std::io::BufWriter;
use tokio::runtime::Runtime;
@@ -44,14 +44,16 @@ pub extern "system" fn Java_org_opensearch_datafusion_DataFusion_agg(
runtime: jlong,
ctx: jlong,
ticket: JByteArray,
size: jint,
callback: JObject,
) {
let input = env.convert_byte_array(&ticket).unwrap();
let context = unsafe { &mut *(ctx as *mut SessionContext) };
let runtime = unsafe { &mut *(runtime as *mut Runtime) };

runtime.block_on(async {
let result = provider::read_aggs(context.clone(), Bytes::from(input)).await;
let result = provider::read_aggs(context.clone(), Bytes::from(input), size).await;
let addr = result.map(|df| Box::into_raw(Box::new(df)));
set_object_result(&mut env, callback, addr);
});
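The JNI entry point above now takes the aggregation size as a jint in addition to the runtime, context, and ticket pointers. Below is a minimal sketch of the Java-side binding that has to mirror that signature; the native declaration matches DataFusion.java later in this commit, while the shape of ObjectResultCallback is inferred from how its lambda is used there (a String error and a DataFrame pointer) and is an assumption rather than the published interface.

// Sketch of the Java binding that pairs with the new JNI signature.
// The native declaration mirrors DataFusion.java in this commit; the
// ObjectResultCallback shape is an assumption inferred from its usage.
public final class DataFusionAggBindingSketch {

    // runtime/ctx are raw pointers boxed as long; size maps to the Rust jint.
    static native void agg(long runtime, long ctx, byte[] ticket, int size, ObjectResultCallback callback);

    @FunctionalInterface
    interface ObjectResultCallback {
        // err is null on success; ptr is the address of the boxed DataFrame returned by Rust
        void callback(String err, long ptr);
    }
}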
25 changes: 14 additions & 11 deletions libs/datafusion/jni/src/provider.rs
@@ -1,3 +1,4 @@
use anyhow::Context;
use arrow_flight::{flight_service_client::FlightServiceClient, FlightDescriptor};
use bytes::Bytes;
use datafusion::common::JoinType;
@@ -32,12 +33,15 @@ pub async fn query(
pub async fn read_aggs(
ctx: SessionContext,
ticket: Bytes,
size: usize
) -> datafusion::common::Result<DataFrame> {
let df = dataframe_for_index(&ctx, "theIndex".to_owned(), ticket, "http://localhost:9400".to_owned()).await?;
let df = dataframe_for_index(&ctx, "theIndex".to_owned(), ticket, "http://localhost:9450".to_owned()).await?;
// Ok(df)
// df.clone().explain(true, true)?.collect().await?;
df.filter(col("ord").is_not_null())?
.aggregate(vec![col("ord")], vec![sum(col("count")).alias("count")])?
.sort(vec![col("count").sort(false, true)])? // Sort by count descending
.limit(0, Some(500)) // Get top 500 results
.limit(0, Some(500)) // Get top 500 results
}

// inner join two tables together, returning a single DataFrame that can be consumed
@@ -105,7 +109,6 @@ async fn get_dataframe_for_tickets(
ticket: Bytes,
entry_point: String
) -> Result<DataFrame> {
println!("Register");
register_table(ctx, name.clone(), ticket, entry_point.clone())
.and_then(|_| ctx.table(&name))
.await
@@ -119,7 +122,7 @@ async fn register_table(ctx: &SessionContext, name: String, ticket: Bytes, entry
.open_table(entry_point, HashMap::new())
.await
.map_err(|e| DataFusionError::Execution(format!("Error creating table: {}", e)))?;
println!("Registering table {:?}", table);
// println!("Registering table {:?}", table);
ctx.register_table(name, Arc::new(table))
.map_err(|e| DataFusionError::Execution(format!("Error registering table: {}", e)))?;
// let df = ctx.sql("SHOW TABLES").await?;
@@ -140,27 +143,27 @@ impl FlightDriver for TicketedFlightDriver {
channel: Channel,
_options: &HashMap<String, String>,
) -> arrow_flight::error::Result<FlightMetadata> {
println!("DRIVER: options: {:?}", _options);
// println!("DRIVER: options: {:?}", _options);

let mut client: FlightServiceClient<Channel> = FlightServiceClient::new(channel.clone());

println!("DRIVER: Using ticket: {:?}", self.ticket);
// println!("DRIVER: Using ticket: {:?}", self.ticket);

let descriptor = FlightDescriptor::new_cmd(self.ticket.clone());
println!("DRIVER: Created descriptor: {:?}", descriptor);
// println!("DRIVER: Created descriptor: {:?}", descriptor);

let request = tonic::Request::new(descriptor);
println!("DRIVER: Sending get_flight_info request");
// println!("DRIVER: Sending get_flight_info request");

match client.get_flight_info(request).await {
Ok(info) => {
println!("DRIVER: Received flight info response");
// println!("DRIVER: Received flight info response");
let info = info.into_inner();
println!("DRIVER: Flight info: {:?}", info);
// println!("DRIVER: Flight info: {:?}", info);
FlightMetadata::try_new(info, FlightProperties::default())
}
Err(status) => {
println!("DRIVER: Error getting flight info: {:?}", status);
// println!("DRIVER: Error getting flight info: {:?}", status);
Err(status.into())
}
}
87 changes: 61 additions & 26 deletions libs/datafusion/jni/src/provider/test/mod.rs
@@ -3,8 +3,10 @@ use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use std::sync::atomic::{AtomicUsize, Ordering};

use arrow::array::{StringArray, StructArray};
use arrow::datatypes::SchemaRef;
use arrow::ipc;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
@@ -36,13 +38,14 @@ const BEARER_TOKEN: &str = "Bearer flight-sql-token";

struct TestFlightService {
flight_info: FlightInfo,
partition_data: RecordBatch,
schema: SchemaRef,
batch_counter: Arc<AtomicUsize>,
shutdown_sender: Option<Sender<()>>,
}

impl TestFlightService {
async fn run_in_background(self, rx: Receiver<()>) -> SocketAddr {
let addr = SocketAddr::from(([127, 0, 0, 1], 8815));
let addr = SocketAddr::from(([127, 0, 0, 1], 9450));
let listener = TcpListener::bind(addr).await.unwrap();
let addr = listener.local_addr().unwrap();
let service = FlightServiceServer::new(self);
@@ -60,6 +63,27 @@ impl TestFlightService {
tokio::time::sleep(Duration::from_millis(25)).await;
addr
}

fn generate_batch(schema: SchemaRef, batch_num: usize) -> RecordBatch {
println!("TEST SERVER: Generating batch {}", batch_num);
match batch_num {
0 => RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["A", "B"])),
Arc::new(Int64Array::from(vec![10, 20])),
],
).unwrap(),
1 => RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["A", "C"])),
Arc::new(Int64Array::from(vec![30, 40])),
],
).unwrap(),
_ => panic!("No more batches to generate"),
}
}
}

impl Drop for TestFlightService {
@@ -98,12 +122,34 @@ impl FlightService for TestFlightService {
) -> Result<Response<Self::DoGetStream>, Status> {
println!("TEST SERVER: do_get called");
println!("TEST SERVER: ticket: {:?}", request.get_ref());

let data = self.partition_data.clone();
let rb = async move { Ok(data) };

let counter = self.batch_counter.clone();
let schema = self.schema.clone();

// Create an async stream that generates batches on demand
let stream = futures::stream::unfold(0usize, move |current_batch| {
let counter = counter.clone();
let schema = schema.clone();

async move {
if current_batch >= 2 { // Total number of batches we want to generate
return None;
}

// Simulate some async work before generating the batch
tokio::time::sleep(Duration::from_millis(100)).await;

counter.fetch_add(1, Ordering::SeqCst);
println!("TEST SERVER: Preparing batch {}", current_batch);

let batch = Self::generate_batch(schema, current_batch);
Some((Ok(batch), current_batch + 1))
}
});
// Build the Flight data encoder with the service's schema
let stream = FlightDataEncoderBuilder::default()
.with_schema(self.partition_data.schema())
.build(stream::once(rb))
.with_schema(self.schema.clone())
.build(stream)
.map_err(|e| Status::internal(e.to_string()));

Ok(Response::new(Box::pin(stream)))
@@ -177,29 +223,24 @@ impl FlightService for TestFlightService {

#[tokio::test]
async fn test_read_aggs() -> datafusion::common::Result<()> {
// Create test data
let partition_data: RecordBatch = RecordBatch::try_new(
Arc::new(Schema::new([
Arc::new(Field::new("ord", DataType::Utf8, false)),
Arc::new(Field::new("count", DataType::Int64, false)),
])),
vec![
Arc::new(StringArray::from(vec!["A", "B", "A", "C"])),
Arc::new(Int64Array::from(vec![10, 20, 30, 40])),
],
)?;
// Create test data schema
let schema = Arc::new(Schema::new([
Arc::new(Field::new("ord", DataType::Utf8, false)),
Arc::new(Field::new("count", DataType::Int64, false)),
]));

// Set up flight endpoint
let endpoint = FlightEndpoint::default().with_ticket(Ticket::new("bytes".as_bytes()));
let flight_info = FlightInfo::default()
.try_with_schema(partition_data.schema().as_ref())?
.try_with_schema(schema.as_ref())?
.with_endpoint(endpoint);

// Set up test service
let (tx, rx) = channel();
let service = TestFlightService {
flight_info,
partition_data,
schema: schema.clone(),
batch_counter: Arc::new(AtomicUsize::new(0)),
shutdown_sender: Some(tx),
};

@@ -214,12 +255,6 @@ impl FlightService for TestFlightService {
// Set up session context
let config = SessionConfig::new().with_batch_size(1);
let ctx = SessionContext::new_with_config(config);
// let props_template = FlightProperties::new().with_reusable_flight_info(true);
// let driver = FlightSqlDriver::new().with_properties_template(props_template);
// ctx.state_ref().write().table_factories_mut().insert(
// "FLIGHT_SQL".into(),
// Arc::new(FlightTableFactory::new(Arc::new(driver))),
// );

// Create test bytes for ticket
let test_bytes = Bytes::from("test_ticket");
@@ -76,7 +76,6 @@ public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
df = frameSupplier.apply(rootTicket).join();
recordBatchStream = df.getStream(allocator, root).get();
while (recordBatchStream.loadNextBatch().join()) {
// logger.info(recordBatchStream.getVectorSchemaRoot().getRowCount());
// wait for a signal to load the next batch
flushSignal.awaitConsumption(1000);
}
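The loop above is the producer half of the stream: each successful loadNextBatch call fills the shared VectorSchemaRoot with one Arrow batch, and awaitConsumption blocks until the Flight consumer signals the root has been drained before the next batch overwrites it. A hedged sketch of that pattern follows; the parameter types stand in for fields the surrounding class already holds, and RecordBatchStream is assumed to be the stream type returned by DataFrame.getStream.

// Hedged sketch of the producer loop above. Parameter names and types stand
// in for fields of the surrounding class; RecordBatchStream is assumed to be
// the type returned by DataFrame.getStream in this module.
static void streamBatches(DataFrame df,
                          BufferAllocator allocator,
                          VectorSchemaRoot root,
                          FlushSignal flushSignal) throws Exception {
    RecordBatchStream stream = df.getStream(allocator, root).get();  // bind the native stream to the shared root
    while (stream.loadNextBatch().join()) {                          // loads the next batch into `root`, false when exhausted
        flushSignal.awaitConsumption(1000);                          // wait (up to 1s) for the consumer to drain `root`
    }
}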
@@ -11,7 +11,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;

@@ -27,7 +26,7 @@ public class DataFusion {
// create a DataFrame from a list of tickets.
static native void query(long runtime, long ctx, byte[] ticket, ObjectResultCallback callback);

static native void agg(long runtime, long ctx, byte[] ticket, ObjectResultCallback callback);
static native void agg(long runtime, long ctx, byte[] ticket, int size, ObjectResultCallback callback);

// collect the DataFrame
static native void collect(long runtime, long df, BiConsumer<String, byte[]> callback);
@@ -48,10 +47,10 @@ public static CompletableFuture<DataFrame> query(byte[] ticket) {
return future;
}

public static CompletableFuture<DataFrame> agg(byte[] ticket) {
public static CompletableFuture<DataFrame> agg(byte[] ticket, int size) {
SessionContext ctx = new SessionContext();
CompletableFuture<DataFrame> future = new CompletableFuture<>();
DataFusion.agg(ctx.getRuntime(), ctx.getPointer(), ticket, (err, ptr) -> {
DataFusion.agg(ctx.getRuntime(), ctx.getPointer(), ticket, size, (err, ptr) -> {
if (err != null) {
future.completeExceptionally(new RuntimeException(err));
} else {
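With the extra parameter, a caller passes the requested bucket count down to the native layer when it starts the ticketed aggregation. A minimal usage sketch follows, assuming the package path implied by the JNI symbol Java_org_opensearch_datafusion_DataFusion_agg; the ticket bytes and size value are placeholders.

// Hedged usage sketch for the updated agg(byte[], int) API. The import
// paths are assumed from the JNI symbol names in this commit, and the
// ticket bytes and size are placeholders.
import org.opensearch.datafusion.DataFrame;
import org.opensearch.datafusion.DataFusion;
import java.util.concurrent.CompletableFuture;

public class AggCallSketch {
    public static void main(String[] args) {
        byte[] ticket = new byte[] { 0x01, 0x02 };  // placeholder flight ticket bytes
        int size = 500;                             // top-N bucket count from the terms aggregation

        CompletableFuture<DataFrame> future = DataFusion.agg(ticket, size);
        DataFrame df = future.join();               // completes when the native callback fires
        System.out.println("aggregated DataFrame ready: " + df);
    }
}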
@@ -44,15 +44,18 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.stream.StreamSearchResult;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@@ -88,6 +91,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas

private final PendingMerges pendingMerges;
private final Consumer<Exception> onPartialMergeFailure;
private final Collection<AggregationBuilder> aggregationBuilders;

/**
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
@@ -118,6 +122,7 @@ public QueryPhaseResultConsumer(
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
this.aggregationBuilders = source.aggregations().getAggregatorFactories();
}

@Override
@@ -146,8 +151,12 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {

SearchPhaseController.ReducedQueryPhase reducePhase = null;
if (results.get(0) instanceof StreamSearchResult) {
int size = 0;
for (AggregationBuilder aggregatorFactory : aggregationBuilders) {
size = ((TermsAggregationBuilder) aggregatorFactory).size();
}
reducePhase = controller.reducedAggsFromStream(results.asList()
.stream().map(r -> (StreamSearchResult) r).collect(Collectors.toList()));
.stream().map(r -> (StreamSearchResult) r).collect(Collectors.toList()), size);
} else {
final SearchPhaseController.TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats();
final List<TopDocs> topDocsList = pendingMerges.consumeTopDocs();
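The reduce path now derives the bucket count from the request's aggregation builders and forwards it into reducedAggsFromStream. The loop above casts every builder to TermsAggregationBuilder unconditionally; below is a hedged sketch of the same extraction with a type guard and a fallback default, both of which are illustrative additions rather than part of this commit.

// Hedged sketch: resolve the stream-reduce size from the aggregation
// builders. The instanceof guard and defaultSize fallback are assumptions
// added for illustration; the commit itself casts unconditionally.
import java.util.Collection;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;

class StreamAggSizeSketch {
    static int resolveStreamAggSize(Collection<AggregationBuilder> builders, int defaultSize) {
        int size = defaultSize;
        for (AggregationBuilder builder : builders) {
            if (builder instanceof TermsAggregationBuilder) {
                size = ((TermsAggregationBuilder) builder).size();  // requested number of term buckets
            }
        }
        return size;
    }
}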
@@ -726,15 +726,16 @@ static int getTopDocsSize(SearchRequest request) {

public static Logger logger = LogManager.getLogger(SearchPhaseController.class);

public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) throws Exception {
public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list, int size) throws Exception {

List<byte[]> tickets = list.stream().flatMap(r -> r.getFlightTickets().stream())
.map(OSTicket::getBytes)
.collect(Collectors.toList());

Set<StreamTicket> streamTickets = tickets.stream().map(t -> streamManager.getStreamTicketFactory().fromBytes(t)).collect(Collectors.toSet());

DataFrameStreamProducer producer = AccessController.doPrivileged((PrivilegedAction<DataFrameStreamProducer>) () ->
new DataFrameStreamProducer((p -> streamManager.registerStream(p, TaskId.EMPTY_TASK_ID)), streamTickets, (t) -> DataFusion.agg(t.toBytes())));
new DataFrameStreamProducer((p -> streamManager.registerStream(p, TaskId.EMPTY_TASK_ID)), streamTickets, (t) -> DataFusion.agg(t.toBytes(), size)));

logger.info("Register stream at coordinator");
return AccessController.doPrivileged((PrivilegedAction<ReducedQueryPhase>) () -> {
@@ -748,7 +749,6 @@ public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) th
List<StringTerms.Bucket> buckets = new ArrayList<>();
logger.info("Starting iteration at coordinator");
while (streamIterator.next()) {
logger.info(root.getRowCount());
int rowCount = root.getRowCount();
totalRows+= rowCount;

Expand All @@ -763,24 +763,6 @@ public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) th
}
}

// recordBatchStream.close();
// dataFrame.close();
// while (streamIterator.next()) {
// int rowCount = root.getRowCount();
// totalRows+= rowCount;
// logger.info("AT COORD Record Batch with " + rowCount + " rows:");
//
// // Iterate through rows
// for (int row = 0; row < rowCount; row++) {
// FieldVector ordKey = root.getVector("ord");
// String ordName = (String) getValue(ordKey, row);
// Float8Vector count = (Float8Vector) root.getVector("count");
//
// Double bucketCount = (Double) getValue(count, row);
// logger.info("Ord: " + ordName + " Bucket Count" + bucketCount + "NodeID: ");
// buckets.add(new StringTerms.Bucket(new BytesRef(ordName.getBytes()), bucketCount.longValue(), new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW));
// }
// }
aggs.add(new StringTerms(
"category",
InternalOrder.key(true),
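The body of the live iteration loop that turns each streamed Arrow row into a terms bucket is elided in this view; the commented-out block above shows the pattern it follows. Below is a hedged per-row sketch using Arrow's stock accessors instead of the commit's private getValue helper; the "ord" and "count" vector names and the Float8 count type are taken from the commented-out code and the test schema in this commit.

// Hedged sketch: convert one row of the streamed batch into a StringTerms
// bucket, following the commented-out block above. getObject()/get() stand
// in for the commit's private getValue() helper (an assumption).
import java.util.List;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.lucene.util.BytesRef;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.bucket.terms.StringTerms;

class StreamBucketSketch {
    static StringTerms.Bucket bucketForRow(VectorSchemaRoot root, int row) {
        FieldVector ordVector = root.getVector("ord");
        Float8Vector countVector = (Float8Vector) root.getVector("count");

        String ord = ordVector.getObject(row).toString();  // term key for this row
        double docCount = countVector.get(row);            // summed count produced by read_aggs

        return new StringTerms.Bucket(
            new BytesRef(ord),                              // BytesRef(String) stores the UTF-8 bytes of the term
            (long) docCount,
            new InternalAggregations(List.of()),            // no sub-aggregations on this path
            false, 0, DocValueFormat.RAW);
    }
}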