Merge pull request #2638 from vasilmkd/circular-overflow-fix
`PipedStreamBuffer` circular copying overflow fix
mpilquist authored Sep 28, 2021
2 parents a2a4740 + 7c30d6f commit b6790a0
Showing 2 changed files with 74 additions and 47 deletions.
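
The underlying bug, visible in the diff below: `PipedStreamBuffer` tracks its read and write positions as monotonically increasing `Int` counters (`head` and `tail`). Once more than `Int.MaxValue` bytes have flowed through the buffer, the counters wrap negative, and the signed `%` operator then yields a negative array index. Replacing `%` with `Integer.remainderUnsigned` keeps the index in `[0, capacity)`. A minimal sketch of the failure mode (illustrative only, not part of the commit):

```scala
object SignedRemainderDemo extends App {
  val capacity = 8192
  // A position counter that has advanced past Int.MaxValue wraps negative.
  val head: Int = Int.MaxValue + 2 // overflows to -2147483647

  // Signed remainder keeps the sign of the dividend: unusable as an array index.
  println(head % capacity) // -8191, would throw ArrayIndexOutOfBoundsException

  // Interpreting the counter as unsigned keeps the index in [0, capacity).
  println(Integer.remainderUnsigned(head, capacity)) // 1
}
```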
8 changes: 4 additions & 4 deletions io/jvm/src/main/scala/fs2/io/internal/PipedStreamBuffer.scala
@@ -66,7 +66,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
     self.synchronized {
       if (head != tail) {
         // There is at least one byte to read.
-        val byte = buffer(head % capacity) & 0xff
+        val byte = buffer(Integer.remainderUnsigned(head, capacity)) & 0xff
         // The byte is marked as read by advancing the head of the
         // circular buffer.
         head += 1
@@ -211,7 +211,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
       dstPos: Int,
       length: Int
   ): Unit = {
-    val srcOffset = srcPos % srcCap
+    val srcOffset = Integer.remainderUnsigned(srcPos, srcCap)
     if (srcOffset + length >= srcCap) {
       val batch1 = srcCap - srcOffset
       val batch2 = length - batch1
@@ -237,7 +237,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
     self.synchronized {
       if (tail - head < capacity) {
         // There is capacity for at least one byte to be written.
-        buffer(tail % capacity) = (b & 0xff).toByte
+        buffer(Integer.remainderUnsigned(tail, capacity)) = (b & 0xff).toByte
         // The byte is marked as written by advancing the tail of the
         // circular buffer.
         tail += 1
@@ -364,7 +364,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
       dstCap: Int,
       length: Int
   ): Unit = {
-    val dstOffset = dstPos % dstCap
+    val dstOffset = Integer.remainderUnsigned(dstPos, dstCap)
     if (dstOffset + length >= dstCap) {
       val batch1 = dstCap - dstOffset
       val batch2 = length - batch1
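
For context, a toy version of the data structure being patched (hypothetical; the real `PipedStreamBuffer` adds blocking, synchronization, and bulk copies). Note that the fullness check `tail - head < capacity` already survives overflow, because `Int` subtraction is modular; only the indexing needed the unsigned remainder:

```scala
// Hypothetical single-threaded circular byte buffer illustrating the fix.
final class ToyRing(capacity: Int) {
  private val buffer = new Array[Byte](capacity)
  private var head = 0 // total bytes ever read; never reset
  private var tail = 0 // total bytes ever written; never reset

  def write(b: Byte): Boolean =
    if (tail - head < capacity) { // overflow-safe: the difference stays small
      buffer(Integer.remainderUnsigned(tail, capacity)) = b
      tail += 1
      true
    } else false // full

  def read(): Int =
    if (head != tail) {
      val b = buffer(Integer.remainderUnsigned(head, capacity)) & 0xff
      head += 1
      b
    } else -1 // empty
}
```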
113 changes: 70 additions & 43 deletions io/jvm/src/test/scala/fs2/io/IoPlatformSuite.scala
@@ -22,18 +22,24 @@
 package fs2
 package io
 
-import java.io.OutputStream
-import java.util.concurrent.Executors
 import cats.effect.{IO, Resource}
 import cats.effect.unsafe.{IORuntime, IORuntimeConfig}
-import fs2.Fs2Suite
-import fs2.Err
-import scala.concurrent.ExecutionContext
+import fs2.{Err, Fs2Suite}
 import org.scalacheck.{Arbitrary, Gen, Shrink}
 import org.scalacheck.effect.PropF.forAllF
 
+import scala.concurrent.ExecutionContext
+import scala.concurrent.duration._
+
+import java.io.OutputStream
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.Executors
+
 class IoPlatformSuite extends Fs2Suite {
 
+  // This suite runs for a long time, this avoids timeouts in CI.
+  override def munitTimeout: Duration = 1.minute
+
   group("readOutputStream") {
     test("writes data and terminates when `f` returns") {
       forAllF { (bytes: Array[Byte], chunkSize0: Int) =>
@@ -130,51 +136,72 @@ class IoPlatformSuite extends Fs2Suite {
           .drain
       }
     }
-  }
-
-  test("Doesn't deadlock with size-1 thread pool") {
-    def singleThreadedRuntime(): IORuntime = {
-      val compute = {
-        val pool = Executors.newSingleThreadExecutor()
-        (ExecutionContext.fromExecutor(pool), () => pool.shutdown())
-      }
-      val blocking = IORuntime.createDefaultBlockingExecutionContext()
-      val scheduler = IORuntime.createDefaultScheduler()
-      IORuntime(
-        compute._1,
-        blocking._1,
-        scheduler._1,
-        () => {
-          compute._2.apply()
-          blocking._2.apply()
-          scheduler._2.apply()
-        },
-        IORuntimeConfig()
-      )
-    }
-
-    val runtime = Resource.make(IO(singleThreadedRuntime()))(rt => IO(rt.shutdown()))
-
-    def write(os: OutputStream): IO[Unit] =
-      IO.blocking {
-        os.write(1)
-        os.write(1)
-        os.write(1)
-        os.write(1)
-        os.write(1)
-        os.write(1)
-      }
-
-    val prog = readOutputStream[IO](chunkSize = 1)(write)
-      .take(5)
-      .compile
-      .toVector
-      .map(_.size)
-      .assertEquals(5)
-
-    runtime.use { rt =>
-      IO.fromFuture(IO(prog.unsafeToFuture()(rt)))
-    }
-  }
+
+    test("Doesn't deadlock with size-1 thread pool") {
+      def singleThreadedRuntime(): IORuntime = {
+        val compute = {
+          val pool = Executors.newSingleThreadExecutor()
+          (ExecutionContext.fromExecutor(pool), () => pool.shutdown())
+        }
+        val blocking = IORuntime.createDefaultBlockingExecutionContext()
+        val scheduler = IORuntime.createDefaultScheduler()
+        IORuntime(
+          compute._1,
+          blocking._1,
+          scheduler._1,
+          () => {
+            compute._2.apply()
+            blocking._2.apply()
+            scheduler._2.apply()
+          },
+          IORuntimeConfig()
+        )
+      }
+
+      val runtime = Resource.make(IO(singleThreadedRuntime()))(rt => IO(rt.shutdown()))
+
+      def write(os: OutputStream): IO[Unit] =
+        IO.blocking {
+          os.write(1)
+          os.write(1)
+          os.write(1)
+          os.write(1)
+          os.write(1)
+          os.write(1)
+        }
+
+      val prog = readOutputStream[IO](chunkSize = 1)(write)
+        .take(5)
+        .compile
+        .toVector
+        .map(_.size)
+        .assertEquals(5)
+
+      runtime.use { rt =>
+        IO.fromFuture(IO(prog.unsafeToFuture()(rt)))
+      }
+    }
+
+    test("can copy more than Int.MaxValue bytes") {
+      // Unit test adapted from the original issue reproduction at https://github.com/mrdziuban/fs2-writeOutputStream.
+
+      val byteStream =
+        Stream
+          .chunk[IO, Byte](Chunk.array(("foobar" * 50000).getBytes(StandardCharsets.UTF_8)))
+          .repeatN(7200L) // 6 * 50,000 * 7,200 == 2,160,000,000 > 2,147,483,647 == Int.MaxValue
+
+      def writeToOutputStream(out: OutputStream): IO[Unit] =
+        byteStream
+          .through(writeOutputStream(IO.pure(out)))
+          .compile
+          .drain
+
+      readOutputStream[IO](1024 * 8)(writeToOutputStream)
+        .chunkN(6 * 50000)
+        .map(c => new String(c.toArray[Byte], StandardCharsets.UTF_8))
+        .foreach(str => IO.pure(str).assertEquals("foobar" * 50000))
+        .compile
+        .drain
+    }
+  }
 
 }
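
As a sanity check on the new test's arithmetic (plain Scala, using only the numbers already in the test): 7,200 repetitions of a 300,000-byte chunk push just past the signed 32-bit limit, forcing `head` and `tail` through an overflow.

```scala
val bytesPerChunk: Long = 6L * 50000L        // "foobar".length * repetitions per chunk
val totalBytes: Long = bytesPerChunk * 7200L // 2,160,000,000
assert(totalBytes > Int.MaxValue.toLong)     // Int.MaxValue == 2,147,483,647
```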
