Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PipedStreamBuffer circular copying overflow fix #2638

Merged
merged 7 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions io/jvm/src/main/scala/fs2/io/internal/PipedStreamBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
self.synchronized {
if (head != tail) {
// There is at least one byte to read.
val byte = buffer(head % capacity) & 0xff
val byte = buffer(Integer.remainderUnsigned(head, capacity)) & 0xff
// The byte is marked as read by advancing the head of the
// circular buffer.
head += 1
Expand Down Expand Up @@ -211,7 +211,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
dstPos: Int,
length: Int
): Unit = {
val srcOffset = srcPos % srcCap
val srcOffset = Integer.remainderUnsigned(srcPos, srcCap)
if (srcOffset + length >= srcCap) {
val batch1 = srcCap - srcOffset
val batch2 = length - batch1
Expand All @@ -237,7 +237,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
self.synchronized {
if (tail - head < capacity) {
// There is capacity for at least one byte to be written.
buffer(tail % capacity) = (b & 0xff).toByte
buffer(Integer.remainderUnsigned(tail, capacity)) = (b & 0xff).toByte
// The byte is marked as written by advancing the tail of the
// circular buffer.
tail += 1
Expand Down Expand Up @@ -364,7 +364,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
dstCap: Int,
length: Int
): Unit = {
val dstOffset = dstPos % dstCap
val dstOffset = Integer.remainderUnsigned(dstPos, dstCap)
if (dstOffset + length >= dstCap) {
val batch1 = dstCap - dstOffset
val batch2 = length - batch1
Expand Down
113 changes: 70 additions & 43 deletions io/jvm/src/test/scala/fs2/io/IoPlatformSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,24 @@
package fs2
package io

import java.io.OutputStream
import java.util.concurrent.Executors
import cats.effect.{IO, Resource}
import cats.effect.unsafe.{IORuntime, IORuntimeConfig}
import fs2.Fs2Suite
import fs2.Err
import scala.concurrent.ExecutionContext
import fs2.{Err, Fs2Suite}
import org.scalacheck.{Arbitrary, Gen, Shrink}
import org.scalacheck.effect.PropF.forAllF

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

import java.io.OutputStream
import java.nio.charset.StandardCharsets
import java.util.concurrent.Executors

class IoPlatformSuite extends Fs2Suite {

// This suite runs for a long time, this avoids timeouts in CI.
override def munitTimeout: Duration = 1.minute

group("readOutputStream") {
test("writes data and terminates when `f` returns") {
forAllF { (bytes: Array[Byte], chunkSize0: Int) =>
Expand Down Expand Up @@ -130,51 +136,72 @@ class IoPlatformSuite extends Fs2Suite {
.drain
}
}
}

test("Doesn't deadlock with size-1 thread pool") {
def singleThreadedRuntime(): IORuntime = {
val compute = {
val pool = Executors.newSingleThreadExecutor()
(ExecutionContext.fromExecutor(pool), () => pool.shutdown())
test("Doesn't deadlock with size-1 thread pool") {
def singleThreadedRuntime(): IORuntime = {
val compute = {
val pool = Executors.newSingleThreadExecutor()
(ExecutionContext.fromExecutor(pool), () => pool.shutdown())
}
val blocking = IORuntime.createDefaultBlockingExecutionContext()
val scheduler = IORuntime.createDefaultScheduler()
IORuntime(
compute._1,
blocking._1,
scheduler._1,
() => {
compute._2.apply()
blocking._2.apply()
scheduler._2.apply()
},
IORuntimeConfig()
)
}

val runtime = Resource.make(IO(singleThreadedRuntime()))(rt => IO(rt.shutdown()))

def write(os: OutputStream): IO[Unit] =
IO.blocking {
os.write(1)
os.write(1)
os.write(1)
os.write(1)
os.write(1)
os.write(1)
}

val prog = readOutputStream[IO](chunkSize = 1)(write)
.take(5)
.compile
.toVector
.map(_.size)
.assertEquals(5)

runtime.use { rt =>
IO.fromFuture(IO(prog.unsafeToFuture()(rt)))
}
val blocking = IORuntime.createDefaultBlockingExecutionContext()
val scheduler = IORuntime.createDefaultScheduler()
IORuntime(
compute._1,
blocking._1,
scheduler._1,
() => {
compute._2.apply()
blocking._2.apply()
scheduler._2.apply()
},
IORuntimeConfig()
)
}

val runtime = Resource.make(IO(singleThreadedRuntime()))(rt => IO(rt.shutdown()))
test("can copy more than Int.MaxValue bytes") {
// Unit test adapted from the original issue reproduction at https://github.com/mrdziuban/fs2-writeOutputStream.

def write(os: OutputStream): IO[Unit] =
IO.blocking {
os.write(1)
os.write(1)
os.write(1)
os.write(1)
os.write(1)
os.write(1)
}
val byteStream =
Stream
.chunk[IO, Byte](Chunk.array(("foobar" * 50000).getBytes(StandardCharsets.UTF_8)))
.repeatN(7200L) // 6 * 50,000 * 7,200 == 2,160,000,000 > 2,147,483,647 == Int.MaxValue

val prog = readOutputStream[IO](chunkSize = 1)(write)
.take(5)
.compile
.toVector
.map(_.size)
.assertEquals(5)
def writeToOutputStream(out: OutputStream): IO[Unit] =
byteStream
.through(writeOutputStream(IO.pure(out)))
.compile
.drain

runtime.use { rt =>
IO.fromFuture(IO(prog.unsafeToFuture()(rt)))
readOutputStream[IO](1024 * 8)(writeToOutputStream)
.chunkN(6 * 50000)
.map(c => new String(c.toArray[Byte], StandardCharsets.UTF_8))
.foreach(str => IO.pure(str).assertEquals("foobar" * 50000))
.compile
.drain
}
}

}