Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PipedStreamBuffer circular copying overflow fix #2638

Merged
merged 7 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions io/jvm/src/main/scala/fs2/io/internal/PipedStreamBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
self.synchronized {
if (head != tail) {
// There is at least one byte to read.
val byte = buffer(head % capacity) & 0xff
val byte = buffer(modulus(head, capacity)) & 0xff
// The byte is marked as read by advancing the head of the
// circular buffer.
head += 1
Expand Down Expand Up @@ -211,7 +211,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
dstPos: Int,
length: Int
): Unit = {
val srcOffset = srcPos % srcCap
val srcOffset = modulus(srcPos, srcCap)
if (srcOffset + length >= srcCap) {
val batch1 = srcCap - srcOffset
val batch2 = length - batch1
Expand All @@ -237,7 +237,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
self.synchronized {
if (tail - head < capacity) {
// There is capacity for at least one byte to be written.
buffer(tail % capacity) = (b & 0xff).toByte
buffer(modulus(tail, capacity)) = (b & 0xff).toByte
// The byte is marked as written by advancing the tail of the
// circular buffer.
tail += 1
Expand Down Expand Up @@ -364,15 +364,35 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
dstCap: Int,
length: Int
): Unit = {
val dstOffset = dstPos % dstCap
val dstOffset = modulus(dstPos, dstCap)
if (dstOffset + length >= dstCap) {
val batch1 = dstCap - dstOffset
val batch2 = length - batch1
System.arraycopy(src, srcPos, dst, dstOffset, batch1)
System.arraycopy(src, srcPos + batch1, dst, 0, batch2)
} else {
System.arraycopy(src, srcPos, dst, dstOffset, length)
try System.arraycopy(src, srcPos, dst, dstOffset, length)
catch {
case e: ArrayIndexOutOfBoundsException =>
println(
s"srcPos = $srcPos, dstPos = $dstPos, dstCap = $dstCap, length = $length, dstOffset = $dstOffset, dstOffset + length = ${dstOffset + length}"
)
throw e
}
}
}
}

/** Calculates `n` modulo `m`. This is different from the JVMs built in `%`
* remainder operator which can return a negative result for negative
* numbers. This method always returns a non-negative result.
*
* @param n the divident
* @param m the divisor
* @return the result of `n` modulo `m`
*/
private[this] def modulus(n: Int, m: Int): Int = {
val r = n % m
if (r < 0) r + m else r
}
vasilmkd marked this conversation as resolved.
Show resolved Hide resolved
}
28 changes: 23 additions & 5 deletions io/jvm/src/test/scala/fs2/io/IoPlatformSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@
package fs2
package io

import java.io.OutputStream
import java.util.concurrent.Executors
import cats.effect.{IO, Resource}
import cats.effect.unsafe.{IORuntime, IORuntimeConfig}
import fs2.Fs2Suite
import fs2.Err
import scala.concurrent.ExecutionContext
import fs2.{Err, Fs2Suite}
import org.scalacheck.{Arbitrary, Gen, Shrink}
import org.scalacheck.effect.PropF.forAllF

import scala.concurrent.ExecutionContext

import java.io.OutputStream
import java.nio.charset.StandardCharsets
import java.util.concurrent.Executors

class IoPlatformSuite extends Fs2Suite {

group("readOutputStream") {
Expand Down Expand Up @@ -177,4 +179,20 @@ class IoPlatformSuite extends Fs2Suite {
}
}

test("can copy more than Int.MaxValue bytes") {
// Unit test adapted from the original issue reproduction at https://github.com/mrdziuban/fs2-writeOutputStream.

val byteStream =
Stream
.chunk[IO, Byte](Chunk.array(("foobar" * 50000).getBytes(StandardCharsets.UTF_8)))
.repeatN(7200L) // 6 * 50,000 * 7,200 == 2,160,000,000 > 2,147,483,647 == Int.MaxValue

def writeToOutputStream(out: OutputStream): IO[Unit] =
byteStream
.through(writeOutputStream(IO(out)))
.compile
.drain

readOutputStream[IO](1024 * 8)(writeToOutputStream).compile.drain
}
}