Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PipedStreamBuffer circular copying overflow fix #2638

Merged
merged 7 commits into from
Sep 28, 2021
30 changes: 25 additions & 5 deletions io/jvm/src/main/scala/fs2/io/internal/PipedStreamBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
self.synchronized {
if (head != tail) {
// There is at least one byte to read.
val byte = buffer(head % capacity) & 0xff
val byte = buffer(modulus(head, capacity)) & 0xff
// The byte is marked as read by advancing the head of the
// circular buffer.
head += 1
Expand Down Expand Up @@ -211,7 +211,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
dstPos: Int,
length: Int
): Unit = {
val srcOffset = srcPos % srcCap
val srcOffset = modulus(srcPos, srcCap)
if (srcOffset + length >= srcCap) {
val batch1 = srcCap - srcOffset
val batch2 = length - batch1
Expand All @@ -237,7 +237,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
self.synchronized {
if (tail - head < capacity) {
// There is capacity for at least one byte to be written.
buffer(tail % capacity) = (b & 0xff).toByte
buffer(modulus(tail, capacity)) = (b & 0xff).toByte
// The byte is marked as written by advancing the tail of the
// circular buffer.
tail += 1
Expand Down Expand Up @@ -364,15 +364,35 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
dstCap: Int,
length: Int
): Unit = {
val dstOffset = dstPos % dstCap
val dstOffset = modulus(dstPos, dstCap)
if (dstOffset + length >= dstCap) {
val batch1 = dstCap - dstOffset
val batch2 = length - batch1
System.arraycopy(src, srcPos, dst, dstOffset, batch1)
System.arraycopy(src, srcPos + batch1, dst, 0, batch2)
} else {
System.arraycopy(src, srcPos, dst, dstOffset, length)
try System.arraycopy(src, srcPos, dst, dstOffset, length)
catch {
case e: ArrayIndexOutOfBoundsException =>
println(
s"srcPos = $srcPos, dstPos = $dstPos, dstCap = $dstCap, length = $length, dstOffset = $dstOffset, dstOffset + length = ${dstOffset + length}"
)
throw e
}
}
}
}

/** Calculates `n` modulo `m`. This is different from the JVMs built in `%`
* remainder operator which can return a negative result for negative
* numbers. This method always returns a non-negative result.
*
* @param n the divident
* @param m the divisor
* @return the result of `n` modulo `m`
*/
private[this] def modulus(n: Int, m: Int): Int = {
val r = n % m
if (r < 0) r + m else r
}
vasilmkd marked this conversation as resolved.
Show resolved Hide resolved
}