Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-35351][SQL] Add code-gen for left anti sort merge join
### What changes were proposed in this pull request? As title. This PR is to add code-gen support for LEFT ANTI sort merge join. The main change is to extract `loadStreamed` in `SortMergeJoinExec.doProduce()`. That is to set all columns values for streamed row, when the streamed row has no output row. Example query: ``` val df1 = spark.range(10).select($"id".as("k1")) val df2 = spark.range(4).select($"id".as("k2")) df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti") ``` Example generated code: ``` == Subtree 5 / 5 (maxMethodCodeSize:296; maxConstantPoolSize:156(0.24% used); numInnerClasses:0) == *(5) Project [id#0L AS k1#2L] +- *(5) SortMergeJoin [id#0L], [k2#6L], LeftAnti :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(id#0L, 5), ENSURE_REQUIREMENTS, [id=#27] : +- *(1) Range (0, 10, step=1, splits=2) +- *(4) Sort [k2#6L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(k2#6L, 5), ENSURE_REQUIREMENTS, [id=#33] +- *(3) Project [id#4L AS k2#6L] +- *(3) Range (0, 4, step=1, splits=2) Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage5(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=5 /* 006 */ final class GeneratedIteratorForCodegenStage5 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private scala.collection.Iterator smj_streamedInput_0; /* 010 */ private scala.collection.Iterator smj_bufferedInput_0; /* 011 */ private InternalRow smj_streamedRow_0; /* 012 */ private InternalRow smj_bufferedRow_0; /* 013 */ private long smj_value_2; /* 014 */ private org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray smj_matches_0; /* 015 */ private long smj_value_3; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] smj_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 017 */ /* 018 */ public GeneratedIteratorForCodegenStage5(Object[] references) { /* 019 */ this.references = references; /* 020 */ } /* 021 */ /* 022 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 023 */ partitionIndex = index; /* 024 */ this.inputs = inputs; /* 025 */ smj_streamedInput_0 = inputs[0]; /* 026 */ smj_bufferedInput_0 = inputs[1]; /* 027 */ /* 028 */ smj_matches_0 = new org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray(1, 2147483647); /* 029 */ smj_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 030 */ smj_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 031 */ /* 032 */ } /* 033 */ /* 034 */ private boolean findNextJoinRows( /* 035 */ scala.collection.Iterator streamedIter, /* 036 */ scala.collection.Iterator bufferedIter) { /* 037 */ smj_streamedRow_0 = null; /* 038 */ int comp = 0; /* 039 */ while (smj_streamedRow_0 == null) { /* 040 */ if (!streamedIter.hasNext()) return false; /* 041 */ smj_streamedRow_0 = (InternalRow) streamedIter.next(); /* 042 */ long smj_value_0 = smj_streamedRow_0.getLong(0); /* 043 */ if (false) { /* 044 */ if (!smj_matches_0.isEmpty()) { /* 045 */ smj_matches_0.clear(); /* 046 */ } /* 047 */ return false; /* 048 */ /* 049 */ } /* 050 */ if (!smj_matches_0.isEmpty()) { /* 051 */ comp = 0; /* 052 */ if (comp == 0) { /* 053 */ comp = (smj_value_0 > smj_value_3 ? 1 : smj_value_0 < smj_value_3 ? -1 : 0); /* 054 */ } /* 055 */ /* 056 */ if (comp == 0) { /* 057 */ return true; /* 058 */ } /* 059 */ smj_matches_0.clear(); /* 060 */ } /* 061 */ /* 062 */ do { /* 063 */ if (smj_bufferedRow_0 == null) { /* 064 */ if (!bufferedIter.hasNext()) { /* 065 */ smj_value_3 = smj_value_0; /* 066 */ return !smj_matches_0.isEmpty(); /* 067 */ } /* 068 */ smj_bufferedRow_0 = (InternalRow) bufferedIter.next(); /* 069 */ long smj_value_1 = smj_bufferedRow_0.getLong(0); /* 070 */ if (false) { /* 071 */ smj_bufferedRow_0 = null; /* 072 */ continue; /* 073 */ } /* 074 */ smj_value_2 = smj_value_1; /* 075 */ } /* 076 */ /* 077 */ comp = 0; /* 078 */ if (comp == 0) { /* 079 */ comp = (smj_value_0 > smj_value_2 ? 1 : smj_value_0 < smj_value_2 ? -1 : 0); /* 080 */ } /* 081 */ /* 082 */ if (comp > 0) { /* 083 */ smj_bufferedRow_0 = null; /* 084 */ } else if (comp < 0) { /* 085 */ if (!smj_matches_0.isEmpty()) { /* 086 */ smj_value_3 = smj_value_0; /* 087 */ return true; /* 088 */ } else { /* 089 */ return false; /* 090 */ } /* 091 */ } else { /* 092 */ if (smj_matches_0.isEmpty()) { /* 093 */ smj_matches_0.add((UnsafeRow) smj_bufferedRow_0); /* 094 */ } /* 095 */ /* 096 */ smj_bufferedRow_0 = null; /* 097 */ } /* 098 */ } while (smj_streamedRow_0 != null); /* 099 */ } /* 100 */ return false; // unreachable /* 101 */ } /* 102 */ /* 103 */ protected void processNext() throws java.io.IOException { /* 104 */ while (smj_streamedInput_0.hasNext()) { /* 105 */ findNextJoinRows(smj_streamedInput_0, smj_bufferedInput_0); /* 106 */ /* 107 */ long smj_value_4 = -1L; /* 108 */ smj_value_4 = smj_streamedRow_0.getLong(0); /* 109 */ scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator(); /* 110 */ /* 111 */ boolean wholestagecodegen_hasOutputRow_0 = false; /* 112 */ /* 113 */ while (!wholestagecodegen_hasOutputRow_0 && smj_iterator_0.hasNext()) { /* 114 */ InternalRow smj_bufferedRow_1 = (InternalRow) smj_iterator_0.next(); /* 115 */ /* 116 */ wholestagecodegen_hasOutputRow_0 = true; /* 117 */ } /* 118 */ /* 119 */ if (!wholestagecodegen_hasOutputRow_0) { /* 120 */ // load all values of streamed row, because the values not in join condition are not /* 121 */ // loaded yet. /* 122 */ /* 123 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 124 */ /* 125 */ // common sub-expressions /* 126 */ /* 127 */ smj_mutableStateArray_0[1].reset(); /* 128 */ /* 129 */ smj_mutableStateArray_0[1].write(0, smj_value_4); /* 130 */ append((smj_mutableStateArray_0[1].getRow()).copy()); /* 131 */ /* 132 */ } /* 133 */ if (shouldStop()) return; /* 134 */ } /* 135 */ ((org.apache.spark.sql.execution.joins.SortMergeJoinExec) references[1] /* plan */).cleanupResources(); /* 136 */ } /* 137 */ /* 138 */ } ``` ### Why are the changes needed? Improve the query CPU performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit test in `WholeStageCodegenSuite.scala`, and existed unit test in `ExistenceJoinSuite.scala`. Closes #32547 from c21/smj-left-anti. Authored-by: Cheng Su <chengsu@fb.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
- Loading branch information