Validate string lengths in StringDirectColumnReader.

Summary of the change carried by this patch:

  * StringDirectColumnReader::computeSize() now tracks two conditions
    while summing the decoded lengths: a negative int64_t length, and
    wrap-around of the unsigned running total. Both are recorded as
    flags inside the loop and reported once, after the loop, as a
    ParseError. The negative-length error is reported in preference
    to the overflow error when both flags are set.
  * StringDirectColumnReader::skip() checks, before adding each
    per-step byte count to totalBytes, that the addition cannot exceed
    std::numeric_limits<size_t>::max(); otherwise it throws ParseError.
  * Two tests feed hand-built RLEv1 literal runs through a shared
    helper, expectStringDirectLengthError(), which exercises either
    next() or skip() and expects a ParseError.

NOTE(review): the first hunk uses std::numeric_limits; confirm that
ColumnReader.cc already includes <limits> (the include block is not
part of this patch).

diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc
index ab1fe24b1c..fe1eed65c4 100644
--- a/c++/src/ColumnReader.cc
+++ b/c++/src/ColumnReader.cc
@@ -666,7 +666,11 @@ namespace orc {
     while (done < numValues) {
       uint64_t step = std::min(BUFFER_SIZE, static_cast<size_t>(numValues - done));
       lengthRle_->next(buffer, step, nullptr);
-      totalBytes += computeSize(buffer, nullptr, step);
+      size_t stepBytes = computeSize(buffer, nullptr, step);
+      if (totalBytes > std::numeric_limits<size_t>::max() - stepBytes) {
+        throw ParseError("String length overflow in StringDirectColumn");
+      }
+      totalBytes += stepBytes;
       done += step;
     }
     if (totalBytes <= lastBufferLength_) {
@@ -694,17 +698,32 @@ namespace orc {
   size_t StringDirectColumnReader::computeSize(const int64_t* lengths, const char* notNull,
                                                uint64_t numValues) {
     size_t totalLength = 0;
+    bool hasNegativeLength = false;
+    bool hasLengthOverflow = false;
+    auto addLength = [&](int64_t value) {
+      hasNegativeLength |= value < 0;
+      size_t length = static_cast<size_t>(value);
+      size_t nextTotalLength = totalLength + length;
+      hasLengthOverflow |= nextTotalLength < totalLength;
+      totalLength = nextTotalLength;
+    };
     if (notNull) {
       for (size_t i = 0; i < numValues; ++i) {
         if (notNull[i]) {
-          totalLength += static_cast<size_t>(lengths[i]);
+          addLength(lengths[i]);
         }
       }
     } else {
       for (size_t i = 0; i < numValues; ++i) {
-        totalLength += static_cast<size_t>(lengths[i]);
+        addLength(lengths[i]);
       }
     }
+    if (hasNegativeLength) {
+      throw ParseError("Negative string length in StringDirectColumn");
+    }
+    if (hasLengthOverflow) {
+      throw ParseError("String length overflow in StringDirectColumn");
+    }
     return totalLength;
   }
 
diff --git a/c++/test/TestColumnReader.cc b/c++/test/TestColumnReader.cc
index d2aa38cb66..7528c9f769 100644
--- a/c++/test/TestColumnReader.cc
+++ b/c++/test/TestColumnReader.cc
@@ -42,6 +42,43 @@ namespace orc {
     return timeptr != nullptr;
   }
 
+  void expectStringDirectLengthError(const unsigned char* lengthData, uint64_t lengthDataSize,
+                                     uint64_t numValues, bool skip) {
+    MockStripeStreams streams;
+
+    std::vector<bool> selectedColumns(2, true);
+    EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns));
+
+    proto::ColumnEncoding directEncoding;
+    directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT);
+    EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding));
+    EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr));
+
+    EXPECT_CALL(streams, getStreamProxy(0, proto::Stream_Kind_PRESENT, true))
+        .WillRepeatedly(testing::Return(nullptr));
+    EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_PRESENT, true))
+        .WillRepeatedly(testing::Return(nullptr));
+
+    const char blob[] = {'x'};
+    EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_DATA, true))
+        .WillRepeatedly(testing::Return(new SeekableArrayInputStream(blob, ARRAY_SIZE(blob))));
+    EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_LENGTH, true))
+        .WillRepeatedly(testing::Return(new SeekableArrayInputStream(lengthData, lengthDataSize)));
+
+    std::unique_ptr<Type> rowType = createStructType();
+    rowType->addStructField("col0", createPrimitiveType(STRING));
+
+    std::unique_ptr<ColumnReader> reader = buildReader(*rowType, streams);
+    if (skip) {
+      EXPECT_THROW(reader->skip(numValues), ParseError);
+    } else {
+      StructVectorBatch batch(numValues, *getDefaultPool());
+      StringVectorBatch* strings = new StringVectorBatch(numValues, *getDefaultPool());
+      batch.fields.push_back(strings);
+      EXPECT_THROW(reader->next(batch, numValues, nullptr), ParseError);
+    }
+  }
+
   class TestColumnReaderEncoded : public TestWithParam<bool> {
     void SetUp() override;
 
@@ -882,6 +919,24 @@ namespace orc {
     EXPECT_THROW(reader->next(batch, 100, 0), ParseError);
   }
 
+  TEST(TestColumnReader, testStringDirectRejectsNegativeLength) {
+    // RLEv1 literal run with one unsigned value UINT64_MAX, which becomes -1
+    // when decoded into int64_t.
+    const unsigned char lengthData[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+                                        0xff, 0xff, 0xff, 0xff, 0x01};
+    expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 1, false);
+    expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 1, true);
+  }
+
+  TEST(TestColumnReader, testStringDirectRejectsLengthOverflow) {
+    // RLEv1 literal run with three INT64_MAX lengths.
+    const unsigned char lengthData[] = {0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f,
+                                        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff,
+                                        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f};
+    expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, false);
+    expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, true);
+  }
+
   TEST_P(TestColumnReaderEncoded, testStringDirectShortBuffer) {
     MockStripeStreams streams;
 