From a3b8941a463785981f5c522153c19dd27caedf22 Mon Sep 17 00:00:00 2001 From: Marc Handalian Date: Wed, 29 Jan 2025 17:23:40 -0800 Subject: [PATCH] remove pruning on data nodes Signed-off-by: Marc Handalian --- libs/datafusion/jni/src/lib.rs | 6 +- libs/datafusion/jni/src/provider.rs | 25 +++--- libs/datafusion/jni/src/provider/test/mod.rs | 87 +++++++++++++------ .../DataFrameStreamProducer.java | 1 - .../org.opensearch.datafusion/DataFusion.java | 7 +- .../search/QueryPhaseResultConsumer.java | 11 ++- .../action/search/SearchPhaseController.java | 24 +---- .../GlobalOrdinalsStringTermsAggregator.java | 13 ++- .../support/StreamingAggregator.java | 15 +++- 9 files changed, 115 insertions(+), 74 deletions(-) diff --git a/libs/datafusion/jni/src/lib.rs b/libs/datafusion/jni/src/lib.rs index d8965029a9475..89361cfe008f4 100644 --- a/libs/datafusion/jni/src/lib.rs +++ b/libs/datafusion/jni/src/lib.rs @@ -11,7 +11,7 @@ use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use futures::stream::TryStreamExt; use jni::objects::{JByteArray, JClass, JObject, JString}; -use jni::sys::jlong; +use jni::sys::{jint, jlong}; use jni::JNIEnv; use std::io::BufWriter; use tokio::runtime::Runtime; @@ -44,14 +44,16 @@ pub extern "system" fn Java_org_opensearch_datafusion_DataFusion_agg( runtime: jlong, ctx: jlong, ticket: JByteArray, + size: jint, callback: JObject, ) { let input = env.convert_byte_array(&ticket).unwrap(); + env.lon let context = unsafe { &mut *(ctx as *mut SessionContext) }; let runtime = unsafe { &mut *(runtime as *mut Runtime) }; runtime.block_on(async { - let result = provider::read_aggs(context.clone(), Bytes::from(input)).await; + let result = provider::read_aggs(context.clone(), Bytes::from(input), size).await; let addr = result.map(|df| Box::into_raw(Box::new(df))); set_object_result(&mut env, callback, addr); }); diff --git a/libs/datafusion/jni/src/provider.rs b/libs/datafusion/jni/src/provider.rs index f0028812ca78d..19f7dcbc83e56 100644 --- a/libs/datafusion/jni/src/provider.rs +++ b/libs/datafusion/jni/src/provider.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use arrow_flight::{flight_service_client::FlightServiceClient, FlightDescriptor}; use bytes::Bytes; use datafusion::common::JoinType; @@ -32,12 +33,15 @@ pub async fn query( pub async fn read_aggs( ctx: SessionContext, ticket: Bytes, + size: usize ) -> datafusion::common::Result { - let df = dataframe_for_index(&ctx, "theIndex".to_owned(), ticket, "http://localhost:9400".to_owned()).await?; + let df = dataframe_for_index(&ctx, "theIndex".to_owned(), ticket, "http://localhost:9450".to_owned()).await?; + // Ok(df) + // df.clone().explain(true, true)?.collect().await?; df.filter(col("ord").is_not_null())? .aggregate(vec![col("ord")], vec![sum(col("count")).alias("count")])? .sort(vec![col("count").sort(false, true)])? // Sort by count descending - .limit(0, Some(500)) // Get top 500 results + .limit(0, Some(500)) // Get top 500 results } // inner join two tables together, returning a single DataFrame that can be consumed @@ -105,7 +109,6 @@ async fn get_dataframe_for_tickets( ticket: Bytes, entry_point: String ) -> Result { - println!("Register"); register_table(ctx, name.clone(), ticket, entry_point.clone()) .and_then(|_| ctx.table(&name)) .await @@ -119,7 +122,7 @@ async fn register_table(ctx: &SessionContext, name: String, ticket: Bytes, entry .open_table(entry_point, HashMap::new()) .await .map_err(|e| DataFusionError::Execution(format!("Error creating table: {}", e)))?; - println!("Registering table {:?}", table); + // println!("Registering table {:?}", table); ctx.register_table(name, Arc::new(table)) .map_err(|e| DataFusionError::Execution(format!("Error registering table: {}", e)))?; // let df = ctx.sql("SHOW TABLES").await?; @@ -140,27 +143,27 @@ impl FlightDriver for TicketedFlightDriver { channel: Channel, _options: &HashMap, ) -> arrow_flight::error::Result { - println!("DRIVER: options: {:?}", _options); + // println!("DRIVER: options: {:?}", _options); let mut client: FlightServiceClient = FlightServiceClient::new(channel.clone()); - println!("DRIVER: Using ticket: {:?}", self.ticket); + // println!("DRIVER: Using ticket: {:?}", self.ticket); let descriptor = FlightDescriptor::new_cmd(self.ticket.clone()); - println!("DRIVER: Created descriptor: {:?}", descriptor); + // println!("DRIVER: Created descriptor: {:?}", descriptor); let request = tonic::Request::new(descriptor); - println!("DRIVER: Sending get_flight_info request"); + // println!("DRIVER: Sending get_flight_info request"); match client.get_flight_info(request).await { Ok(info) => { - println!("DRIVER: Received flight info response"); + // println!("DRIVER: Received flight info response"); let info = info.into_inner(); - println!("DRIVER: Flight info: {:?}", info); + // println!("DRIVER: Flight info: {:?}", info); FlightMetadata::try_new(info, FlightProperties::default()) } Err(status) => { - println!("DRIVER: Error getting flight info: {:?}", status); + // println!("DRIVER: Error getting flight info: {:?}", status); Err(status.into()) } } diff --git a/libs/datafusion/jni/src/provider/test/mod.rs b/libs/datafusion/jni/src/provider/test/mod.rs index 182ff94e3ba5a..fe74341b83f89 100644 --- a/libs/datafusion/jni/src/provider/test/mod.rs +++ b/libs/datafusion/jni/src/provider/test/mod.rs @@ -3,8 +3,10 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +use std::sync::atomic::{AtomicUsize, Ordering}; use arrow::array::{StringArray, StructArray}; +use arrow::datatypes::SchemaRef; use arrow::ipc; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; @@ -36,13 +38,14 @@ const BEARER_TOKEN: &str = "Bearer flight-sql-token"; struct TestFlightService { flight_info: FlightInfo, - partition_data: RecordBatch, + schema: SchemaRef, + batch_counter: Arc, shutdown_sender: Option>, } impl TestFlightService { async fn run_in_background(self, rx: Receiver<()>) -> SocketAddr { - let addr = SocketAddr::from(([127, 0, 0, 1], 8815)); + let addr = SocketAddr::from(([127, 0, 0, 1], 9450)); let listener = TcpListener::bind(addr).await.unwrap(); let addr = listener.local_addr().unwrap(); let service = FlightServiceServer::new(self); @@ -60,6 +63,27 @@ impl TestFlightService { tokio::time::sleep(Duration::from_millis(25)).await; addr } + + fn generate_batch(schema: SchemaRef, batch_num: usize) -> RecordBatch { + println!("TEST SERVER: Generating batch {}", batch_num); + match batch_num { + 0 => RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["A", "B"])), + Arc::new(Int64Array::from(vec![10, 20])), + ], + ).unwrap(), + 1 => RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["A", "C"])), + Arc::new(Int64Array::from(vec![30, 40])), + ], + ).unwrap(), + _ => panic!("No more batches to generate"), + } + } } impl Drop for TestFlightService { @@ -98,12 +122,34 @@ impl FlightService for TestFlightService { ) -> Result, Status> { println!("TEST SERVER: do_get called"); println!("TEST SERVER: ticket: {:?}", request.get_ref()); - - let data = self.partition_data.clone(); - let rb = async move { Ok(data) }; + + let counter = self.batch_counter.clone(); + let schema = self.schema.clone(); + + // Create an async stream that generates batches on demand + let stream = futures::stream::unfold(0usize, move |current_batch| { + let counter = counter.clone(); + let schema = schema.clone(); + + async move { + if current_batch >= 2 { // Total number of batches we want to generate + return None; + } + + // Simulate some async work before generating the batch + tokio::time::sleep(Duration::from_millis(100)).await; + + counter.fetch_add(1, Ordering::SeqCst); + println!("TEST SERVER: Preparing batch {}", current_batch); + + let batch = Self::generate_batch(schema, current_batch); + Some((Ok(batch), current_batch + 1)) + } + }); + // Build the Flight data encoder with the schema from the first batch let stream = FlightDataEncoderBuilder::default() - .with_schema(self.partition_data.schema()) - .build(stream::once(rb)) + .with_schema(self.schema.clone()) + .build(stream) .map_err(|e| Status::internal(e.to_string())); Ok(Response::new(Box::pin(stream))) @@ -177,29 +223,24 @@ impl FlightService for TestFlightService { #[tokio::test] async fn test_read_aggs() -> datafusion::common::Result<()> { - // Create test data - let partition_data: RecordBatch = RecordBatch::try_new( - Arc::new(Schema::new([ - Arc::new(Field::new("ord", DataType::Utf8, false)), - Arc::new(Field::new("count", DataType::Int64, false)), - ])), - vec![ - Arc::new(StringArray::from(vec!["A", "B", "A", "C"])), - Arc::new(Int64Array::from(vec![10, 20, 30, 40])), - ], - )?; + // Create test data schema + let schema = Arc::new(Schema::new([ + Arc::new(Field::new("ord", DataType::Utf8, false)), + Arc::new(Field::new("count", DataType::Int64, false)), + ])); // Set up flight endpoint let endpoint = FlightEndpoint::default().with_ticket(Ticket::new("bytes".as_bytes())); let flight_info = FlightInfo::default() - .try_with_schema(partition_data.schema().as_ref())? + .try_with_schema(schema.as_ref())? .with_endpoint(endpoint); // Set up test service let (tx, rx) = channel(); let service = TestFlightService { flight_info, - partition_data, + schema: schema.clone(), + batch_counter: Arc::new(AtomicUsize::new(0)), shutdown_sender: Some(tx), }; @@ -214,12 +255,6 @@ impl FlightService for TestFlightService { // Set up session context let config = SessionConfig::new().with_batch_size(1); let ctx = SessionContext::new_with_config(config); - // let props_template = FlightProperties::new().with_reusable_flight_info(true); - // let driver = FlightSqlDriver::new().with_properties_template(props_template); - // ctx.state_ref().write().table_factories_mut().insert( - // "FLIGHT_SQL".into(), - // Arc::new(FlightTableFactory::new(Arc::new(driver))), - // ); // Create test bytes for ticket let test_bytes = Bytes::from("test_ticket"); diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java index a287304634dd3..c5c7aa1e4fca0 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java @@ -76,7 +76,6 @@ public void run(VectorSchemaRoot root, FlushSignal flushSignal) { df = frameSupplier.apply(rootTicket).join(); recordBatchStream = df.getStream(allocator, root).get(); while (recordBatchStream.loadNextBatch().join()) { -// logger.info(recordBatchStream.getVectorSchemaRoot().getRowCount()); // wait for a signal to load the next batch flushSignal.awaitConsumption(1000); } diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java index 32303cbb0ce68..00fd51dcb9944 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; @@ -27,7 +26,7 @@ public class DataFusion { // create a DataFrame from a list of tickets. static native void query(long runtime, long ctx, byte[] ticket, ObjectResultCallback callback); - static native void agg(long runtime, long ctx, byte[] ticket, ObjectResultCallback callback); + static native void agg(long runtime, long ctx, byte[] ticket, int size, ObjectResultCallback callback); // collect the DataFrame static native void collect(long runtime, long df, BiConsumer callback); @@ -48,10 +47,10 @@ public static CompletableFuture query(byte[] ticket) { return future; } - public static CompletableFuture agg(byte[] ticket) { + public static CompletableFuture agg(byte[] ticket, int size) { SessionContext ctx = new SessionContext(); CompletableFuture future = new CompletableFuture<>(); - DataFusion.agg(ctx.getRuntime(), ctx.getPointer(), ticket, (err, ptr) -> { + DataFusion.agg(ctx.getRuntime(), ctx.getPointer(), ticket, size, (err, ptr) -> { if (err != null) { future.completeExceptionally(new RuntimeException(err)); } else { diff --git a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java index 35e1c33081c26..3c4046e250bed 100644 --- a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java @@ -44,8 +44,10 @@ import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder; import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.stream.StreamSearchResult; @@ -53,6 +55,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -88,6 +91,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults onPartialMergeFailure; + private final Collection aggregationBuilders; /** * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results @@ -118,6 +122,7 @@ public QueryPhaseResultConsumer( this.hasAggs = source != null && source.aggregations() != null; int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo()); + this.aggregationBuilders = source.aggregations().getAggregatorFactories(); } @Override @@ -146,8 +151,12 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { SearchPhaseController.ReducedQueryPhase reducePhase = null; if (results.get(0) instanceof StreamSearchResult) { + int size = 0; + for (AggregationBuilder aggregatorFactory : aggregationBuilders) { + size = ((TermsAggregationBuilder) aggregatorFactory).size(); + } reducePhase = controller.reducedAggsFromStream(results.asList() - .stream().map(r -> (StreamSearchResult) r).collect(Collectors.toList())); + .stream().map(r -> (StreamSearchResult) r).collect(Collectors.toList()), size); } else { final SearchPhaseController.TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); final List topDocsList = pendingMerges.consumeTopDocs(); diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index a9d7aac754cb9..af236e4625479 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -726,15 +726,16 @@ static int getTopDocsSize(SearchRequest request) { public static Logger logger = LogManager.getLogger(SearchPhaseController.class); - public ReducedQueryPhase reducedAggsFromStream(List list) throws Exception { + public ReducedQueryPhase reducedAggsFromStream(List list, int size) throws Exception { List tickets = list.stream().flatMap(r -> r.getFlightTickets().stream()) .map(OSTicket::getBytes) .collect(Collectors.toList()); Set streamTickets = tickets.stream().map(t -> streamManager.getStreamTicketFactory().fromBytes(t)).collect(Collectors.toSet()); + DataFrameStreamProducer producer = AccessController.doPrivileged((PrivilegedAction) () -> - new DataFrameStreamProducer((p -> streamManager.registerStream(p, TaskId.EMPTY_TASK_ID)), streamTickets, (t) -> DataFusion.agg(t.toBytes()))); + new DataFrameStreamProducer((p -> streamManager.registerStream(p, TaskId.EMPTY_TASK_ID)), streamTickets, (t) -> DataFusion.agg(t.toBytes(), size))); logger.info("Register stream at coordinator"); return AccessController.doPrivileged((PrivilegedAction) () -> { @@ -748,7 +749,6 @@ public ReducedQueryPhase reducedAggsFromStream(List list) th List buckets = new ArrayList<>(); logger.info("Starting iteration at coordinator"); while (streamIterator.next()) { - logger.info(root.getRowCount()); int rowCount = root.getRowCount(); totalRows+= rowCount; @@ -763,24 +763,6 @@ public ReducedQueryPhase reducedAggsFromStream(List list) th } } -// recordBatchStream.close(); -// dataFrame.close(); -// while (streamIterator.next()) { -// int rowCount = root.getRowCount(); -// totalRows+= rowCount; -// logger.info("AT COORD Record Batch with " + rowCount + " rows:"); -// -// // Iterate through rows -// for (int row = 0; row < rowCount; row++) { -// FieldVector ordKey = root.getVector("ord"); -// String ordName = (String) getValue(ordKey, row); -// Float8Vector count = (Float8Vector) root.getVector("count"); -// -// Double bucketCount = (Double) getValue(count, row); -// logger.info("Ord: " + ordName + " Bucket Count" + bucketCount + "NodeID: "); -// buckets.add(new StringTerms.Bucket(new BytesRef(ordName.getBytes()), bucketCount.longValue(), new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW)); -// } -// } aggs.add(new StringTerms( "category", InternalOrder.key(true), diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java index af339b5df8d68..97957e77ad0e7 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java @@ -45,6 +45,7 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.PriorityQueue; +import org.opensearch.action.search.SearchType; import org.opensearch.common.SetOnce; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; @@ -733,11 +734,15 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws long[] otherDocCount = new long[owningBucketOrds.length]; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { final int size; - if (localBucketCountThresholds.getMinDocCount() == 0) { - // if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns - size = (int) Math.min(valueCount, localBucketCountThresholds.getRequiredSize()); + if (context.searchType().equals(SearchType.STREAM)) { + size = (int) maxBucketOrd(); } else { - size = (int) Math.min(maxBucketOrd(), localBucketCountThresholds.getRequiredSize()); + if (localBucketCountThresholds.getMinDocCount() == 0) { + // if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns + size = (int) Math.min(valueCount, localBucketCountThresholds.getRequiredSize()); + } else { + size = (int) Math.min(maxBucketOrd(), localBucketCountThresholds.getRequiredSize()); + } } PriorityQueue ordered = buildPriorityQueue(size); final int finalOrdIdx = ordIdx; diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java index e7ec11c89d720..a2a615d2cc81e 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java @@ -52,6 +52,7 @@ public class StreamingAggregator extends FilterCollector { private final StreamProducer.FlushSignal flushSignal; private final int batchSize; private final ShardId shardId; + private final Map vectors; /** * Sole constructor. * @@ -72,16 +73,19 @@ public StreamingAggregator( this.batchSize = batchSize; this.flushSignal = flushSignal; this.shardId = shardId; + this.vectors = new HashMap<>(); + vectors.put("ord", root.getVector("ord")); + vectors.put("count", root.getVector("count")); } public static Logger logger = LogManager.getLogger(StreamingAggregator.class); +// final int[] sum = {0}; + final int[] totalDocs = {0}; @Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - Map vectors = new HashMap<>(); - vectors.put("ord", root.getVector("ord")); - vectors.put("count", root.getVector("count")); final int[] currentRow = {0}; final LeafBucketCollector leaf = aggregator.getLeafCollector(context); +// logger.info("New segment collect {}", sum[0]); return new LeafBucketCollector() { @@ -103,6 +107,8 @@ public void collect(int doc, long owningBucketOrd) throws IOException { public void finish() throws IOException { if (currentRow[0] > 0) { flushBatch(); +// logger.info("otherDocCount {}", sum[0]); +// logger.info("total docs {}", totalDocs[0]); logger.info("Flushed last batch for segment {}", context.toString()); } } @@ -143,7 +149,7 @@ private void flushBatch() throws IOException { // Also access high-level statistics -// long otherDocCount = terms.getSumOfOtherDocCounts(); +// sum[0] += terms.getSumOfOtherDocCounts(); // long docCountError = terms.getDocCountError(); } else if (agg instanceof InternalCardinality) { InternalCardinality ic = (InternalCardinality) agg; @@ -162,6 +168,7 @@ private void flushBatch() throws IOException { flushSignal.awaitConsumption(10000000); logger.info("Consumed batch at data node"); aggregator.reset(); +// totalDocs[0] += currentRow[0]; currentRow[0] = 0; } };