Skip to content

Commit

Permalink
Prepare zlib's deflater for Kotlin/Native (#1420)
Browse files Browse the repository at this point in the history
* Prepare zlib's deflater for Kotlin/Native

* Note on exhausted sources

* Figure out assertk later

* Properly alloc and free the z_stream_s struct

* Add tests where deflate() returns false
  • Loading branch information
swankjesse committed Feb 6, 2024
1 parent 0d76d41 commit c1aaf58
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 0 deletions.
152 changes: 152 additions & 0 deletions okio/src/nativeMain/kotlin/okio/Deflater.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

import kotlinx.cinterop.CPointer
import kotlinx.cinterop.UByteVar
import kotlinx.cinterop.addressOf
import kotlinx.cinterop.alloc
import kotlinx.cinterop.free
import kotlinx.cinterop.nativeHeap
import kotlinx.cinterop.ptr
import kotlinx.cinterop.usePinned
import platform.zlib.Z_BEST_COMPRESSION
import platform.zlib.Z_DEFAULT_STRATEGY
import platform.zlib.Z_DEFLATED
import platform.zlib.Z_FINISH
import platform.zlib.Z_NO_FLUSH
import platform.zlib.Z_OK
import platform.zlib.Z_STREAM_END
import platform.zlib.Z_STREAM_ERROR
import platform.zlib.Z_SYNC_FLUSH
import platform.zlib.deflate
import platform.zlib.deflateEnd
import platform.zlib.deflateInit2
import platform.zlib.z_stream_s

private val emptyByteArray = byteArrayOf()

/**
* Deflate using Kotlin/Native's built-in zlib bindings. This uses the raw deflate format and omits
* the zlib header and trailer, and does not compute a check value.
*
* To use:
*
* 1. Create an instance.
*
* 2. Populate [source] with uncompressed data. Set [sourcePos] and [sourceLimit] to a readable
* slice of this array.
*
* 3. Populate [target] with a destination for compressed data. Set [targetPos] and [targetLimit] to
* a writable slice of this array.
*
* 4. Call [deflate] to read input data from [source] and write compressed output to [target]. This
* function advances [sourcePos] if input data was read and [targetPos] if compressed output was
* written. If the input array is exhausted (`sourcePos == sourceLimit`) or the output array is
* full (`targetPos == targetLimit`), make an adjustment and call [deflate] again.
*
* 5. Repeat steps 2 through 4 until the input data is completely exhausted. Set [sourceFinished]
* to true before the last call to [deflate]. (It is okay to call deflate() when the source is
* exhausted.)
*
* 6. Close the Deflater.
*
* See also, the [zlib manual](https://www.zlib.net/manual.html).
*/
internal class Deflater : Closeable {
private val zStream: z_stream_s = nativeHeap.alloc<z_stream_s> {
zalloc = null
zfree = null
opaque = null
check(
deflateInit2(
strm = ptr,
level = Z_BEST_COMPRESSION,
method = Z_DEFLATED,
windowBits = -15, // Default value for raw deflate.
memLevel = 8, // Default value.
strategy = Z_DEFAULT_STRATEGY,
) == Z_OK,
)
}

var source: ByteArray = emptyByteArray
var sourcePos: Int = 0
var sourceLimit: Int = 0
var sourceFinished = false

var target: ByteArray = emptyByteArray
var targetPos: Int = 0
var targetLimit: Int = 0

private var closed = false

/**
* Returns true if no further calls to [deflate] are required to complete the operation.
* Otherwise, make space available in [target] and call [deflate] again with the same arguments.
*/
fun deflate(flush: Boolean = false): Boolean {
check(!closed) { "closed" }
require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size)
require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size)

source.usePinned { pinnedSource ->
target.usePinned { pinnedTarget ->
val sourceByteCount = sourceLimit - sourcePos
zStream.next_in = when {
sourceByteCount > 0 -> pinnedSource.addressOf(sourcePos) as CPointer<UByteVar>
else -> null
}
zStream.avail_in = sourceByteCount.toUInt()

val targetByteCount = targetLimit - targetPos
zStream.next_out = when {
targetByteCount > 0 -> pinnedTarget.addressOf(targetPos) as CPointer<UByteVar>
else -> null
}
zStream.avail_out = targetByteCount.toUInt()

val deflateFlush = when {
sourceFinished -> Z_FINISH
flush -> Z_SYNC_FLUSH
else -> Z_NO_FLUSH
}

// One of Z_OK, Z_STREAM_END, Z_STREAM_ERROR, or Z_BUF_ERROR.
val deflateResult = deflate(zStream.ptr, deflateFlush)
check(deflateResult != Z_STREAM_ERROR)

sourcePos += sourceByteCount - zStream.avail_in.toInt()
targetPos += targetByteCount - zStream.avail_out.toInt()

return when {
sourceFinished -> deflateResult == Z_STREAM_END
flush -> targetPos < targetLimit
else -> true
}
}
}
}

override fun close() {
if (closed) return
closed = true

val deflateEndResult = deflateEnd(zStream.ptr)
check(deflateEndResult == Z_OK)
nativeHeap.free(zStream)
}
}
177 changes: 177 additions & 0 deletions okio/src/nativeTest/kotlin/okio/DeflaterTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
import okio.ByteString.Companion.decodeBase64
import okio.ByteString.Companion.encodeUtf8
import okio.ByteString.Companion.toByteString

class DeflaterTest {
@Test
fun happyPath() {
val deflater = Deflater().apply {
source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray()
sourcePos = 0
sourceLimit = source.size
sourceFinished = true

target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

assertTrue(deflater.deflate())
assertEquals(deflater.sourceLimit, deflater.sourcePos)
val deflated = deflater.target.toByteString(0, deflater.targetPos)

// Golden compressed output.
assertEquals(
"c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64(),
deflated,
)

deflater.close()
}

@Test
fun deflateInParts() {
val deflater = Deflater().apply {
target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

deflater.source = "God help us, we're in the hands".encodeUtf8().toByteArray()
deflater.sourcePos = 0
deflater.sourceLimit = deflater.source.size
deflater.sourceFinished = false
assertTrue(deflater.deflate())
assertEquals(deflater.sourceLimit, deflater.sourcePos)

deflater.source = " of engineers.".encodeUtf8().toByteArray()
deflater.sourcePos = 0
deflater.sourceLimit = deflater.source.size
deflater.sourceFinished = true
assertTrue(deflater.deflate())
assertEquals(deflater.sourceLimit, deflater.sourcePos)

val deflated = deflater.target.toByteString(0, deflater.targetPos)

// Golden compressed output.
assertEquals(
"c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64(),
deflated,
)

deflater.close()
}

@Test
fun deflateInsufficientSpaceInTargetWithoutSourceFinished() {
val targetBuffer = Buffer()

val deflater = Deflater().apply {
source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray()
sourcePos = 0
sourceLimit = source.size
}

deflater.target = ByteArray(10)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertFalse(deflater.deflate(flush = true))
assertEquals(deflater.targetLimit, deflater.targetPos)
targetBuffer.write(deflater.target)

deflater.target = ByteArray(256)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertTrue(deflater.deflate())
assertEquals(deflater.sourcePos, deflater.sourceLimit)
targetBuffer.write(deflater.target, 0, deflater.targetPos)

deflater.sourceFinished = true
assertTrue(deflater.deflate())

// Golden compressed output.
assertEquals(
"cs9PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAw==".decodeBase64(),
targetBuffer.readByteString(),
)

deflater.close()
}

@Test
fun deflateInsufficientSpaceInTargetWithSourceFinished() {
val targetBuffer = Buffer()

val deflater = Deflater().apply {
source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray()
sourcePos = 0
sourceLimit = source.size
sourceFinished = true
}

deflater.target = ByteArray(10)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertFalse(deflater.deflate())
assertEquals(deflater.targetLimit, deflater.targetPos)
targetBuffer.write(deflater.target)

deflater.target = ByteArray(256)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertTrue(deflater.deflate())
assertEquals(deflater.sourcePos, deflater.sourceLimit)
targetBuffer.write(deflater.target, 0, deflater.targetPos)

// Golden compressed output.
assertEquals(
"c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64(),
targetBuffer.readByteString(),
)

deflater.close()
}

@Test
fun deflateEmptySource() {
val deflater = Deflater().apply {
sourceFinished = true

target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

assertTrue(deflater.deflate())
val deflated = deflater.target.toByteString(0, deflater.targetPos)

// Golden compressed output.
assertEquals(
"AwA=".decodeBase64(),
deflated,
)

deflater.close()
}
}

0 comments on commit c1aaf58

Please sign in to comment.