Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Avro sampler to FileSystems API #140

Merged
Merged 2 commits on Nov 8, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object Ratatool {
val o = opts.get
o.mode match {
case "avro" =>
val data = new AvroSampler(new Path(o.in)).sample(o.n, o.head)
val data = new AvroSampler(o.in).sample(o.n, o.head)
AvroIO.writeToFile(data, data.head.getSchema, o.out)
case "bigquery" =>
val sampler = new BigQuerySampler(BigQueryIO.parseTableSpec(o.in))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package com.spotify.ratatool.diffy

import java.net.URI

import com.google.api.services.bigquery.model.{TableFieldSchema, TableRow, TableSchema}
import com.google.protobuf.AbstractMessage
import com.spotify.ratatool.{Command, GcsConfiguration}
import com.spotify.ratatool.Command
import com.spotify.ratatool.samplers.AvroSampler
import com.spotify.scio._
import com.spotify.scio.bigquery.BigQueryClient
Expand All @@ -31,8 +29,7 @@ import com.spotify.scio.values.SCollection
import com.twitter.algebird._
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.beam.sdk.io.TextIO
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.beam.sdk.io.{FileSystems, TextIO}

import scala.annotation.tailrec
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -480,10 +477,9 @@ object BigDiffy extends Command {

val result = inputMode match {
case "avro" =>
// TODO: handle schema evolution
val fs = FileSystem.get(new URI(rhs), GcsConfiguration.get())
val path = fs.globStatus(new Path(rhs)).head.getPath
val schema = new AvroSampler(path).sample(1, true).head.getSchema
// TODO: handle schema
val schema = new AvroSampler(rhs, conf = Some(sc.options))
.sample(1, head = true).head.getSchema
val diffy = new AvroDiffy[GenericRecord](ignore, unordered)
BigDiffy.diffAvro[GenericRecord](sc, lhs, rhs, avroKeyFn(key), diffy, schema)
case "bigquery" =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ public static Configuration get() {

conf.setIfUnset(
HadoopCredentialConfiguration.BASE_KEY_PREFIX +
HadoopCredentialConfiguration.ENABLE_SERVICE_ACCOUNTS_SUFFIX,
HadoopCredentialConfiguration.ENABLE_SERVICE_ACCOUNTS_SUFFIX,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if we can get rid of this class.
Seems like it is only being used in ParquetIO (besides AvroSampler and an example)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, Parquet cleanup is not a priority for me at the moment but I think we can get rid of it once that's done.

"false");
conf.setIfUnset(
HadoopCredentialConfiguration.BASE_KEY_PREFIX +
HadoopCredentialConfiguration.CLIENT_ID_SUFFIX,
HadoopCredentialConfiguration.CLIENT_ID_SUFFIX,
clientId);
conf.setIfUnset(
HadoopCredentialConfiguration.BASE_KEY_PREFIX +
HadoopCredentialConfiguration.CLIENT_SECRET_SUFFIX,
HadoopCredentialConfiguration.CLIENT_SECRET_SUFFIX,
clientSecret);
}
}
Expand Down Expand Up @@ -157,4 +157,4 @@ private static Map<String, String> getEnvironment() {
return System.getenv();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,62 @@

package com.spotify.ratatool.samplers

import com.spotify.ratatool.GcsConfiguration
import org.apache.avro.file.DataFileReader
import java.nio.channels.Channels

import org.apache.avro.file.{DataFileReader, DataFileStream}
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.hadoop.fs._
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable.{ListBuffer, Set => MSet}
import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set => MSet}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}
import scala.util.Random
import scala.collection.JavaConverters._

/** Sampler for Avro files. */
class AvroSampler(path: Path, protected val seed: Option[Long] = None)
class AvroSampler(path: String, protected val seed: Option[Long] = None,
protected val conf: Option[PipelineOptions] = None)
extends Sampler[GenericRecord] {

private val logger: Logger = LoggerFactory.getLogger(classOf[AvroSampler])

private def getFileContext: FileContext = FileContext.getFileContext(GcsConfiguration.get())
// private def getFileContext: FileContext = FileContext.getFileContext(GcsConfiguration.get())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove comment?


override def sample(n: Long, head: Boolean): Seq[GenericRecord] = {
require(n > 0, "n must be > 0")
logger.info("Taking a sample of {} from Avro {}", n, path)

val fs = FileSystem.get(path.toUri, GcsConfiguration.get())
if (fs.isFile(path)) {
new AvroFileSampler(getFileContext, path, seed).sample(n, head)
// val fs = FileSystem.get(path.toUri, GcsConfiguration.get())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove comment

FileSystems.setDefaultPipelineOptions(conf.getOrElse(PipelineOptionsFactory.create()))
val matches = FileSystems.`match`(path).metadata().asScala
if (!FileSystems.hasGlobWildcard(path)) {
val resource = matches.head.resourceId()
new AvroFileSampler(resource, seed).sample(n, head)
} else {
val filter = new PathFilter {
override def accept(path: Path): Boolean = path.getName.endsWith(".avro")
}
val statuses = fs.listStatus(path, filter).sortBy(_.getPath.toString)
val paths = statuses.map(_.getPath)

if (head) {
val resources = matches.map(_.resourceId())
.sortBy(_.toString)
// read from the start
val result = ListBuffer.empty[GenericRecord]
val iter = paths.iterator
val iter = resources.toIterator
while (result.size < n && iter.hasNext) {
result.appendAll(new AvroFileSampler(getFileContext, iter.next()).sample(n, head))
result.appendAll(new AvroFileSampler(iter.next()).sample(n, head))
}
result
} else {
val tups = matches
.map(md => (md.resourceId(), md.sizeBytes()))
.sortBy(_._1.toString).toArray
// randomly sample from shards
val sizes = statuses.map(_.getLen)
val sizes = tups.map(_._2)
val resources = tups.map(_._1)
val samples = scaleWeights(sizes, n)
val futures = paths.zip(samples).map { case (p, s) =>
val fc = getFileContext
Future(new AvroFileSampler(fc, p)
.sample(s, head))
val futures = resources.zip(samples).map { case (r, s) =>
Future(new AvroFileSampler(r).sample(s, head))
}.toSeq
Await.result(Future.sequence(futures), Duration.Inf).flatten
}
Expand All @@ -92,51 +98,44 @@ class AvroSampler(path: Path, protected val seed: Option[Long] = None)

}

private class AvroFileSampler(fc: FileContext, path: Path, protected val seed: Option[Long] = None)
private class AvroFileSampler(r: ResourceId, protected val seed: Option[Long] = None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't it be more generic to expect the file path instead of the resource id?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it doesn't matter that much since the class is private

extends Sampler[GenericRecord] {

private val logger: Logger = LoggerFactory.getLogger(classOf[AvroFileSampler])

override def sample(n: Long, head: Boolean): Seq[GenericRecord] = {
require(n > 0, "n must be > 0")
logger.debug("Taking a sample of {} from Avro file {}", n, path)

val input = new AvroFSInput(fc, path)
val input = Channels.newInputStream(FileSystems.open(r))
val datumReader = new GenericDatumReader[GenericRecord]()
val fileReader = DataFileReader.openReader(input, datumReader)
val fileStream = new DataFileStream[GenericRecord](input, datumReader)
fileStream.getBlockCount

val schema = fileReader.getSchema
val schema = fileStream.getSchema
logger.debug("Avro schema {}", schema)
val start = fileReader.tell()
val end = input.length()
val range = end - start

val result = ListBuffer.empty[GenericRecord]
val result = ArrayBuffer.empty[GenericRecord]
if (head) {
// read from the start
while (result.size < n && fileReader.hasNext) {
result.append(fileReader.next())
while (result.size < n && fileStream.hasNext) {
result.append(fileStream.next())
}
} else {
// rejection sampling until n unique samples are obtained
val positions = MSet.empty[Long]
var collisions = 0
while (result.size < n && collisions < 10) {
// pick a random offset and move to the next sync point
val off = start + nextLong(range)
fileReader.sync(off)
val pos = fileReader.tell()

// sync position may be sampled already
if (positions.contains(pos) || !fileReader.hasNext) {
collisions += 1
logger.debug("Sync point collision {} at position {}", collisions, pos)
} else {
collisions = 0
positions.add(pos)
result.append(fileReader.next())
logger.debug("New sample sync point at position {}", pos)
// Reservoir sample imperative way
// Fill result with first n elements
while (result.size < n && fileStream.hasNext) {
result.append(fileStream.next())
}

// Then randomly select from all other elements in the stream
var index = n
while (fileStream.hasNext) {
val next = fileStream.next()
val loc = nextLong(index + 1)
if (loc < n) {
result(loc.toInt) = next
}
index += 1
}
}
result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package com.spotify.ratatool.samplers

import java.io.File
import java.nio.file.Files
import java.nio.file.{Files, Paths}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary import?


import com.spotify.ratatool.Schemas
import com.spotify.ratatool.scalacheck._
Expand All @@ -35,6 +35,7 @@ class AvroSamplerTest extends FlatSpec with Matchers with BeforeAndAfterAllConfi
val dir = Files.createTempDirectory("ratatool-")
val file1 = new File(dir.toString, "part-00000.avro")
val file2 = new File(dir.toString, "part-00001.avro")
val dirWildcard = new File(dir.toString, "*.avro")

override protected def beforeAll(configMap: ConfigMap): Unit = {
AvroIO.writeToFile(data1, schema, file1)
Expand All @@ -46,25 +47,25 @@ class AvroSamplerTest extends FlatSpec with Matchers with BeforeAndAfterAllConfi
}

"AvroSampler" should "support single file in head mode" in {
val result = new AvroSampler(new Path(file1.getAbsolutePath)).sample(10, head = true)
val result = new AvroSampler(file1.getAbsolutePath).sample(10, head = true)
result.size shouldBe 10
result should equal (data1.take(10))
}

it should "support single file in random mode" in {
val result = new AvroSampler(new Path(file1.getAbsolutePath)).sample(10, head = false)
val result = new AvroSampler(file1.getAbsolutePath).sample(10, head = false)
result.size shouldBe 10
result.forall(data1.contains(_)) shouldBe true
}

it should "support multiple files in head mode" in {
val result = new AvroSampler(new Path(dir.toString)).sample(10, head = true)
val result = new AvroSampler(dirWildcard.getAbsolutePath).sample(10, head = true)
result.size shouldBe 10
result should equal (data1.take(10))
}

it should "support multiple files in random mode" in {
val result = new AvroSampler(new Path(dir.toString)).sample(10, head = false)
val result = new AvroSampler(dirWildcard.getAbsolutePath).sample(10, head = false)
result.size shouldBe 10
result.exists(data1.contains(_)) shouldBe true
result.exists(data2.contains(_)) shouldBe true
Expand Down