diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridDecoder.java b/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridDecoder.java index e55b276b29..b715d00af1 100644 --- a/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridDecoder.java +++ b/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridDecoder.java @@ -18,9 +18,9 @@ */ package org.apache.parquet.column.values.rle; -import java.io.DataInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.bitpacking.BytePacker; @@ -48,6 +48,8 @@ private static enum MODE { private int currentCount; private int currentValue; private int[] currentBuffer; + private int currentBufferLength; + private byte[] packedBytes; public RunLengthBitPackingHybridDecoder(int bitWidth, InputStream in) { LOG.debug("decoding bitWidth {}", bitWidth); @@ -69,7 +71,7 @@ public int readInt() throws IOException { result = currentValue; break; case PACKED: - result = currentBuffer[currentBuffer.length - 1 - currentCount]; + result = currentBuffer[currentBufferLength - 1 - currentCount]; break; default: throw new ParquetDecodingException("not a valid mode " + mode); @@ -90,17 +92,23 @@ private void readNext() throws IOException { case PACKED: int numGroups = header >>> 1; currentCount = numGroups * 8; + currentBufferLength = currentCount; LOG.debug("reading {} values BIT PACKED", currentCount); - currentBuffer = new int[currentCount]; // TODO: reuse a buffer - byte[] bytes = new byte[numGroups * bitWidth]; - // At the end of the file RLE data though, there might not be that many bytes left. - int bytesToRead = (int) Math.ceil(currentCount * bitWidth / 8.0); - bytesToRead = Math.min(bytesToRead, in.available()); - new DataInputStream(in).readFully(bytes, 0, bytesToRead); + if (currentBuffer == null || currentBuffer.length < currentCount) { + currentBuffer = new int[currentCount]; + } + int bytesNeeded = numGroups * bitWidth; + if (packedBytes == null || packedBytes.length < bytesNeeded) { + packedBytes = new byte[bytesNeeded]; + } + int bytesRead = in.readNBytes(packedBytes, 0, bytesNeeded); + if (bytesRead < bytesNeeded) { + Arrays.fill(packedBytes, bytesRead, bytesNeeded, (byte) 0); + } for (int valueIndex = 0, byteIndex = 0; valueIndex < currentCount; valueIndex += 8, byteIndex += bitWidth) { - packer.unpack8Values(bytes, byteIndex, currentBuffer, valueIndex); + packer.unpack8Values(packedBytes, byteIndex, currentBuffer, valueIndex); } break; default: diff --git a/parquet-column/src/test/java/org/apache/parquet/column/values/rle/TestRunLengthBitPackingHybridEncoder.java b/parquet-column/src/test/java/org/apache/parquet/column/values/rle/TestRunLengthBitPackingHybridEncoder.java index 93a6c8deb4..04dbeed23e 100644 --- a/parquet-column/src/test/java/org/apache/parquet/column/values/rle/TestRunLengthBitPackingHybridEncoder.java +++ b/parquet-column/src/test/java/org/apache/parquet/column/values/rle/TestRunLengthBitPackingHybridEncoder.java @@ -19,9 +19,11 @@ package org.apache.parquet.column.values.rle; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; import java.io.ByteArrayInputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.bytes.DirectByteBufferAllocator; @@ -298,6 +300,49 @@ public void testGroupBoundary() throws Exception { assertEquals(stream.available(), 0); } + @Test + public void testTruncatedPackedRunAfterFullPackedRunDoesNotReuseStaleBytes() throws Exception { + int bitWidth = 3; + BytePacker packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); + + int[] firstRunValues = new int[8]; + Arrays.fill(firstRunValues, 7); + byte[] firstRunPacked = new byte[bitWidth]; + packer.pack8Values(firstRunValues, 0, firstRunPacked, 0); + + int[] secondRunValues = {1, 2, 3, 4, 5, 6, 7, 0}; + byte[] secondRunPacked = new byte[bitWidth]; + packer.pack8Values(secondRunValues, 0, secondRunPacked, 0); + + byte[] encoded = { + (byte) ((1 << 1) | 1), + firstRunPacked[0], + firstRunPacked[1], + firstRunPacked[2], + (byte) ((1 << 1) | 1), + secondRunPacked[0] + }; + + RunLengthBitPackingHybridDecoder decoder = + new RunLengthBitPackingHybridDecoder(bitWidth, new ByteArrayInputStream(encoded)); + + for (int ignored = 0; ignored < 8; ignored++) { + assertEquals(7, decoder.readInt()); + } + + int[] actualSecondRun = new int[8]; + for (int i = 0; i < 8; i++) { + actualSecondRun[i] = decoder.readInt(); + } + + byte[] expectedSecondPacked = new byte[bitWidth]; + expectedSecondPacked[0] = secondRunPacked[0]; + int[] expectedSecondRun = new int[8]; + packer.unpack8Values(expectedSecondPacked, 0, expectedSecondRun, 0); + + assertArrayEquals(expectedSecondRun, actualSecondRun); + } + private static List unpack(int bitWidth, int numValues, ByteArrayInputStream is) throws Exception { BytePacker packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth);