-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move Avro sampler to FileSystems API #140
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,56 +17,62 @@ | |
|
||
package com.spotify.ratatool.samplers | ||
|
||
import com.spotify.ratatool.GcsConfiguration | ||
import org.apache.avro.file.DataFileReader | ||
import java.nio.channels.Channels | ||
|
||
import org.apache.avro.file.{DataFileReader, DataFileStream} | ||
import org.apache.avro.generic.{GenericDatumReader, GenericRecord} | ||
import org.apache.hadoop.fs._ | ||
import org.apache.beam.sdk.io.FileSystems | ||
import org.apache.beam.sdk.io.fs.ResourceId | ||
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory} | ||
import org.slf4j.{Logger, LoggerFactory} | ||
|
||
import scala.collection.mutable.{ListBuffer, Set => MSet} | ||
import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set => MSet} | ||
import scala.concurrent.ExecutionContext.Implicits.global | ||
import scala.concurrent.duration.Duration | ||
import scala.concurrent.{Await, Future} | ||
import scala.util.Random | ||
import scala.collection.JavaConverters._ | ||
|
||
/** Sampler for Avro files. */ | ||
class AvroSampler(path: Path, protected val seed: Option[Long] = None) | ||
class AvroSampler(path: String, protected val seed: Option[Long] = None, | ||
protected val conf: Option[PipelineOptions] = None) | ||
extends Sampler[GenericRecord] { | ||
|
||
private val logger: Logger = LoggerFactory.getLogger(classOf[AvroSampler]) | ||
|
||
private def getFileContext: FileContext = FileContext.getFileContext(GcsConfiguration.get()) | ||
// private def getFileContext: FileContext = FileContext.getFileContext(GcsConfiguration.get()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: remove comment? |
||
|
||
override def sample(n: Long, head: Boolean): Seq[GenericRecord] = { | ||
require(n > 0, "n must be > 0") | ||
logger.info("Taking a sample of {} from Avro {}", n, path) | ||
|
||
val fs = FileSystem.get(path.toUri, GcsConfiguration.get()) | ||
if (fs.isFile(path)) { | ||
new AvroFileSampler(getFileContext, path, seed).sample(n, head) | ||
// val fs = FileSystem.get(path.toUri, GcsConfiguration.get()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: remove comment |
||
FileSystems.setDefaultPipelineOptions(conf.getOrElse(PipelineOptionsFactory.create())) | ||
val matches = FileSystems.`match`(path).metadata().asScala | ||
if (!FileSystems.hasGlobWildcard(path)) { | ||
val resource = matches.head.resourceId() | ||
new AvroFileSampler(resource, seed).sample(n, head) | ||
} else { | ||
val filter = new PathFilter { | ||
override def accept(path: Path): Boolean = path.getName.endsWith(".avro") | ||
} | ||
val statuses = fs.listStatus(path, filter).sortBy(_.getPath.toString) | ||
val paths = statuses.map(_.getPath) | ||
|
||
if (head) { | ||
val resources = matches.map(_.resourceId()) | ||
.sortBy(_.toString) | ||
// read from the start | ||
val result = ListBuffer.empty[GenericRecord] | ||
val iter = paths.iterator | ||
val iter = resources.toIterator | ||
while (result.size < n && iter.hasNext) { | ||
result.appendAll(new AvroFileSampler(getFileContext, iter.next()).sample(n, head)) | ||
result.appendAll(new AvroFileSampler(iter.next()).sample(n, head)) | ||
} | ||
result | ||
} else { | ||
val tups = matches | ||
.map(md => (md.resourceId(), md.sizeBytes())) | ||
.sortBy(_._1.toString).toArray | ||
// randomly sample from shards | ||
val sizes = statuses.map(_.getLen) | ||
val sizes = tups.map(_._2) | ||
val resources = tups.map(_._1) | ||
val samples = scaleWeights(sizes, n) | ||
val futures = paths.zip(samples).map { case (p, s) => | ||
val fc = getFileContext | ||
Future(new AvroFileSampler(fc, p) | ||
.sample(s, head)) | ||
val futures = resources.zip(samples).map { case (r, s) => | ||
Future(new AvroFileSampler(r).sample(s, head)) | ||
}.toSeq | ||
Await.result(Future.sequence(futures), Duration.Inf).flatten | ||
} | ||
|
@@ -92,51 +98,44 @@ class AvroSampler(path: Path, protected val seed: Option[Long] = None) | |
|
||
} | ||
|
||
private class AvroFileSampler(fc: FileContext, path: Path, protected val seed: Option[Long] = None) | ||
private class AvroFileSampler(r: ResourceId, protected val seed: Option[Long] = None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: wouldn't it be more generic to expect the file path instead of the resource id? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: I guess it doesn't matter that much since the class is private |
||
extends Sampler[GenericRecord] { | ||
|
||
private val logger: Logger = LoggerFactory.getLogger(classOf[AvroFileSampler]) | ||
|
||
override def sample(n: Long, head: Boolean): Seq[GenericRecord] = { | ||
require(n > 0, "n must be > 0") | ||
logger.debug("Taking a sample of {} from Avro file {}", n, path) | ||
|
||
val input = new AvroFSInput(fc, path) | ||
val input = Channels.newInputStream(FileSystems.open(r)) | ||
val datumReader = new GenericDatumReader[GenericRecord]() | ||
val fileReader = DataFileReader.openReader(input, datumReader) | ||
val fileStream = new DataFileStream[GenericRecord](input, datumReader) | ||
fileStream.getBlockCount | ||
|
||
val schema = fileReader.getSchema | ||
val schema = fileStream.getSchema | ||
logger.debug("Avro schema {}", schema) | ||
val start = fileReader.tell() | ||
val end = input.length() | ||
val range = end - start | ||
|
||
val result = ListBuffer.empty[GenericRecord] | ||
val result = ArrayBuffer.empty[GenericRecord] | ||
if (head) { | ||
// read from the start | ||
while (result.size < n && fileReader.hasNext) { | ||
result.append(fileReader.next()) | ||
while (result.size < n && fileStream.hasNext) { | ||
result.append(fileStream.next()) | ||
} | ||
} else { | ||
// rejection sampling until n unique samples are obtained | ||
val positions = MSet.empty[Long] | ||
var collisions = 0 | ||
while (result.size < n && collisions < 10) { | ||
// pick a random offset and move to the next sync point | ||
val off = start + nextLong(range) | ||
fileReader.sync(off) | ||
val pos = fileReader.tell() | ||
|
||
// sync position may be sampled already | ||
if (positions.contains(pos) || !fileReader.hasNext) { | ||
collisions += 1 | ||
logger.debug("Sync point collision {} at position {}", collisions, pos) | ||
} else { | ||
collisions = 0 | ||
positions.add(pos) | ||
result.append(fileReader.next()) | ||
logger.debug("New sample sync point at position {}", pos) | ||
// Reservoir sample imperative way | ||
// Fill result with first n elements | ||
while (result.size < n && fileStream.hasNext) { | ||
result.append(fileStream.next()) | ||
} | ||
|
||
// Then randomly select from all other elements in the stream | ||
var index = n | ||
while (fileStream.hasNext) { | ||
val next = fileStream.next() | ||
val loc = nextLong(index + 1) | ||
if (loc < n) { | ||
result(loc.toInt) = next | ||
} | ||
index += 1 | ||
} | ||
} | ||
result | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
package com.spotify.ratatool.samplers | ||
|
||
import java.io.File | ||
import java.nio.file.Files | ||
import java.nio.file.{Files, Paths} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: unnecessary import? |
||
|
||
import com.spotify.ratatool.Schemas | ||
import com.spotify.ratatool.scalacheck._ | ||
|
@@ -35,6 +35,7 @@ class AvroSamplerTest extends FlatSpec with Matchers with BeforeAndAfterAllConfi | |
val dir = Files.createTempDirectory("ratatool-") | ||
val file1 = new File(dir.toString, "part-00000.avro") | ||
val file2 = new File(dir.toString, "part-00001.avro") | ||
val dirWildcard = new File(dir.toString, "*.avro") | ||
|
||
override protected def beforeAll(configMap: ConfigMap): Unit = { | ||
AvroIO.writeToFile(data1, schema, file1) | ||
|
@@ -46,25 +47,25 @@ class AvroSamplerTest extends FlatSpec with Matchers with BeforeAndAfterAllConfi | |
} | ||
|
||
"AvroSampler" should "support single file in head mode" in { | ||
val result = new AvroSampler(new Path(file1.getAbsolutePath)).sample(10, head = true) | ||
val result = new AvroSampler(file1.getAbsolutePath).sample(10, head = true) | ||
result.size shouldBe 10 | ||
result should equal (data1.take(10)) | ||
} | ||
|
||
it should "support single file in random mode" in { | ||
val result = new AvroSampler(new Path(file1.getAbsolutePath)).sample(10, head = false) | ||
val result = new AvroSampler(file1.getAbsolutePath).sample(10, head = false) | ||
result.size shouldBe 10 | ||
result.forall(data1.contains(_)) shouldBe true | ||
} | ||
|
||
it should "support multiple files in head mode" in { | ||
val result = new AvroSampler(new Path(dir.toString)).sample(10, head = true) | ||
val result = new AvroSampler(dirWildcard.getAbsolutePath).sample(10, head = true) | ||
result.size shouldBe 10 | ||
result should equal (data1.take(10)) | ||
} | ||
|
||
it should "support multiple files in random mode" in { | ||
val result = new AvroSampler(new Path(dir.toString)).sample(10, head = false) | ||
val result = new AvroSampler(dirWildcard.getAbsolutePath).sample(10, head = false) | ||
result.size shouldBe 10 | ||
result.exists(data1.contains(_)) shouldBe true | ||
result.exists(data2.contains(_)) shouldBe true | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: wondering if we can get rid of this class.
Seems like it is only being used in
ParquetIO
(besides AvroSampler and an example). There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Reply: Yea, Parquet cleanup is not a priority for me at the moment but I think we can get rid of it once that's done.