diff options
Diffstat (limited to 'icing/index/embed/posting-list-embedding-hit-serializer_test.cc')
-rw-r--r-- | icing/index/embed/posting-list-embedding-hit-serializer_test.cc | 864 |
1 files changed, 864 insertions, 0 deletions
diff --git a/icing/index/embed/posting-list-embedding-hit-serializer_test.cc b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc new file mode 100644 index 0000000..f829634 --- /dev/null +++ b/icing/index/embed/posting-list-embedding-hit-serializer_test.cc @@ -0,0 +1,864 @@ +// Copyright (C) 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/embed/posting-list-embedding-hit-serializer.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <deque> +#include <iterator> +#include <limits> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/embed/embedding-hit.h" +#include "icing/index/hit/hit.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/hit-test-utils.h" + +using testing::ElementsAre; +using testing::ElementsAreArray; +using testing::Eq; +using testing::IsEmpty; +using testing::Le; +using testing::Lt; + +namespace icing { +namespace lib { + +namespace { + +struct HitElt { + HitElt() = default; + explicit HitElt(const EmbeddingHit &hit_in) : hit(hit_in) {} + + static EmbeddingHit get_hit(const HitElt &hit_elt) { return hit_elt.hit; } + + EmbeddingHit hit; +}; + +TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedPrependHitNotFull) { + PostingListEmbeddingHitSerializer serializer; + + static const int kNumHits = 2551; + static const size_t kHitsSize = kNumHits * sizeof(EmbeddingHit); + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize)); + + // Make used. + EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit0)); + int expected_size = sizeof(EmbeddingHit::Value); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0))); + + EmbeddingHit hit1(BasicHit(/*section_id=*/0, /*document_id=*/1), + /*location=*/1); + uint64_t delta = hit0.value() - hit1.value(); + uint8_t delta_buf[VarInt::kMaxEncodedLen64]; + size_t delta_len = VarInt::Encode(delta, delta_buf); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit1)); + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit1, hit0))); + + EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/2), + /*location=*/2); + delta = hit1.value() - hit2.value(); + delta_len = VarInt::Encode(delta, delta_buf); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit2)); + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); + + EmbeddingHit hit3(BasicHit(/*section_id=*/0, /*document_id=*/3), + /*location=*/3); + delta = hit2.value() - hit3.value(); + delta_len = VarInt::Encode(delta, delta_buf); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit3)); + expected_size += delta_len; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListUsedPrependHitAlmostFull) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 32 + int pl_size = 2 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + // Fill up the compressed region. + // Transitions: + // Adding hit0: EMPTY -> NOT_FULL + // Adding hit1: NOT_FULL -> NOT_FULL + // Adding hit2: NOT_FULL -> NOT_FULL + EmbeddingHit hit0(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/1); + EmbeddingHit hit1 = CreateEmbeddingHit(hit0, /*desired_byte_length=*/3); + EmbeddingHit hit2 = CreateEmbeddingHit(hit1, /*desired_byte_length=*/3); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit2)); + // Size used will be 8 (hit2) + 3 (hit1-hit2) + 3 (hit0-hit1) = 14 bytes + int expected_size = sizeof(EmbeddingHit) + 3 + 3; + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit2, hit1, hit0))); + + // Add one more hit to transition NOT_FULL -> ALMOST_FULL + EmbeddingHit hit3 = CreateEmbeddingHit(hit2, /*desired_byte_length=*/3); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit3)); + // Storing them in the compressed region requires 8 (hit) + 3 (hit2-hit3) + + // 3 (hit1-hit2) + 3 (hit0-hit1) = 17 bytes, but there are only 16 bytes in + // the compressed region. So instead, the posting list will transition to + // ALMOST_FULL. The in-use compressed region will actually shrink from 14 + // bytes to 9 bytes because the uncompressed version of hit2 will be + // overwritten with the compressed delta of hit2. hit3 will be written to one + // of the special hits. Because we're in ALMOST_FULL, the expected size is the + // size of the pl minus the one hit used to mark the posting list as + // ALMOST_FULL. + expected_size = pl_size - sizeof(EmbeddingHit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit3, hit2, hit1, hit0))); + + // Add one more hit to transition ALMOST_FULL -> ALMOST_FULL + EmbeddingHit hit4 = CreateEmbeddingHit(hit3, /*desired_byte_length=*/6); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit4)); + // There are currently 9 bytes in use in the compressed region. Hit3 will + // have a 6-byte delta, which fits in the compressed region. Hit3 will be + // moved from the special hit to the compressed region (which will have 15 + // bytes in use after adding hit3). Hit4 will be placed in one of the special + // hits and the posting list will remain in ALMOST_FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit4, hit3, hit2, hit1, hit0))); + + // Add one more hit to transition ALMOST_FULL -> FULL + EmbeddingHit hit5 = CreateEmbeddingHit(hit4, /*desired_byte_length=*/2); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit5)); + // There are currently 15 bytes in use in the compressed region. Hit4 will + // have a 2-byte delta which will not fit in the compressed region. So hit4 + // will remain in one of the special hits and hit5 will occupy the other, + // making the posting list FULL. + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit5, hit4, hit3, hit2, hit1, hit0))); + + // The posting list is FULL. Adding another hit should fail. + EmbeddingHit hit6 = CreateEmbeddingHit(hit5, /*desired_byte_length=*/1); + EXPECT_THAT(serializer.PrependHit(&pl_used, hit6), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListEmbeddingHitSerializerTest, PostingListUsedMinSize) { + PostingListEmbeddingHitSerializer serializer; + + // Min size = 16 + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + // PL State: EMPTY + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); + + // Add a hit, PL should shift to ALMOST_FULL state + EmbeddingHit hit0(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/1); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit0)); + // Size = sizeof(uncompressed hit0) + int expected_size = sizeof(EmbeddingHit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(ElementsAre(hit0))); + + // Add the smallest hit possible with a delta of 0b1. PL should shift to FULL + // state. + EmbeddingHit hit1(BasicHit(/*section_id=*/1, /*document_id=*/0), + /*location=*/0); + ICING_EXPECT_OK(serializer.PrependHit(&pl_used, hit1)); + // Size = sizeof(uncompressed hit1) + sizeof(uncompressed hit0) + expected_size += sizeof(EmbeddingHit); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit1, hit0))); + + // Try to add the smallest hit possible. Should fail + EmbeddingHit hit2(BasicHit(/*section_id=*/0, /*document_id=*/0), + /*location=*/0); + EXPECT_THAT(serializer.PrependHit(&pl_used, hit2), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAre(hit1, hit0))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayMinSizePostingList) { + PostingListEmbeddingHitSerializer serializer; + + // Min Size = 16 + int pl_size = serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + std::vector<HitElt> hits_in; + hits_in.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + std::reverse(hits_in.begin(), hits_in.end()); + + // Add five hits. The PL is in the empty state and an empty min size PL can + // only fit two hits. So PrependHitArray should fail. + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_can_prepend, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_can_prepend, Eq(2)); + + int can_fit_hits = num_can_prepend; + // The PL has room for 2 hits. We should be able to add them without any + // problem, transitioning the PL from EMPTY -> ALMOST_FULL -> FULL + const HitElt *hits_in_ptr = hits_in.data() + (hits_in.size() - 2); + ICING_ASSERT_OK_AND_ASSIGN( + num_can_prepend, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, hits_in_ptr, can_fit_hits, /*keep_prepended=*/false))); + EXPECT_THAT(num_can_prepend, Eq(can_fit_hits)); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); + std::deque<EmbeddingHit> hits_pushed; + std::transform(hits_in.rbegin(), + hits_in.rend() - hits_in.size() + can_fit_hits, + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayPostingList) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 48 + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + std::vector<HitElt> hits_in; + hits_in.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/1, /*document_id=*/0), /*location=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + std::reverse(hits_in.begin(), hits_in.end()); + // The last hit is uncompressed and the four before it should only take one + // byte. Total use = 8 bytes. + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43-36 EmbeddingHit #4 + // 35-16 <unused> + // 15-8 kSpecialHit + // 7-0 Offset=36 + // ---------------------- + int byte_size = sizeof(EmbeddingHit::Value) + hits_in.size() - 1; + + // Add five hits. The PL is in the empty state and should be able to fit all + // five hits without issue, transitioning the PL from EMPTY -> NOT_FULL. + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); + std::deque<EmbeddingHit> hits_pushed; + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + EmbeddingHit first_hit = + CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/1); + hits_in.clear(); + hits_in.emplace_back(first_hit); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/1)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/2)); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); + std::reverse(hits_in.begin(), hits_in.end()); + // Size increased by the deltas of these hits (1+2+1+2+3+2+3) = 14 bytes + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43 delta(EmbeddingHit #4) + // 42-41 delta(EmbeddingHit #5) + // 40 delta(EmbeddingHit #6) + // 39-38 delta(EmbeddingHit #7) + // 37-35 delta(EmbeddingHit #8) + // 34-33 delta(EmbeddingHit #9) + // 32-30 delta(EmbeddingHit #10) + // 29-22 EmbeddingHit #11 + // 21-16 <unused> + // 15-8 kSpecialHit + // 7-0 Offset=22 + // ---------------------- + byte_size += 14; + + // Add these 7 hits. The PL is currently in the NOT_FULL state and should + // remain in the NOT_FULL state. + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + first_hit = + CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/8); + hits_in.clear(); + hits_in.emplace_back(first_hit); + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43 delta(EmbeddingHit #4) + // 42-41 delta(EmbeddingHit #5) + // 40 delta(EmbeddingHit #6) + // 39-38 delta(EmbeddingHit #7) + // 37-35 delta(EmbeddingHit #8) + // 34-33 delta(EmbeddingHit #9) + // 32-30 delta(EmbeddingHit #10) + // 29-22 delta(EmbeddingHit #11) + // 21-16 <unused> + // 15-8 EmbeddingHit #12 + // 7-0 kSpecialHit + // ---------------------- + byte_size = 40; // 48 - 8 + + // Add this 1 hit. The PL is currently in the NOT_FULL state and should + // transition to the ALMOST_FULL state - even though there is still some + // unused space. + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(byte_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + first_hit = + CreateEmbeddingHit(hits_in.begin()->hit, /*desired_byte_length=*/5); + hits_in.clear(); + hits_in.emplace_back(first_hit); + hits_in.emplace_back( + CreateEmbeddingHit(hits_in.rbegin()->hit, /*desired_byte_length=*/3)); + std::reverse(hits_in.begin(), hits_in.end()); + // ---------------------- + // 47 delta(EmbeddingHit #0) + // 46 delta(EmbeddingHit #1) + // 45 delta(EmbeddingHit #2) + // 44 delta(EmbeddingHit #3) + // 43 delta(EmbeddingHit #4) + // 42-41 delta(EmbeddingHit #5) + // 40 delta(EmbeddingHit #6) + // 39-38 delta(EmbeddingHit #7) + // 37-35 delta(EmbeddingHit #8) + // 34-33 delta(EmbeddingHit #9) + // 32-30 delta(EmbeddingHit #10) + // 29-22 delta(EmbeddingHit #11) + // 21-17 delta(EmbeddingHit #12) + // 16 <unused> + // 15-8 EmbeddingHit #13 + // 7-0 EmbeddingHit #14 + // ---------------------- + + // Add these 2 hits. + // - The PL is currently in the ALMOST_FULL state. Adding the first hit should + // keep the PL in ALMOST_FULL because the delta between + // EmbeddingHit #12 and EmbeddingHit #13 (5 byte) can fit in the unused area + // (6 bytes). + // - Adding the second hit should transition to the FULL state because the + // delta between EmbeddingHit #13 and EmbeddingHit #14 (3 bytes) is larger + // than the remaining unused area (1 byte). + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hits_in[0], hits_in.size(), /*keep_prepended=*/false))); + EXPECT_THAT(num_could_fit, Eq(hits_in.size())); + EXPECT_THAT(pl_size, Eq(serializer.GetBytesUsed(&pl_used))); + // All hits from hits_in were added. + std::transform(hits_in.rbegin(), hits_in.rend(), + std::front_inserter(hits_pushed), HitElt::get_hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListPrependHitArrayTooManyHits) { + PostingListEmbeddingHitSerializer serializer; + + static constexpr int kNumHits = 130; + static constexpr int kDeltaSize = 1; + static constexpr size_t kHitsSize = + ((kNumHits - 2) * kDeltaSize + (2 * sizeof(EmbeddingHit))); + + // Create an array with one too many hits + std::vector<HitElt> hit_elts_in_too_many; + hit_elts_in_too_many.emplace_back(EmbeddingHit( + BasicHit(/*section_id=*/0, /*document_id=*/0), /*location=*/0)); + for (int i = 0; i < kNumHits; ++i) { + hit_elts_in_too_many.emplace_back(CreateEmbeddingHit( + hit_elts_in_too_many.back().hit, /*desired_byte_length=*/1)); + } + // Reverse so that hits are inserted in descending order + std::reverse(hit_elts_in_too_many.begin(), hit_elts_in_too_many.end()); + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + // PrependHitArray should fail because hit_elts_in_too_many is far too large + // for the minimum size pl. + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(2)); + ASSERT_THAT(num_could_fit, Lt(hit_elts_in_too_many.size())); + ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); + + ICING_ASSERT_OK_AND_ASSIGN( + pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, kHitsSize)); + // PrependHitArray should fail because hit_elts_in_too_many is one hit too + // large for this pl. + ICING_ASSERT_OK_AND_ASSIGN( + num_could_fit, + (serializer.PrependHitArray<HitElt, HitElt::get_hit>( + &pl_used, &hit_elts_in_too_many[0], hit_elts_in_too_many.size(), + /*keep_prepended=*/false))); + ASSERT_THAT(num_could_fit, Eq(hit_elts_in_too_many.size() - 1)); + ASSERT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + ASSERT_THAT(serializer.GetHits(&pl_used), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListEmbeddingHitSerializerTest, + PostingListStatusJumpFromNotFullToFullAndBack) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 24 + const uint32_t pl_size = 3 * sizeof(EmbeddingHit); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + EmbeddingHit max_valued_hit( + BasicHit(/*section_id=*/kMaxSectionId, /*document_id=*/kMinDocumentId), + /*location=*/std::numeric_limits<uint32_t>::max()); + ICING_ASSERT_OK(serializer.PrependHit(&pl, max_valued_hit)); + uint32_t bytes_used = serializer.GetBytesUsed(&pl); + ASSERT_THAT(bytes_used, sizeof(EmbeddingHit)); + // Status not full. + ASSERT_THAT( + bytes_used, + Le(pl_size - PostingListEmbeddingHitSerializer::kSpecialHitsSize)); + + EmbeddingHit min_valued_hit( + BasicHit(/*section_id=*/kMinSectionId, /*document_id=*/kMaxDocumentId), + /*location=*/0); + ICING_ASSERT_OK(serializer.PrependHit(&pl, min_valued_hit)); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAre(min_valued_hit, max_valued_hit))); + // Status should jump to full directly. + ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(pl_size)); + ICING_ASSERT_OK(serializer.PopFrontHits(&pl, 1)); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAre(max_valued_hit))); + // Status should return to not full as before. + ASSERT_THAT(serializer.GetBytesUsed(&pl), Eq(bytes_used)); +} + +TEST(PostingListEmbeddingHitSerializerTest, DeltaOverflow) { + PostingListEmbeddingHitSerializer serializer; + + const uint32_t pl_size = 4 * sizeof(EmbeddingHit); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + + static const EmbeddingHit::Value kMaxHitValue = + std::numeric_limits<EmbeddingHit::Value>::max(); + static const EmbeddingHit::Value kOverflow[4] = { + kMaxHitValue >> 2, + (kMaxHitValue >> 2) * 2, + (kMaxHitValue >> 2) * 3, + kMaxHitValue - 1, + }; + + // Fit at least 4 ordinary values. + std::deque<EmbeddingHit> hits_pushed; + for (EmbeddingHit::Value v = 0; v < 4; v++) { + hits_pushed.push_front( + EmbeddingHit(BasicHit(kMaxSectionId, kMaxDocumentId), 4 - v)); + ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front())); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + } + + // Cannot fit 4 overflow values. + hits_pushed.clear(); + ICING_ASSERT_OK_AND_ASSIGN( + pl, PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + for (int i = 3; i >= 1; i--) { + hits_pushed.push_front(EmbeddingHit(/*value=*/kOverflow[i])); + ICING_EXPECT_OK(serializer.PrependHit(&pl, hits_pushed.front())); + EXPECT_THAT(serializer.GetHits(&pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + } + EXPECT_THAT(serializer.PrependHit(&pl, EmbeddingHit(/*value=*/kOverflow[0])), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListEmbeddingHitSerializerTest, + GetMinPostingListToFitForNotFullPL) { + PostingListEmbeddingHitSerializer serializer; + + // Size = 64 + int pl_size = 4 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used NOT_FULL + std::vector<EmbeddingHit> hits_in = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 63-62 delta(EmbeddingHit #0) + // 61-60 delta(EmbeddingHit #1) + // 59-58 delta(EmbeddingHit #2) + // 57-56 delta(EmbeddingHit #3) + // 55-48 EmbeddingHit #5 + // 47-16 <unused> + // 15-8 kSpecialHit + // 7-0 Offset=48 + // ---------------------- + int bytes_used = 16; + + // Check that all hits have been inserted + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // Get the min size to fit for the hits in pl_used. Moving the hits in pl_used + // into a posting list with this min size should make it ALMOST_FULL, which we + // can see should have size = 24. + // ---------------------- + // 23-22 delta(EmbeddingHit #0) + // 21-20 delta(EmbeddingHit #1) + // 19-18 delta(EmbeddingHit #2) + // 17-16 delta(EmbeddingHit #3) + // 15-8 EmbeddingHit #4 + // 7-0 kSpecialHit + // ---------------------- + int expected_min_size = 24; + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(expected_min_size)); + + // Also check that this min size to fit posting list actually does fit all the + // hits and can only hit one more hit in the ALMOST_FULL state. + ICING_ASSERT_OK_AND_ASSIGN(PostingListUsed min_size_to_fit_pl, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, min_size_to_fit)); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + } + + // Adding another hit to the min-size-to-fit posting list should succeed + EmbeddingHit hit = + CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&min_size_to_fit_pl, hit)); + // Adding any other hits should fail with RESOURCE_EXHAUSTED error. + EXPECT_THAT( + serializer.PrependHit(&min_size_to_fit_pl, + CreateEmbeddingHit(hit, /*desired_byte_length=*/1)), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + + // Check that all hits have been inserted and the min-fit posting list is now + // FULL. + EXPECT_THAT(serializer.GetBytesUsed(&min_size_to_fit_pl), + Eq(min_size_to_fit)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&min_size_to_fit_pl), + IsOkAndHolds(ElementsAreArray(hits_pushed))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + GetMinPostingListToFitForAlmostFullAndFullPLReturnsSameSize) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 24; + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + // Create and add some hits to make pl_used ALMOST_FULL + std::vector<EmbeddingHit> hits_in = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits_in) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + } + // ---------------------- + // 23-22 delta(EmbeddingHit #0) + // 21-20 delta(EmbeddingHit #1) + // 19-18 delta(EmbeddingHit #2) + // 17-16 delta(EmbeddingHit #3) + // 15-8 EmbeddingHit #4 + // 7-0 kSpecialHit + // ---------------------- + int bytes_used = 16; + + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(bytes_used)); + std::deque<EmbeddingHit> hits_pushed(hits_in.rbegin(), hits_in.rend()); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should return the same size as pl_used. + uint32_t min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); + + // Add another hit to make the posting list FULL + EmbeddingHit hit = + CreateEmbeddingHit(hits_in.back(), /*desired_byte_length=*/1); + ICING_ASSERT_OK(serializer.PrependHit(&pl_used, hit)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(pl_size)); + hits_pushed.emplace_front(hit); + EXPECT_THAT(serializer.GetHits(&pl_used), + IsOkAndHolds(ElementsAreArray(hits_pushed))); + + // GetMinPostingListSizeToFit should still be the same size as pl_used. + min_size_to_fit = serializer.GetMinPostingListSizeToFit(&pl_used); + EXPECT_THAT(min_size_to_fit, Eq(pl_size)); +} + +TEST(PostingListEmbeddingHitSerializerTest, MoveFrom) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + ICING_ASSERT_OK(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1)); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); + EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveFromNullArgumentReturnsInvalidArgument) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveFromInvalidPostingListReturnsInvalidArgument) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + // Write invalid hits to the beginning of pl_used1 to make it invalid. + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EmbeddingHit *first_hit = + reinterpret_cast<EmbeddingHit *>(pl_used1.posting_list_buffer()); + *first_hit = invalid_hit; + ++first_hit; + *first_hit = invalid_hit; + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, + MoveToInvalidPostingListReturnsFailedPrecondition) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + // Write invalid hits to the beginning of pl_used2 to make it invalid. + EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue); + EmbeddingHit *first_hit = + reinterpret_cast<EmbeddingHit *>(pl_used2.posting_list_buffer()); + *first_hit = invalid_hit; + ++first_hit; + *first_hit = invalid_hit; + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); +} + +TEST(PostingListEmbeddingHitSerializerTest, MoveToPostingListTooSmall) { + PostingListEmbeddingHitSerializer serializer; + + int pl_size = 3 * serializer.GetMinPostingListSize(); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion(&serializer, pl_size)); + std::vector<EmbeddingHit> hits1 = + CreateEmbeddingHits(/*num_hits=*/5, /*desired_byte_length=*/1); + for (const EmbeddingHit &hit : hits1) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, serializer.GetMinPostingListSize())); + std::vector<EmbeddingHit> hits2 = + CreateEmbeddingHits(/*num_hits=*/1, /*desired_byte_length=*/2); + for (const EmbeddingHit &hit : hits2) { + ICING_ASSERT_OK(serializer.PrependHit(&pl_used2, hit)); + } + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(serializer.GetHits(&pl_used1), + IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); + EXPECT_THAT(serializer.GetHits(&pl_used2), + IsOkAndHolds(ElementsAreArray(hits2.rbegin(), hits2.rend()))); +} + +} // namespace + +} // namespace lib +} // namespace icing |