From b946778e7dc4f41f633c4ce38acc6ac05b87e765 Mon Sep 17 00:00:00 2001 From: Martin Devillers Date: Thu, 9 Jul 2020 15:00:33 +0200 Subject: [PATCH 1/2] Base64 read & write Buffer --- okio/src/commonMain/kotlin/okio/-Base64.kt | 21 ++- okio/src/commonMain/kotlin/okio/Buffer.kt | 159 ++++++++++++++++++ .../jvmTest/kotlin/okio/BufferBase64Test.kt | 135 +++++++++++++++ 3 files changed, 304 insertions(+), 11 deletions(-) create mode 100644 okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt diff --git a/okio/src/commonMain/kotlin/okio/-Base64.kt b/okio/src/commonMain/kotlin/okio/-Base64.kt index bdf2e1bc65..2cbd28908d 100644 --- a/okio/src/commonMain/kotlin/okio/-Base64.kt +++ b/okio/src/commonMain/kotlin/okio/-Base64.kt @@ -18,15 +18,14 @@ @file:JvmName("-Base64") package okio -import okio.ByteString.Companion.encodeUtf8 import kotlin.jvm.JvmName /** @author Alexander Y. Kleymenov */ -internal val BASE64 = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".encodeUtf8().data -internal val BASE64_URL_SAFE = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_".encodeUtf8().data +internal const val BASE64 = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +internal const val BASE64_URL_SAFE = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" internal fun String.decodeBase64ToArray(): ByteArray? { // Ignore trailing '=' padding and whitespace from the input. @@ -112,9 +111,9 @@ internal fun String.decodeBase64ToArray(): ByteArray? { return out.copyOf(outCount) } -internal fun ByteArray.encodeBase64(map: ByteArray = BASE64): String { +internal fun ByteArray.encodeBase64(map: String = BASE64): String { val length = (size + 2) / 3 * 4 - val out = ByteArray(length) + val out = CharArray(length) var index = 0 val end = size - size % 3 var i = 0 @@ -132,8 +131,8 @@ internal fun ByteArray.encodeBase64(map: ByteArray = BASE64): String { val b0 = this[i].toInt() out[index++] = map[b0 and 0xff shr 2] out[index++] = map[b0 and 0x03 shl 4] - out[index++] = '='.toByte() - out[index] = '='.toByte() + out[index++] = '=' + out[index] = '=' } 2 -> { val b0 = this[i++].toInt() @@ -141,8 +140,8 @@ internal fun ByteArray.encodeBase64(map: ByteArray = BASE64): String { out[index++] = map[(b0 and 0xff shr 2)] out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] out[index++] = map[(b1 and 0x0f shl 2)] - out[index] = '='.toByte() + out[index] = '=' } } - return out.toUtf8String() + return out.concatToString() } diff --git a/okio/src/commonMain/kotlin/okio/Buffer.kt b/okio/src/commonMain/kotlin/okio/Buffer.kt index 1ff7342556..130dce56db 100644 --- a/okio/src/commonMain/kotlin/okio/Buffer.kt +++ b/okio/src/commonMain/kotlin/okio/Buffer.kt @@ -123,3 +123,162 @@ expect class Buffer() : BufferedSource, BufferedSink { /** Returns an immutable copy of the first `byteCount` bytes of this buffer as a byte string. */ fun snapshot(byteCount: Int): ByteString } + +fun Buffer.writeBase64(string: String) { + // Ignore trailing '=' padding and whitespace from the input. + var limit = string.length + while (limit > 0) { + val c = string[limit - 1] + if (c != '=' && c != '\n' && c != '\r' && c != ' ' && c != '\t') { + break + } + limit-- + } + + var inCount = 0 + var word = 0 + var pos = 0 + var s = head + while (pos < limit) { + val c = string[pos++] + val bits: Int + if (c in 'A'..'Z') { + // char ASCII value + // A 65 0 + // Z 90 25 (ASCII - 65) + bits = c.toInt() - 65 + } else if (c in 'a'..'z') { + // char ASCII value + // a 97 26 + // z 122 51 (ASCII - 71) + bits = c.toInt() - 71 + } else if (c in '0'..'9') { + // char ASCII value + // 0 48 52 + // 9 57 61 (ASCII + 4) + bits = c.toInt() + 4 + } else if (c == '+' || c == '-') { + bits = 62 + } else if (c == '/' || c == '_') { + bits = 63 + } else if (c == '\n' || c == '\r' || c == ' ' || c == '\t') { + continue + } else { + throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? + } + + // Append this char's 6 bits to the word. + word = word shl 6 or bits + + // For every 4 chars of input, we accumulate 24 bits of output. Emit 3 bytes. + inCount++ + if (inCount % 4 == 0) { + if (s == null || s.limit + 3 > Segment.SIZE) { + // For simplicity, don't try to write blocks across different segments, allocate new segment when current doesn't have enough capacity + s = writableSegment(3) + } + val data = s.data + var i = s.limit + data[i++] = (word shr 16).toByte() + data[i++] = (word shr 8).toByte() + data[i++] = word.toByte() + s.limit = i + size += 3 + } + } + + val lastWordChars = inCount % 4 + when (lastWordChars) { + 1 -> { + // We read 1 char followed by "===". But 6 bits is a truncated byte! Fail. + throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? + } + 2 -> { + // We read 2 chars followed by "==". Emit 1 byte with 8 of those 12 bits. + if (s == null || s.limit + 1 > Segment.SIZE) { + s = writableSegment(1) + } + word = word shl 12 + s.data[s.limit++] = (word shr 16).toByte() + size += 1 + } + 3 -> { + // We read 3 chars, followed by "=". Emit 2 bytes for 16 of those 18 bits. + if (s == null || s.limit + 2 > Segment.SIZE) { + s = writableSegment(2) + } + word = word shl 6 + val data = s.data + var i = s.limit + data[i++] = (word shr 16).toByte() + data[i++] = (word shr 8).toByte() + s.limit = i + size += 2 + } + } +} + +fun Buffer.readBase64(): String = + readBase64(BASE64) + +fun Buffer.readBase64Url(): String = + readBase64(BASE64_URL_SAFE) + +private fun Buffer.readBase64(map: String = BASE64): String { + val length = ((size + 2) / 3 * 4).toInt() // TODO: Prevent Int overflow / arithmetic overflow ? + val out = CharArray(length) + var index = 0 + while (size >= 3) { + val s = head!! + val segmentSize = s.limit - s.pos + if (segmentSize > 3) { + // Read all complete blocks from head segment + val data = s.data + val end = s.limit - segmentSize % 3 + var i = s.pos + while (i < end) { + val b0 = data[i++].toInt() + val b1 = data[i++].toInt() + val b2 = data[i++].toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] + out[index++] = map[(b2 and 0x3f)] + } + size -= end - s.pos + if (end == s.limit) { + head = s.pop() + SegmentPool.recycle(s) + } else { + s.pos = end + } + } else { + // Read next block, which is spread over multiple segments + val b0 = readByte().toInt() + val b1 = readByte().toInt() + val b2 = readByte().toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] + out[index++] = map[(b2 and 0x3f)] + } + } + when (size) { + 1L -> { + val b0 = readByte().toInt() + out[index++] = map[b0 and 0xff shr 2] + out[index++] = map[b0 and 0x03 shl 4] + out[index++] = '=' + out[index] = '=' + } + 2L -> { + val b0 = readByte().toInt() + val b1 = readByte().toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2)] + out[index] = '=' + } + } + return out.concatToString() +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt b/okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt new file mode 100644 index 0000000000..73df64b147 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt @@ -0,0 +1,135 @@ +package okio + +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +@RunWith(Parameterized::class) +class BufferBase64Test(val size: Int) { + + private val random = Random(4975347L + size) + private val bytes = random.nextBytes(size) + + companion object { + @get:Parameterized.Parameters(name = "{0}") + @get:JvmStatic + val parameters: List + get() = (0..32).toList() + ((Segment.SIZE - 32)..Segment.SIZE + 32).toList() + + private val base64Encoder = Base64.getEncoder() + private val base64UrlEncoder = Base64.getUrlEncoder() + + private fun Random.nextBytes(size: Int): ByteArray = + ByteArray(size).also { nextBytes(it) } + } + + @Test + fun write() { + val encoded = base64Encoder.encodeToString(bytes) + + val buffer = Buffer().apply { writeBase64(encoded) } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun writeWithWhitespace() { + val encoded = base64Encoder.encodeToString(bytes).chunked(8).joinToString("\n") + + val buffer = Buffer().apply { writeBase64(encoded) } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun writeCorruptedInvalidChar() { + if (size == 0) return // Skip this test when there is no data + val corruptedIndex = random.nextInt(size) + val encoded = + base64Encoder.encodeToString(bytes).replaceRange(corruptedIndex..corruptedIndex, "?") + + val buffer = Buffer() + assertFailsWith { + buffer.writeBase64(encoded) + } + } + + @Test + fun writeCorruptedInvalidLength() { + if (size == 0) return // Skip this test when there is no data + val encoded = base64Encoder.encodeToString(bytes) + "A" + + val buffer = Buffer() + assertFailsWith { + buffer.writeBase64(encoded) + } + } + + @Test + fun writeMultiple() { + val buffer = Buffer().apply { + bytes.asList().chunked(4).forEach { + val encoded = base64Encoder.encodeToString(it.toByteArray()) + writeBase64(encoded) + } + } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun writeUrlEncoded() { + val encoded = base64UrlEncoder.encodeToString(bytes) + + val buffer = Buffer().apply { writeBase64(encoded) } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun read() { + val buffer = Buffer().apply { write(bytes) } + + val s = buffer.readBase64() + + assertEquals(base64Encoder.encodeToString(bytes), s) + } + + @Test + fun readUrlEncoded() { + val buffer = Buffer().apply { write(bytes) } + + val s = buffer.readBase64Url() + + val encoded = base64UrlEncoder.encodeToString(bytes) + assertEquals(encoded, s) + } + + @Test + fun readFragmented() { + // Buffer made of segments with only one byte, randomly located + val buffer = Buffer().apply { + bytes.forEach { + val s = writableSegment(Segment.SIZE) + check(s.pos == 0 && s.limit == 0) // Implementation should provide an empty segment + val pos = random.nextInt(Segment.SIZE) + s.pos = pos + s.data[pos] = it + s.limit = pos + 1 + size++ + } + } + + val s = buffer.readBase64() + + val encoded = base64Encoder.encodeToString(bytes) + assertEquals(encoded, s) + } +} From 479a7f78f7929911fa61698cb4f7e5c0b0b99806 Mon Sep 17 00:00:00 2001 From: Martin Devillers Date: Thu, 19 Nov 2020 17:08:00 +0100 Subject: [PATCH 2/2] Base64 buffer-first implementation `ByteString` relies on Buffer implementation for Base64 capabilities. `BufferedSource` and `BufferedSink` API for Base64. --- okio/src/commonMain/kotlin/okio/-Base64.kt | 124 +++++++++----- okio/src/commonMain/kotlin/okio/Buffer.kt | 161 +----------------- .../commonMain/kotlin/okio/BufferedSink.kt | 5 + .../commonMain/kotlin/okio/BufferedSource.kt | 12 ++ .../kotlin/okio/internal/ByteString.kt | 19 ++- .../kotlin/okio/internal/RealBufferedSink.kt | 6 + .../okio/internal/RealBufferedSource.kt | 10 ++ okio/src/jsMain/kotlin/okio/Buffer.kt | 7 + okio/src/jsMain/kotlin/okio/BufferedSink.kt | 2 + okio/src/jsMain/kotlin/okio/BufferedSource.kt | 4 + .../jsMain/kotlin/okio/RealBufferedSink.kt | 2 + .../jsMain/kotlin/okio/RealBufferedSource.kt | 4 + okio/src/jvmMain/kotlin/okio/Buffer.kt | 9 + okio/src/jvmMain/kotlin/okio/BufferedSink.kt | 3 + .../src/jvmMain/kotlin/okio/BufferedSource.kt | 6 + .../jvmMain/kotlin/okio/RealBufferedSink.kt | 2 + .../jvmMain/kotlin/okio/RealBufferedSource.kt | 4 + okio/src/nativeMain/kotlin/okio/Buffer.kt | 7 + .../nativeMain/kotlin/okio/BufferedSink.kt | 2 + .../nativeMain/kotlin/okio/BufferedSource.kt | 4 + .../kotlin/okio/RealBufferedSink.kt | 2 + .../kotlin/okio/RealBufferedSource.kt | 4 + 22 files changed, 192 insertions(+), 207 deletions(-) diff --git a/okio/src/commonMain/kotlin/okio/-Base64.kt b/okio/src/commonMain/kotlin/okio/-Base64.kt index 2cbd28908d..9d401ce3f3 100644 --- a/okio/src/commonMain/kotlin/okio/-Base64.kt +++ b/okio/src/commonMain/kotlin/okio/-Base64.kt @@ -27,26 +27,23 @@ internal const val BASE64 = internal const val BASE64_URL_SAFE = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" -internal fun String.decodeBase64ToArray(): ByteArray? { +fun Buffer.commonWriteBase64(string: String): Buffer { // Ignore trailing '=' padding and whitespace from the input. - var limit = length + var limit = string.length while (limit > 0) { - val c = this[limit - 1] + val c = string[limit - 1] if (c != '=' && c != '\n' && c != '\r' && c != ' ' && c != '\t') { break } limit-- } - // If the input includes whitespace, this output array will be longer than necessary. - val out = ByteArray((limit * 6L / 8L).toInt()) - var outCount = 0 var inCount = 0 - var word = 0 - for (pos in 0 until limit) { - val c = this[pos] - + var pos = 0 + var s = head + while (pos < limit) { + val c = string[pos++] val bits: Int if (c in 'A'..'Z') { // char ASCII value @@ -70,7 +67,7 @@ internal fun String.decodeBase64ToArray(): ByteArray? { } else if (c == '\n' || c == '\r' || c == ' ' || c == '\t') { continue } else { - return null + throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? } // Append this char's 6 bits to the word. @@ -79,9 +76,17 @@ internal fun String.decodeBase64ToArray(): ByteArray? { // For every 4 chars of input, we accumulate 24 bits of output. Emit 3 bytes. inCount++ if (inCount % 4 == 0) { - out[outCount++] = (word shr 16).toByte() - out[outCount++] = (word shr 8).toByte() - out[outCount++] = word.toByte() + if (s == null || s.limit + 3 > Segment.SIZE) { + // For simplicity, don't try to write blocks across different segments, allocate new segment when current doesn't have enough capacity + s = writableSegment(3) + } + val data = s.data + var i = s.limit + data[i++] = (word shr 16).toByte() + data[i++] = (word shr 8).toByte() + data[i++] = word.toByte() + s.limit = i + size += 3 } } @@ -89,54 +94,91 @@ internal fun String.decodeBase64ToArray(): ByteArray? { when (lastWordChars) { 1 -> { // We read 1 char followed by "===". But 6 bits is a truncated byte! Fail. - return null + throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? } 2 -> { // We read 2 chars followed by "==". Emit 1 byte with 8 of those 12 bits. + if (s == null || s.limit + 1 > Segment.SIZE) { + s = writableSegment(1) + } word = word shl 12 - out[outCount++] = (word shr 16).toByte() + s.data[s.limit++] = (word shr 16).toByte() + size += 1 } 3 -> { // We read 3 chars, followed by "=". Emit 2 bytes for 16 of those 18 bits. + if (s == null || s.limit + 2 > Segment.SIZE) { + s = writableSegment(2) + } word = word shl 6 - out[outCount++] = (word shr 16).toByte() - out[outCount++] = (word shr 8).toByte() + val data = s.data + var i = s.limit + data[i++] = (word shr 16).toByte() + data[i++] = (word shr 8).toByte() + s.limit = i + size += 2 } } - // If we sized our out array perfectly, we're done. - if (outCount == out.size) return out - - // Copy the decoded bytes to a new, right-sized array. - return out.copyOf(outCount) + return this } -internal fun ByteArray.encodeBase64(map: String = BASE64): String { - val length = (size + 2) / 3 * 4 +fun Buffer.commonReadBase64(): String = + readBase64(BASE64) + +fun Buffer.commonReadBase64Url(): String = + readBase64(BASE64_URL_SAFE) + +private fun Buffer.readBase64(map: String = BASE64): String { + val length = ((size + 2) / 3 * 4).toInt() // TODO: Prevent Int overflow / arithmetic overflow ? val out = CharArray(length) var index = 0 - val end = size - size % 3 - var i = 0 - while (i < end) { - val b0 = this[i++].toInt() - val b1 = this[i++].toInt() - val b2 = this[i++].toInt() - out[index++] = map[(b0 and 0xff shr 2)] - out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] - out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] - out[index++] = map[(b2 and 0x3f)] + while (size >= 3) { + val s = head!! + val segmentSize = s.limit - s.pos + if (segmentSize > 3) { + // Read all complete blocks from head segment + val data = s.data + val end = s.limit - segmentSize % 3 + var i = s.pos + while (i < end) { + val b0 = data[i++].toInt() + val b1 = data[i++].toInt() + val b2 = data[i++].toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] + out[index++] = map[(b2 and 0x3f)] + } + size -= end - s.pos + if (end == s.limit) { + head = s.pop() + SegmentPool.recycle(s) + } else { + s.pos = end + } + } else { + // Read next block, which is spread over multiple segments + val b0 = readByte().toInt() + val b1 = readByte().toInt() + val b2 = readByte().toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] + out[index++] = map[(b2 and 0x3f)] + } } - when (size - end) { - 1 -> { - val b0 = this[i].toInt() + when (size) { + 1L -> { + val b0 = readByte().toInt() out[index++] = map[b0 and 0xff shr 2] out[index++] = map[b0 and 0x03 shl 4] out[index++] = '=' out[index] = '=' } - 2 -> { - val b0 = this[i++].toInt() - val b1 = this[i].toInt() + 2L -> { + val b0 = readByte().toInt() + val b1 = readByte().toInt() out[index++] = map[(b0 and 0xff shr 2)] out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] out[index++] = map[(b1 and 0x0f shl 2)] diff --git a/okio/src/commonMain/kotlin/okio/Buffer.kt b/okio/src/commonMain/kotlin/okio/Buffer.kt index 130dce56db..0bbc7abb07 100644 --- a/okio/src/commonMain/kotlin/okio/Buffer.kt +++ b/okio/src/commonMain/kotlin/okio/Buffer.kt @@ -114,6 +114,8 @@ expect class Buffer() : BufferedSource, BufferedSink { override fun writeHexadecimalUnsignedLong(v: Long): Buffer + override fun writeBase64(string: String): Buffer + /** Returns a deep copy of this buffer. */ fun copy(): Buffer @@ -123,162 +125,3 @@ expect class Buffer() : BufferedSource, BufferedSink { /** Returns an immutable copy of the first `byteCount` bytes of this buffer as a byte string. */ fun snapshot(byteCount: Int): ByteString } - -fun Buffer.writeBase64(string: String) { - // Ignore trailing '=' padding and whitespace from the input. - var limit = string.length - while (limit > 0) { - val c = string[limit - 1] - if (c != '=' && c != '\n' && c != '\r' && c != ' ' && c != '\t') { - break - } - limit-- - } - - var inCount = 0 - var word = 0 - var pos = 0 - var s = head - while (pos < limit) { - val c = string[pos++] - val bits: Int - if (c in 'A'..'Z') { - // char ASCII value - // A 65 0 - // Z 90 25 (ASCII - 65) - bits = c.toInt() - 65 - } else if (c in 'a'..'z') { - // char ASCII value - // a 97 26 - // z 122 51 (ASCII - 71) - bits = c.toInt() - 71 - } else if (c in '0'..'9') { - // char ASCII value - // 0 48 52 - // 9 57 61 (ASCII + 4) - bits = c.toInt() + 4 - } else if (c == '+' || c == '-') { - bits = 62 - } else if (c == '/' || c == '_') { - bits = 63 - } else if (c == '\n' || c == '\r' || c == ' ' || c == '\t') { - continue - } else { - throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? - } - - // Append this char's 6 bits to the word. - word = word shl 6 or bits - - // For every 4 chars of input, we accumulate 24 bits of output. Emit 3 bytes. - inCount++ - if (inCount % 4 == 0) { - if (s == null || s.limit + 3 > Segment.SIZE) { - // For simplicity, don't try to write blocks across different segments, allocate new segment when current doesn't have enough capacity - s = writableSegment(3) - } - val data = s.data - var i = s.limit - data[i++] = (word shr 16).toByte() - data[i++] = (word shr 8).toByte() - data[i++] = word.toByte() - s.limit = i - size += 3 - } - } - - val lastWordChars = inCount % 4 - when (lastWordChars) { - 1 -> { - // We read 1 char followed by "===". But 6 bits is a truncated byte! Fail. - throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? - } - 2 -> { - // We read 2 chars followed by "==". Emit 1 byte with 8 of those 12 bits. - if (s == null || s.limit + 1 > Segment.SIZE) { - s = writableSegment(1) - } - word = word shl 12 - s.data[s.limit++] = (word shr 16).toByte() - size += 1 - } - 3 -> { - // We read 3 chars, followed by "=". Emit 2 bytes for 16 of those 18 bits. - if (s == null || s.limit + 2 > Segment.SIZE) { - s = writableSegment(2) - } - word = word shl 6 - val data = s.data - var i = s.limit - data[i++] = (word shr 16).toByte() - data[i++] = (word shr 8).toByte() - s.limit = i - size += 2 - } - } -} - -fun Buffer.readBase64(): String = - readBase64(BASE64) - -fun Buffer.readBase64Url(): String = - readBase64(BASE64_URL_SAFE) - -private fun Buffer.readBase64(map: String = BASE64): String { - val length = ((size + 2) / 3 * 4).toInt() // TODO: Prevent Int overflow / arithmetic overflow ? - val out = CharArray(length) - var index = 0 - while (size >= 3) { - val s = head!! - val segmentSize = s.limit - s.pos - if (segmentSize > 3) { - // Read all complete blocks from head segment - val data = s.data - val end = s.limit - segmentSize % 3 - var i = s.pos - while (i < end) { - val b0 = data[i++].toInt() - val b1 = data[i++].toInt() - val b2 = data[i++].toInt() - out[index++] = map[(b0 and 0xff shr 2)] - out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] - out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] - out[index++] = map[(b2 and 0x3f)] - } - size -= end - s.pos - if (end == s.limit) { - head = s.pop() - SegmentPool.recycle(s) - } else { - s.pos = end - } - } else { - // Read next block, which is spread over multiple segments - val b0 = readByte().toInt() - val b1 = readByte().toInt() - val b2 = readByte().toInt() - out[index++] = map[(b0 and 0xff shr 2)] - out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] - out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] - out[index++] = map[(b2 and 0x3f)] - } - } - when (size) { - 1L -> { - val b0 = readByte().toInt() - out[index++] = map[b0 and 0xff shr 2] - out[index++] = map[b0 and 0x03 shl 4] - out[index++] = '=' - out[index] = '=' - } - 2L -> { - val b0 = readByte().toInt() - val b1 = readByte().toInt() - out[index++] = map[(b0 and 0xff shr 2)] - out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] - out[index++] = map[(b1 and 0x0f shl 2)] - out[index] = '=' - } - } - return out.concatToString() -} diff --git a/okio/src/commonMain/kotlin/okio/BufferedSink.kt b/okio/src/commonMain/kotlin/okio/BufferedSink.kt index 40c26585c8..0646ae49b9 100644 --- a/okio/src/commonMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/commonMain/kotlin/okio/BufferedSink.kt @@ -241,6 +241,11 @@ expect interface BufferedSink : Sink { */ fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + /** + * Decodes the Base64-encoded bytes from [string] and writes them to this sink. + */ + fun writeBase64(string: String): BufferedSink + /** * Writes all buffered data to the underlying sink, if one exists. Then that sink is recursively * flushed which pushes data as far as possible towards its ultimate destination. Typically that diff --git a/okio/src/commonMain/kotlin/okio/BufferedSource.kt b/okio/src/commonMain/kotlin/okio/BufferedSource.kt index 0ba4d152bd..eccda298c0 100644 --- a/okio/src/commonMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/commonMain/kotlin/okio/BufferedSource.kt @@ -421,6 +421,18 @@ expect interface BufferedSource : Source { */ fun readUtf8CodePoint(): Int + /** + * Removes all bytes from this, encodes them as as [Base64](http://www.ietf.org/rfc/rfc2045.txt), and returns the + * string. In violation of the RFC, the returned string does not wrap lines at 76 columns. + */ + fun readBase64(): String + + /** + * Removes all bytes from this, encodes them as as [URL-safe Base64](http://www.ietf.org/rfc/rfc4648.txt), and + * returns the string. + */ + fun readBase64Url(): String + /** Equivalent to [indexOf(b, 0)][indexOf]. */ fun indexOf(b: Byte): Long diff --git a/okio/src/commonMain/kotlin/okio/internal/ByteString.kt b/okio/src/commonMain/kotlin/okio/internal/ByteString.kt index 5f0ec021e4..1d4fe82ded 100644 --- a/okio/src/commonMain/kotlin/okio/internal/ByteString.kt +++ b/okio/src/commonMain/kotlin/okio/internal/ByteString.kt @@ -16,7 +16,6 @@ package okio.internal -import okio.BASE64_URL_SAFE import okio.Buffer import okio.ByteString import okio.REPLACEMENT_CODE_POINT @@ -24,10 +23,9 @@ import okio.and import okio.arrayRangeEquals import okio.asUtf8ToByteArray import okio.checkOffsetAndCount -import okio.decodeBase64ToArray -import okio.encodeBase64 import okio.isIsoControl import okio.processUtf8CodePoints +import okio.commonReadBase64Url import okio.shr import okio.toUtf8String @@ -46,10 +44,12 @@ internal inline fun ByteString.commonUtf8(): String { } @Suppress("NOTHING_TO_INLINE") -internal inline fun ByteString.commonBase64(): String = data.encodeBase64() +internal inline fun ByteString.commonBase64(): String = + Buffer().write(this).readBase64() @Suppress("NOTHING_TO_INLINE") -internal inline fun ByteString.commonBase64Url() = data.encodeBase64(map = BASE64_URL_SAFE) +internal inline fun ByteString.commonBase64Url(): String = + Buffer().write(this).commonReadBase64Url() internal val HEX_DIGIT_CHARS = charArrayOf('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f') @@ -266,8 +266,13 @@ internal inline fun String.commonEncodeUtf8(): ByteString { @Suppress("NOTHING_TO_INLINE") internal inline fun String.commonDecodeBase64(): ByteString? { - val decoded = decodeBase64ToArray() - return if (decoded != null) ByteString(decoded) else null + val buffer = Buffer() + try { + buffer.writeBase64(this) + } catch (e: IllegalArgumentException) { // TODO: Dedicated Base64 exception? + return null + } + return buffer.readByteString() } @Suppress("NOTHING_TO_INLINE") diff --git a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt index 49b0c4d755..54aae6bd6b 100644 --- a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt +++ b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt @@ -163,6 +163,12 @@ internal inline fun RealBufferedSink.commonWriteHexadecimalUnsignedLong(v: Long) return emitCompleteSegments() } +internal inline fun RealBufferedSink.commonWriteBase64(string: String): BufferedSink { + check(!closed) { "closed" } + buffer.writeBase64(string) + return emitCompleteSegments() +} + internal inline fun RealBufferedSink.commonEmitCompleteSegments(): BufferedSink { check(!closed) { "closed" } val byteCount = buffer.completeSegmentByteCount() diff --git a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt index 40938ea2f8..b955058449 100644 --- a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt +++ b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt @@ -110,6 +110,16 @@ internal inline fun RealBufferedSource.commonReadByteArray(byteCount: Long): Byt return buffer.readByteArray(byteCount) } +internal inline fun RealBufferedSource.commonReadBase64(): String { + buffer.writeAll(source) + return buffer.readBase64() +} + +internal inline fun RealBufferedSource.commonReadBase64Url(): String { + buffer.writeAll(source) + return buffer.readBase64Url() +} + internal inline fun RealBufferedSource.commonReadFully(sink: ByteArray) { try { require(sink.size.toLong()) diff --git a/okio/src/jsMain/kotlin/okio/Buffer.kt b/okio/src/jsMain/kotlin/okio/Buffer.kt index 402bf6c41d..28a742ea49 100644 --- a/okio/src/jsMain/kotlin/okio/Buffer.kt +++ b/okio/src/jsMain/kotlin/okio/Buffer.kt @@ -130,6 +130,10 @@ actual class Buffer : BufferedSource, BufferedSink { override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint() + override fun readBase64(): String = commonReadBase64() + + override fun readBase64Url(): String = commonReadBase64Url() + override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() @@ -192,6 +196,9 @@ actual class Buffer : BufferedSource, BufferedSink { actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer = commonWriteHexadecimalUnsignedLong(v) + actual override fun writeBase64(string: String): Buffer = + commonWriteBase64(string) + override fun write(source: Buffer, byteCount: Long): Unit = commonWrite(source, byteCount) override fun read(sink: Buffer, byteCount: Long): Long = commonRead(sink, byteCount) diff --git a/okio/src/jsMain/kotlin/okio/BufferedSink.kt b/okio/src/jsMain/kotlin/okio/BufferedSink.kt index 65d717c60a..20db6cbc00 100644 --- a/okio/src/jsMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/jsMain/kotlin/okio/BufferedSink.kt @@ -54,6 +54,8 @@ actual interface BufferedSink : Sink { actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + actual fun writeBase64(string: String): BufferedSink + actual fun emit(): BufferedSink actual fun emitCompleteSegments(): BufferedSink diff --git a/okio/src/jsMain/kotlin/okio/BufferedSource.kt b/okio/src/jsMain/kotlin/okio/BufferedSource.kt index 98b7718a14..23323e81e6 100644 --- a/okio/src/jsMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/jsMain/kotlin/okio/BufferedSource.kt @@ -76,6 +76,10 @@ actual interface BufferedSource : Source { actual fun readUtf8CodePoint(): Int + actual fun readBase64(): String + + actual fun readBase64Url(): String + actual fun indexOf(b: Byte): Long actual fun indexOf(b: Byte, fromIndex: Long): Long diff --git a/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt b/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt index ed03094ec3..924fe32316 100644 --- a/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt +++ b/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt @@ -24,6 +24,7 @@ import okio.internal.commonTimeout import okio.internal.commonToString import okio.internal.commonWrite import okio.internal.commonWriteAll +import okio.internal.commonWriteBase64 import okio.internal.commonWriteByte import okio.internal.commonWriteDecimalLong import okio.internal.commonWriteHexadecimalUnsignedLong @@ -66,6 +67,7 @@ internal actual class RealBufferedSink actual constructor( override fun writeLongLe(v: Long) = commonWriteLongLe(v) override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v) override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v) + override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string) override fun emitCompleteSegments() = commonEmitCompleteSegments() override fun emit() = commonEmit() override fun flush() = commonFlush() diff --git a/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt b/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt index d6f4b94221..bec6bd7b87 100644 --- a/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt +++ b/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt @@ -23,6 +23,8 @@ import okio.internal.commonPeek import okio.internal.commonRangeEquals import okio.internal.commonRead import okio.internal.commonReadAll +import okio.internal.commonReadBase64 +import okio.internal.commonReadBase64Url import okio.internal.commonReadByte import okio.internal.commonReadByteArray import okio.internal.commonReadByteString @@ -62,6 +64,8 @@ internal actual class RealBufferedSource actual constructor( override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) + override fun readBase64(): String = commonReadBase64() + override fun readBase64Url(): String = commonReadBase64Url() override fun read(sink: ByteArray): Int = read(sink, 0, sink.size) override fun readFully(sink: ByteArray): Unit = commonReadFully(sink) override fun read(sink: ByteArray, offset: Int, byteCount: Int): Int = diff --git a/okio/src/jvmMain/kotlin/okio/Buffer.kt b/okio/src/jvmMain/kotlin/okio/Buffer.kt index c514bb6eba..9c76b01bb3 100644 --- a/okio/src/jvmMain/kotlin/okio/Buffer.kt +++ b/okio/src/jvmMain/kotlin/okio/Buffer.kt @@ -330,6 +330,12 @@ actual class Buffer : BufferedSource, BufferedSink, Cloneable, ByteChannel { @Throws(EOFException::class) override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint() + @Throws(EOFException::class) + override fun readBase64(): String = commonReadBase64() + + @Throws(EOFException::class) + override fun readBase64Url(): String = commonReadBase64Url() + override fun readByteArray() = commonReadByteArray() @Throws(EOFException::class) @@ -448,6 +454,9 @@ actual class Buffer : BufferedSource, BufferedSink, Cloneable, ByteChannel { actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer = commonWriteHexadecimalUnsignedLong(v) + actual override fun writeBase64(string: String): Buffer = + commonWriteBase64(string) + internal actual fun writableSegment(minimumCapacity: Int): Segment = commonWritableSegment(minimumCapacity) diff --git a/okio/src/jvmMain/kotlin/okio/BufferedSink.kt b/okio/src/jvmMain/kotlin/okio/BufferedSink.kt index edb632f910..953293ae2b 100644 --- a/okio/src/jvmMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/jvmMain/kotlin/okio/BufferedSink.kt @@ -90,6 +90,9 @@ actual interface BufferedSink : Sink, WritableByteChannel { @Throws(IOException::class) actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + @Throws(IOException::class) + actual fun writeBase64(string: String): BufferedSink + @Throws(IOException::class) actual override fun flush() diff --git a/okio/src/jvmMain/kotlin/okio/BufferedSource.kt b/okio/src/jvmMain/kotlin/okio/BufferedSource.kt index b30c635ae2..24d328f2a1 100644 --- a/okio/src/jvmMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/jvmMain/kotlin/okio/BufferedSource.kt @@ -117,6 +117,12 @@ actual interface BufferedSource : Source, ReadableByteChannel { @Throws(IOException::class) actual fun readUtf8CodePoint(): Int + @Throws(IOException::class) + actual fun readBase64(): String + + @Throws(IOException::class) + actual fun readBase64Url(): String + /** Removes all bytes from this, decodes them as `charset`, and returns the string. */ @Throws(IOException::class) fun readString(charset: Charset): String diff --git a/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt b/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt index 7df3f93776..d3fb3ec5ef 100644 --- a/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt +++ b/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt @@ -23,6 +23,7 @@ import okio.internal.commonTimeout import okio.internal.commonToString import okio.internal.commonWrite import okio.internal.commonWriteAll +import okio.internal.commonWriteBase64 import okio.internal.commonWriteByte import okio.internal.commonWriteDecimalLong import okio.internal.commonWriteHexadecimalUnsignedLong @@ -100,6 +101,7 @@ internal actual class RealBufferedSink actual constructor( override fun writeLongLe(v: Long) = commonWriteLongLe(v) override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v) override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v) + override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string) override fun emitCompleteSegments() = commonEmitCompleteSegments() override fun emit() = commonEmit() diff --git a/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt b/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt index 109ef1402e..1a58f9e288 100644 --- a/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt +++ b/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt @@ -23,6 +23,8 @@ import okio.internal.commonPeek import okio.internal.commonRangeEquals import okio.internal.commonRead import okio.internal.commonReadAll +import okio.internal.commonReadBase64 +import okio.internal.commonReadBase64Url import okio.internal.commonReadByte import okio.internal.commonReadByteArray import okio.internal.commonReadByteString @@ -72,6 +74,8 @@ internal actual class RealBufferedSource actual constructor( override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) + override fun readBase64(): String = commonReadBase64() + override fun readBase64Url(): String = commonReadBase64Url() override fun read(sink: ByteArray): Int = read(sink, 0, sink.size) override fun readFully(sink: ByteArray): Unit = commonReadFully(sink) override fun read(sink: ByteArray, offset: Int, byteCount: Int): Int = diff --git a/okio/src/nativeMain/kotlin/okio/Buffer.kt b/okio/src/nativeMain/kotlin/okio/Buffer.kt index 49623cbae8..63be873dc7 100644 --- a/okio/src/nativeMain/kotlin/okio/Buffer.kt +++ b/okio/src/nativeMain/kotlin/okio/Buffer.kt @@ -132,6 +132,10 @@ actual class Buffer : BufferedSource, BufferedSink { override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint() + override fun readBase64(): String = commonReadBase64() + + override fun readBase64Url(): String = commonReadBase64Url() + override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) @@ -192,6 +196,9 @@ actual class Buffer : BufferedSource, BufferedSink { actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer = commonWriteHexadecimalUnsignedLong(v) + actual override fun writeBase64(string: String): Buffer = + commonWriteBase64(string) + override fun write(source: Buffer, byteCount: Long): Unit = commonWrite(source, byteCount) override fun read(sink: Buffer, byteCount: Long): Long = commonRead(sink, byteCount) diff --git a/okio/src/nativeMain/kotlin/okio/BufferedSink.kt b/okio/src/nativeMain/kotlin/okio/BufferedSink.kt index 65d717c60a..20db6cbc00 100644 --- a/okio/src/nativeMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/nativeMain/kotlin/okio/BufferedSink.kt @@ -54,6 +54,8 @@ actual interface BufferedSink : Sink { actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + actual fun writeBase64(string: String): BufferedSink + actual fun emit(): BufferedSink actual fun emitCompleteSegments(): BufferedSink diff --git a/okio/src/nativeMain/kotlin/okio/BufferedSource.kt b/okio/src/nativeMain/kotlin/okio/BufferedSource.kt index 98b7718a14..23323e81e6 100644 --- a/okio/src/nativeMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/nativeMain/kotlin/okio/BufferedSource.kt @@ -76,6 +76,10 @@ actual interface BufferedSource : Source { actual fun readUtf8CodePoint(): Int + actual fun readBase64(): String + + actual fun readBase64Url(): String + actual fun indexOf(b: Byte): Long actual fun indexOf(b: Byte, fromIndex: Long): Long diff --git a/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt b/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt index ed03094ec3..924fe32316 100644 --- a/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt +++ b/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt @@ -24,6 +24,7 @@ import okio.internal.commonTimeout import okio.internal.commonToString import okio.internal.commonWrite import okio.internal.commonWriteAll +import okio.internal.commonWriteBase64 import okio.internal.commonWriteByte import okio.internal.commonWriteDecimalLong import okio.internal.commonWriteHexadecimalUnsignedLong @@ -66,6 +67,7 @@ internal actual class RealBufferedSink actual constructor( override fun writeLongLe(v: Long) = commonWriteLongLe(v) override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v) override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v) + override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string) override fun emitCompleteSegments() = commonEmitCompleteSegments() override fun emit() = commonEmit() override fun flush() = commonFlush() diff --git a/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt b/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt index d6f4b94221..bec6bd7b87 100644 --- a/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt +++ b/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt @@ -23,6 +23,8 @@ import okio.internal.commonPeek import okio.internal.commonRangeEquals import okio.internal.commonRead import okio.internal.commonReadAll +import okio.internal.commonReadBase64 +import okio.internal.commonReadBase64Url import okio.internal.commonReadByte import okio.internal.commonReadByteArray import okio.internal.commonReadByteString @@ -62,6 +64,8 @@ internal actual class RealBufferedSource actual constructor( override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) + override fun readBase64(): String = commonReadBase64() + override fun readBase64Url(): String = commonReadBase64Url() override fun read(sink: ByteArray): Int = read(sink, 0, sink.size) override fun readFully(sink: ByteArray): Unit = commonReadFully(sink) override fun read(sink: ByteArray, offset: Int, byteCount: Int): Int =