[improve] add auto redirect options (apache#150)
Add an auto-redirect option. When it is enabled, there is no need to obtain the BE list; stream load requests are sent through the FE, which redirects them to a BE.
JNSimba authored Oct 26, 2023
1 parent 00ff399 commit 1ee1cae
Showing 7 changed files with 51 additions and 25 deletions.
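For reference, a minimal usage sketch of the new option from the Spark side (a sketch only — it assumes the connector's "doris" data source as exercised in TestSparkConnector below; the fenodes address, table identifier, and credentials are placeholders; the full key registered in ConfigurationOptions is doris.sink.auto-redirect, while the commented-out line in TestSparkConnector uses the shortened sink.auto-redirect form):

// sketch: write a small DataFrame with auto-redirect enabled
val spark = org.apache.spark.sql.SparkSession.builder()
  .appName("doris-auto-redirect-demo")
  .master("local[*]")
  .getOrCreate()
val df = spark.createDataFrame(Seq(("zhangsan", "male"), ("lisi", "female"))).toDF("name", "gender")

df.write
  .format("doris")
  .option("doris.fenodes", "your_fe_host:8030")              // FE HTTP address (placeholder)
  .option("doris.table.identifier", "your_db.your_table")    // placeholder
  .option("user", "your_user")                               // placeholder
  .option("password", "your_password")                       // placeholder
  .option("doris.sink.auto-redirect", "true")                // option added by this commit
  .save()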
ConfigurationOptions.java
@@ -119,4 +119,9 @@ public interface ConfigurationOptions {
String DORIS_SINK_TXN_RETRIES = "doris.sink.txn.retries";
int DORIS_SINK_TXN_RETRIES_DEFAULT = 3;

/**
* Use automatic redirection of fe without explicitly obtaining the be list
*/
String DORIS_SINK_AUTO_REDIRECT = "doris.sink.auto-redirect";
boolean DORIS_SINK_AUTO_REDIRECT_DEFAULT = false;
}
DorisStreamLoad.java
@@ -16,14 +16,6 @@
// under the License.
package org.apache.doris.spark.load;

import org.apache.doris.spark.cfg.ConfigurationOptions;
import org.apache.doris.spark.cfg.SparkSettings;
import org.apache.doris.spark.exception.StreamLoadException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.BackendV2;
import org.apache.doris.spark.rest.models.RespContent;
import org.apache.doris.spark.util.ResponseUtil;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -32,6 +24,14 @@
import com.google.common.cache.LoadingCache;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.doris.spark.cfg.ConfigurationOptions;
import org.apache.doris.spark.cfg.SparkSettings;
import org.apache.doris.spark.exception.IllegalArgumentException;
import org.apache.doris.spark.exception.StreamLoadException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.BackendV2;
import org.apache.doris.spark.rest.models.RespContent;
import org.apache.doris.spark.util.ResponseUtil;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
@@ -41,7 +41,9 @@
import org.apache.http.entity.BufferedHttpEntity;
import org.apache.http.entity.InputStreamEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.DefaultRedirectStrategy;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;
@@ -94,18 +96,21 @@ public class DorisStreamLoad implements Serializable {
private boolean addDoubleQuotes;
private static final long cacheExpireTimeout = 4 * 60;
private final LoadingCache<String, List<BackendV2.BackendRowV2>> cache;
private final String fenodes;
private final String fileType;
private String FIELD_DELIMITER;
private final String LINE_DELIMITER;
private boolean streamingPassthrough = false;
private final boolean enable2PC;
private final Integer txnRetries;
private final Integer txnIntervalMs;
private final boolean autoRedirect;

public DorisStreamLoad(SparkSettings settings) {
String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
this.db = dbTable[0];
this.tbl = dbTable[1];
this.fenodes = settings.getProperty(ConfigurationOptions.DORIS_FENODES);
String user = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER);
String passwd = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD);
this.authEncoded = getAuthEncoded(user, passwd);
@@ -133,6 +138,9 @@ public DorisStreamLoad(SparkSettings settings) {
ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT);
this.txnIntervalMs = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT);

this.autoRedirect = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_AUTO_REDIRECT,
ConfigurationOptions.DORIS_SINK_AUTO_REDIRECT_DEFAULT);
}

public String getLoadUrlStr() {
@@ -143,7 +151,14 @@ public String getLoadUrlStr() {
}

private CloseableHttpClient getHttpClient() {
HttpClientBuilder httpClientBuilder = HttpClientBuilder.create().disableRedirectHandling();
HttpClientBuilder httpClientBuilder = HttpClients
.custom()
.setRedirectStrategy(new DefaultRedirectStrategy() {
@Override
protected boolean isRedirectable(String method) {
return true;
}
});
return httpClientBuilder.build();
}
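
For context: Doris stream load is an HTTP PUT, and when the request is sent to an FE, the FE replies with a 307 redirect pointing at a BE. HttpClient's DefaultRedirectStrategy only treats GET and HEAD as redirectable, so the override above is what lets the PUT follow that redirect. A standalone sketch of the same client setup (Apache HttpClient 4.x, written here in Scala purely for illustration):

import org.apache.http.impl.client.{CloseableHttpClient, DefaultRedirectStrategy, HttpClients}

// Redirect strategy that follows redirects for every HTTP method, including the PUT
// used by stream load (the default strategy only follows GET and HEAD).
val followAllRedirects = new DefaultRedirectStrategy {
  override protected def isRedirectable(method: String): Boolean = true
}

val client: CloseableHttpClient = HttpClients.custom()
  .setRedirectStrategy(followAllRedirects)
  .build()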

@@ -187,7 +202,7 @@ public String toString() {
}
}

public int load(Iterator<InternalRow> rows, StructType schema)
public long load(Iterator<InternalRow> rows, StructType schema)
throws StreamLoadException, JsonProcessingException {

String label = generateLoadLabel();
@@ -240,15 +255,15 @@ public int load(Iterator<InternalRow> rows, StructType schema)

}

public Integer loadStream(Iterator<InternalRow> rows, StructType schema)
public Long loadStream(Iterator<InternalRow> rows, StructType schema)
throws StreamLoadException, JsonProcessingException {
if (this.streamingPassthrough) {
handleStreamPassThrough();
}
return load(rows, schema);
}

public void commit(int txnId) throws StreamLoadException {
public void commit(long txnId) throws StreamLoadException {

try (CloseableHttpClient client = getHttpClient()) {

@@ -296,7 +311,7 @@ public void commit(int txnId) throws StreamLoadException {
* @param txnId transaction id
* @throws StreamLoadException
*/
public void abortById(int txnId) throws StreamLoadException {
public void abortById(long txnId) throws StreamLoadException {

LOG.info("start abort transaction {}.", txnId);

@@ -385,13 +400,18 @@ public Map<String, String> getStreamLoadProp(SparkSettings sparkSettings) {

private String getBackend() {
try {
if (autoRedirect) {
return RestService.randomEndpoint(fenodes, LOG);
}
// get backends from cache
List<BackendV2.BackendRowV2> backends = cache.get("backends");
Collections.shuffle(backends);
BackendV2.BackendRowV2 backend = backends.get(0);
return backend.getIp() + ":" + backend.getHttpPort();
} catch (ExecutionException e) {
throw new RuntimeException("get backends info fail", e);
} catch (IllegalArgumentException e) {
throw new RuntimeException("get frontend info fail", e);
}
}

RestService.java
@@ -235,7 +235,7 @@ static String[] parseIdentifier(String tableIdentifier, Logger logger) throws Il
* @throws IllegalArgumentException fe nodes is illegal
*/
@VisibleForTesting
static String randomEndpoint(String feNodes, Logger logger) throws IllegalArgumentException {
public static String randomEndpoint(String feNodes, Logger logger) throws IllegalArgumentException {
logger.trace("Parse fenodes '{}'.", feNodes);
if (StringUtils.isEmpty(feNodes)) {
logger.error(ILLEGAL_ARGUMENT_MESSAGE, "fenodes", feNodes);
RespContent.java
@@ -25,7 +25,7 @@
public class RespContent {

@JsonProperty(value = "TxnId")
private int TxnId;
private long TxnId;

@JsonProperty(value = "Label")
private String Label;
@@ -75,7 +75,7 @@ public class RespContent {
@JsonProperty(value = "ErrorURL")
private String ErrorURL;

public int getTxnId() {
public long getTxnId() {
return TxnId;
}

DorisTransactionListener.scala
@@ -28,14 +28,14 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success}

class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], dorisStreamLoad: DorisStreamLoad, sinkTnxIntervalMs: Int, sinkTxnRetries: Int)
class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], dorisStreamLoad: DorisStreamLoad, sinkTnxIntervalMs: Int, sinkTxnRetries: Int)
extends SparkListener {

val logger: Logger = LoggerFactory.getLogger(classOf[DorisTransactionListener])

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
val txnIds: mutable.Buffer[Int] = preCommittedTxnAcc.value.asScala
val failedTxnIds = mutable.Buffer[Int]()
val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
val failedTxnIds = mutable.Buffer[Long]()
jobEnd.jobResult match {
// if job succeed, commit all transactions
case JobSucceeded =>
DorisWriter.scala
@@ -82,10 +82,10 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
doWrite(dataFrame, dorisStreamLoader.loadStream)
}

private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {
private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Long): Unit = {

val sc = dataFrame.sqlContext.sparkContext
val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
val preCommittedTxnAcc = sc.collectionAccumulator[Long]("preCommittedTxnAcc")
if (enable2PC) {
sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
}
@@ -99,7 +99,7 @@ class DorisWriter(settings: SparkSettings) extends Serializable {

while (iterator.hasNext) {
val batchIterator = new BatchIterator[InternalRow](iterator, batchSize, maxRetryTimes > 0)
val retry = Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
val retry = Utils.retry[Long, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()) match {
case Success(txnId) =>
if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
@@ -116,19 +116,19 @@

}

private def handleLoadSuccess(txnId: Int, acc: CollectionAccumulator[Int]): Unit = {
private def handleLoadSuccess(txnId: Long, acc: CollectionAccumulator[Long]): Unit = {
acc.add(txnId)
}

private def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
private def handleLoadFailure(acc: CollectionAccumulator[Long]): Unit = {
// if task run failed, acc value will not be returned to driver,
// should abort all pre committed transactions inside the task
logger.info("load task failed, start aborting previously pre-committed transactions")
if (acc.isZero) {
logger.info("no pre-committed transactions, skip abort")
return
}
val abortFailedTxnIds = mutable.Buffer[Int]()
val abortFailedTxnIds = mutable.Buffer[Long]()
acc.value.asScala.foreach(txnId => {
Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
dorisStreamLoader.abortById(txnId)
TestSparkConnector.scala
@@ -65,6 +65,7 @@ class TestSparkConnector {
.option("doris.table.identifier", dorisTable)
.option("user", dorisUser)
.option("password", dorisPwd)
// .option("sink.auto-redirect", "true")
//specify your field
.option("doris.write.fields", "name,gender")
.option("sink.batch.size",2)
