diff --git a/src/jni/java/pom.xml b/src/jni/java/pom.xml
index bec60ad673e..9d0503462de 100644
--- a/src/jni/java/pom.xml
+++ b/src/jni/java/pom.xml
@@ -12,7 +12,7 @@
<groupId>com.vesoft</groupId>
<artifactId>nebula-utils</artifactId>
- <version>1.0.0-rc2</version>
+ <version>1.0.0-rc3</version>
<name>nebula-utils</name>
<description>Nebula jni utils</description>
diff --git a/src/tools/spark-sstfile-generator/pom.xml b/src/tools/spark-sstfile-generator/pom.xml
index 36f3b0df1f4..126824274b6 100644
--- a/src/tools/spark-sstfile-generator/pom.xml
+++ b/src/tools/spark-sstfile-generator/pom.xml
@@ -6,7 +6,7 @@
com.vesoft
sst.generator
- 1.0.0-rc2
+ 1.0.0-rc3
1.8
@@ -21,7 +21,7 @@
1.4.0
3.9.2
3.7.1
- 1.0.0-rc2
+ 1.0.0-rc3
1.0.0
diff --git a/src/tools/spark-sstfile-generator/src/main/resources/application.conf b/src/tools/spark-sstfile-generator/src/main/resources/application.conf
index dd2f00d7f2e..178a03a498a 100644
--- a/src/tools/spark-sstfile-generator/src/main/resources/application.conf
+++ b/src/tools/spark-sstfile-generator/src/main/resources/application.conf
@@ -30,13 +30,24 @@
execution {
retry: 3
}
+
+ error: {
+ max: 32
+ output: /tmp/errors
+ }
+
+ rate: {
+ limit: 1024
+ timeout: 1000
+ }
}
# Processing tags
- tags: {
+ tags: [
# Loading tag from HDFS and data type is parquet
- tag-name-0: {
+ {
+ name: tag-name-0
type: parquet
path: hdfs tag path 0
fields: {
@@ -48,7 +59,8 @@
}
# Loading from Hive
- tag-name-1: {
+ {
+ name: tag-name-1
type: hive
exec: "select hive-field0, hive-field1, hive-field2 from database.table"
fields: {
@@ -63,12 +75,13 @@
vertex: hive-field-0
partition: 32
}
- }
+ ]
# Processing edges
- edges: {
+ edges: [
# Loading tag from HDFS and data type is parquet
- edge-name-0: {
+ {
+ name: edge-name-0
type: json
path: hdfs edge path 0
fields: {
@@ -81,25 +94,26 @@
policy: "hash"
}
target: {
- field:hive-field-1
+ field: hive-field-1
policy: "uuid"
}
ranking: hive-field-2
partition: 32
}
- }
- # Loading from Hive
- edge-name-1: {
- type: hive
- exec: "select hive-field0, hive-field1, hive-field2 from database.table"
- fields: {
- hive-field-0: nebula-field-0,
- hive-field-1: nebula-field-1,
- hive-field-2: nebula-field-2
+ # Loading from Hive
+ {
+ name: edge-name-1
+ type: hive
+ exec: "select hive-field0, hive-field1, hive-field2 from database.table"
+ fields: {
+ hive-field-0: nebula-field-0,
+ hive-field-1: nebula-field-1,
+ hive-field-2: nebula-field-2
+ }
+ source: hive-field-0
+ target: hive-field-1
+ partition: 32
}
- source: hive-field-0
- target: hive-field-1
- partition: 32
- }
+ ]
}
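
For reference, a minimal sketch (not part of the repository) of how the new list-based `tags` layout above is read with Typesafe Config. A tag-side `vertex: { field, policy }` block mirrors the `source`/`target` blocks shown for edges; the object name, HDFS path and field names below are placeholders.

import com.typesafe.config.ConfigFactory
import scala.collection.JavaConverters._

object TagListFormatSketch {
  def main(args: Array[String]): Unit = {
    // A minimal config in the new list-based layout; all values are placeholders.
    val config = ConfigFactory.parseString(
      """
        |tags: [
        |  {
        |    name: tag-name-0
        |    type: parquet
        |    path: "hdfs://namenode/tag/path/0"
        |    vertex: {
        |      field: hive-field-0
        |      policy: "hash"
        |    }
        |    partition: 32
        |  }
        |]
      """.stripMargin)

    // `tags` is now a list of configs, so it is read with getConfigList instead of getObject.
    for (tag <- config.getConfigList("tags").asScala) {
      println(s"tag=${tag.getString("name")} type=${tag.getString("type")} vertex=${tag.getString("vertex.field")}")
    }
  }
}
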
diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala
index e54b4a3a6bd..fc5474d7420 100644
--- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala
+++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/nebula/tools/generator/v2/SparkClientGenerator.scala
@@ -7,66 +7,88 @@
package com.vesoft.nebula.tools.generator.v2
import org.apache.spark.sql.{DataFrame, Encoders, Row, SparkSession}
-import com.typesafe.config.{Config, ConfigFactory}
import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.udf
+import com.typesafe.config.{Config, ConfigFactory}
import java.io.File
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
import com.google.common.base.Optional
import com.google.common.geometry.{S2CellId, S2LatLng}
import com.google.common.net.HostAndPort
-import com.google.common.util.concurrent.{FutureCallback, Futures}
+import com.google.common.util.concurrent.{
+ FutureCallback,
+ Futures,
+ ListenableFuture,
+ MoreExecutors,
+ RateLimiter
+}
import com.vesoft.nebula.client.graph.async.AsyncGraphClientImpl
import com.vesoft.nebula.graph.ErrorCode
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.log4j.Logger
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import util.control.Breaks._
-case class Argument(config: File = new File("application.conf"),
- hive: Boolean = false,
- directly: Boolean = false,
- dry: Boolean = false)
+final case class Argument(
+ config: File = new File("application.conf"),
+ hive: Boolean = false,
+ directly: Boolean = false,
+ dry: Boolean = false,
+ reload: String = ""
+)
+
+final case class TooManyErrorsException(private val message: String) extends Exception(message)
/**
* SparkClientGenerator is a simple spark job used to write data into Nebula Graph parallel.
*/
object SparkClientGenerator {
-
private[this] val LOG = Logger.getLogger(this.getClass)
- private[this] val HASH_POLICY = "hash"
- private[this] val UUID_POLICY = "uuid"
private[this] val BATCH_INSERT_TEMPLATE = "INSERT %s %s(%s) VALUES %s"
- private[this] val INSERT_VALUE_TEMPLATE = "%d: (%s)"
- private[this] val INSERT_VALUE_TEMPLATE_WITH_POLICY = "%s(%d): (%s)"
- private[this] val ENDPOINT_TEMPLATE = "%s(%d)"
- private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%d->%d: (%s)"
+ private[this] val INSERT_VALUE_TEMPLATE = "%s: (%s)"
+ private[this] val INSERT_VALUE_TEMPLATE_WITH_POLICY = "%s(\"%s\"): (%s)"
+ private[this] val ENDPOINT_TEMPLATE = "%s(\"%s\")"
+ private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%s->%s: (%s)"
private[this] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY = "%s->%s: (%s)"
- private[this] val EDGE_VALUE_TEMPLATE = "%d->%d@%d: (%s)"
+ private[this] val EDGE_VALUE_TEMPLATE = "%s->%s@%d: (%s)"
private[this] val EDGE_VALUE_TEMPLATE_WITH_POLICY = "%s->%s@%d: (%s)"
private[this] val USE_TEMPLATE = "USE %s"
- private[this] val DEFAULT_BATCH = 64
- private[this] val DEFAULT_PARTITION = -1
- private[this] val DEFAULT_CONNECTION_TIMEOUT = 3000
- private[this] val DEFAULT_CONNECTION_RETRY = 3
- private[this] val DEFAULT_EXECUTION_RETRY = 3
- private[this] val DEFAULT_EXECUTION_INTERVAL = 3000
- private[this] val DEFAULT_EDGE_RANKING = 0L
- private[this] val DEFAULT_ERROR_TIMES = 16
+ private[this] val DEFAULT_BATCH = 2
+ private[this] val DEFAULT_PARTITION = -1
+ private[this] val DEFAULT_CONNECTION_TIMEOUT = 3000
+ private[this] val DEFAULT_CONNECTION_RETRY = 3
+ private[this] val DEFAULT_EXECUTION_RETRY = 3
+ private[this] val DEFAULT_EXECUTION_INTERVAL = 3000
+ private[this] val DEFAULT_EDGE_RANKING = 0L
+ private[this] val DEFAULT_ERROR_TIMES = 16
+ private[this] val DEFAULT_ERROR_OUTPUT_PATH = "/tmp/nebula.writer.errors/"
+ private[this] val DEFAULT_ERROR_MAX_BATCH_SIZE = Int.MaxValue
+ private[this] val DEFAULT_RATE_LIMIT = 1024
+ private[this] val DEFAULT_RATE_TIMEOUT = 100
+ private[this] val NEWLINE = "\n"
// GEO default config
- private[this] val DEFAULT_MIN_CELL_LEVEL = 5
- private[this] val DEFAULT_MAX_CELL_LEVEL = 24
+ private[this] val DEFAULT_MIN_CELL_LEVEL = 10
+ private[this] val DEFAULT_MAX_CELL_LEVEL = 18
private[this] val MAX_CORES = 64
private object Type extends Enumeration {
type Type = Value
- val Vertex = Value("Vertex")
- val Edge = Value("Edge")
+ val VERTEX = Value("VERTEX")
+ val EDGE = Value("EDGE")
+ }
+
+ private object KeyPolicy extends Enumeration {
+ type POLICY = Value
+ val HASH = Value("hash")
+ val UUID = Value("uuid")
}
def main(args: Array[String]): Unit = {
@@ -91,6 +113,11 @@ object SparkClientGenerator {
opt[Unit]('D', "dry")
.action((_, c) => c.copy(dry = true))
.text("dry run")
+
+ opt[String]('r', "reload")
+ .valueName("")
+ .action((x, c) => c.copy(reload = x))
+ .text("reload path")
}
val c: Argument = parser.parse(args, Argument()) match {
@@ -108,28 +135,48 @@ object SparkClientGenerator {
val pswd = nebulaConfig.getString("pswd")
val space = nebulaConfig.getString("space")
- val connectionTimeout =
- getOrElse(nebulaConfig, "connection.timeout", DEFAULT_CONNECTION_TIMEOUT)
- val connectionRetry = getOrElse(nebulaConfig, "connection.retry", DEFAULT_CONNECTION_RETRY)
+ val connectionConfig = getConfigOrNone(nebulaConfig, "connection")
+ val connectionTimeout = getOptOrElse(connectionConfig, "timeout", DEFAULT_CONNECTION_TIMEOUT)
+ val connectionRetry = getOptOrElse(connectionConfig, "retry", DEFAULT_CONNECTION_RETRY)
+
+ val executionConfig = getConfigOrNone(nebulaConfig, "execution")
+ val executionRetry = getOptOrElse(executionConfig, "retry", DEFAULT_EXECUTION_RETRY)
+ val executionInterval = getOptOrElse(executionConfig, "interval", DEFAULT_EXECUTION_INTERVAL)
- val executionRetry = getOrElse(nebulaConfig, "execution.retry", DEFAULT_EXECUTION_RETRY)
- val executionInterval =
- getOrElse(nebulaConfig, "execution.interval", DEFAULT_EXECUTION_INTERVAL)
+ val errorConfig = getConfigOrNone(nebulaConfig, "error")
+ val errorPath = getOptOrElse(errorConfig, "output", DEFAULT_ERROR_OUTPUT_PATH)
+ val errorMaxSize = getOptOrElse(errorConfig, "max", DEFAULT_ERROR_MAX_BATCH_SIZE)
+
+ val rateConfig = getConfigOrNone(nebulaConfig, "rate")
+ val rateLimit = getOptOrElse(rateConfig, "limit", DEFAULT_RATE_LIMIT)
+ val rateTimeout = getOptOrElse(rateConfig, "timeout", DEFAULT_RATE_TIMEOUT)
LOG.info(s"Nebula Addresses ${addresses} for ${user}:${pswd}")
LOG.info(s"Connection Timeout ${connectionTimeout} Retry ${connectionRetry}")
LOG.info(s"Execution Retry ${executionRetry} Interval Base ${executionInterval}")
+ LOG.info(s"Error Path ${errorPath} Max Size ${errorMaxSize}")
LOG.info(s"Switch to ${space}")
- val session = SparkSession
- .builder()
- .appName(PROGRAM_NAME)
+ val fs = FileSystem.get(new Configuration())
+ val hdfsPath = new Path(errorPath)
+ try {
+ if (!fs.exists(hdfsPath)) {
+ LOG.info(s"Create HDFS directory: ${errorPath}")
+ fs.mkdirs(hdfsPath)
+ }
+ } finally {
+ fs.close()
+ }
if (config.hasPath("spark.cores.max") &&
config.getInt("spark.cores.max") > MAX_CORES) {
LOG.warn(s"Concurrency is higher than ${MAX_CORES}")
}
+ val session = SparkSession
+ .builder()
+ .appName(PROGRAM_NAME)
+
val sparkConfig = config.getObject("spark")
for (key <- sparkConfig.unwrapped().keySet().asScala) {
val configKey = s"spark.${key}"
@@ -147,26 +194,74 @@ object SparkClientGenerator {
}
}
- val spark =
- if (c.hive) session.enableHiveSupport().getOrCreate()
- else session.getOrCreate()
-
- val tagConfigs =
- if (config.hasPath("tags"))
- Some(config.getObject("tags"))
- else None
+ val spark = if (c.hive) {
+ session.enableHiveSupport().getOrCreate()
+ } else {
+ session.getOrCreate()
+ }
- class TooManyErrorException(e: String) extends Exception(e) {}
+ if (!c.reload.isEmpty) {
+ val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.reload")
+ val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.reload")
+
+ spark.read
+ .text(c.reload)
+ .foreachPartition { records =>
+ val hostAndPorts = addresses.map(HostAndPort.fromString).asJava
+ val client = new AsyncGraphClientImpl(
+ hostAndPorts,
+ connectionTimeout,
+ connectionRetry,
+ executionRetry
+ )
+ client.setUser(user)
+ client.setPassword(pswd)
+
+ if (isSuccessfully(client.connect())) {
+ val rateLimiter = RateLimiter.create(rateLimit)
+ records.foreach { row =>
+ val exec = row.getString(0)
+ if (rateLimiter.tryAcquire(rateTimeout, TimeUnit.MILLISECONDS)) {
+ val future = client.execute(exec)
+ Futures.addCallback(
+ future,
+ new FutureCallback[Optional[Integer]] {
+ override def onSuccess(result: Optional[Integer]): Unit = {
+ batchSuccess.add(1)
+ }
+
+ override def onFailure(t: Throwable): Unit = {
+ if (batchFailure.value > DEFAULT_ERROR_TIMES) {
+ throw TooManyErrorsException("too many errors")
+ }
+ batchFailure.add(1)
+ }
+ }
+ )
+ } else {
+ batchFailure.add(1)
+ }
+ }
+ client.close()
+ } else {
+ LOG.error(s"Client connection failed. ${user}:${pswd}")
+ }
+ }
+ sys.exit(0)
+ }
+ val tagConfigs = getConfigsOrNone(config, "tags")
if (tagConfigs.isDefined) {
- for (tagName <- tagConfigs.get.unwrapped.keySet.asScala) {
- LOG.info(s"Processing Tag ${tagName}")
- val tagConfig = config.getConfig(s"tags.${tagName}")
- if (!tagConfig.hasPath("type")) {
- LOG.error("The type must be specified")
+ for (tagConfig <- tagConfigs.get.asScala) {
+ if (!tagConfig.hasPath("name") ||
+ !tagConfig.hasPath("type")) {
+ LOG.error("The `name` and `type` must be specified")
break()
}
+ val tagName = tagConfig.getString("name")
+ LOG.info(s"Processing Tag ${tagName}")
+
val pathOpt = if (tagConfig.hasPath("path")) {
Some(tagConfig.getString("path"))
} else {
@@ -174,14 +269,18 @@ object SparkClientGenerator {
}
val fields = tagConfig.getObject("fields").unwrapped
- val vertex = if (tagConfig.hasPath("vertex")) {
- tagConfig.getString("vertex")
- } else {
+
+ // You can specify the vertex field name via the config item `vertex`.
+ // If you want to qualify the key policy, you can wrap it in a block.
+ val vertex = if (tagConfig.hasPath("vertex.field")) {
tagConfig.getString("vertex.field")
+ } else {
+ tagConfig.getString("vertex")
}
val policyOpt = if (tagConfig.hasPath("vertex.policy")) {
- Some(tagConfig.getString("vertex.policy").toLowerCase)
+ val policy = tagConfig.getString("vertex.policy").toLowerCase
+ Some(KeyPolicy.withName(policy))
} else {
None
}
@@ -198,49 +297,42 @@ object SparkClientGenerator {
fields.asScala.keys.toList
}
- val sourceColumn = sourceProperties.map { property =>
- if (property == vertex) {
- col(property).cast(LongType)
- } else {
- col(property)
- }
- }
-
val vertexIndex = sourceProperties.indexOf(vertex)
val nebulaProperties = properties.mkString(",")
-
- val toVertex: String => Long = _.toLong
- val toVertexUDF = udf(toVertex)
- val data = createDataSource(spark, pathOpt, tagConfig)
+ val data = createDataSource(spark, pathOpt, tagConfig)
if (data.isDefined && !c.dry) {
val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${tagName}")
val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${tagName}")
repartition(data.get, partition)
- .select(sourceColumn: _*)
- .withColumn(vertex, toVertexUDF(col(vertex)))
.map { row =>
- (row.getLong(vertexIndex),
- (for { property <- valueProperties if property.trim.length != 0 } yield
- extraValue(row, property))
- .mkString(","))
- }(Encoders.tuple(Encoders.scalaLong, Encoders.STRING))
- .foreachPartition { iterator: Iterator[(Long, String)] =>
+ val values = (for {
+ property <- valueProperties if property.trim.length != 0
+ } yield extraValue(row, property)).mkString(",")
+ (row.getString(vertexIndex), values)
+ }(Encoders.tuple(Encoders.STRING, Encoders.STRING))
+ .foreachPartition { iterator: Iterator[(String, String)] =>
+ val service = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1))
val hostAndPorts = addresses.map(HostAndPort.fromString).asJava
- val client = new AsyncGraphClientImpl(hostAndPorts,
- connectionTimeout,
- connectionRetry,
- executionRetry)
+ val client = new AsyncGraphClientImpl(
+ hostAndPorts,
+ connectionTimeout,
+ connectionRetry,
+ executionRetry
+ )
client.setUser(user)
client.setPassword(pswd)
if (isSuccessfully(client.connect())) {
+ val errorBuffer = ArrayBuffer[String]()
val switchSpaceCode = client.execute(USE_TEMPLATE.format(space)).get().get()
if (isSuccessfully(switchSpaceCode)) {
+ val rateLimiter = RateLimiter.create(rateLimit)
+ val futures = new ListBuffer[ListenableFuture[Optional[Integer]]]()
iterator.grouped(batch).foreach { tags =>
val exec = BATCH_INSERT_TEMPLATE.format(
- Type.Vertex.toString,
+ Type.VERTEX.toString,
tagName,
nebulaProperties,
tags
@@ -249,14 +341,12 @@ object SparkClientGenerator {
INSERT_VALUE_TEMPLATE.format(tag._1, tag._2)
} else {
policyOpt.get match {
- case HASH_POLICY =>
- INSERT_VALUE_TEMPLATE_WITH_POLICY.format(HASH_POLICY,
- tag._1,
- tag._2)
- case UUID_POLICY =>
- INSERT_VALUE_TEMPLATE_WITH_POLICY.format(UUID_POLICY,
- tag._1,
- tag._2)
+ case KeyPolicy.HASH =>
+ INSERT_VALUE_TEMPLATE_WITH_POLICY
+ .format(KeyPolicy.HASH.toString, tag._1, tag._2)
+ case KeyPolicy.UUID =>
+ INSERT_VALUE_TEMPLATE_WITH_POLICY
+ .format(KeyPolicy.UUID.toString, tag._1, tag._2)
case _ => throw new IllegalArgumentException
}
}
@@ -264,31 +354,69 @@ object SparkClientGenerator {
.mkString(", ")
)
- LOG.debug(s"Exec : ${exec}")
- val future = client.execute(exec)
+ LOG.info(s"Exec : ${exec}")
+ if (rateLimiter.tryAcquire(rateTimeout, TimeUnit.MILLISECONDS)) {
+ val future = client.execute(exec)
+ futures += future
+ } else {
+ batchFailure.add(1)
+ errorBuffer += exec
+ if (errorBuffer.size == errorMaxSize) {
+ throw TooManyErrorsException(s"Too Many Errors ${errorMaxSize}")
+ }
+ }
+ }
+
+ val latch = new CountDownLatch(futures.size)
+ for (future <- futures) {
Futures.addCallback(
future,
new FutureCallback[Optional[Integer]] {
override def onSuccess(result: Optional[Integer]): Unit = {
+ latch.countDown()
batchSuccess.add(1)
}
override def onFailure(t: Throwable): Unit = {
+ latch.countDown()
if (batchFailure.value > DEFAULT_ERROR_TIMES) {
- throw new TooManyErrorException("too many error")
+ throw TooManyErrorsException("too many errors")
+ } else {
+ batchFailure.add(1)
}
- batchFailure.add(1)
}
- }
+ },
+ service
)
}
+
+ if (!errorBuffer.isEmpty) {
+ val fileSystem = FileSystem.get(new Configuration())
+ val errors = fileSystem.create(new Path(s"${errorPath}/${tagName}"))
+
+ try {
+ for (error <- errorBuffer) {
+ errors.writeBytes(error)
+ errors.writeBytes(NEWLINE)
+ }
+ } finally {
+ errors.close()
+ fileSystem.close()
+ }
+ }
+ latch.await()
} else {
LOG.error(s"Switch ${space} Failed")
}
- client.close()
} else {
LOG.error(s"Client connection failed. ${user}:${pswd}")
}
+
+ service.shutdown()
+ while (!service.awaitTermination(100, TimeUnit.MILLISECONDS)) {
+ Thread.sleep(10)
+ }
+ client.close()
}
}
}
@@ -296,34 +424,35 @@ object SparkClientGenerator {
LOG.warn("Tag is not defined")
}
- val edgeConfigs = getConfigOrNone(config, "edges")
-
+ val edgeConfigs = getConfigsOrNone(config, "edges")
if (edgeConfigs.isDefined) {
- for (edgeName <- edgeConfigs.get.unwrapped.keySet.asScala) {
- LOG.info(s"Processing Edge ${edgeName}")
- val edgeConfig = config.getConfig(s"edges.${edgeName}")
- if (!edgeConfig.hasPath("type")) {
- LOG.error("The type must be specified")
+ for (edgeConfig <- edgeConfigs.get.asScala) {
+ if (!edgeConfig.hasPath("name") ||
+ !edgeConfig.hasPath("type")) {
+ LOG.error("The `name` and `type`must be specified")
break()
}
+ val edgeName = edgeConfig.getString("name")
+ LOG.info(s"Processing Edge ${edgeName}")
+
val pathOpt = if (edgeConfig.hasPath("path")) {
Some(edgeConfig.getString("path"))
} else {
- LOG.warn("The path is not setting")
None
}
val fields = edgeConfig.getObject("fields").unwrapped
val isGeo = checkGeoSupported(edgeConfig)
- val target = if (edgeConfig.hasPath("target")) {
- edgeConfig.getString("target")
- } else {
+ val target = if (edgeConfig.hasPath("target.field")) {
edgeConfig.getString("target.field")
+ } else {
+ edgeConfig.getString("target")
}
val targetPolicyOpt = if (edgeConfig.hasPath("target.policy")) {
- Some(edgeConfig.getString("target.policy").toLowerCase)
+ val policy = edgeConfig.getString("target.policy").toLowerCase
+ Some(KeyPolicy.withName(policy))
} else {
None
}
@@ -339,74 +468,33 @@ object SparkClientGenerator {
val properties = fields.asScala.values.map(_.toString).toList
val valueProperties = fields.asScala.keys.toList
- val sourceProperties = if (!isGeo) {
- val source = if (edgeConfig.hasPath("source")) {
- edgeConfig.getString("source")
- } else {
- edgeConfig.getString("source.field")
- }
-
- if (!fields.containsKey(source) ||
- !fields.containsKey(target)) {
- (fields.asScala.keySet + source + target).toList
- } else {
- fields.asScala.keys.toList
- }
- } else {
- val latitude = edgeConfig.getString("latitude")
- val longitude = edgeConfig.getString("longitude")
- if (!fields.containsKey(latitude) ||
- !fields.containsKey(longitude) ||
- !fields.containsKey(target)) {
- (fields.asScala.keySet + latitude + longitude + target).toList
- } else {
- fields.asScala.keys.toList
- }
- }
-
val sourcePolicyOpt = if (edgeConfig.hasPath("source.policy")) {
- Some(edgeConfig.getString("source.policy").toLowerCase)
+ val policy = edgeConfig.getString("source.policy").toLowerCase
+ Some(KeyPolicy.withName(policy))
} else {
None
}
- val sourceColumn = if (!isGeo) {
- val source = edgeConfig.getString("source")
- sourceProperties.map { property =>
- if (property == source || property == target) {
- col(property).cast(LongType)
- } else {
- col(property)
- }
- }
- } else {
- val latitude = edgeConfig.getString("latitude")
- val longitude = edgeConfig.getString("longitude")
- sourceProperties.map { property =>
- if (property == latitude || property == longitude) {
- col(property).cast(DoubleType)
- } else {
- col(property)
- }
- }
- }
-
val nebulaProperties = properties.mkString(",")
-
- val data = createDataSource(spark, pathOpt, edgeConfig)
- val encoder =
- Encoders.tuple(Encoders.STRING, Encoders.scalaLong, Encoders.scalaLong, Encoders.STRING)
-
+ val data = createDataSource(spark, pathOpt, edgeConfig)
if (data.isDefined && !c.dry) {
val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeName}")
val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeName}")
repartition(data.get, partition)
- .select(sourceColumn: _*)
.map { row =>
val sourceField = if (!isGeo) {
- val source = edgeConfig.getString("source")
- row.getLong(row.schema.fieldIndex(source)).toString
+ val source = if (edgeConfig.hasPath("source.field")) {
+ edgeConfig.getString("source.field")
+ } else {
+ edgeConfig.getString("source")
+ }
+
+ if (sourcePolicyOpt.isEmpty) {
+ row.getLong(row.schema.fieldIndex(source)).toString
+ } else {
+ row.getString(row.schema.fieldIndex(source))
+ }
} else {
val latitude = edgeConfig.getString("latitude")
val longitude = edgeConfig.getString("longitude")
@@ -415,12 +503,17 @@ object SparkClientGenerator {
indexCells(lat, lng).mkString(",")
}
- val targetField = row.getLong(row.schema.fieldIndex(target))
+ val targetField =
+ if (targetPolicyOpt.isEmpty) {
+ row.getLong(row.schema.fieldIndex(target)).toString
+ } else {
+ row.getString(row.schema.fieldIndex(target))
+ }
val values =
- (for { property <- valueProperties if property.trim.length != 0 } yield
- extraValue(row, property))
- .mkString(",")
+ (for {
+ property <- valueProperties if property.trim.length != 0
+ } yield extraValue(row, property)).mkString(",")
if (rankingOpt.isDefined) {
val ranking = row.getLong(row.schema.fieldIndex(rankingOpt.get))
@@ -428,22 +521,31 @@ object SparkClientGenerator {
} else {
(sourceField, targetField, DEFAULT_EDGE_RANKING, values)
}
- }(encoder)
- .foreachPartition { iterator: Iterator[(String, Long, Long, String)] =>
+ }(
+ Encoders
+ .tuple(Encoders.STRING, Encoders.STRING, Encoders.scalaLong, Encoders.STRING)
+ )
+ .foreachPartition { iterator: Iterator[(String, String, Long, String)] =>
+ val service = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1))
val hostAndPorts = addresses.map(HostAndPort.fromString).asJava
- val client = new AsyncGraphClientImpl(hostAndPorts,
- connectionTimeout,
- connectionRetry,
- executionRetry)
+ val client = new AsyncGraphClientImpl(
+ hostAndPorts,
+ connectionTimeout,
+ connectionRetry,
+ executionRetry
+ )
client.setUser(user)
client.setPassword(pswd)
if (isSuccessfully(client.connect())) {
val switchSpaceCode = client.switchSpace(space).get().get()
if (isSuccessfully(switchSpaceCode)) {
+ val errorBuffer = ArrayBuffer[String]()
+ val rateLimiter = RateLimiter.create(rateLimit)
+ val futures = new ListBuffer[ListenableFuture[Optional[Integer]]]()
iterator.grouped(batch).foreach { edges =>
val values =
- if (rankingOpt.isEmpty)
+ if (rankingOpt.isEmpty) {
edges
.map { edge =>
// TODO: (darion.yaphet) dataframe.explode() would be better ?
@@ -451,22 +553,32 @@ object SparkClientGenerator {
yield
if (sourcePolicyOpt.isEmpty && targetPolicyOpt.isEmpty) {
EDGE_VALUE_WITHOUT_RANKING_TEMPLATE
- .format(source.toLong, edge._2, edge._4)
+ .format(source, edge._2, edge._4)
} else {
- val source = sourcePolicyOpt.get match {
- case HASH_POLICY =>
- ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._1)
- case UUID_POLICY =>
- ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._1)
- case _ => throw new IllegalArgumentException
+ val source = if (sourcePolicyOpt.isDefined) {
+ sourcePolicyOpt.get match {
+ case KeyPolicy.HASH =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.HASH.toString, edge._1)
+ case KeyPolicy.UUID =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.UUID.toString, edge._1)
+ case _ =>
+ throw new IllegalArgumentException()
+ }
+ } else {
+ edge._1
}
- val target = targetPolicyOpt.get match {
- case HASH_POLICY =>
- ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._2)
- case UUID_POLICY =>
- ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._2)
- case _ => throw new IllegalArgumentException
+ val target = if (targetPolicyOpt.isDefined) {
+ targetPolicyOpt.get match {
+ case KeyPolicy.HASH =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.HASH.toString, edge._2)
+ case KeyPolicy.UUID =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.UUID.toString, edge._2)
+ case _ =>
+ throw new IllegalArgumentException()
+ }
+ } else {
+ edge._2
}
EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY
@@ -475,7 +587,7 @@ object SparkClientGenerator {
}
.toList
.mkString(", ")
- else
+ } else {
edges
.map { edge =>
// TODO: (darion.yaphet) dataframe.explode() would be better ?
@@ -483,22 +595,24 @@ object SparkClientGenerator {
yield
if (sourcePolicyOpt.isEmpty && targetPolicyOpt.isEmpty) {
EDGE_VALUE_TEMPLATE
- .format(source.toLong, edge._2, edge._3, edge._4)
+ .format(source, edge._2, edge._3, edge._4)
} else {
val source = sourcePolicyOpt.get match {
- case HASH_POLICY =>
- ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._1)
- case UUID_POLICY =>
- ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._1)
- case _ => throw new IllegalArgumentException
+ case KeyPolicy.HASH =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.HASH.toString, edge._1)
+ case KeyPolicy.UUID =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.UUID.toString, edge._1)
+ case _ =>
+ edge._1
}
val target = targetPolicyOpt.get match {
- case HASH_POLICY =>
- ENDPOINT_TEMPLATE.format(HASH_POLICY, edge._2)
- case UUID_POLICY =>
- ENDPOINT_TEMPLATE.format(UUID_POLICY, edge._2)
- case _ => throw new IllegalArgumentException
+ case KeyPolicy.HASH =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.HASH.toString, edge._2)
+ case KeyPolicy.UUID =>
+ ENDPOINT_TEMPLATE.format(KeyPolicy.UUID.toString, edge._2)
+ case _ =>
+ edge._2
}
EDGE_VALUE_TEMPLATE_WITH_POLICY
@@ -508,53 +622,93 @@ object SparkClientGenerator {
}
.toList
.mkString(", ")
+ }
val exec = BATCH_INSERT_TEMPLATE
- .format(Type.Edge.toString, edgeName, nebulaProperties, values)
- LOG.debug(s"Exec : ${exec}")
- val future = client.execute(exec)
+ .format(Type.EDGE.toString, edgeName, nebulaProperties, values)
+ LOG.info(s"Exec : ${exec}")
+ if (rateLimiter.tryAcquire(rateTimeout, TimeUnit.MILLISECONDS)) {
+ val future = client.execute(exec)
+ futures += future
+ } else {
+ batchFailure.add(1)
+ LOG.debug("Save the error execution sentence into buffer")
+ errorBuffer += exec
+ if (errorBuffer.size == errorMaxSize) {
+ throw TooManyErrorsException(s"Too Many Errors ${errorMaxSize}")
+ }
+ }
+ }
+
+ val latch = new CountDownLatch(futures.size)
+ for (future <- futures) {
Futures.addCallback(
future,
new FutureCallback[Optional[Integer]] {
override def onSuccess(result: Optional[Integer]): Unit = {
+ latch.countDown()
batchSuccess.add(1)
}
override def onFailure(t: Throwable): Unit = {
+ latch.countDown()
if (batchFailure.value > DEFAULT_ERROR_TIMES) {
- throw new TooManyErrorException("too many error")
+ throw TooManyErrorsException("too many errors")
+ } else {
+ batchFailure.add(1)
}
- batchFailure.add(1)
}
- }
+ },
+ service
)
}
+
+ if (!errorBuffer.isEmpty) {
+ val fileSystem = FileSystem.get(new Configuration())
+ val errors = fileSystem.create(new Path(s"${errorPath}/${edgeName}"))
+ try {
+ for (error <- errorBuffer) {
+ errors.writeBytes(error)
+ errors.writeBytes(NEWLINE)
+ }
+ } finally {
+ errors.close()
+ fileSystem.close()
+ }
+ }
+ latch.await()
} else {
LOG.error(s"Switch ${space} Failed")
}
- client.close()
} else {
LOG.error(s"Client connection failed. ${user}:${pswd}")
}
+ service.shutdown()
+ while (!service.awaitTermination(100, TimeUnit.MILLISECONDS)) {
+ Thread.sleep(10)
+ }
+ client.close()
}
+ } else {
+ LOG.warn("Edge is not defined")
}
}
- } else {
- LOG.warn("Edge is not defined")
}
}
/**
* Create data source for different data type.
*
- * @param session The Spark Session.
- * @param pathOpt The path for config.
- * @param config The config.
+ * @param session The Spark Session.
+ * @param pathOpt The path for config.
+ * @param config The config.
* @return
*/
- private[this] def createDataSource(session: SparkSession,
- pathOpt: Option[String],
- config: Config): Option[DataFrame] = {
+ private[this] def createDataSource(
+ session: SparkSession,
+ pathOpt: Option[String],
+ config: Config
+ ): Option[DataFrame] = {
val `type` = config.getString("type")
pathOpt match {
@@ -598,9 +752,9 @@ object SparkClientGenerator {
.format("socket")
.option("host", host)
.option("port", port)
- .load())
+ .load()
+ )
}
-
case "kafka" => {
if (!config.hasPath("servers") || !config.hasPath("topic")) {
LOG.error("Reading kafka source should specify servers and topic")
@@ -615,7 +769,8 @@ object SparkClientGenerator {
.format("kafka")
.option("kafka.bootstrap.servers", server)
.option("subscribe", topic)
- .load())
+ .load()
+ )
}
case _ => {
LOG.error(s"Data source ${`type`} not supported")
@@ -630,8 +785,8 @@ object SparkClientGenerator {
* Extra value from the row by field name.
* When the field is null, we will fill it with default value.
*
- * @param row The row value.
- * @param field The field name.
+ * @param row The row value.
+ * @param field The field name.
* @return
*/
private[this] def extraValue(row: Row, field: String): Any = {
@@ -730,51 +885,48 @@ object SparkClientGenerator {
/**
* Check the statement execution result.
*
- * @param code The statement's execution result code.
+ * @param code The statement's execution result code.
* @return
*/
private[this] def isSuccessfully(code: Int) = code == ErrorCode.SUCCEEDED
/**
- * Check the statement execution result.
- * If the result code is not SUCCEEDED, will sleep a little while.
+ * Whether the edge is Geo supported.
*
- * @param code The sentence execute's result code.
- * @param interval The sleep interval.
+ * @param edgeConfig The config of edge.
* @return
*/
- private[this] def isSuccessfullyWithSleep(code: Int, interval: Long)(
- implicit exec: String): Boolean = {
- val result = isSuccessfully(code)
- if (!result) {
- LOG.error(s"Exec Failed: ${exec} retry interval ${interval}")
- Thread.sleep(interval)
- }
- result
+ private[this] def checkGeoSupported(edgeConfig: Config): Boolean = {
+ !edgeConfig.hasPath("source") &&
+ edgeConfig.hasPath("latitude") &&
+ edgeConfig.hasPath("longitude")
}
/**
- * Whether the edge is Geo supported.
+ * Get the config list by the path.
*
- * @param edgeConfig The config of edge.
+ * @param config The config.
+ * @param path The path of the config.
* @return
*/
- private[this] def checkGeoSupported(edgeConfig: Config): Boolean = {
- !edgeConfig.hasPath("source") &&
- edgeConfig.hasPath("latitude") &&
- edgeConfig.hasPath("longitude")
+ private[this] def getConfigsOrNone(config: Config, path: String) = {
+ if (config.hasPath(path)) {
+ Some(config.getConfigList(path))
+ } else {
+ None
+ }
}
/**
* Get the config by the path.
*
- * @param config The config.
- * @param path The path of the config.
+ * @param config
+ * @param path
* @return
*/
private[this] def getConfigOrNone(config: Config, path: String) = {
if (config.hasPath(path)) {
- Some(config.getObject(path))
+ Some(config.getConfig(path))
} else {
None
}
@@ -783,9 +935,9 @@ object SparkClientGenerator {
/**
* Get the value from config by the path. If the path not exist, return the default value.
*
- * @param config The config.
- * @param path The path of the config.
- * @param defaultValue The default value for the path.
+ * @param config The config.
+ * @param path The path of the config.
+ * @param defaultValue The default value for the path.
* @return
*/
private[this] def getOrElse[T](config: Config, path: String, defaultValue: T): T = {
@@ -796,11 +948,29 @@ object SparkClientGenerator {
}
}
+ /**
+ * Get the value from an optional config by the path.
+ * If the path does not exist, return the default value.
+ *
+ * @param config The optional config.
+ * @param path The path of the config.
+ * @param defaultValue The default value for the path.
+ * @tparam T The type of the value.
+ * @return
+ */
+ private[this] def getOptOrElse[T](config: Option[Config], path: String, defaultValue: T): T = {
+ if (!config.isEmpty && config.get.hasPath(path)) {
+ config.get.getAnyRef(path).asInstanceOf[T]
+ } else {
+ defaultValue
+ }
+ }
+
/**
* Calculate the coordinate's correlation id list.
*
- * @param lat The latitude of coordinate.
- * @param lng The longitude of coordinate.
+ * @param lat The latitude of coordinate.
+ * @param lng The longitude of coordinate.
* @return
*/
private[this] def indexCells(lat: Double, lng: Double): IndexedSeq[Long] = {
@@ -810,3 +980,4 @@ object SparkClientGenerator {
yield s2CellId.parent(index).id()
}
}
+
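
The writer now throttles statement submission with a Guava RateLimiter, collects the returned futures, and waits on a CountDownLatch before shutting down; statements that cannot be sent go into an error buffer that is flushed to HDFS under error.output and can be replayed with the new --reload option. Below is a standalone sketch of that pattern using the same Guava APIs the diff imports; the Callable is a dummy stand-in for client.execute and all names are illustrative.

import java.util.concurrent.{Callable, CountDownLatch, Executors, TimeUnit}
import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFuture, MoreExecutors, RateLimiter}
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

object RateLimitedWriteSketch {
  def main(args: Array[String]): Unit = {
    val service     = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1))
    val rateLimiter = RateLimiter.create(1024.0) // permits per second, cf. rate.limit
    val rateTimeout = 100                        // milliseconds, cf. rate.timeout
    val errorBuffer = ArrayBuffer[String]()
    val futures     = new ListBuffer[ListenableFuture[Integer]]()

    for (exec <- (1 to 10).map(i => s"INSERT ... batch ${i}")) {
      if (rateLimiter.tryAcquire(rateTimeout, TimeUnit.MILLISECONDS)) {
        // Dummy task standing in for client.execute(exec).
        futures += service.submit(new Callable[Integer] {
          override def call(): Integer = Integer.valueOf(exec.length)
        })
      } else {
        // Statements that could not acquire a permit would be written out for a later --reload run.
        errorBuffer += exec
      }
    }

    // Count every callback down before closing the client, as the generator does.
    val latch = new CountDownLatch(futures.size)
    for (future <- futures) {
      Futures.addCallback(future, new FutureCallback[Integer] {
        override def onSuccess(result: Integer): Unit = latch.countDown()
        override def onFailure(t: Throwable): Unit    = latch.countDown()
      }, service)
    }
    latch.await()
    println(s"submitted=${futures.size} rejected=${errorBuffer.size}")

    service.shutdown()
    service.awaitTermination(1, TimeUnit.SECONDS)
  }
}
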
diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/FNVHash.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/FNVHash.scala
index 577f5c4c6b3..b471256f8de 100644
--- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/FNVHash.scala
+++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/FNVHash.scala
@@ -13,9 +13,9 @@ package com.vesoft.tools
*/
object FNVHash {
- private val FNV_64_INIT = 0xcbf29ce484222325L
- private val FNV_64_PRIME = 0x100000001b3L
- private val FNV_32_INIT = 0x811c9dc5
+ private val FNV_64_INIT = 0xCBF29CE484222325L
+ private val FNV_64_PRIME = 0x100000001B3L
+ private val FNV_32_INIT = 0x811c9dc5
private val FNV_32_PRIME = 0x01000193
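
The only change here is the casing of the 64-bit hex literals; the values are the standard FNV-64 offset basis and prime. For orientation, a hedged sketch of how those constants combine in an FNV-1a style hash; the repository's actual FNVHash.hash64 is not shown in this diff and may differ in detail (for example FNV-1 ordering).

object FnvSketch {
  private val FNV_64_INIT  = 0xCBF29CE484222325L // offset basis
  private val FNV_64_PRIME = 0x100000001B3L

  // FNV-1a over the UTF-8 bytes of a key: xor in each byte, then multiply by the prime.
  def hash64(key: String): Long = {
    var hash = FNV_64_INIT
    for (b <- key.getBytes("UTF-8")) {
      hash ^= (b & 0xFF)
      hash *= FNV_64_PRIME
    }
    hash
  }
}
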
diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/MappingConfiguration.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/MappingConfiguration.scala
index c7b7784734d..fcb3ed4351a 100644
--- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/MappingConfiguration.scala
+++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/MappingConfiguration.scala
@@ -12,11 +12,11 @@ import scala.io.{Codec, Source}
import scala.language.implicitConversions
/**
-* column mapping
-*
-* @param columnName hive column name
-* @param propertyName the property name this column maps to
-* @param `type` map to certain data type of nebula graph
+ * column mapping
+ *
+ * @param columnName hive column name
+ * @param propertyName the property name this column maps to
+ * @param `type` map to certain data type of nebula graph
*/
case class Column(
columnName: String,
@@ -58,8 +58,8 @@ object Column {
}
/**
-* a trait that both Tag and Edge should extends
-*/
+ * a trait that both Tag and Edge should extends
+ */
trait WithColumnMapping {
def tableName: String
@@ -70,14 +70,14 @@ trait WithColumnMapping {
}
/**
-* tag section of configuration file
-*
-* @param tableName hive table name
-* @param name tag name
-* @param primaryKey show the PK column
-* @param datePartitionKey date partition column,Hive table in production is usually Date partitioned
-* @param typePartitionKey type partition columns, when different vertex/edge's properties are identical,they are stored in one hive table, and partitioned by a `type` column
-* @param columnMappings map of hive table column to properties
+ * tag section of configuration file
+ *
+ * @param tableName hive table name
+ * @param name tag name
+ * @param primaryKey show the PK column
+ * @param datePartitionKey date partition column,Hive table in production is usually Date partitioned
+ * @param typePartitionKey type partition columns, when different vertex/edge's properties are identical,they are stored in one hive table, and partitioned by a `type` column
+ * @param columnMappings map of hive table column to properties
*/
case class Tag(
override val tableName: String,
@@ -158,15 +158,15 @@ object Tag {
}
/**
-* edge section of configuration file
-*
-* @param tableName hive table name
-* @param name edge type name
-* @param fromForeignKeyColumn srcID column
-* @param fromReferenceTag Tag srcID column referenced
-* @param toForeignKeyColumn dstID column
-* @param toReferenceTag Tag dstID column referenced
-* @param columnMappings map of hive table column to properties
+ * edge section of configuration file
+ *
+ * @param tableName hive table name
+ * @param name edge type name
+ * @param fromForeignKeyColumn srcID column
+ * @param fromReferenceTag Tag srcID column referenced
+ * @param toForeignKeyColumn dstID column
+ * @param toReferenceTag Tag dstID column referenced
+ * @param columnMappings map of hive table column to properties
*/
case class Edge(
override val tableName: String,
@@ -254,13 +254,13 @@ object Edge {
}
/**
-* a mapping file in-memory representation
-*
-* @param databaseName hive database name for this mapping configuration
-* @param partitions partition number of the target graphspace
-* @param tags tag's mapping
-* @param edges edge's mapping
-* @param keyPolicy policy used to generate unique id, default=hash_primary_key
+ * a mapping file in-memory representation
+ *
+ * @param databaseName hive database name for this mapping configuration
+ * @param partitions partition number of the target graphspace
+ * @param tags tag's mapping
+ * @param edges edge's mapping
+ * @param keyPolicy policy used to generate unique id, default=hash_primary_key
*/
case class MappingConfiguration(
databaseName: String,
@@ -293,7 +293,6 @@ object MappingConfiguration {
}
}
-
implicit val MappingConfigurationReads: Reads[MappingConfiguration] =
new Reads[MappingConfiguration] {
override def reads(json: JsValue): JsResult[MappingConfiguration] = {
@@ -343,11 +342,11 @@ object MappingConfiguration {
}
/**
-* construct from a mapping file
-*
-* @param mappingFile mapping file should be provided through "--files" option, and specified the application arg "---mapping_file_input"(--mi for short) at the same time,
-* it will be consumed as a classpath resource
-* @return MappingConfiguration instance
+ * construct from a mapping file
+ *
+ * @param mappingFile mapping file should be provided through "--files" option, and specified the application arg "---mapping_file_input"(--mi for short) at the same time,
+ * it will be consumed as a classpath resource
+ * @return MappingConfiguration instance
*/
def apply(mappingFile: String): MappingConfiguration = {
val bufferedSource = Source.fromFile(mappingFile)(Codec("UTF-8"))
diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SparkSstFileGenerator.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SparkSstFileGenerator.scala
index 552fe1d2867..f1445f51f1b 100644
--- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SparkSstFileGenerator.scala
+++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SparkSstFileGenerator.scala
@@ -110,12 +110,10 @@ object SparkSstFileGenerator {
)
.build
-
// when the newest data arrive, used in non-incremental environment
val latestDate = CliOption
.builder("di")
.longOpt("latest_date_input")
-
.required()
.hasArg()
.desc("Latest date to query,date format YYYY-MM-dd")
@@ -130,12 +128,10 @@ object SparkSstFileGenerator {
)
.build
-
// may be used in some test run to prove the correctness
val limit = CliOption
.builder("li")
.longOpt("limit_input")
-
.hasArg()
.desc(
"Return at most this number of edges/vertex, usually used in POC stage, when omitted, fetch all data."
@@ -332,7 +328,6 @@ object SparkSstFileGenerator {
case None => (key: String) => FNVHash.hash64(key)
}
-
// implicit ordering used by PairedRDD.repartitionAndSortWithinPartitions whose key is PartitionIdAndBytesEncoded typed
implicit def ordering[A <: GraphPartitionIdAndKeyValueEncoded]: Ordering[A] = new Ordering[A] {
override def compare(x: A, y: A): Int = {
@@ -367,9 +362,7 @@ object SparkSstFileGenerator {
}
val whereClause = tag.typePartitionKey
- .map(
- key => s"${key}='${tag.name}' AND ${datePartitionKey}='${latestDate}'"
- )
+ .map(key => s"${key}='${tag.name}' AND ${datePartitionKey}='${latestDate}'")
.getOrElse(s"${datePartitionKey}='${latestDate}'")
//TODO:to handle multiple partition columns' Cartesian product
val sql =
@@ -453,7 +446,6 @@ object SparkSstFileGenerator {
mappingConfiguration.databaseName
)
-
val columnExpression = {
assert(allColumns.size > 0)
s"${edge.fromForeignKeyColumn},${edge.toForeignKeyColumn}," + allColumns
@@ -462,12 +454,9 @@ object SparkSstFileGenerator {
}
val whereClause = edge.typePartitionKey
- .map(
- key => s"${key}='${edge.name}' AND ${datePartitionKey}='${latestDate}'"
- )
+ .map(key => s"${key}='${edge.name}' AND ${datePartitionKey}='${latestDate}'")
.getOrElse(s"${datePartitionKey}='${latestDate}'")
-
//TODO: join FROM_COLUMN and join TO_COLUMN from the table where this columns referencing, to make sure that the claimed id really exists in the reference table.BUT with HUGE Perf penalty
val edgeDf = sqlContext.sql(
s"SELECT ${columnExpression} FROM ${mappingConfiguration.databaseName}.${edge.tableName} WHERE ${whereClause} ${limit}"
@@ -475,24 +464,20 @@ object SparkSstFileGenerator {
assert(edgeDf.count() > 0)
//RDD[Tuple3(from_vertex_businessKey,end_vertex_businessKey,values)]
val edgeKeyAndValues: RDD[(String, String, Seq[AnyRef])] =
- edgeDf.map(
- row => {
- (
- row.getAs[String](edge.fromForeignKeyColumn), // consistent with vertexId generation logic, to make sure that vertex and its' outbound edges are in the same partition
- row.getAs[String](edge.toForeignKeyColumn),
- allColumns
- .filterNot(
- col =>
- (col.columnName.equalsIgnoreCase(
- edge.fromForeignKeyColumn
- ) || col.columnName
- .equalsIgnoreCase(edge.toForeignKeyColumn))
- )
- .map(valueExtractor(row, _, charset))
- )
- }
-
- )
+ edgeDf.map(row => {
+ (
+ row.getAs[String](edge.fromForeignKeyColumn), // consistent with vertexId generation logic, to make sure that vertex and its' outbound edges are in the same partition
+ row.getAs[String](edge.toForeignKeyColumn),
+ allColumns
+ .filterNot(col =>
+ (col.columnName.equalsIgnoreCase(
+ edge.fromForeignKeyColumn
+ ) || col.columnName
+ .equalsIgnoreCase(edge.toForeignKeyColumn))
+ )
+ .map(valueExtractor(row, _, charset))
+ )
+ })
edgeKeyAndValues
.map {
@@ -615,7 +600,6 @@ object SparkSstFileGenerator {
)
}
-
// check the claimed columns really exist in db
colsMustCheck.map(_.toUpperCase).foreach { col =>
if (allColumnMap.get(col).isEmpty) {
@@ -646,13 +630,11 @@ object SparkSstFileGenerator {
// tag/edge's columnMappings should be checked and returned
val columnMappings = edge.columnMappings.get
val notValid = columnMappings
- .filter(
- col => {
- val typeInDb = allColumnMap.get(col.columnName.toUpperCase)
- typeInDb.isEmpty || !DataTypeCompatibility
- .isCompatible(col.`type`, typeInDb.get)
- }
- )
+ .filter(col => {
+ val typeInDb = allColumnMap.get(col.columnName.toUpperCase)
+ typeInDb.isEmpty || !DataTypeCompatibility
+ .isCompatible(col.`type`, typeInDb.get)
+ })
.map {
case col => s"name=${col.columnName},type=${col.`type`}"
}
diff --git a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SstFileOutputFormat.scala b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SstFileOutputFormat.scala
index e9fe036c863..8def9d8b44f 100644
--- a/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SstFileOutputFormat.scala
+++ b/src/tools/spark-sstfile-generator/src/main/scala/com/vesoft/tools/SstFileOutputFormat.scala
@@ -170,7 +170,6 @@ class SstRecordWriter(localSstFileOutput: String, configuration: Configuration)
val hdfsSubDirectory =
s"${File.separator}${key.partitionId}${File.separator}"
-
val localDir = s"${localSstFileOutput}${hdfsSubDirectory}"
val sstFileName =
s"${value.vertexOrEdgeEnum}-${key.`type`}-${DatatypeConverter
@@ -224,7 +223,6 @@ class SstRecordWriter(localSstFileOutput: String, configuration: Configuration)
}
}
-
try {
// There could be multiple containers on a single host, parent dir are shared between multiple containers,
// so should not delete parent dir but delete individual file
diff --git a/src/tools/spark-sstfile-generator/src/test/scala/com/vesoft/nebula/tools/generator/v2/ConfigTest.scala b/src/tools/spark-sstfile-generator/src/test/scala/com/vesoft/nebula/tools/generator/v2/ConfigTest.scala
index 9f4d7c2eb0e..2fbbdba3545 100644
--- a/src/tools/spark-sstfile-generator/src/test/scala/com/vesoft/nebula/tools/generator/v2/ConfigTest.scala
+++ b/src/tools/spark-sstfile-generator/src/test/scala/com/vesoft/nebula/tools/generator/v2/ConfigTest.scala
@@ -8,6 +8,4 @@ package com.vesoft.nebula.tools.generator.v2
import org.scalatest.{BeforeAndAfter, FlatSpec}
-class ConfigTest extends FlatSpec with BeforeAndAfter {
-
-}
+class ConfigTest extends FlatSpec with BeforeAndAfter {}