author      Android Build Coastguard Worker <android-build-coastguard-worker@google.com>  2022-09-23 10:02:16 +0000
committer   Android Build Coastguard Worker <android-build-coastguard-worker@google.com>  2022-09-23 10:02:16 +0000
commit      c98a298235bbd6e4ad9fda46062b861df4b629fe (patch)
tree        f6c3e8206a8372e464c6d2576ca90821f2df300b
parent      4bd49c828850130e6c09a393af57216ad6333570 (diff)
parent      fb3ae39e2f6ece0a75b3670b7256587d18bc81ff (diff)
download    icing-android13-mainline-scheduling-release.tar.gz
Snap for 9098257 from fb3ae39e2f6ece0a75b3670b7256587d18bc81ff to mainline-scheduling-release
Refs: aml_sch_331113000, aml_sch_331111000, android13-mainline-scheduling-release
Change-Id: Ia6d6645e60c0326ea72596c80c392c56a5bb2f82
130 files changed, 14789 insertions(+), 2623 deletions(-)
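Among the changes below, this snap adds icing/file/destructible-directory.h, an RAII helper that recursively creates a directory on construction and recursively deletes it on destruction, and extends FileBackedVector with a configurable max file size plus mutable-view/append/allocate APIs. As a reading aid, here is a minimal usage sketch of the new DestructibleDirectory class (illustrative only; the function name and paths are invented for the example and are not part of this change):

#include <string>

#include "icing/file/destructible-directory.h"
#include "icing/file/filesystem.h"

namespace icing {
namespace lib {

// Hypothetical caller: creates a self-cleaning scratch directory.
void DoWorkInScratchDir(const Filesystem* filesystem,
                        const std::string& base_dir) {
  DestructibleDirectory scratch(filesystem, base_dir + "/scratch");
  if (!scratch.is_valid()) {
    return;  // CreateDirectoryRecursively failed; nothing to clean up.
  }
  // ... create and use temporary files under scratch.dir() ...
}  // scratch is destroyed here and deletes the directory recursively.

}  // namespace lib
}  // namespace icing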
diff --git a/AndroidManifest.xml b/AndroidManifest.xml deleted file mode 100644 index 7377c53..0000000 --- a/AndroidManifest.xml +++ /dev/null @@ -1,2 +0,0 @@ -<?xml version="1.0" encoding="utf-8"?> -<manifest package="com.google.android.icing" /> diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c8e439..48a63d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,7 @@ project(icing) add_definitions("-DICING_REVERSE_JNI_SEGMENTATION=1") set(VERSION_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/icing/jni.lds") +set(CMAKE_CXX_STANDARD 17) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections -Wl,--version-script=${VERSION_SCRIPT}") diff --git a/build.gradle b/build.gradle index 5b5f3a6..2ac1d39 100644 --- a/build.gradle +++ b/build.gradle @@ -42,13 +42,13 @@ android { sourceSets { main { java.srcDir 'java/src/' - manifest.srcFile 'AndroidManifest.xml' proto.srcDir 'proto/' } // TODO(b/161205849): Re-enable this test once icing nativeLib is no longer being built // inside appsearch:appsearch. //androidTest.java.srcDir 'java/tests/instrumentation/' } + namespace "com.google.android.icing" } dependencies { diff --git a/icing/file/destructible-directory.h b/icing/file/destructible-directory.h new file mode 100644 index 0000000..9a8bd4b --- /dev/null +++ b/icing/file/destructible-directory.h @@ -0,0 +1,74 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_FILE_DESTRUCTIBLE_DIRECTORY_H_ +#define ICING_FILE_DESTRUCTIBLE_DIRECTORY_H_ + +#include "icing/file/filesystem.h" +#include "icing/util/logging.h" + +namespace icing { +namespace lib { + +// A convenient RAII class which will recursively create the directory at the +// specified file path and delete it upon destruction. +class DestructibleDirectory { + public: + explicit DestructibleDirectory(const Filesystem* filesystem, std::string dir) + : filesystem_(filesystem), dir_(std::move(dir)) { + is_valid_ = filesystem_->CreateDirectoryRecursively(dir_.c_str()); + } + + DestructibleDirectory(const DestructibleDirectory&) = delete; + DestructibleDirectory& operator=(const DestructibleDirectory&) = delete; + + DestructibleDirectory(DestructibleDirectory&& rhs) + : filesystem_(nullptr), is_valid_(false) { + Swap(rhs); + } + + DestructibleDirectory& operator=(DestructibleDirectory&& rhs) { + Swap(rhs); + return *this; + } + + ~DestructibleDirectory() { + if (filesystem_ != nullptr && + !filesystem_->DeleteDirectoryRecursively(dir_.c_str())) { + // Swallow deletion failures as there's nothing actionable to do about + // them. 
+ ICING_LOG(WARNING) << "Unable to delete temporary directory: " << dir_; + } + } + + const std::string& dir() const { return dir_; } + + bool is_valid() const { return is_valid_; } + + private: + void Swap(DestructibleDirectory& other) { + std::swap(filesystem_, other.filesystem_); + std::swap(dir_, other.dir_); + std::swap(is_valid_, other.is_valid_); + } + + const Filesystem* filesystem_; + std::string dir_; + bool is_valid_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_FILE_DESTRUCTIBLE_DIRECTORY_H_ diff --git a/icing/file/destructible-directory_test.cc b/icing/file/destructible-directory_test.cc new file mode 100644 index 0000000..c62db3b --- /dev/null +++ b/icing/file/destructible-directory_test.cc @@ -0,0 +1,118 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/file/destructible-directory.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; + +TEST(DestructibleFileTest, DeletesDirectoryProperly) { + Filesystem filesystem; + std::string dir_path = GetTestTempDir() + "/dir1"; + std::string file_path = dir_path + "/file1"; + + { + // 1. Create a file in the directory. + ASSERT_TRUE(filesystem.CreateDirectoryRecursively(dir_path.c_str())); + ScopedFd sfd(filesystem.OpenForWrite(file_path.c_str())); + ASSERT_TRUE(sfd.is_valid()); + int i = 127; + ASSERT_TRUE(filesystem.Write(sfd.get(), &i, sizeof(i))); + } + + { + // 2. Open the directory with a DestructibleDirectory + DestructibleDirectory destructible(&filesystem, dir_path); + EXPECT_TRUE(destructible.is_valid()); + EXPECT_THAT(destructible.dir(), Eq(dir_path)); + } + + // 3. Ensure that the file and directory don't exist. + EXPECT_FALSE(filesystem.FileExists(file_path.c_str())); + EXPECT_FALSE(filesystem.DirectoryExists(dir_path.c_str())); +} + +TEST(DestructibleFileTest, MoveAssignDeletesFileProperly) { + Filesystem filesystem; + std::string filepath1 = GetTestTempDir() + "/dir1"; + std::string filepath2 = GetTestTempDir() + "/dir2"; + + // 1. Create dir1 + DestructibleDirectory destructible1(&filesystem, filepath1); + ASSERT_TRUE(destructible1.is_valid()); + ASSERT_TRUE(filesystem.DirectoryExists(filepath1.c_str())); + + { + // 2. Create dir2 + DestructibleDirectory destructible2(&filesystem, filepath2); + ASSERT_TRUE(destructible2.is_valid()); + + // Move assign destructible2 into destructible1 + destructible1 = std::move(destructible2); + } + + // 3. dir1 shouldn't exist because it was destroyed when destructible1 was + // move assigned to. + EXPECT_FALSE(filesystem.DirectoryExists(filepath1.c_str())); + + // 4. dir2 should still exist because it moved into destructible1 from + // destructible2. 
+ EXPECT_TRUE(filesystem.DirectoryExists(filepath2.c_str())); +} + +TEST(DestructibleFileTest, MoveConstructionDeletesFileProperly) { + Filesystem filesystem; + std::string filepath1 = GetTestTempDir() + "/dir1"; + + // 1. Create destructible1, it'll be reconstructed soon anyways. + std::unique_ptr<DestructibleDirectory> destructible1; + { + // 2. Create file1 + DestructibleDirectory destructible2(&filesystem, filepath1); + ASSERT_TRUE(destructible2.is_valid()); + + // Move construct destructible1 from destructible2 + destructible1 = + std::make_unique<DestructibleDirectory>(std::move(destructible2)); + } + + // 3. dir1 should still exist because it moved into destructible1 from + // destructible2. + EXPECT_TRUE(destructible1->is_valid()); + EXPECT_TRUE(filesystem.DirectoryExists(filepath1.c_str())); + + { + // 4. Move construct destructible3 from destructible1 + DestructibleDirectory destructible3(std::move(*destructible1)); + EXPECT_TRUE(destructible3.is_valid()); + } + + // 5. dir1 shouldn't exist because it was destroyed when destructible3 was + // destroyed. + EXPECT_FALSE(filesystem.DirectoryExists(filepath1.c_str())); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/file/file-backed-bitmap.cc b/icing/file/file-backed-bitmap.cc index eec7668..a8231e3 100644 --- a/icing/file/file-backed-bitmap.cc +++ b/icing/file/file-backed-bitmap.cc @@ -269,8 +269,7 @@ libtextclassifier3::Status FileBackedBitmap::GrowTo(int new_num_bits) { return status; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Grew file %s to new size %zd", file_path_.c_str(), new_file_size); + ICING_VLOG(1) << "Grew file " << file_path_ << " to new size " << new_file_size; mutable_header()->state = Header::ChecksumState::kStale; return libtextclassifier3::Status::OK; } diff --git a/icing/file/file-backed-proto-log.h b/icing/file/file-backed-proto-log.h index 686b4fb..ad7fae9 100644 --- a/icing/file/file-backed-proto-log.h +++ b/icing/file/file-backed-proto-log.h @@ -455,8 +455,8 @@ FileBackedProtoLog<ProtoT>::InitializeExistingFile(const Filesystem* filesystem, absl_ports::StrCat("Error truncating file: ", file_path)); } - ICING_LOG(INFO) << "Truncated '" << file_path << "' to size " - << last_known_good; + ICING_LOG(WARNING) << "Truncated '" << file_path << "' to size " + << last_known_good; } CreateResult create_result = { diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h index 7e42e32..bcfbbdd 100644 --- a/icing/file/file-backed-vector.h +++ b/icing/file/file-backed-vector.h @@ -58,8 +58,12 @@ #include <sys/mman.h> +#include <algorithm> #include <cinttypes> #include <cstdint> +#include <cstring> +#include <functional> +#include <limits> #include <memory> #include <string> #include <utility> @@ -83,6 +87,9 @@ namespace lib { template <typename T> class FileBackedVector { public: + class MutableArrayView; + class MutableView; + // Header stored at the beginning of the file before the rest of the vector // elements. Stores metadata on the vector. struct Header { @@ -133,15 +140,24 @@ class FileBackedVector { kHeaderChecksumOffset, ""); - Crc32 crc; - std::string_view header_str( - reinterpret_cast<const char*>(this), - offsetof(FileBackedVector::Header, header_checksum)); - crc.Append(header_str); - return crc.Get(); + return Crc32(std::string_view( + reinterpret_cast<const char*>(this), + offsetof(FileBackedVector::Header, header_checksum))) + .Get(); } }; + // Absolute max file size for FileBackedVector. 
Note that Android has a + // (2^31-1)-byte single file size limit, so kMaxFileSize is 2^31-1. + static constexpr int32_t kMaxFileSize = + std::numeric_limits<int32_t>::max(); // 2^31-1 Bytes, ~2.1 GB; + + // Size of element type T. The value is same as sizeof(T), while we should + // avoid using sizeof(T) in our codebase to prevent unexpected unsigned + // integer casting. + static constexpr int32_t kElementTypeSize = static_cast<int32_t>(sizeof(T)); + static_assert(sizeof(T) <= (1 << 10)); + // Creates a new FileBackedVector to read/write content to. // // filesystem: Object to make system level calls @@ -149,15 +165,20 @@ class FileBackedVector { // within a directory that already exists. // mmap_strategy : Strategy/optimizations to access the content in the vector, // see MemoryMappedFile::Strategy for more details + // max_file_size: Maximum file size for FileBackedVector, default + // kMaxFileSize. See max_file_size_ and kMaxFileSize for more + // details. // // Return: // FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored // checksum. // INTERNAL_ERROR on I/O errors. + // INVALID_ARGUMENT_ERROR if max_file_size is incorrect. // UNIMPLEMENTED_ERROR if created with strategy READ_WRITE_MANUAL_SYNC. static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> Create(const Filesystem& filesystem, const std::string& file_path, - MemoryMappedFile::Strategy mmap_strategy); + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size = kMaxFileSize); // Deletes the FileBackedVector // @@ -184,13 +205,13 @@ class FileBackedVector { // referencing the now-invalidated region. // // Returns: - // OUT_OF_RANGE_ERROR if idx < 0 or > num_elements() + // OUT_OF_RANGE_ERROR if idx < 0 or idx >= num_elements() libtextclassifier3::StatusOr<T> GetCopy(int32_t idx) const; - // Gets a pointer to the element at idx. + // Gets an immutable pointer to the element at idx. // - // WARNING: Subsequent calls to Set may invalidate the pointer returned by - // Get. + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the pointer + // returned by Get. // // This is useful if you do not think the FileBackedVector will grow before // you need to reference this value, and you want to avoid a copy. When the @@ -198,27 +219,102 @@ class FileBackedVector { // which will invalidate this pointer to the previously mapped region. // // Returns: - // OUT_OF_RANGE_ERROR if idx < 0 or > num_elements() + // OUT_OF_RANGE_ERROR if idx < 0 or idx >= num_elements() libtextclassifier3::StatusOr<const T*> Get(int32_t idx) const; + // Gets a MutableView to the element at idx. + // + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the + // reference returned by MutableView::Get(). + // + // This is useful if you do not think the FileBackedVector will grow before + // you need to reference this value, and you want to mutate the underlying + // data directly. When the FileBackedVector grows, the underlying mmap will be + // unmapped and remapped, which will invalidate this MutableView to the + // previously mapped region. + // + // Returns: + // OUT_OF_RANGE_ERROR if idx < 0 or idx >= num_elements() + libtextclassifier3::StatusOr<MutableView> GetMutable(int32_t idx); + + // Gets a MutableArrayView to the elements at range [idx, idx + len). + // + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the + // reference/pointer returned by MutableArrayView::operator[]/data(). 
+ // + // This is useful if you do not think the FileBackedVector will grow before + // you need to reference this value, and you want to mutate the underlying + // data directly. When the FileBackedVector grows, the underlying mmap will be + // unmapped and remapped, which will invalidate this MutableArrayView to the + // previously mapped region. + // + // Returns: + // OUT_OF_RANGE_ERROR if idx < 0 or idx + len > num_elements() + libtextclassifier3::StatusOr<MutableArrayView> GetMutable(int32_t idx, + int32_t len); + // Writes the value at idx. // // May grow the underlying file and mmapped region as needed to fit the new - // value. If it does grow, then any pointers to previous values returned - // from Get() may be invalidated. + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. // // Returns: - // OUT_OF_RANGE_ERROR if idx < 0 or file cannot be grown idx size + // OUT_OF_RANGE_ERROR if idx < 0 or idx > kMaxIndex or file cannot be grown + // idx size libtextclassifier3::Status Set(int32_t idx, const T& value); + // Appends the value to the end of the vector. + // + // May grow the underlying file and mmapped region as needed to fit the new + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. + // + // Returns: + // OUT_OF_RANGE_ERROR if file cannot be grown (i.e. reach max_file_size_) + libtextclassifier3::Status Append(const T& value) { + return Set(header_->num_elements, value); + } + + // Allocates spaces with given length in the end of the vector and returns a + // MutableArrayView to the space. + // + // May grow the underlying file and mmapped region as needed to fit the new + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. + // + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the + // reference/pointer returned by MutableArrayView::operator[]/data(). + // + // This is useful if you do not think the FileBackedVector will grow before + // you need to reference this value, and you want to allocate adjacent spaces + // for multiple elements and mutate the underlying data directly. When the + // FileBackedVector grows, the underlying mmap will be unmapped and remapped, + // which will invalidate this MutableArrayView to the previously mapped + // region. + // + // Returns: + // OUT_OF_RANGE_ERROR if len <= 0 or file cannot be grown (i.e. reach + // max_file_size_) + libtextclassifier3::StatusOr<MutableArrayView> Allocate(int32_t len); + // Resizes to first len elements. The crc is cleared on truncation and will be // updated on destruction, or once the client calls ComputeChecksum() or // PersistToDisk(). // // Returns: - // OUT_OF_RANGE_ERROR if len < 0 or >= num_elements() + // OUT_OF_RANGE_ERROR if len < 0 or len >= num_elements() libtextclassifier3::Status TruncateTo(int32_t new_num_elements); + // Mark idx as changed iff idx < changes_end_, so later ComputeChecksum() can + // update checksum by the cached changes without going over [0, changes_end_). + // + // If the buffer size exceeds kPartialCrcLimitDiv, then clear all change + // buffers and set changes_end_ as 0, indicating that the checksum should be + // recomputed from idx 0 (starting from the beginning). Otherwise cache the + // change. + void SetDirty(int32_t idx); + // Flushes content to underlying file. 
// // Returns: @@ -248,10 +344,6 @@ class FileBackedVector { return reinterpret_cast<const T*>(mmapped_file_->region()); } - T* mutable_array() const { - return reinterpret_cast<T*>(mmapped_file_->mutable_region()); - } - int32_t num_elements() const { return header_->num_elements; } // Updates checksum of the vector contents and returns it. @@ -260,6 +352,66 @@ class FileBackedVector { // INTERNAL_ERROR if the vector's internal state is inconsistent libtextclassifier3::StatusOr<Crc32> ComputeChecksum(); + public: + class MutableArrayView { + public: + const T& operator[](int32_t idx) const { return data_[idx]; } + T& operator[](int32_t idx) { + SetDirty(idx); + return data_[idx]; + } + + const T* data() const { return data_; } + + int32_t size() const { return len_; } + + // Set the mutable array slice (starting at idx) by the given element array. + // It handles SetDirty properly for the file-backed-vector when modifying + // elements. + // + // REQUIRES: arr is valid && arr_len >= 0 && idx + arr_len <= size(), + // otherwise the behavior is undefined. + void SetArray(int32_t idx, const T* arr, int32_t arr_len) { + for (int32_t i = 0; i < arr_len; ++i) { + SetDirty(idx + i); + data_[idx + i] = arr[i]; + } + } + + private: + MutableArrayView(FileBackedVector<T>* vector, T* data, int32_t len) + : vector_(vector), + data_(data), + original_idx_(data - vector->array()), + len_(len) {} + + void SetDirty(int32_t idx) { vector_->SetDirty(original_idx_ + idx); } + + // Does not own. For SetDirty only. + FileBackedVector<T>* vector_; + + // data_ points at vector_->mutable_array()[original_idx_] + T* data_; + int32_t original_idx_; + int32_t len_; + + friend class FileBackedVector; + }; + + class MutableView { + public: + const T& Get() const { return mutable_array_view_[0]; } + T& Get() { return mutable_array_view_[0]; } + + private: + MutableView(FileBackedVector<T>* vector, T* data) + : mutable_array_view_(vector, data, 1) {} + + MutableArrayView mutable_array_view_; + + friend class FileBackedVector; + }; + private: // We track partial updates to the array for crc updating. This // requires extra memory to keep track of original buffers but @@ -271,24 +423,33 @@ class FileBackedVector { // Grow file by at least this many elements if array is growable. static constexpr int64_t kGrowElements = 1u << 14; // 16K - // Max number of elements that can be held by the vector. - static constexpr int64_t kMaxNumElements = 1u << 20; // 1M + // Absolute max # of elements allowed. Since we are using int32_t to store + // num_elements, max value is 2^31-1. Still the actual max # of elements are + // determined by max_file_size, kElementTypeSize, and Header::kHeaderSize. + static constexpr int32_t kMaxNumElements = + std::numeric_limits<int32_t>::max(); + + // Absolute max index allowed. + static constexpr int32_t kMaxIndex = kMaxNumElements - 1; // Can only be created through the factory ::Create function FileBackedVector(const Filesystem& filesystem, const std::string& file_path, std::unique_ptr<Header> header, - std::unique_ptr<MemoryMappedFile> mmapped_file); + std::unique_ptr<MemoryMappedFile> mmapped_file, + int32_t max_file_size); // Initialize a new FileBackedVector, and create the file. 
static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> InitializeNewFile(const Filesystem& filesystem, const std::string& file_path, - ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy); + ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size); // Initialize a FileBackedVector from an existing file. static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> InitializeExistingFile(const Filesystem& filesystem, const std::string& file_path, ScopedFd fd, - MemoryMappedFile::Strategy mmap_strategy); + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size); // Grows the underlying file to hold at least num_elements // @@ -296,6 +457,10 @@ class FileBackedVector { // OUT_OF_RANGE_ERROR if we can't grow to the specified size libtextclassifier3::Status GrowIfNecessary(int32_t num_elements); + T* mutable_array() const { + return reinterpret_cast<T*>(mmapped_file_->mutable_region()); + } + // Cached constructor params. const Filesystem* const filesystem_; const std::string file_path_; @@ -314,25 +479,42 @@ class FileBackedVector { // update. Will be cleared if the size grows too big. std::string saved_original_buffer_; - // Keep track of all pages we touched so we can write them back to - // disk. - std::vector<bool> dirty_pages_; + // Max file size for FileBackedVector, default kMaxFileSize. Note that this + // value won't be written into the header, so maximum file size will always be + // specified in runtime and the caller should make sure its value is correct + // and reasonable. Note that file size includes size of header + elements. + // + // The range should be in + // [Header::kHeaderSize + kElementTypeSize, kMaxFileSize], and + // (max_file_size_ - Header::kHeaderSize) / kElementTypeSize is max # of + // elements that can be stored. + int32_t max_file_size_; }; template <typename T> +constexpr int32_t FileBackedVector<T>::kMaxFileSize; + +template <typename T> +constexpr int32_t FileBackedVector<T>::kElementTypeSize; + +template <typename T> constexpr int32_t FileBackedVector<T>::kPartialCrcLimitDiv; template <typename T> constexpr int64_t FileBackedVector<T>::kGrowElements; template <typename T> -constexpr int64_t FileBackedVector<T>::kMaxNumElements; +constexpr int32_t FileBackedVector<T>::kMaxNumElements; + +template <typename T> +constexpr int32_t FileBackedVector<T>::kMaxIndex; template <typename T> libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> FileBackedVector<T>::Create(const Filesystem& filesystem, const std::string& file_path, - MemoryMappedFile::Strategy mmap_strategy) { + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size) { if (mmap_strategy == MemoryMappedFile::Strategy::READ_WRITE_MANUAL_SYNC) { // FileBackedVector's behavior of growing the file underneath the mmap is // inherently broken with MAP_PRIVATE. Growing the vector requires extending @@ -345,6 +527,14 @@ FileBackedVector<T>::Create(const Filesystem& filesystem, "mmap strategy."); } + if (max_file_size < Header::kHeaderSize + kElementTypeSize || + max_file_size > kMaxFileSize) { + // FileBackedVector should be able to store at least 1 element, so + // max_file_size should be at least Header::kHeaderSize + kElementTypeSize. 
+ return absl_ports::InvalidArgumentError( + "Invalid max file size for FileBackedVector"); + } + ScopedFd fd(filesystem.OpenForWrite(file_path.c_str())); if (!fd.is_valid()) { return absl_ports::InternalError( @@ -357,31 +547,38 @@ FileBackedVector<T>::Create(const Filesystem& filesystem, absl_ports::StrCat("Bad file size for file ", file_path)); } + if (max_file_size < file_size) { + return absl_ports::InvalidArgumentError( + "Max file size should not be smaller than the existing file size"); + } + const bool new_file = file_size == 0; if (new_file) { return InitializeNewFile(filesystem, file_path, std::move(fd), - mmap_strategy); + mmap_strategy, max_file_size); } return InitializeExistingFile(filesystem, file_path, std::move(fd), - mmap_strategy); + mmap_strategy, max_file_size); } template <typename T> libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> -FileBackedVector<T>::InitializeNewFile( - const Filesystem& filesystem, const std::string& file_path, ScopedFd fd, - MemoryMappedFile::Strategy mmap_strategy) { +FileBackedVector<T>::InitializeNewFile(const Filesystem& filesystem, + const std::string& file_path, + ScopedFd fd, + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size) { // Create header. auto header = std::make_unique<Header>(); header->magic = FileBackedVector<T>::Header::kMagic; - header->element_size = sizeof(T); + header->element_size = kElementTypeSize; header->header_checksum = header->CalculateHeaderChecksum(); // We use Write() here, instead of writing through the mmapped region // created below, so we can gracefully handle errors that occur when the // disk is full. See b/77309668 for details. if (!filesystem.PWrite(fd.get(), /*offset=*/0, header.get(), - sizeof(Header))) { + Header::kHeaderSize)) { return absl_ports::InternalError("Failed to write header"); } @@ -393,23 +590,30 @@ FileBackedVector<T>::InitializeNewFile( auto mmapped_file = std::make_unique<MemoryMappedFile>(filesystem, file_path, mmap_strategy); - return std::unique_ptr<FileBackedVector<T>>(new FileBackedVector<T>( - filesystem, file_path, std::move(header), std::move(mmapped_file))); + return std::unique_ptr<FileBackedVector<T>>( + new FileBackedVector<T>(filesystem, file_path, std::move(header), + std::move(mmapped_file), max_file_size)); } template <typename T> libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> FileBackedVector<T>::InitializeExistingFile( const Filesystem& filesystem, const std::string& file_path, - const ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy) { + const ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size) { int64_t file_size = filesystem.GetFileSize(file_path.c_str()); - if (file_size < sizeof(FileBackedVector<T>::Header)) { + if (file_size == Filesystem::kBadFileSize) { + return absl_ports::InternalError( + absl_ports::StrCat("Bad file size for file ", file_path)); + } + + if (file_size < Header::kHeaderSize) { return absl_ports::InternalError( absl_ports::StrCat("File header too short for ", file_path)); } auto header = std::make_unique<Header>(); - if (!filesystem.PRead(fd.get(), header.get(), sizeof(Header), + if (!filesystem.PRead(fd.get(), header.get(), Header::kHeaderSize, /*offset=*/0)) { return absl_ports::InternalError( absl_ports::StrCat("Failed to read header of ", file_path)); @@ -429,13 +633,15 @@ FileBackedVector<T>::InitializeExistingFile( absl_ports::StrCat("Invalid header crc for ", file_path)); } - if (header->element_size != sizeof(T)) { + if (header->element_size != 
kElementTypeSize) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( - "Inconsistent element size, expected %zd, actual %d", sizeof(T), + "Inconsistent element size, expected %d, actual %d", kElementTypeSize, header->element_size)); } - int64_t min_file_size = header->num_elements * sizeof(T) + sizeof(Header); + int64_t min_file_size = + static_cast<int64_t>(header->num_elements) * kElementTypeSize + + Header::kHeaderSize; if (min_file_size > file_size) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Inconsistent file size, expected %" PRId64 ", actual %" PRId64, @@ -446,23 +652,22 @@ FileBackedVector<T>::InitializeExistingFile( // access elements from the mmapped region auto mmapped_file = std::make_unique<MemoryMappedFile>(filesystem, file_path, mmap_strategy); - ICING_RETURN_IF_ERROR( - mmapped_file->Remap(sizeof(Header), file_size - sizeof(Header))); + ICING_RETURN_IF_ERROR(mmapped_file->Remap(Header::kHeaderSize, + file_size - Header::kHeaderSize)); // Check vector contents - Crc32 vector_checksum; - std::string_view vector_contents( - reinterpret_cast<const char*>(mmapped_file->region()), - header->num_elements * sizeof(T)); - vector_checksum.Append(vector_contents); + Crc32 vector_checksum( + std::string_view(reinterpret_cast<const char*>(mmapped_file->region()), + header->num_elements * kElementTypeSize)); if (vector_checksum.Get() != header->vector_checksum) { return absl_ports::FailedPreconditionError( absl_ports::StrCat("Invalid vector contents for ", file_path)); } - return std::unique_ptr<FileBackedVector<T>>(new FileBackedVector<T>( - filesystem, file_path, std::move(header), std::move(mmapped_file))); + return std::unique_ptr<FileBackedVector<T>>( + new FileBackedVector<T>(filesystem, file_path, std::move(header), + std::move(mmapped_file), max_file_size)); } template <typename T> @@ -479,12 +684,13 @@ template <typename T> FileBackedVector<T>::FileBackedVector( const Filesystem& filesystem, const std::string& file_path, std::unique_ptr<Header> header, - std::unique_ptr<MemoryMappedFile> mmapped_file) + std::unique_ptr<MemoryMappedFile> mmapped_file, int32_t max_file_size) : filesystem_(&filesystem), file_path_(file_path), header_(std::move(header)), mmapped_file_(std::move(mmapped_file)), - changes_end_(header_->num_elements) {} + changes_end_(header_->num_elements), + max_file_size_(max_file_size) {} template <typename T> FileBackedVector<T>::~FileBackedVector() { @@ -523,6 +729,40 @@ libtextclassifier3::StatusOr<const T*> FileBackedVector<T>::Get( } template <typename T> +libtextclassifier3::StatusOr<typename FileBackedVector<T>::MutableView> +FileBackedVector<T>::GetMutable(int32_t idx) { + if (idx < 0) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); + } + + if (idx >= header_->num_elements) { + return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( + "Index, %d, was greater than vector size, %d", idx, + header_->num_elements)); + } + + return MutableView(this, &mutable_array()[idx]); +} + +template <typename T> +libtextclassifier3::StatusOr<typename FileBackedVector<T>::MutableArrayView> +FileBackedVector<T>::GetMutable(int32_t idx, int32_t len) { + if (idx < 0) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); + } + + if (idx > header_->num_elements - len) { + return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( + "Index with len, %d %d, was greater than vector size, %d", idx, len, + 
header_->num_elements)); + } + + return MutableArrayView(this, &mutable_array()[idx], len); +} + +template <typename T> libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, const T& value) { if (idx < 0) { @@ -530,6 +770,11 @@ libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); } + if (idx > kMaxIndex) { + return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( + "Index, %d, was greater than max index allowed, %d", idx, kMaxIndex)); + } + ICING_RETURN_IF_ERROR(GrowIfNecessary(idx + 1)); if (idx + 1 > header_->num_elements) { @@ -541,36 +786,39 @@ libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, return libtextclassifier3::Status::OK; } - // Cache original value to update crcs. - if (idx < changes_end_) { - // If we exceed kPartialCrcLimitDiv, clear changes_end_ to - // revert to full CRC. - if ((saved_original_buffer_.size() + sizeof(T)) * - FileBackedVector<T>::kPartialCrcLimitDiv > - changes_end_ * sizeof(T)) { - ICING_VLOG(2) << "FileBackedVector change tracking limit exceeded"; - changes_.clear(); - saved_original_buffer_.clear(); - changes_end_ = 0; - header_->vector_checksum = 0; - } else { - int32_t start_byte = idx * sizeof(T); - - changes_.push_back(idx); - saved_original_buffer_.append( - reinterpret_cast<char*>(const_cast<T*>(array())) + start_byte, - sizeof(T)); - } - } + SetDirty(idx); mutable_array()[idx] = value; return libtextclassifier3::Status::OK; } template <typename T> +libtextclassifier3::StatusOr<typename FileBackedVector<T>::MutableArrayView> +FileBackedVector<T>::Allocate(int32_t len) { + if (len <= 0) { + return absl_ports::OutOfRangeError("Invalid allocate length"); + } + + if (len > kMaxNumElements - header_->num_elements) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Cannot allocate %d elements", len)); + } + + // Although header_->num_elements + len doesn't exceed kMaxNumElements, the + // actual max # of elements are determined by max_file_size, kElementTypeSize, + // and kHeaderSize. Thus, it is still possible to fail to grow the file. 
+ ICING_RETURN_IF_ERROR(GrowIfNecessary(header_->num_elements + len)); + + int32_t start_idx = header_->num_elements; + header_->num_elements += len; + + return MutableArrayView(this, &mutable_array()[start_idx], len); +} + +template <typename T> libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( int32_t num_elements) { - if (sizeof(T) == 0) { + if (kElementTypeSize == 0) { // Growing is a no-op return libtextclassifier3::Status::OK; } @@ -579,10 +827,12 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( return libtextclassifier3::Status::OK; } - if (num_elements > FileBackedVector<T>::kMaxNumElements) { + if (num_elements > + (max_file_size_ - Header::kHeaderSize) / kElementTypeSize) { return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( - "%d exceeds maximum number of elements allowed, %lld", num_elements, - static_cast<long long>(FileBackedVector<T>::kMaxNumElements))); + "%d elements total size exceed maximum bytes of elements allowed, " + "%d bytes", + num_elements, max_file_size_ - Header::kHeaderSize)); } int64_t current_file_size = filesystem_->GetFileSize(file_path_.c_str()); @@ -590,7 +840,8 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( return absl_ports::InternalError("Unable to retrieve file size."); } - int64_t least_file_size_needed = sizeof(Header) + num_elements * sizeof(T); + int32_t least_file_size_needed = + Header::kHeaderSize + num_elements * kElementTypeSize; // Won't overflow if (least_file_size_needed <= current_file_size) { // Our underlying file can hold the target num_elements cause we've grown // before @@ -598,9 +849,13 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( } // Otherwise, we need to grow. Grow to kGrowElements boundary. - least_file_size_needed = math_util::RoundUpTo( - least_file_size_needed, - int64_t{FileBackedVector<T>::kGrowElements * sizeof(T)}); + // Note that we need to use int64_t here, since int32_t might overflow after + // round up. + int64_t round_up_file_size_needed = math_util::RoundUpTo( + int64_t{least_file_size_needed}, + int64_t{FileBackedVector<T>::kGrowElements} * kElementTypeSize); + least_file_size_needed = + std::min(round_up_file_size_needed, int64_t{max_file_size_}); // We use PWrite here rather than Grow because Grow doesn't actually allocate // an underlying disk block. This can lead to problems with mmap because mmap @@ -609,20 +864,22 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( // these blocks, which will ensure that any failure to grow will surface here. 
int64_t page_size = getpagesize(); auto buf = std::make_unique<uint8_t[]>(page_size); - int64_t size_to_write = page_size - (current_file_size % page_size); + int64_t size_to_write = std::min(page_size - (current_file_size % page_size), + max_file_size_ - current_file_size); ScopedFd sfd(filesystem_->OpenForWrite(file_path_.c_str())); - while (current_file_size < least_file_size_needed) { + while (size_to_write > 0 && current_file_size < least_file_size_needed) { if (!filesystem_->PWrite(sfd.get(), current_file_size, buf.get(), size_to_write)) { return absl_ports::InternalError( absl_ports::StrCat("Couldn't grow file ", file_path_)); } current_file_size += size_to_write; - size_to_write = page_size - (current_file_size % page_size); + size_to_write = std::min(page_size - (current_file_size % page_size), + max_file_size_ - current_file_size); } ICING_RETURN_IF_ERROR(mmapped_file_->Remap( - sizeof(Header), least_file_size_needed - sizeof(Header))); + Header::kHeaderSize, least_file_size_needed - Header::kHeaderSize)); return libtextclassifier3::Status::OK; } @@ -653,6 +910,31 @@ libtextclassifier3::Status FileBackedVector<T>::TruncateTo( } template <typename T> +void FileBackedVector<T>::SetDirty(int32_t idx) { + // Cache original value to update crcs. + if (idx >= 0 && idx < changes_end_) { + // If we exceed kPartialCrcLimitDiv, clear changes_end_ to + // revert to full CRC. + if ((saved_original_buffer_.size() + kElementTypeSize) * + FileBackedVector<T>::kPartialCrcLimitDiv > + changes_end_ * kElementTypeSize) { + ICING_VLOG(2) << "FileBackedVector change tracking limit exceeded"; + changes_.clear(); + saved_original_buffer_.clear(); + changes_end_ = 0; + header_->vector_checksum = 0; + } else { + int32_t start_byte = idx * kElementTypeSize; + + changes_.push_back(idx); + saved_original_buffer_.append( + reinterpret_cast<char*>(const_cast<T*>(array())) + start_byte, + kElementTypeSize); + } + } +} + +template <typename T> libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { // First apply the modified area. Keep a bitmap of already updated // regions so we don't double-update. @@ -663,8 +945,7 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { int num_truncated = 0; int num_overlapped = 0; int num_duplicate = 0; - for (size_t i = 0; i < changes_.size(); i++) { - const int32_t change_offset = changes_[i]; + for (const int32_t change_offset : changes_) { if (change_offset > changes_end_) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Failed to update crc, change offset %d, changes_end_ %d", @@ -678,9 +959,10 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { } // Turn change buffer into change^original. - const char* buffer_end = &saved_original_buffer_[cur_offset + sizeof(T)]; - const char* cur_array = - reinterpret_cast<const char*>(array()) + change_offset * sizeof(T); + const char* buffer_end = + &saved_original_buffer_[cur_offset + kElementTypeSize]; + const char* cur_array = reinterpret_cast<const char*>(array()) + + change_offset * kElementTypeSize; // Now xor in. SSE acceleration please? 
for (char* cur = &saved_original_buffer_[cur_offset]; cur < buffer_end; cur++, cur_array++) { @@ -692,9 +974,9 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { bool overlap = false; uint32_t cur_element = change_offset; for (char* cur = &saved_original_buffer_[cur_offset]; cur < buffer_end; - cur_element++, cur += sizeof(T)) { + cur_element++, cur += kElementTypeSize) { if (updated[cur_element]) { - memset(cur, 0, sizeof(T)); + memset(cur, 0, kElementTypeSize); overlap = true; } else { updated[cur_element] = true; @@ -705,10 +987,11 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { // Apply update to crc. if (new_update) { // Explicitly create the string_view with length - std::string_view xored_str(buffer_end - sizeof(T), sizeof(T)); + std::string_view xored_str(buffer_end - kElementTypeSize, + kElementTypeSize); if (!cur_crc - .UpdateWithXor(xored_str, changes_end_ * sizeof(T), - change_offset * sizeof(T)) + .UpdateWithXor(xored_str, changes_end_ * kElementTypeSize, + change_offset * kElementTypeSize) .ok()) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Failed to update crc, change offset %d, change " @@ -722,7 +1005,7 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { } else { num_duplicate++; } - cur_offset += sizeof(T); + cur_offset += kElementTypeSize; } if (!changes_.empty()) { @@ -735,8 +1018,9 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { if (changes_end_ < header_->num_elements) { // Explicitly create the string_view with length std::string_view update_str( - reinterpret_cast<const char*>(array()) + changes_end_ * sizeof(T), - (header_->num_elements - changes_end_) * sizeof(T)); + reinterpret_cast<const char*>(array()) + + changes_end_ * kElementTypeSize, + (header_->num_elements - changes_end_) * kElementTypeSize); cur_crc.Append(update_str); ICING_VLOG(2) << IcingStringUtil::StringPrintf( "Array update tail crc offset %d -> %d", changes_end_, @@ -761,7 +1045,7 @@ libtextclassifier3::Status FileBackedVector<T>::PersistToDisk() { header_->header_checksum = header_->CalculateHeaderChecksum(); if (!filesystem_->PWrite(file_path_.c_str(), /*offset=*/0, header_.get(), - sizeof(Header))) { + Header::kHeaderSize)) { return absl_ports::InternalError("Failed to sync header"); } @@ -795,7 +1079,11 @@ libtextclassifier3::StatusOr<int64_t> FileBackedVector<T>::GetElementsFileSize() return absl_ports::InternalError( "Failed to get file size of elements in the file-backed vector"); } - return total_file_size - sizeof(Header); + if (total_file_size < Header::kHeaderSize) { + return absl_ports::InternalError( + "File size should not be smaller than header size"); + } + return total_file_size - Header::kHeaderSize; } } // namespace lib diff --git a/icing/file/file-backed-vector_test.cc b/icing/file/file-backed-vector_test.cc index ed94fa5..60ed887 100644 --- a/icing/file/file-backed-vector_test.cc +++ b/icing/file/file-backed-vector_test.cc @@ -19,25 +19,31 @@ #include <algorithm> #include <cerrno> #include <cstdint> +#include <limits> #include <memory> +#include <string> #include <string_view> #include <vector> -#include "knowledge/cerebra/sense/text_classifier/lib3/utils/base/status.h" -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" -#include "third_party/icing/file/filesystem.h" -#include "third_party/icing/file/memory-mapped-file.h" -#include "third_party/icing/file/mock-filesystem.h" -#include 
"third_party/icing/testing/common-matchers.h" -#include "third_party/icing/testing/tmp-directory.h" -#include "third_party/icing/util/crc32.h" -#include "third_party/icing/util/logging.h" - +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/file/mock-filesystem.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" +#include "icing/util/crc32.h" +#include "icing/util/logging.h" + +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::IsTrue; +using ::testing::Lt; +using ::testing::Not; using ::testing::Pointee; using ::testing::Return; +using ::testing::SizeIs; namespace icing { namespace lib { @@ -60,20 +66,30 @@ class FileBackedVectorTest : public testing::Test { // Helper method to loop over some data and insert into the vector at some idx template <typename T> - void Insert(FileBackedVector<T>* vector, int32_t idx, std::string data) { - for (int i = 0; i < data.length(); ++i) { + void Insert(FileBackedVector<T>* vector, int32_t idx, + const std::vector<T>& data) { + for (int i = 0; i < data.size(); ++i) { ICING_ASSERT_OK(vector->Set(idx + i, data.at(i))); } } + void Insert(FileBackedVector<char>* vector, int32_t idx, std::string data) { + Insert(vector, idx, std::vector<char>(data.begin(), data.end())); + } + // Helper method to retrieve data from the beginning of the vector template <typename T> - std::string_view Get(FileBackedVector<T>* vector, int32_t expected_len) { + std::vector<T> Get(FileBackedVector<T>* vector, int32_t idx, + int32_t expected_len) { + return std::vector<T>(vector->array() + idx, + vector->array() + idx + expected_len); + } + + std::string_view Get(FileBackedVector<char>* vector, int32_t expected_len) { return Get(vector, 0, expected_len); } - template <typename T> - std::string_view Get(FileBackedVector<T>* vector, int32_t idx, + std::string_view Get(FileBackedVector<char>* vector, int32_t idx, int32_t expected_len) { return std::string_view(vector->array() + idx, expected_len); } @@ -103,6 +119,79 @@ TEST_F(FileBackedVectorTest, Create) { } } +TEST_F(FileBackedVectorTest, CreateWithInvalidStrategy) { + // Create a vector with unimplemented strategy + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_MANUAL_SYNC), + StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED)); +} + +TEST_F(FileBackedVectorTest, CreateWithCustomMaxFileSize) { + int32_t header_size = FileBackedVector<char>::Header::kHeaderSize; + + // Create a vector with invalid max_file_size + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/-1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size - 1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size + sizeof(char) - 1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + { + // Create a vector with max_file_size that allows only 1 element. 
+ ICING_ASSERT_OK_AND_ASSIGN( + auto vector, FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size + sizeof(char) * 1)); + ICING_ASSERT_OK(vector->Set(0, 'a')); + } + + { + // We can create it again with larger max_file_size, as long as it is not + // greater than kMaxFileSize. + ICING_ASSERT_OK_AND_ASSIGN( + auto vector, FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size + sizeof(char) * 2)); + EXPECT_THAT(vector->Get(0), IsOkAndHolds(Pointee(Eq('a')))); + ICING_ASSERT_OK(vector->Set(1, 'b')); + } + + // We cannot create it again with max_file_size < current_file_size, even if + // it is a valid value. + int64_t current_file_size = filesystem_.GetFileSize(file_path_.c_str()); + ASSERT_THAT(current_file_size, Eq(header_size + sizeof(char) * 2)); + ASSERT_THAT(current_file_size - 1, Not(Lt(header_size + sizeof(char)))); + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/current_file_size - 1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + { + // We can create it again with max_file_size == current_file_size. + ICING_ASSERT_OK_AND_ASSIGN( + auto vector, FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/current_file_size)); + EXPECT_THAT(vector->Get(0), IsOkAndHolds(Pointee(Eq('a')))); + EXPECT_THAT(vector->Get(1), IsOkAndHolds(Pointee(Eq('b')))); + } +} + TEST_F(FileBackedVectorTest, SimpleShared) { // Create a vector and add some data. ICING_ASSERT_OK_AND_ASSIGN( @@ -195,6 +284,373 @@ TEST_F(FileBackedVectorTest, Get) { StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); } +TEST_F(FileBackedVectorTest, MutableView) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'a')); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2620640643U))); + + ICING_ASSERT_OK_AND_ASSIGN(FileBackedVector<char>::MutableView mutable_elt, + vector->GetMutable(3)); + + mutable_elt.Get() = 'b'; + EXPECT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('b')))); + + mutable_elt.Get() = 'c'; + EXPECT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('c')))); +} + +TEST_F(FileBackedVectorTest, MutableViewShouldSetDirty) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'a')); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2620640643U))); + + std::string_view reconstructed_view = + std::string_view(vector->array(), vector->num_elements()); + + ICING_ASSERT_OK_AND_ASSIGN(FileBackedVector<char>::MutableView mutable_elt, + vector->GetMutable(3)); + + // Mutate the element via MutateView + // If non-const Get() is called, MutateView should set the element index dirty + // so that ComputeChecksum() can pick up the change and compute the checksum + // correctly. Validate by mapping another array on top. 
+ mutable_elt.Get() = 'b'; + ASSERT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('b')))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc1, vector->ComputeChecksum()); + Crc32 full_crc1; + full_crc1.Append(reconstructed_view); + EXPECT_THAT(crc1, Eq(full_crc1)); + + // Mutate and test again. + mutable_elt.Get() = 'c'; + ASSERT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('c')))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc2, vector->ComputeChecksum()); + Crc32 full_crc2; + full_crc2.Append(reconstructed_view); + EXPECT_THAT(crc2, Eq(full_crc2)); +} + +TEST_F(FileBackedVectorTest, MutableArrayView) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + constexpr int kArrayViewOffset = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, /*len=*/3)); + EXPECT_THAT(mutable_arr, SizeIs(3)); + + mutable_arr[0] = 2; + mutable_arr[1] = 3; + mutable_arr[2] = 4; + + EXPECT_THAT(vector->Get(kArrayViewOffset + 0), IsOkAndHolds(Pointee(Eq(2)))); + EXPECT_THAT(mutable_arr.data()[0], Eq(2)); + + EXPECT_THAT(vector->Get(kArrayViewOffset + 1), IsOkAndHolds(Pointee(Eq(3)))); + EXPECT_THAT(mutable_arr.data()[1], Eq(3)); + + EXPECT_THAT(vector->Get(kArrayViewOffset + 2), IsOkAndHolds(Pointee(Eq(4)))); + EXPECT_THAT(mutable_arr.data()[2], Eq(4)); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewSetArray) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + constexpr int kArrayViewOffset = 3; + constexpr int kArrayViewLen = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, kArrayViewLen)); + + std::vector<int> change1{2, 3, 4}; + mutable_arr.SetArray(/*idx=*/0, change1.data(), change1.size()); + EXPECT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(2, 3, 4, 1, 1)); + + std::vector<int> change2{5, 6}; + mutable_arr.SetArray(/*idx=*/2, change2.data(), change2.size()); + EXPECT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(2, 3, 5, 6, 1)); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewSetArrayWithZeroLength) { + // Create a vector and add some data. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + constexpr int kArrayViewOffset = 3; + constexpr int kArrayViewLen = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, kArrayViewLen)); + + // Zero arr_len should work and change nothing + std::vector<int> change{2, 3}; + mutable_arr.SetArray(/*idx=*/0, change.data(), /*arr_len=*/0); + EXPECT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(1, 1, 1, 1, 1)); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewIndexOperatorShouldSetDirty) { + // Create an array with some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + std::string_view reconstructed_view( + reinterpret_cast<const char*>(vector->array()), + vector->num_elements() * sizeof(int)); + + constexpr int kArrayViewOffset = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, /*len=*/3)); + + // Use operator[] to mutate elements + // If non-const operator[] is called, MutateView should set the element index + // dirty so that ComputeChecksum() can pick up the change and compute the + // checksum correctly. Validate by mapping another array on top. + mutable_arr[0] = 2; + ASSERT_THAT(vector->Get(kArrayViewOffset + 0), IsOkAndHolds(Pointee(Eq(2)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc1, vector->ComputeChecksum()); + EXPECT_THAT(crc1, Eq(Crc32(reconstructed_view))); + + mutable_arr[1] = 3; + ASSERT_THAT(vector->Get(kArrayViewOffset + 1), IsOkAndHolds(Pointee(Eq(3)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc2, vector->ComputeChecksum()); + EXPECT_THAT(crc2, Eq(Crc32(reconstructed_view))); + + mutable_arr[2] = 4; + ASSERT_THAT(vector->Get(kArrayViewOffset + 2), IsOkAndHolds(Pointee(Eq(4)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc3, vector->ComputeChecksum()); + EXPECT_THAT(crc3, Eq(Crc32(reconstructed_view))); + + // Change the same position. It should set dirty again. + mutable_arr[0] = 5; + ASSERT_THAT(vector->Get(kArrayViewOffset + 0), IsOkAndHolds(Pointee(Eq(5)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc4, vector->ComputeChecksum()); + EXPECT_THAT(crc4, Eq(Crc32(reconstructed_view))); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewSetArrayShouldSetDirty) { + // Create an array with some data. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + std::string_view reconstructed_view( + reinterpret_cast<const char*>(vector->array()), + vector->num_elements() * sizeof(int)); + + constexpr int kArrayViewOffset = 3; + constexpr int kArrayViewLen = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, kArrayViewLen)); + + std::vector<int> change{2, 3, 4}; + mutable_arr.SetArray(/*idx=*/0, change.data(), change.size()); + ASSERT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(2, 3, 4, 1, 1)); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc, vector->ComputeChecksum()); + EXPECT_THAT(crc, Eq(Crc32(reconstructed_view))); +} + +TEST_F(FileBackedVectorTest, Append) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_EXPECT_OK(vector->Append('a')); + EXPECT_THAT(vector->num_elements(), Eq(1)); + EXPECT_THAT(vector->Get(0), IsOkAndHolds(Pointee(Eq('a')))); + + ICING_EXPECT_OK(vector->Append('b')); + EXPECT_THAT(vector->num_elements(), Eq(2)); + EXPECT_THAT(vector->Get(1), IsOkAndHolds(Pointee(Eq('b')))); +} + +TEST_F(FileBackedVectorTest, AppendAfterSet) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_ASSERT_OK(vector->Set(9, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(10)); + ICING_EXPECT_OK(vector->Append('a')); + EXPECT_THAT(vector->num_elements(), Eq(11)); + EXPECT_THAT(vector->Get(10), IsOkAndHolds(Pointee(Eq('a')))); +} + +TEST_F(FileBackedVectorTest, AppendAfterTruncate) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(1000)); + + ICING_ASSERT_OK(vector->TruncateTo(5)); + ICING_EXPECT_OK(vector->Append('a')); + EXPECT_THAT(vector->num_elements(), Eq(6)); + EXPECT_THAT(vector->Get(5), IsOkAndHolds(Pointee(Eq('a')))); +} + +TEST_F(FileBackedVectorTest, AppendShouldFailIfExceedingMaxFileSize) { + int32_t max_file_size = (1 << 10) - 1; + int32_t max_num_elements = + (max_file_size - FileBackedVector<char>::Header::kHeaderSize) / + sizeof(char); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); + ICING_ASSERT_OK(vector->Set(max_num_elements - 1, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(max_num_elements)); + + EXPECT_THAT(vector->Append('a'), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + +TEST_F(FileBackedVectorTest, Allocate) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + 
MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_ASSERT_OK_AND_ASSIGN( + typename FileBackedVector<char>::MutableArrayView mutable_arr, + vector->Allocate(3)); + EXPECT_THAT(vector->num_elements(), Eq(3)); + EXPECT_THAT(mutable_arr, SizeIs(3)); + std::string change = "abc"; + mutable_arr.SetArray(/*idx=*/0, /*arr=*/change.data(), /*arr_len=*/3); + EXPECT_THAT(Get(vector.get(), /*idx=*/0, /*expected_len=*/3), Eq(change)); +} + +TEST_F(FileBackedVectorTest, AllocateAfterSet) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_ASSERT_OK(vector->Set(9, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(10)); + ICING_ASSERT_OK_AND_ASSIGN( + typename FileBackedVector<char>::MutableArrayView mutable_arr, + vector->Allocate(3)); + EXPECT_THAT(vector->num_elements(), Eq(13)); + EXPECT_THAT(mutable_arr, SizeIs(3)); + std::string change = "abc"; + mutable_arr.SetArray(/*idx=*/0, /*arr=*/change.data(), /*arr_len=*/3); + EXPECT_THAT(Get(vector.get(), /*idx=*/10, /*expected_len=*/3), Eq(change)); +} + +TEST_F(FileBackedVectorTest, AllocateAfterTruncate) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(1000)); + + ICING_ASSERT_OK(vector->TruncateTo(5)); + ICING_ASSERT_OK_AND_ASSIGN( + typename FileBackedVector<char>::MutableArrayView mutable_arr, + vector->Allocate(3)); + EXPECT_THAT(vector->num_elements(), Eq(8)); + std::string change = "abc"; + mutable_arr.SetArray(/*idx=*/0, /*arr=*/change.data(), /*arr_len=*/3); + EXPECT_THAT(Get(vector.get(), /*idx=*/5, /*expected_len=*/3), Eq(change)); +} + +TEST_F(FileBackedVectorTest, AllocateInvalidLengthShouldFail) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + EXPECT_THAT(vector->Allocate(-1), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->num_elements(), Eq(0)); + + EXPECT_THAT(vector->Allocate(0), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->num_elements(), Eq(0)); +} + +TEST_F(FileBackedVectorTest, AllocateShouldFailIfExceedingMaxFileSize) { + int32_t max_file_size = (1 << 10) - 1; + int32_t max_num_elements = + (max_file_size - FileBackedVector<char>::Header::kHeaderSize) / + sizeof(char); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); + ICING_ASSERT_OK(vector->Set(max_num_elements - 3, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(max_num_elements - 2)); + + EXPECT_THAT(vector->Allocate(3), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Allocate(2), IsOk()); +} + TEST_F(FileBackedVectorTest, IncrementalCrc_NonOverlappingChanges) { int num_elements = 1000; int incremental_size = 3; @@ -272,29 +728,58 @@ TEST_F(FileBackedVectorTest, IncrementalCrc_OverlappingChanges) { } } 
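For readers skimming the new Allocate()/MutableArrayView tests above, the pattern they exercise reduces to the short sketch below. It is illustrative only (not part of the patch): it assumes the FileBackedVector API exactly as called by the tests, and the helper name WriteAbc is hypothetical.

#include <memory>
#include <string>

#include "icing/file/file-backed-vector.h"
#include "icing/file/filesystem.h"
#include "icing/file/memory-mapped-file.h"
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/util/status-macros.h"

namespace icing {
namespace lib {

// Hypothetical helper: reserves three elements at the tail of a
// FileBackedVector<char>, writes "abc" into them, and flushes to disk.
libtextclassifier3::Status WriteAbc(const Filesystem& filesystem,
                                    const std::string& file_path) {
  ICING_ASSIGN_OR_RETURN(
      std::unique_ptr<FileBackedVector<char>> vector,
      FileBackedVector<char>::Create(
          filesystem, file_path,
          MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));

  // Allocate() grows num_elements() by 3 and hands back a MutableArrayView
  // over the new tail; writes through the view mark those elements dirty so
  // that a later ComputeChecksum() reflects them.
  ICING_ASSIGN_OR_RETURN(FileBackedVector<char>::MutableArrayView mutable_arr,
                         vector->Allocate(3));
  std::string data = "abc";
  mutable_arr.SetArray(/*idx=*/0, /*arr=*/data.data(), /*arr_len=*/3);

  // Persist both the elements and the updated header/checksum.
  return vector->PersistToDisk();
}

}  // namespace lib
}  // namespace icing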
+TEST_F(FileBackedVectorTest, SetIntMaxShouldReturnOutOfRangeError) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int32_t>> vector, + FileBackedVector<int32_t>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(0))); + + // It is an edge case. Since Set() calls GrowIfNecessary(idx + 1), we have to + // make sure that when idx is INT32_MAX, Set() should handle it correctly. + EXPECT_THAT(vector->Set(std::numeric_limits<int32_t>::max(), 1), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + TEST_F(FileBackedVectorTest, Grow) { - // This is the same value as FileBackedVector::kMaxNumElts - constexpr int32_t kMaxNumElts = 1U << 20; + int32_t max_file_size = (1 << 20) - 1; + int32_t header_size = FileBackedVector<int32_t>::Header::kHeaderSize; + int32_t element_type_size = static_cast<int32_t>(sizeof(int32_t)); + + // Max file size includes size of the header and elements, so max # of + // elements will be (max_file_size - header_size) / element_type_size. + // + // Also ensure that (max_file_size - header_size) is not a multiple of + // element_type_size, in order to test if the desired # of elements is + // computed by (math) floor instead of ceil. + ASSERT_THAT((max_file_size - header_size) % element_type_size, Not(Eq(0))); + int32_t max_num_elements = (max_file_size - header_size) / element_type_size; ASSERT_TRUE(filesystem_.Truncate(fd_, 0)); - // Create an array and add some data. + // Create a vector and add some data. ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<FileBackedVector<char>> vector, - FileBackedVector<char>::Create( + std::unique_ptr<FileBackedVector<int32_t>> vector, + FileBackedVector<int32_t>::Create( filesystem_, file_path_, - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(0))); - EXPECT_THAT(vector->Set(kMaxNumElts + 11, 'a'), + // max_num_elements is the allowed max # of elements, so the valid index + // should be 0 to max_num_elements-1. + EXPECT_THAT(vector->Set(max_num_elements, 1), StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); - EXPECT_THAT(vector->Set(-1, 'a'), + EXPECT_THAT(vector->Set(-1, 1), StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Set(max_num_elements - 1, 1), IsOk()); - uint32_t start = kMaxNumElts - 13; - Insert(vector.get(), start, "abcde"); + int32_t start = max_num_elements - 5; + std::vector<int32_t> data{1, 2, 3, 4, 5}; + Insert(vector.get(), start, data); // Crc works? 
- const Crc32 good_crc(1134899064U); + const Crc32 good_crc(650981917U); EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(good_crc)); // PersistToDisk does nothing bad, and ensures the content is still there @@ -306,12 +791,12 @@ TEST_F(FileBackedVectorTest, Grow) { vector.reset(); ICING_ASSERT_OK_AND_ASSIGN( - vector, FileBackedVector<char>::Create( - filesystem_, file_path_, - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + vector, + FileBackedVector<int32_t>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); - std::string expected = "abcde"; - EXPECT_EQ(expected, Get(vector.get(), start, expected.length())); + EXPECT_THAT(Get(vector.get(), start, data.size()), Eq(data)); } TEST_F(FileBackedVectorTest, GrowsInChunks) { @@ -334,20 +819,20 @@ TEST_F(FileBackedVectorTest, GrowsInChunks) { // Once we add something though, we'll grow to be kGrowElements big. From this // point on, file size and disk usage should be the same because Growing will // explicitly allocate the number of blocks needed to accomodate the file. - Insert(vector.get(), 0, "a"); - int file_size = kGrowElements * sizeof(int); + Insert(vector.get(), 0, {1}); + int file_size = 1 * kGrowElements * sizeof(int); EXPECT_THAT(filesystem_.GetFileSize(fd_), Eq(file_size)); EXPECT_THAT(filesystem_.GetDiskUsage(fd_), Eq(file_size)); // Should still be the same size, don't need to grow underlying file - Insert(vector.get(), 1, "b"); + Insert(vector.get(), 1, {2}); EXPECT_THAT(filesystem_.GetFileSize(fd_), Eq(file_size)); EXPECT_THAT(filesystem_.GetDiskUsage(fd_), Eq(file_size)); // Now we grow by a kGrowElements chunk, so the underlying file is 2 // kGrowElements big - file_size *= 2; - Insert(vector.get(), 2, std::string(kGrowElements, 'c')); + file_size = 2 * kGrowElements * sizeof(int); + Insert(vector.get(), 2, std::vector<int>(kGrowElements, 3)); EXPECT_THAT(filesystem_.GetFileSize(fd_), Eq(file_size)); EXPECT_THAT(filesystem_.GetDiskUsage(fd_), Eq(file_size)); @@ -476,6 +961,48 @@ TEST_F(FileBackedVectorTest, TruncateAndReReadFile) { } } +TEST_F(FileBackedVectorTest, SetDirty) { + // 1. Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), 0, "abcd"); + + std::string_view reconstructed_view = + std::string_view(vector->array(), vector->num_elements()); + + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc1, vector->ComputeChecksum()); + Crc32 full_crc_before_overwrite; + full_crc_before_overwrite.Append(reconstructed_view); + EXPECT_THAT(crc1, Eq(full_crc_before_overwrite)); + + // 2. Manually overwrite the values of the first two elements. + std::string corrupted_content = "ef"; + ASSERT_THAT( + filesystem_.PWrite(fd_, /*offset=*/sizeof(FileBackedVector<char>::Header), + corrupted_content.c_str(), corrupted_content.length()), + IsTrue()); + ASSERT_THAT(Get(vector.get(), 0, 4), Eq("efcd")); + Crc32 full_crc_after_overwrite; + full_crc_after_overwrite.Append(reconstructed_view); + ASSERT_THAT(full_crc_before_overwrite, Not(Eq(full_crc_after_overwrite))); + + // 3. Without calling SetDirty(), the checksum will be recomputed incorrectly. + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc2, vector->ComputeChecksum()); + EXPECT_THAT(crc2, Not(Eq(full_crc_after_overwrite))); + + // 4. Call SetDirty() + vector->SetDirty(0); + vector->SetDirty(1); + + // 5. 
The checksum should be computed correctly after calling SetDirty() with + // correct index. + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc3, vector->ComputeChecksum()); + EXPECT_THAT(crc3, Eq(full_crc_after_overwrite)); +} + TEST_F(FileBackedVectorTest, InitFileTooSmallForHeaderFails) { { // 1. Create a vector with a few elements. @@ -662,7 +1189,7 @@ TEST_F(FileBackedVectorTest, RemapFailureStillValidInstance) { // 2. The next Set call should cause a resize and a remap. Make that remap // fail. int num_calls = 0; - auto open_lambda = [this, &num_calls](const char* file_name){ + auto open_lambda = [this, &num_calls](const char* file_name) { if (++num_calls == 2) { return -1; } diff --git a/icing/file/filesystem.cc b/icing/file/filesystem.cc index 82b8d98..10b77db 100644 --- a/icing/file/filesystem.cc +++ b/icing/file/filesystem.cc @@ -63,18 +63,16 @@ void LogOpenFileDescriptors() { constexpr int kMaxFileDescriptorsToStat = 4096; struct rlimit rlim = {0, 0}; if (getrlimit(RLIMIT_NOFILE, &rlim) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "getrlimit() failed (errno=%d)", errno); + ICING_LOG(ERROR) << "getrlimit() failed (errno=" << errno << ")"; return; } int fd_lim = rlim.rlim_cur; if (fd_lim > kMaxFileDescriptorsToStat) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Maximum number of file descriptors (%d) too large.", fd_lim); + ICING_LOG(ERROR) << "Maximum number of file descriptors (" << fd_lim + << ") too large."; fd_lim = kMaxFileDescriptorsToStat; } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Listing up to %d file descriptors.", fd_lim); + ICING_LOG(ERROR) << "Listing up to " << fd_lim << " file descriptors."; // Verify that /proc/self/fd is a directory. If not, procfs is not mounted or // inaccessible for some other reason. In that case, there's no point trying @@ -96,15 +94,12 @@ void LogOpenFileDescriptors() { if (len >= 0) { // Zero-terminate the buffer, because readlink() won't. target[len < target_size ? len : target_size - 1] = '\0'; - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> \"%s\"", fd, - target); + ICING_LOG(ERROR) << "fd " << fd << " -> \"" << target << "\""; } else if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> ? (errno=%d)", - fd, errno); + ICING_LOG(ERROR) << "fd " << fd << " -> ? (errno=" << errno << ")"; } } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "File descriptor list complete."); + ICING_LOG(ERROR) << "File descriptor list complete."; } // Logs an error formatted as: desc1 + file_name + desc2 + strerror(errnum). @@ -113,8 +108,7 @@ void LogOpenFileDescriptors() { // file descriptors (see LogOpenFileDescriptors() above). 
void LogOpenError(const char* desc1, const char* file_name, const char* desc2, int errnum) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "%s%s%s%s", desc1, file_name, desc2, strerror(errnum)); + ICING_LOG(ERROR) << desc1 << file_name << desc2 << strerror(errnum); if (errnum == EMFILE) { LogOpenFileDescriptors(); } @@ -155,8 +149,7 @@ bool ListDirectoryInternal(const char* dir_name, } } if (closedir(dir) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Error closing %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Error closing " << dir_name << " " << strerror(errno); } return true; } @@ -179,11 +172,10 @@ void ScopedFd::reset(int fd) { const int64_t Filesystem::kBadFileSize; bool Filesystem::DeleteFile(const char* file_name) const { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Deleting file %s", file_name); + ICING_VLOG(1) << "Deleting file " << file_name; int ret = unlink(file_name); if (ret != 0 && errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting file %s failed: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting file " << file_name << " failed: " << strerror(errno); return false; } return true; @@ -192,8 +184,7 @@ bool Filesystem::DeleteFile(const char* file_name) const { bool Filesystem::DeleteDirectory(const char* dir_name) const { int ret = rmdir(dir_name); if (ret != 0 && errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting directory " << dir_name << " failed: " << strerror(errno); return false; } return true; @@ -206,8 +197,7 @@ bool Filesystem::DeleteDirectoryRecursively(const char* dir_name) const { if (errno == ENOENT) { return true; // If directory didn't exist, this was successful. } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Stat " << dir_name << " failed: " << strerror(errno); return false; } vector<std::string> entries; @@ -220,8 +210,7 @@ bool Filesystem::DeleteDirectoryRecursively(const char* dir_name) const { ++i) { std::string filename = std::string(dir_name) + '/' + *i; if (stat(filename.c_str(), &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", filename.c_str(), strerror(errno)); + ICING_LOG(ERROR) << "Stat " << filename << " failed: " << strerror(errno); success = false; } else if (S_ISDIR(st.st_mode)) { success = DeleteDirectoryRecursively(filename.c_str()) && success; @@ -244,8 +233,7 @@ bool Filesystem::FileExists(const char* file_name) const { exists = S_ISREG(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file " << file_name << ": " << strerror(errno); } exists = false; } @@ -259,8 +247,7 @@ bool Filesystem::DirectoryExists(const char* dir_name) const { exists = S_ISDIR(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat directory %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat directory " << dir_name << ": " << strerror(errno); } exists = false; } @@ -316,8 +303,7 @@ bool Filesystem::GetMatchingFiles(const char* glob, int basename_idx = GetBasenameIndex(glob); if (basename_idx == 0) { // We need a directory. 
- ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Expected directory, no matching files for: %s", glob); + ICING_VLOG(1) << "Expected directory, no matching files for: " << glob; return true; } const char* basename_glob = glob + basename_idx; @@ -372,8 +358,7 @@ int Filesystem::OpenForRead(const char* file_name) const { int64_t Filesystem::GetFileSize(int fd) const { struct stat st; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); return kBadFileSize; } return st.st_size; @@ -383,11 +368,9 @@ int64_t Filesystem::GetFileSize(const char* filename) const { struct stat st; if (stat(filename, &st) < 0) { if (errno == ENOENT) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", filename, strerror(errno)); + ICING_VLOG(1) << "Unable to stat file " << filename << ": " << strerror(errno); } else { - ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", filename, strerror(errno)); + ICING_LOG(WARNING) << "Unable to stat file " << filename << ": " << strerror(errno); } return kBadFileSize; } @@ -396,8 +379,7 @@ int64_t Filesystem::GetFileSize(const char* filename) const { bool Filesystem::Truncate(int fd, int64_t new_size) const { if (ftruncate(fd, new_size) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to truncate file: %s", strerror(errno)); + ICING_LOG(ERROR) << "Unable to truncate file: " << strerror(errno); return false; } lseek(fd, new_size, SEEK_SET); @@ -416,8 +398,7 @@ bool Filesystem::Truncate(const char* filename, int64_t new_size) const { bool Filesystem::Grow(int fd, int64_t new_size) const { if (ftruncate(fd, new_size) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to grow file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to grow file: " << strerror(errno); return false; } @@ -442,8 +423,7 @@ bool Filesystem::Write(int fd, const void* data, size_t data_size) const { size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = write(fd, data, chunk_size); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t*>(data) + wrote; @@ -521,8 +501,7 @@ bool Filesystem::CopyDirectory(const char* src_dir, const char* dst_dir, } } if (closedir(dir) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Error closing %s: %s", - src_dir, strerror(errno)); + ICING_LOG(ERROR) << "Error closing " << src_dir << ": " << strerror(errno); } return true; } @@ -535,8 +514,7 @@ bool Filesystem::PWrite(int fd, off_t offset, const void* data, size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = pwrite(fd, data, chunk_size, offset); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t*>(data) + wrote; @@ -561,8 +539,7 @@ bool Filesystem::PWrite(const char* filename, off_t offset, const void* data, bool Filesystem::Read(int fd, void* buf, size_t buf_size) const { ssize_t read_status = read(fd, buf, buf_size); if (read_status < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad read: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad read: " << strerror(errno); return false; } return true; 
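The logging changes in this file all follow one mechanical pattern: printf-style formatting through IcingStringUtil::StringPrintf is replaced by streaming the pieces into ICING_LOG directly. A representative call site is sketched below; the wrapping function name is hypothetical, and the headers already included by filesystem.cc (logging plus <cstring> for strerror) are assumed.

// Sketch of the pattern applied throughout filesystem.cc. The "before" form
// is kept as a comment; the active statement mirrors the new call sites.
void LogStatFailure(const char* file_name) {
  // Before: a format string that has to stay in sync with its argument list.
  //   ICING_LOG(ERROR) << IcingStringUtil::StringPrintf(
  //       "Unable to stat file %s: %s", file_name, strerror(errno));
  // After: stream the pieces directly.
  ICING_LOG(ERROR) << "Unable to stat file " << file_name << ": "
                   << strerror(errno);
}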
@@ -582,8 +559,7 @@ bool Filesystem::Read(const char* filename, void* buf, size_t buf_size) const { bool Filesystem::PRead(int fd, void* buf, size_t buf_size, off_t offset) const { ssize_t read_status = pread(fd, buf, buf_size, offset); if (read_status < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad read: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad read: " << strerror(errno); return false; } return true; @@ -609,8 +585,7 @@ bool Filesystem::DataSync(int fd) const { #endif if (result < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to sync data: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to sync data: " << strerror(errno); return false; } return true; @@ -618,9 +593,7 @@ bool Filesystem::DataSync(int fd) const { bool Filesystem::RenameFile(const char* old_name, const char* new_name) const { if (rename(old_name, new_name) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to rename file %s to %s: %s", old_name, new_name, - strerror(errno)); + ICING_LOG(ERROR) << "Unable to rename file " << old_name << " to " << new_name << ": " << strerror(errno); return false; } return true; @@ -658,8 +631,7 @@ bool Filesystem::CreateDirectory(const char* dir_name) const { if (mkdir(dir_name, S_IRUSR | S_IWUSR | S_IXUSR) == 0) { success = true; } else { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Creating directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Creating directory " << dir_name << " failed: " << strerror(errno); } } return success; @@ -679,8 +651,7 @@ bool Filesystem::CreateDirectoryRecursively(const char* dir_name) const { int64_t Filesystem::GetDiskUsage(int fd) const { struct stat st; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -689,8 +660,7 @@ int64_t Filesystem::GetDiskUsage(int fd) const { int64_t Filesystem::GetFileDiskUsage(const char* path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -699,8 +669,7 @@ int64_t Filesystem::GetFileDiskUsage(const char* path) const { int64_t Filesystem::GetDiskUsage(const char* path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } int64_t result = st.st_blocks * kStatBlockSize; diff --git a/icing/file/memory-mapped-file.cc b/icing/file/memory-mapped-file.cc index 9ff3adb..fc13a79 100644 --- a/icing/file/memory-mapped-file.cc +++ b/icing/file/memory-mapped-file.cc @@ -73,8 +73,6 @@ libtextclassifier3::Status MemoryMappedFile::Remap(size_t file_offset, if (mmap_size == 0) { // First unmap any previously mmapped region. Unmap(); - - // Nothing more to do. 
return libtextclassifier3::Status::OK; } @@ -122,6 +120,7 @@ libtextclassifier3::Status MemoryMappedFile::Remap(size_t file_offset, mmap_flags, fd.get(), aligned_offset); if (mmap_result == MAP_FAILED) { + mmap_result = nullptr; return absl_ports::InternalError(absl_ports::StrCat( "Failed to mmap region due to error: ", strerror(errno))); } diff --git a/icing/file/persistent-hash-map.cc b/icing/file/persistent-hash-map.cc new file mode 100644 index 0000000..d20285a --- /dev/null +++ b/icing/file/persistent-hash-map.cc @@ -0,0 +1,534 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/file/persistent-hash-map.h" + +#include <cstring> +#include <memory> +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/util/crc32.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +// Helper function to check if there is no termination character '\0' in the +// key. +libtextclassifier3::Status ValidateKey(std::string_view key) { + if (key.find('\0') != std::string_view::npos) { // NOLINT + return absl_ports::InvalidArgumentError( + "Key cannot contain termination character '\\0'"); + } + return libtextclassifier3::Status::OK; +} + +// Helper function to convert the key to bucket index by hash. +// +// Returns: +// int32_t: A valid bucket index with range [0, num_buckets - 1]. +// INTERNAL_ERROR if num_buckets == 0 +libtextclassifier3::StatusOr<int32_t> HashKeyToBucketIndex( + std::string_view key, int32_t num_buckets) { + if (num_buckets == 0) { + return absl_ports::InternalError("Should not have empty bucket"); + } + return static_cast<int32_t>(std::hash<std::string_view>()(key) % num_buckets); +} + +// Helper function to PWrite crcs and info to metadata_file_path. Note that +// metadata_file_path will be the normal or temporary (for branching use when +// rehashing) metadata file path. +libtextclassifier3::Status WriteMetadata(const Filesystem& filesystem, + const char* metadata_file_path, + const PersistentHashMap::Crcs* crcs, + const PersistentHashMap::Info* info) { + ScopedFd sfd(filesystem.OpenForWrite(metadata_file_path)); + if (!sfd.is_valid()) { + return absl_ports::InternalError("Failed to create metadata file"); + } + + // Write crcs and info. File layout: <Crcs><Info> + if (!filesystem.PWrite(sfd.get(), PersistentHashMap::Crcs::kFileOffset, crcs, + sizeof(PersistentHashMap::Crcs))) { + return absl_ports::InternalError("Failed to write crcs into metadata file"); + } + // Note that PWrite won't change the file offset, so we need to specify + // the correct offset when writing Info. 
+ if (!filesystem.PWrite(sfd.get(), PersistentHashMap::Info::kFileOffset, info, + sizeof(PersistentHashMap::Info))) { + return absl_ports::InternalError("Failed to write info into metadata file"); + } + + return libtextclassifier3::Status::OK; +} + +// Helper function to update checksums from info and storages to a Crcs +// instance. Note that storages will be the normal instances used by +// PersistentHashMap, or the temporary instances (for branching use when +// rehashing). +libtextclassifier3::Status UpdateChecksums( + PersistentHashMap::Crcs* crcs, PersistentHashMap::Info* info, + FileBackedVector<PersistentHashMap::Bucket>* bucket_storage, + FileBackedVector<PersistentHashMap::Entry>* entry_storage, + FileBackedVector<char>* kv_storage) { + // Compute crcs + ICING_ASSIGN_OR_RETURN(Crc32 bucket_storage_crc, + bucket_storage->ComputeChecksum()); + ICING_ASSIGN_OR_RETURN(Crc32 entry_storage_crc, + entry_storage->ComputeChecksum()); + ICING_ASSIGN_OR_RETURN(Crc32 kv_storage_crc, kv_storage->ComputeChecksum()); + + crcs->component_crcs.info_crc = info->ComputeChecksum().Get(); + crcs->component_crcs.bucket_storage_crc = bucket_storage_crc.Get(); + crcs->component_crcs.entry_storage_crc = entry_storage_crc.Get(); + crcs->component_crcs.kv_storage_crc = kv_storage_crc.Get(); + crcs->all_crc = crcs->component_crcs.ComputeChecksum().Get(); + + return libtextclassifier3::Status::OK; +} + +// Helper function to validate checksums. +libtextclassifier3::Status ValidateChecksums( + const PersistentHashMap::Crcs* crcs, const PersistentHashMap::Info* info, + FileBackedVector<PersistentHashMap::Bucket>* bucket_storage, + FileBackedVector<PersistentHashMap::Entry>* entry_storage, + FileBackedVector<char>* kv_storage) { + if (crcs->all_crc != crcs->component_crcs.ComputeChecksum().Get()) { + return absl_ports::FailedPreconditionError( + "Invalid all crc for PersistentHashMap"); + } + + if (crcs->component_crcs.info_crc != info->ComputeChecksum().Get()) { + return absl_ports::FailedPreconditionError( + "Invalid info crc for PersistentHashMap"); + } + + ICING_ASSIGN_OR_RETURN(Crc32 bucket_storage_crc, + bucket_storage->ComputeChecksum()); + if (crcs->component_crcs.bucket_storage_crc != bucket_storage_crc.Get()) { + return absl_ports::FailedPreconditionError( + "Mismatch crc with PersistentHashMap bucket storage"); + } + + ICING_ASSIGN_OR_RETURN(Crc32 entry_storage_crc, + entry_storage->ComputeChecksum()); + if (crcs->component_crcs.entry_storage_crc != entry_storage_crc.Get()) { + return absl_ports::FailedPreconditionError( + "Mismatch crc with PersistentHashMap entry storage"); + } + + ICING_ASSIGN_OR_RETURN(Crc32 kv_storage_crc, kv_storage->ComputeChecksum()); + if (crcs->component_crcs.kv_storage_crc != kv_storage_crc.Get()) { + return absl_ports::FailedPreconditionError( + "Mismatch crc with PersistentHashMap key value storage"); + } + + return libtextclassifier3::Status::OK; +} + +// Since metadata/bucket/entry storages should be branched when rehashing, we +// have to store them together under the same sub directory +// ("<base_dir>/<sub_dir>"). On the other hand, key-value storage won't be +// branched and it will be stored under <base_dir>. +// +// The following 4 methods are helper functions to get the correct path of +// metadata/bucket/entry/key-value storages, according to the given base +// directory and sub directory. 
+std::string GetMetadataFilePath(std::string_view base_dir, + std::string_view sub_dir) { + return absl_ports::StrCat(base_dir, "/", sub_dir, "/", + PersistentHashMap::kFilePrefix, ".m"); +} + +std::string GetBucketStorageFilePath(std::string_view base_dir, + std::string_view sub_dir) { + return absl_ports::StrCat(base_dir, "/", sub_dir, "/", + PersistentHashMap::kFilePrefix, ".b"); +} + +std::string GetEntryStorageFilePath(std::string_view base_dir, + std::string_view sub_dir) { + return absl_ports::StrCat(base_dir, "/", sub_dir, "/", + PersistentHashMap::kFilePrefix, ".e"); +} + +std::string GetKeyValueStorageFilePath(std::string_view base_dir) { + return absl_ports::StrCat(base_dir, "/", PersistentHashMap::kFilePrefix, + ".k"); +} + +} // namespace + +/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> +PersistentHashMap::Create(const Filesystem& filesystem, + std::string_view base_dir, int32_t value_type_size, + int32_t max_load_factor_percent) { + if (!filesystem.FileExists( + GetMetadataFilePath(base_dir, kSubDirectory).c_str()) || + !filesystem.FileExists( + GetBucketStorageFilePath(base_dir, kSubDirectory).c_str()) || + !filesystem.FileExists( + GetEntryStorageFilePath(base_dir, kSubDirectory).c_str()) || + !filesystem.FileExists(GetKeyValueStorageFilePath(base_dir).c_str())) { + // TODO: erase all files if missing any. + return InitializeNewFiles(filesystem, base_dir, value_type_size, + max_load_factor_percent); + } + return InitializeExistingFiles(filesystem, base_dir, value_type_size, + max_load_factor_percent); +} + +PersistentHashMap::~PersistentHashMap() { + if (!PersistToDisk().ok()) { + ICING_LOG(WARNING) + << "Failed to persist hash map to disk while destructing " << base_dir_; + } +} + +libtextclassifier3::Status PersistentHashMap::Put(std::string_view key, + const void* value) { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + FindEntryIndexByKey(bucket_idx, key)); + if (target_entry_idx == Entry::kInvalidIndex) { + // If not found, then insert new key value pair. + return Insert(bucket_idx, key, value); + } + + // Otherwise, overwrite the value. + ICING_ASSIGN_OR_RETURN(const Entry* entry, + entry_storage_->Get(target_entry_idx)); + + int32_t kv_len = key.length() + 1 + info()->value_type_size; + int32_t value_offset = key.length() + 1; + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<char>::MutableArrayView mutable_kv_arr, + kv_storage_->GetMutable(entry->key_value_index(), kv_len)); + // It is the same key and value_size is fixed, so we can directly overwrite + // serialized value. + mutable_kv_arr.SetArray(value_offset, reinterpret_cast<const char*>(value), + info()->value_type_size); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PersistentHashMap::GetOrPut(std::string_view key, + void* next_value) { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + FindEntryIndexByKey(bucket_idx, key)); + if (target_entry_idx == Entry::kInvalidIndex) { + // If not found, then insert new key value pair. + return Insert(bucket_idx, key, next_value); + } + + // Otherwise, copy the hash map value into next_value. 
+ return CopyEntryValue(target_entry_idx, next_value); +} + +libtextclassifier3::Status PersistentHashMap::Get(std::string_view key, + void* value) const { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + FindEntryIndexByKey(bucket_idx, key)); + if (target_entry_idx == Entry::kInvalidIndex) { + return absl_ports::NotFoundError( + absl_ports::StrCat("Key not found in PersistentHashMap ", base_dir_)); + } + + return CopyEntryValue(target_entry_idx, value); +} + +libtextclassifier3::Status PersistentHashMap::PersistToDisk() { + ICING_RETURN_IF_ERROR(bucket_storage_->PersistToDisk()); + ICING_RETURN_IF_ERROR(entry_storage_->PersistToDisk()); + ICING_RETURN_IF_ERROR(kv_storage_->PersistToDisk()); + + ICING_RETURN_IF_ERROR(UpdateChecksums(crcs(), info(), bucket_storage_.get(), + entry_storage_.get(), + kv_storage_.get())); + // Changes should have been applied to the underlying file when using + // MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, but call msync() as an + // extra safety step to ensure they are written out. + ICING_RETURN_IF_ERROR(metadata_mmapped_file_->PersistToDisk()); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::StatusOr<int64_t> PersistentHashMap::GetDiskUsage() const { + ICING_ASSIGN_OR_RETURN(int64_t bucket_storage_disk_usage, + bucket_storage_->GetDiskUsage()); + ICING_ASSIGN_OR_RETURN(int64_t entry_storage_disk_usage, + entry_storage_->GetDiskUsage()); + ICING_ASSIGN_OR_RETURN(int64_t kv_storage_disk_usage, + kv_storage_->GetDiskUsage()); + + int64_t total = bucket_storage_disk_usage + entry_storage_disk_usage + + kv_storage_disk_usage; + Filesystem::IncrementByOrSetInvalid( + filesystem_->GetDiskUsage( + GetMetadataFilePath(base_dir_, kSubDirectory).c_str()), + &total); + + if (total < 0 || total == Filesystem::kBadFileSize) { + return absl_ports::InternalError( + "Failed to get disk usage of PersistentHashMap"); + } + return total; +} + +libtextclassifier3::StatusOr<int64_t> PersistentHashMap::GetElementsSize() + const { + ICING_ASSIGN_OR_RETURN(int64_t bucket_storage_elements_size, + bucket_storage_->GetElementsFileSize()); + ICING_ASSIGN_OR_RETURN(int64_t entry_storage_elements_size, + entry_storage_->GetElementsFileSize()); + ICING_ASSIGN_OR_RETURN(int64_t kv_storage_elements_size, + kv_storage_->GetElementsFileSize()); + return bucket_storage_elements_size + entry_storage_elements_size + + kv_storage_elements_size; +} + +libtextclassifier3::StatusOr<Crc32> PersistentHashMap::ComputeChecksum() { + Crcs* crcs_ptr = crcs(); + ICING_RETURN_IF_ERROR(UpdateChecksums(crcs_ptr, info(), bucket_storage_.get(), + entry_storage_.get(), + kv_storage_.get())); + return Crc32(crcs_ptr->all_crc); +} + +/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> +PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem, + std::string_view base_dir, + int32_t value_type_size, + int32_t max_load_factor_percent) { + // Create directory. 
+ const std::string dir_path = absl_ports::StrCat(base_dir, "/", kSubDirectory); + if (!filesystem.CreateDirectoryRecursively(dir_path.c_str())) { + return absl_ports::InternalError( + absl_ports::StrCat("Failed to create directory: ", dir_path)); + } + + // Initialize 3 storages + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, + FileBackedVector<Bucket>::Create( + filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Entry>> entry_storage, + FileBackedVector<Entry>::Create( + filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage, + FileBackedVector<char>::Create( + filesystem, GetKeyValueStorageFilePath(base_dir), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + // Initialize one bucket. + ICING_RETURN_IF_ERROR(bucket_storage->Append(Bucket())); + ICING_RETURN_IF_ERROR(bucket_storage->PersistToDisk()); + + // Create and initialize new info + Info new_info; + new_info.version = kVersion; + new_info.value_type_size = value_type_size; + new_info.max_load_factor_percent = max_load_factor_percent; + new_info.num_deleted_entries = 0; + new_info.num_deleted_key_value_bytes = 0; + + // Compute checksums + Crcs new_crcs; + ICING_RETURN_IF_ERROR(UpdateChecksums(&new_crcs, &new_info, + bucket_storage.get(), + entry_storage.get(), kv_storage.get())); + + const std::string metadata_file_path = + GetMetadataFilePath(base_dir, kSubDirectory); + // Write new metadata file + ICING_RETURN_IF_ERROR(WriteMetadata(filesystem, metadata_file_path.c_str(), + &new_crcs, &new_info)); + + // Mmap the content of the crcs and info. + auto metadata_mmapped_file = std::make_unique<MemoryMappedFile>( + filesystem, metadata_file_path, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC); + ICING_RETURN_IF_ERROR(metadata_mmapped_file->Remap( + /*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info))); + + return std::unique_ptr<PersistentHashMap>(new PersistentHashMap( + filesystem, base_dir, std::move(metadata_mmapped_file), + std::move(bucket_storage), std::move(entry_storage), + std::move(kv_storage))); +} + +/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> +PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, + std::string_view base_dir, + int32_t value_type_size, + int32_t max_load_factor_percent) { + // Mmap the content of the crcs and info. 
+ auto metadata_mmapped_file = std::make_unique<MemoryMappedFile>( + filesystem, GetMetadataFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC); + ICING_RETURN_IF_ERROR(metadata_mmapped_file->Remap( + /*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info))); + + // Initialize 3 storages + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, + FileBackedVector<Bucket>::Create( + filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Entry>> entry_storage, + FileBackedVector<Entry>::Create( + filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage, + FileBackedVector<char>::Create( + filesystem, GetKeyValueStorageFilePath(base_dir), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + Crcs* crcs_ptr = reinterpret_cast<Crcs*>( + metadata_mmapped_file->mutable_region() + Crcs::kFileOffset); + Info* info_ptr = reinterpret_cast<Info*>( + metadata_mmapped_file->mutable_region() + Info::kFileOffset); + + // Value type size should be consistent. + if (value_type_size != info_ptr->value_type_size) { + return absl_ports::FailedPreconditionError("Incorrect value type size"); + } + + // Validate checksums of info and 3 storages. + ICING_RETURN_IF_ERROR( + ValidateChecksums(crcs_ptr, info_ptr, bucket_storage.get(), + entry_storage.get(), kv_storage.get())); + + // Allow max_load_factor_percent_ change. + if (max_load_factor_percent != info_ptr->max_load_factor_percent) { + ICING_VLOG(2) << "Changing max_load_factor_percent from " << info_ptr->max_load_factor_percent << " to " << max_load_factor_percent; + + info_ptr->max_load_factor_percent = max_load_factor_percent; + crcs_ptr->component_crcs.info_crc = info_ptr->ComputeChecksum().Get(); + crcs_ptr->all_crc = crcs_ptr->component_crcs.ComputeChecksum().Get(); + ICING_RETURN_IF_ERROR(metadata_mmapped_file->PersistToDisk()); + // TODO(b/193919210): rehash if needed + } + + return std::unique_ptr<PersistentHashMap>(new PersistentHashMap( + filesystem, base_dir, std::move(metadata_mmapped_file), + std::move(bucket_storage), std::move(entry_storage), + std::move(kv_storage))); +} + +libtextclassifier3::StatusOr<int32_t> PersistentHashMap::FindEntryIndexByKey( + int32_t bucket_idx, std::string_view key) const { + // Iterate all entries in the bucket, compare with key, and return the entry + // index if exists. + ICING_ASSIGN_OR_RETURN(const Bucket* bucket, + bucket_storage_->Get(bucket_idx)); + int32_t curr_entry_idx = bucket->head_entry_index(); + while (curr_entry_idx != Entry::kInvalidIndex) { + ICING_ASSIGN_OR_RETURN(const Entry* entry, + entry_storage_->Get(curr_entry_idx)); + if (entry->key_value_index() == kInvalidKVIndex) { + ICING_LOG(ERROR) << "Got an invalid key value index in the persistent " + "hash map bucket. 
This shouldn't happen"; + return absl_ports::InternalError("Unexpected invalid key value index"); + } + ICING_ASSIGN_OR_RETURN(const char* kv_arr, + kv_storage_->Get(entry->key_value_index())); + if (key.compare(kv_arr) == 0) { + return curr_entry_idx; + } + + curr_entry_idx = entry->next_entry_index(); + } + + return curr_entry_idx; +} + +libtextclassifier3::Status PersistentHashMap::CopyEntryValue( + int32_t entry_idx, void* value) const { + ICING_ASSIGN_OR_RETURN(const Entry* entry, entry_storage_->Get(entry_idx)); + + ICING_ASSIGN_OR_RETURN(const char* kv_arr, + kv_storage_->Get(entry->key_value_index())); + int32_t value_offset = strlen(kv_arr) + 1; + memcpy(value, kv_arr + value_offset, info()->value_type_size); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PersistentHashMap::Insert(int32_t bucket_idx, + std::string_view key, + const void* value) { + // If size() + 1 exceeds Entry::kMaxNumEntries, then return error. + if (size() > Entry::kMaxNumEntries - 1) { + return absl_ports::ResourceExhaustedError("Cannot insert new entry"); + } + + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<Bucket>::MutableView mutable_bucket, + bucket_storage_->GetMutable(bucket_idx)); + + // Append new key value. + int32_t new_kv_idx = kv_storage_->num_elements(); + int32_t kv_len = key.size() + 1 + info()->value_type_size; + int32_t value_offset = key.size() + 1; + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<char>::MutableArrayView mutable_new_kv_arr, + kv_storage_->Allocate(kv_len)); + mutable_new_kv_arr.SetArray(/*idx=*/0, key.data(), key.size()); + mutable_new_kv_arr.SetArray(/*idx=*/key.size(), "\0", 1); + mutable_new_kv_arr.SetArray(/*idx=*/value_offset, + reinterpret_cast<const char*>(value), + info()->value_type_size); + + // Append new entry. + int32_t new_entry_idx = entry_storage_->num_elements(); + ICING_RETURN_IF_ERROR(entry_storage_->Append( + Entry(new_kv_idx, mutable_bucket.Get().head_entry_index()))); + mutable_bucket.Get().set_head_entry_index(new_entry_idx); + + // TODO: rehash if needed + + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing diff --git a/icing/file/persistent-hash-map.h b/icing/file/persistent-hash-map.h new file mode 100644 index 0000000..24a47ea --- /dev/null +++ b/icing/file/persistent-hash-map.h @@ -0,0 +1,383 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_FILE_PERSISTENT_HASH_MAP_H_ +#define ICING_FILE_PERSISTENT_HASH_MAP_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +// Low level persistent hash map. +// It supports variant length serialized key + fixed length serialized value. 
+// Key and value can be any type, but callers should serialize key/value by +// themselves and pass raw bytes into the hash map, and the serialized key +// should not contain termination character '\0'. +class PersistentHashMap { + public: + // Crcs and Info will be written into the metadata file. + // File layout: <Crcs><Info> + // Crcs + struct Crcs { + static constexpr int32_t kFileOffset = 0; + + struct ComponentCrcs { + uint32_t info_crc; + uint32_t bucket_storage_crc; + uint32_t entry_storage_crc; + uint32_t kv_storage_crc; + + bool operator==(const ComponentCrcs& other) const { + return info_crc == other.info_crc && + bucket_storage_crc == other.bucket_storage_crc && + entry_storage_crc == other.entry_storage_crc && + kv_storage_crc == other.kv_storage_crc; + } + + Crc32 ComputeChecksum() const { + return Crc32(std::string_view(reinterpret_cast<const char*>(this), + sizeof(ComponentCrcs))); + } + } __attribute__((packed)); + + bool operator==(const Crcs& other) const { + return all_crc == other.all_crc && component_crcs == other.component_crcs; + } + + uint32_t all_crc; + ComponentCrcs component_crcs; + } __attribute__((packed)); + static_assert(sizeof(Crcs) == 20, ""); + + // Info + struct Info { + static constexpr int32_t kFileOffset = static_cast<int32_t>(sizeof(Crcs)); + + int32_t version; + int32_t value_type_size; + int32_t max_load_factor_percent; + int32_t num_deleted_entries; + int32_t num_deleted_key_value_bytes; + + Crc32 ComputeChecksum() const { + return Crc32( + std::string_view(reinterpret_cast<const char*>(this), sizeof(Info))); + } + } __attribute__((packed)); + static_assert(sizeof(Info) == 20, ""); + + // Bucket + class Bucket { + public: + // Absolute max # of buckets allowed. Since max file size on Android is + // 2^31-1, we can at most have ~2^29 buckets. To make it power of 2, round + // it down to 2^28. Also since we're using FileBackedVector to store + // buckets, add some static_asserts to ensure numbers here are compatible + // with FileBackedVector. + static constexpr int32_t kMaxNumBuckets = 1 << 28; + + explicit Bucket(int32_t head_entry_index = Entry::kInvalidIndex) + : head_entry_index_(head_entry_index) {} + + // For FileBackedVector + bool operator==(const Bucket& other) const { + return head_entry_index_ == other.head_entry_index_; + } + + int32_t head_entry_index() const { return head_entry_index_; } + void set_head_entry_index(int32_t head_entry_index) { + head_entry_index_ = head_entry_index; + } + + private: + int32_t head_entry_index_; + } __attribute__((packed)); + static_assert(sizeof(Bucket) == 4, ""); + static_assert(sizeof(Bucket) == FileBackedVector<Bucket>::kElementTypeSize, + "Bucket type size is inconsistent with FileBackedVector " + "element type size"); + static_assert(Bucket::kMaxNumBuckets <= + (FileBackedVector<Bucket>::kMaxFileSize - + FileBackedVector<Bucket>::Header::kHeaderSize) / + FileBackedVector<Bucket>::kElementTypeSize, + "Max # of buckets cannot fit into FileBackedVector"); + + // Entry + class Entry { + public: + // Absolute max # of entries allowed. Since max file size on Android is + // 2^31-1, we can at most have ~2^28 entries. To make it power of 2, round + // it down to 2^27. Also since we're using FileBackedVector to store + // entries, add some static_asserts to ensure numbers here are compatible + // with FileBackedVector. + // + // Still the actual max # of entries are determined by key-value storage, + // since length of the key varies and affects # of actual key-value pairs + // that can be stored. 
+ static constexpr int32_t kMaxNumEntries = 1 << 27; + static constexpr int32_t kMaxIndex = kMaxNumEntries - 1; + static constexpr int32_t kInvalidIndex = -1; + + explicit Entry(int32_t key_value_index, int32_t next_entry_index) + : key_value_index_(key_value_index), + next_entry_index_(next_entry_index) {} + + bool operator==(const Entry& other) const { + return key_value_index_ == other.key_value_index_ && + next_entry_index_ == other.next_entry_index_; + } + + int32_t key_value_index() const { return key_value_index_; } + void set_key_value_index(int32_t key_value_index) { + key_value_index_ = key_value_index; + } + + int32_t next_entry_index() const { return next_entry_index_; } + void set_next_entry_index(int32_t next_entry_index) { + next_entry_index_ = next_entry_index; + } + + private: + int32_t key_value_index_; + int32_t next_entry_index_; + } __attribute__((packed)); + static_assert(sizeof(Entry) == 8, ""); + static_assert(sizeof(Entry) == FileBackedVector<Entry>::kElementTypeSize, + "Entry type size is inconsistent with FileBackedVector " + "element type size"); + static_assert(Entry::kMaxNumEntries <= + (FileBackedVector<Entry>::kMaxFileSize - + FileBackedVector<Entry>::Header::kHeaderSize) / + FileBackedVector<Entry>::kElementTypeSize, + "Max # of entries cannot fit into FileBackedVector"); + + // Key-value serialized type + static constexpr int32_t kMaxKVTotalByteSize = + (FileBackedVector<char>::kMaxFileSize - + FileBackedVector<char>::Header::kHeaderSize) / + FileBackedVector<char>::kElementTypeSize; + static constexpr int32_t kMaxKVIndex = kMaxKVTotalByteSize - 1; + static constexpr int32_t kInvalidKVIndex = -1; + static_assert(sizeof(char) == FileBackedVector<char>::kElementTypeSize, + "Char type size is inconsistent with FileBackedVector element " + "type size"); + + static constexpr int32_t kVersion = 1; + static constexpr int32_t kDefaultMaxLoadFactorPercent = 75; + + static constexpr std::string_view kFilePrefix = "persistent_hash_map"; + // Only metadata, bucket, entry files are stored under this sub-directory, for + // rehashing branching use. + static constexpr std::string_view kSubDirectory = "dynamic"; + + // Creates a new PersistentHashMap to read/write/delete key value pairs. + // + // filesystem: Object to make system level calls + // base_dir: Specifies the directory for all persistent hash map related + // sub-directory and files to be stored. If base_dir doesn't exist, + // then PersistentHashMap will automatically create it. If files + // exist, then it will initialize the hash map from existing files. + // value_type_size: (fixed) size of the serialized value type for hash map. + // max_load_factor_percent: percentage of the max loading for the hash map. + // load_factor_percent = 100 * num_keys / num_buckets + // If load_factor_percent exceeds + // max_load_factor_percent, then rehash will be + // invoked (and # of buckets will be doubled). + // Note that load_factor_percent exceeding 100 is + // considered valid. + // + // Returns: + // FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored + // checksum. + // INTERNAL_ERROR on I/O errors. + // Any FileBackedVector errors. + static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + Create(const Filesystem& filesystem, std::string_view base_dir, + int32_t value_type_size, + int32_t max_load_factor_percent = kDefaultMaxLoadFactorPercent); + + ~PersistentHashMap(); + + // Update a key value pair. If key does not exist, then insert (key, value) + // into the storage. 
Otherwise overwrite the value into the storage. + // + // REQUIRES: the buffer pointed to by value must be of value_size() + // + // Returns: + // OK on success + // RESOURCE_EXHAUSTED_ERROR if # of entries reach kMaxNumEntries + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status Put(std::string_view key, const void* value); + + // If key does not exist, then insert (key, next_value) into the storage. + // Otherwise, copy the hash map value into next_value. + // + // REQUIRES: the buffer pointed to by next_value must be of value_size() + // + // Returns: + // OK on success + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status GetOrPut(std::string_view key, void* next_value); + + // Get the value by key from the storage. If key exists, then copy the hash + // map value into into value buffer. Otherwise, return NOT_FOUND_ERROR. + // + // REQUIRES: the buffer pointed to by value must be of value_size() + // + // Returns: + // OK on success + // NOT_FOUND_ERROR if the key doesn't exist + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status Get(std::string_view key, void* value) const; + + // Flushes content to underlying files. + // + // Returns: + // OK on success + // INTERNAL_ERROR on I/O error + libtextclassifier3::Status PersistToDisk(); + + // Calculates and returns the disk usage (metadata + 3 storages total file + // size) in bytes. + // + // Returns: + // Disk usage on success + // INTERNAL_ERROR on I/O error + libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const; + + // Returns the total file size of the all the elements held in the persistent + // hash map. File size is in bytes. This excludes the size of any internal + // metadata, i.e. crcs/info of persistent hash map, file backed vector's + // header. + // + // Returns: + // File size on success + // INTERNAL_ERROR on I/O error + libtextclassifier3::StatusOr<int64_t> GetElementsSize() const; + + // Updates all checksums of the persistent hash map components and returns + // all_crc. 
+ // + // Returns: + // Crc of all components (all_crc) on success + // INTERNAL_ERROR if any data inconsistency + libtextclassifier3::StatusOr<Crc32> ComputeChecksum(); + + int32_t size() const { + return entry_storage_->num_elements() - info()->num_deleted_entries; + } + + bool empty() const { return size() == 0; } + + private: + explicit PersistentHashMap( + const Filesystem& filesystem, std::string_view base_dir, + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file, + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, + std::unique_ptr<FileBackedVector<Entry>> entry_storage, + std::unique_ptr<FileBackedVector<char>> kv_storage) + : filesystem_(&filesystem), + base_dir_(base_dir), + metadata_mmapped_file_(std::move(metadata_mmapped_file)), + bucket_storage_(std::move(bucket_storage)), + entry_storage_(std::move(entry_storage)), + kv_storage_(std::move(kv_storage)) {} + + static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + InitializeNewFiles(const Filesystem& filesystem, std::string_view base_dir, + int32_t value_type_size, int32_t max_load_factor_percent); + + static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + InitializeExistingFiles(const Filesystem& filesystem, + std::string_view base_dir, int32_t value_type_size, + int32_t max_load_factor_percent); + + // Find the index of the key entry from a bucket (specified by bucket index). + // The caller should specify the desired bucket index. + // + // Returns: + // int32_t: on success, the index of the entry, or Entry::kInvalidIndex if + // not found + // INTERNAL_ERROR if any content inconsistency + // Any FileBackedVector errors + libtextclassifier3::StatusOr<int32_t> FindEntryIndexByKey( + int32_t bucket_idx, std::string_view key) const; + + // Copy the hash map value of the entry into value buffer. + // + // REQUIRES: entry_idx should be valid. + // REQUIRES: the buffer pointed to by value must be of value_size() + // + // Returns: + // OK on success + // Any FileBackedVector errors + libtextclassifier3::Status CopyEntryValue(int32_t entry_idx, + void* value) const; + + // Insert a new key value pair into a bucket (specified by the bucket index). + // The caller should specify the desired bucket index and make sure that the + // key is not present in the hash map before calling. 
+ // + // Returns: + // OK on success + // Any FileBackedVector errors + libtextclassifier3::Status Insert(int32_t bucket_idx, std::string_view key, + const void* value); + + Crcs* crcs() { + return reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() + + Crcs::kFileOffset); + } + + Info* info() { + return reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() + + Info::kFileOffset); + } + + const Info* info() const { + return reinterpret_cast<const Info*>(metadata_mmapped_file_->region() + + Info::kFileOffset); + } + + const Filesystem* filesystem_; + std::string base_dir_; + + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_; + + // Storages + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage_; + std::unique_ptr<FileBackedVector<Entry>> entry_storage_; + std::unique_ptr<FileBackedVector<char>> kv_storage_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_FILE_PERSISTENT_HASH_MAP_H_ diff --git a/icing/file/persistent-hash-map_test.cc b/icing/file/persistent-hash-map_test.cc new file mode 100644 index 0000000..fb15175 --- /dev/null +++ b/icing/file/persistent-hash-map_test.cc @@ -0,0 +1,662 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/file/persistent-hash-map.h" + +#include <cstring> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +namespace { + +static constexpr int32_t kCorruptedValueOffset = 3; + +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::Pointee; +using ::testing::SizeIs; + +using Bucket = PersistentHashMap::Bucket; +using Crcs = PersistentHashMap::Crcs; +using Entry = PersistentHashMap::Entry; +using Info = PersistentHashMap::Info; + +class PersistentHashMapTest : public ::testing::Test { + protected: + void SetUp() override { + base_dir_ = GetTestTempDir() + "/persistent_hash_map_test"; + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(base_dir_.c_str()); + } + + std::vector<char> Serialize(int val) { + std::vector<char> ret(sizeof(val)); + memcpy(ret.data(), &val, sizeof(val)); + return ret; + } + + libtextclassifier3::StatusOr<int> GetValueByKey( + PersistentHashMap* persistent_hash_map, std::string_view key) { + int val; + ICING_RETURN_IF_ERROR(persistent_hash_map->Get(key, &val)); + return val; + } + + Filesystem filesystem_; + std::string base_dir_; +}; + +TEST_F(PersistentHashMapTest, InvalidBaseDir) { + EXPECT_THAT(PersistentHashMap::Create(filesystem_, "/dev/null", + /*value_type_size=*/sizeof(int)), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); +} + 
+TEST_F(PersistentHashMapTest, InitializeNewFiles) { + { + ASSERT_FALSE(filesystem_.DirectoryExists(base_dir_.c_str())); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + // Metadata file should be initialized correctly for both info and crcs + // sections. + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + // Check info section + Info info; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), + Info::kFileOffset)); + EXPECT_THAT(info.version, Eq(PersistentHashMap::kVersion)); + EXPECT_THAT(info.value_type_size, Eq(sizeof(int))); + EXPECT_THAT(info.max_load_factor_percent, + Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent)); + EXPECT_THAT(info.num_deleted_entries, Eq(0)); + EXPECT_THAT(info.num_deleted_key_value_bytes, Eq(0)); + + // Check crcs section + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + // # of elements in bucket_storage should be 1, so it should have non-zero + // crc value. + EXPECT_THAT(crcs.component_crcs.bucket_storage_crc, Not(Eq(0))); + // Other empty file backed vectors should have 0 crc value. + EXPECT_THAT(crcs.component_crcs.entry_storage_crc, Eq(0)); + EXPECT_THAT(crcs.component_crcs.kv_storage_crc, Eq(0)); + EXPECT_THAT(crcs.component_crcs.info_crc, + Eq(Crc32(std::string_view(reinterpret_cast<const char*>(&info), + sizeof(Info))) + .Get())); + EXPECT_THAT(crcs.all_crc, + Eq(Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get())); +} + +TEST_F(PersistentHashMapTest, + TestInitializationFailsWithoutPersistToDiskOrDestruction) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing, to + // avoid PersistToDisk being called implicitly by rehashing. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + // Put some key value pairs. + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + // TODO(b/193919210): call Delete() to change PersistentHashMap header + + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + + // Without calling PersistToDisk, checksums will not be recomputed or synced + // to disk, so initializing another instance on the same files should fail. 
+ EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing, to + // avoid PersistToDisk being called implicitly by rehashing. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map1, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + // Put some key value pairs. + ICING_ASSERT_OK(persistent_hash_map1->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map1->Put("b", Serialize(2).data())); + // TODO(b/193919210): call Delete() to change PersistentHashMap header + + ASSERT_THAT(persistent_hash_map1, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map1.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map1.get(), "b"), IsOkAndHolds(2)); + + // After calling PersistToDisk, all checksums should be recomputed and synced + // correctly to disk, so initializing another instance on the same files + // should succeed, and we should be able to get the same contents. + ICING_EXPECT_OK(persistent_hash_map1->PersistToDisk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map2, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + EXPECT_THAT(persistent_hash_map2, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "a"), IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "b"), IsOkAndHolds(2)); +} + +TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) { + { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing, to + // avoid PersistToDisk being called implicitly by rehashing. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + // TODO(b/193919210): call Delete() to change PersistentHashMap header + + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + } + + { + // The previous instance went out of scope and was destructed. Although we + // didn't call PersistToDisk explicitly, the destructor should invoke it and + // thus initializing another instance on the same files should succeed, and + // we should be able to get the same contents. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithDifferentValueTypeSizeShouldFail) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + { + // Attempt to create the persistent hash map with different value type size. + // This should fail. + ASSERT_THAT(sizeof(char), Not(Eq(sizeof(int)))); + libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(char)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Incorrect value type size")); + } +} + +TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt all_crc + crcs.all_crc += kCorruptedValueOffset; + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + metadata_sfd.reset(); + + { + // Attempt to create the persistent hash map with metadata containing + // corrupted all_crc. This should fail. 
+ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Invalid all crc for PersistentHashMap")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithCorruptedInfoShouldFail) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Info info; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), + Info::kFileOffset)); + + // Modify info, but don't update the checksum. This would be similar to + // corruption of info. + info.num_deleted_entries += kCorruptedValueOffset; + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Info::kFileOffset, &info, + sizeof(Info))); + { + // Attempt to create the persistent hash map with info that doesn't match + // its checksum and confirm that it fails. + libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Invalid info crc for PersistentHashMap")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithWrongBucketStorageCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt bucket_storage_crc + crcs.component_crcs.bucket_storage_crc += kCorruptedValueOffset; + crcs.all_crc = Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get(); + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + { + // Attempt to create the persistent hash map with metadata containing + // corrupted bucket_storage_crc. This should fail. 
+ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + persistent_hash_map_or.status().error_message(), + HasSubstr("Mismatch crc with PersistentHashMap bucket storage")); + } +} + +TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt entry_storage_crc + crcs.component_crcs.entry_storage_crc += kCorruptedValueOffset; + crcs.all_crc = Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get(); + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + { + // Attempt to create the persistent hash map with metadata containing + // corrupted entry_storage_crc. This should fail. + libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Mismatch crc with PersistentHashMap entry storage")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithWrongKeyValueStorageCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt kv_storage_crc + crcs.component_crcs.kv_storage_crc += kCorruptedValueOffset; + crcs.all_crc = Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get(); + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + { + // Attempt to create the persistent hash map with metadata containing + // corrupted kv_storage_crc. This should fail. 
+ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + persistent_hash_map_or.status().error_message(), + HasSubstr("Mismatch crc with PersistentHashMap key value storage")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesAllowDifferentMaxLoadFactorPercent) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + int32_t new_max_load_factor_percent = 100; + { + ASSERT_THAT(new_max_load_factor_percent, + Not(Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent))); + // Attempt to create the persistent hash map with different max load factor + // percent. This should succeed and metadata should be modified correctly. + // Also verify all entries should remain unchanged. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + new_max_load_factor_percent)); + + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Info info; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), + Info::kFileOffset)); + EXPECT_THAT(info.max_load_factor_percent, Eq(new_max_load_factor_percent)); + + // Also should update crcs correctly. We test it by creating instance again + // and make sure it won't get corrupted crcs/info errors. 
+ { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + new_max_load_factor_percent)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } +} + +TEST_F(PersistentHashMapTest, PutAndGet) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty())); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-youtube.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + ICING_EXPECT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ICING_EXPECT_OK( + persistent_hash_map->Put("default-youtube.com", Serialize(50).data())); + + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(100)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-youtube.com"), + IsOkAndHolds(50)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "key-not-exist"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); +} + +TEST_F(PersistentHashMapTest, PutShouldOverwriteValueIfKeyExists) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(100)); + + ICING_EXPECT_OK( + persistent_hash_map->Put("default-google.com", Serialize(200).data())); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(200)); + + ICING_EXPECT_OK( + persistent_hash_map->Put("default-google.com", Serialize(300).data())); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(300)); +} + +TEST_F(PersistentHashMapTest, GetOrPutShouldPutIfKeyDoesNotExist) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + int val = 1; + EXPECT_THAT(persistent_hash_map->GetOrPut("default-google.com", &val), + IsOk()); + EXPECT_THAT(val, Eq(1)); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(1)); +} + +TEST_F(PersistentHashMapTest, GetOrPutShouldGetIfKeyExists) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ASSERT_THAT( + 
persistent_hash_map->Put("default-google.com", Serialize(1).data()), + IsOk()); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(1)); + + int val = 2; + EXPECT_THAT(persistent_hash_map->GetOrPut("default-google.com", &val), + IsOk()); + EXPECT_THAT(val, Eq(1)); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(1)); +} + +TEST_F(PersistentHashMapTest, ShouldFailIfKeyContainsTerminationCharacter) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + const char invalid_key[] = "a\0bc"; + std::string_view invalid_key_view(invalid_key, 4); + + int val = 1; + EXPECT_THAT(persistent_hash_map->Put(invalid_key_view, &val), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(persistent_hash_map->GetOrPut(invalid_key_view, &val), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(persistent_hash_map->Get(invalid_key_view, &val), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/file/portable-file-backed-proto-log_benchmark.cc b/icing/file/portable-file-backed-proto-log_benchmark.cc index 80a8011..d7ea4bb 100644 --- a/icing/file/portable-file-backed-proto-log_benchmark.cc +++ b/icing/file/portable-file-backed-proto-log_benchmark.cc @@ -33,7 +33,7 @@ // icing/file:portable-file-backed-proto-log_benchmark // // $ blaze-bin/icing/file/portable-file-backed-proto-log_benchmark -// --benchmarks=all +// --benchmark_filter=all // // // To build and run on an Android device (must be connected and rooted): @@ -48,7 +48,7 @@ // /data/local/tmp/ // // $ adb shell /data/local/tmp/portable-file-backed-proto-log-benchmark -// --benchmarks=all +// --benchmark_filter=all namespace icing { namespace lib { diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 952ba21..4be4ac3 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -49,14 +49,16 @@ #include "icing/proto/status.pb.h" #include "icing/query/query-processor.h" #include "icing/query/suggestion-processor.h" +#include "icing/result/page-result.h" #include "icing/result/projection-tree.h" #include "icing/result/projector.h" -#include "icing/result/result-retriever.h" +#include "icing/result/result-retriever-v2.h" #include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" #include "icing/schema/section.h" -#include "icing/scoring/ranker.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" #include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/scoring/scoring-processor.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" @@ -112,6 +114,11 @@ libtextclassifier3::Status ValidateResultSpec( return absl_ports::InvalidArgumentError( "ResultSpecProto.num_per_page cannot be negative."); } + if (result_spec.num_total_bytes_per_page_threshold() <= 0) { + return absl_ports::InvalidArgumentError( + "ResultSpecProto.num_total_bytes_per_page_threshold cannot be " + "non-positive."); + } std::unordered_set<std::string> unique_namespaces; for (const ResultSpecProto::ResultGrouping& result_grouping : result_spec.result_groupings()) { @@ -263,9 +270,9 @@ 
void TransformStatus(const libtextclassifier3::Status& internal_status, case libtextclassifier3::StatusCode::UNAUTHENTICATED: // Other internal status codes aren't supported externally yet. If it // should be supported, add another switch-case above. - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Internal status code %d not supported in the external API", - internal_status.error_code()); + ICING_LOG(ERROR) << "Internal status code " + << internal_status.error_code() + << " not supported in the external API"; code = StatusProto::UNKNOWN; break; } @@ -295,6 +302,17 @@ libtextclassifier3::Status RetrieveAndAddDocumentInfo( return libtextclassifier3::Status::OK; } +bool ShouldRebuildIndex(const OptimizeStatsProto& optimize_stats) { + int num_invalid_documents = optimize_stats.num_deleted_documents() + + optimize_stats.num_expired_documents(); + // Rebuilding the index could be faster than optimizing the index if we have + // removed most of the documents. + // Based on benchmarks, 85%~95% seems to be a good threshold for most cases. + // TODO(b/238236206): Try using the number of remaining hits in this + // condition, and allow clients to configure the threshold. + return num_invalid_documents >= optimize_stats.num_original_documents() * 0.9; +} + } // namespace IcingSearchEngine::IcingSearchEngine(const IcingSearchEngineOptions& options, @@ -529,7 +547,8 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( } result_state_manager_ = std::make_unique<ResultStateManager>( - performance_configuration_.max_num_total_hits, *document_store_); + performance_configuration_.max_num_total_hits, *document_store_, + clock_.get()); return status; } @@ -633,18 +652,18 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( StatusProto* result_status = result_proto.mutable_status(); absl_ports::unique_lock l(&mutex_); - std::unique_ptr<Timer> timer = clock_->GetNewTimer(); + ScopedTimer timer(clock_->GetNewTimer(), [&result_proto](int64_t t) { + result_proto.set_latency_ms(t); + }); if (!initialized_) { result_status->set_code(StatusProto::FAILED_PRECONDITION); result_status->set_message("IcingSearchEngine has not been initialized!"); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } auto lost_previous_schema_or = LostPreviousSchema(); if (!lost_previous_schema_or.ok()) { TransformStatus(lost_previous_schema_or.status(), result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } bool lost_previous_schema = lost_previous_schema_or.ValueOrDie(); @@ -662,7 +681,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( std::move(new_schema), ignore_errors_and_delete_documents); if (!set_schema_result_or.ok()) { TransformStatus(set_schema_result_or.status(), result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } SchemaStore::SetSchemaResult set_schema_result = @@ -705,7 +723,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( status = document_store_->UpdateSchemaStore(schema_store_.get()); if (!status.ok()) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } } else if (!set_schema_result.old_schema_type_ids_changed.empty() || @@ -715,7 +732,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( set_schema_result); if (!status.ok()) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } } @@ -725,7 +741,6 @@ 
SetSchemaResultProto IcingSearchEngine::SetSchema( status = index_->Reset(); if (!status.ok()) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } @@ -736,7 +751,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( if (!restore_result.status.ok() && !absl_ports::IsDataLoss(restore_result.status)) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } } @@ -747,7 +761,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( result_status->set_message("Schema is incompatible."); } - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } @@ -803,12 +816,13 @@ PutResultProto IcingSearchEngine::Put(const DocumentProto& document) { PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { ICING_VLOG(1) << "Writing document to document store"; - std::unique_ptr<Timer> put_timer = clock_->GetNewTimer(); - PutResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); PutDocumentStatsProto* put_document_stats = result_proto.mutable_put_document_stats(); + ScopedTimer put_timer(clock_->GetNewTimer(), [put_document_stats](int64_t t) { + put_document_stats->set_latency_ms(t); + }); // Lock must be acquired before validation because the DocumentStore uses // the schema file to validate, and the schema could be changed in @@ -817,7 +831,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { if (!initialized_) { result_status->set_code(StatusProto::FAILED_PRECONDITION); result_status->set_message("IcingSearchEngine has not been initialized!"); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } @@ -825,7 +838,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { schema_store_.get(), language_segmenter_.get(), std::move(document)); if (!tokenized_document_or.ok()) { TransformStatus(tokenized_document_or.status(), result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } TokenizedDocument tokenized_document( @@ -836,7 +848,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { tokenized_document.num_tokens(), put_document_stats); if (!document_id_or.ok()) { TransformStatus(document_id_or.status(), result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } DocumentId document_id = document_id_or.ValueOrDie(); @@ -845,7 +856,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); if (!index_processor_or.ok()) { TransformStatus(index_processor_or.status(), result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<IndexProcessor> index_processor = @@ -866,7 +876,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { } TransformStatus(status, result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } @@ -1080,7 +1089,9 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( delete_stats->set_num_schema_types_filtered( search_spec.schema_type_filters_size()); - std::unique_ptr<Timer> delete_timer = clock_->GetNewTimer(); + ScopedTimer delete_timer(clock_->GetNewTimer(), [delete_stats](int64_t t) { + delete_stats->set_latency_ms(t); + }); libtextclassifier3::Status 
status = ValidateSearchSpec(search_spec, performance_configuration_); if (!status.ok()) { @@ -1095,6 +1106,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( document_store_.get(), schema_store_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); + delete_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<QueryProcessor> query_processor = @@ -1103,6 +1116,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( auto query_results_or = query_processor->ParseSearch(search_spec); if (!query_results_or.ok()) { TransformStatus(query_results_or.status(), result_status); + delete_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } QueryProcessor::QueryResults query_results = @@ -1130,6 +1145,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( query_results.root_iterator->doc_hit_info().document_id()); if (!status.ok()) { TransformStatus(status, result_status); + delete_stats->set_document_removal_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } } @@ -1137,6 +1154,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( query_results.root_iterator->doc_hit_info().document_id()); if (!status.ok()) { TransformStatus(status, result_status); + delete_stats->set_document_removal_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } } @@ -1155,7 +1174,6 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( result_proto.mutable_status()->set_message( "No documents matched the query to delete by!"); } - delete_stats->set_latency_ms(delete_timer->GetElapsedMilliseconds()); delete_stats->set_num_documents_deleted(num_deleted); return result_proto; } @@ -1198,11 +1216,10 @@ OptimizeResultProto IcingSearchEngine::Optimize() { return result_proto; } - std::unique_ptr<Timer> optimize_timer = clock_->GetNewTimer(); OptimizeStatsProto* optimize_stats = result_proto.mutable_optimize_stats(); - int64_t before_size = filesystem_->GetDiskUsage(options_.base_dir().c_str()); - optimize_stats->set_storage_size_before( - Filesystem::SanitizeFileSize(before_size)); + ScopedTimer optimize_timer( + clock_->GetNewTimer(), + [optimize_stats](int64_t t) { optimize_stats->set_latency_ms(t); }); // Flushes data to disk before doing optimization auto status = InternalPersistToDisk(PersistType::FULL); @@ -1211,52 +1228,86 @@ OptimizeResultProto IcingSearchEngine::Optimize() { return result_proto; } + int64_t before_size = filesystem_->GetDiskUsage(options_.base_dir().c_str()); + optimize_stats->set_storage_size_before( + Filesystem::SanitizeFileSize(before_size)); + // TODO(b/143646633): figure out if we need to optimize index and doc store // at the same time. std::unique_ptr<Timer> optimize_doc_store_timer = clock_->GetNewTimer(); - libtextclassifier3::Status optimization_status = - OptimizeDocumentStore(optimize_stats); + libtextclassifier3::StatusOr<std::vector<DocumentId>> + document_id_old_to_new_or = OptimizeDocumentStore(optimize_stats); optimize_stats->set_document_store_optimize_latency_ms( optimize_doc_store_timer->GetElapsedMilliseconds()); - if (!optimization_status.ok() && - !absl_ports::IsDataLoss(optimization_status)) { + if (!document_id_old_to_new_or.ok() && + !absl_ports::IsDataLoss(document_id_old_to_new_or.status())) { // The status now is either ABORTED_ERROR or INTERNAL_ERROR. // If ABORTED_ERROR, Icing should still be working. 
// If INTERNAL_ERROR, we're having IO errors or other errors that we can't // recover from. - TransformStatus(optimization_status, result_status); + TransformStatus(document_id_old_to_new_or.status(), result_status); return result_proto; } // The status is either OK or DATA_LOSS. The optimized document store is // guaranteed to work, so we update index according to the new document store. std::unique_ptr<Timer> optimize_index_timer = clock_->GetNewTimer(); - libtextclassifier3::Status index_reset_status = index_->Reset(); - if (!index_reset_status.ok()) { - status = absl_ports::Annotate( - absl_ports::InternalError("Failed to reset index after optimization."), - index_reset_status.error_message()); - TransformStatus(status, result_status); - return result_proto; + bool should_rebuild_index = + !document_id_old_to_new_or.ok() || ShouldRebuildIndex(*optimize_stats); + if (!should_rebuild_index) { + optimize_stats->set_index_restoration_mode( + OptimizeStatsProto::INDEX_TRANSLATION); + libtextclassifier3::Status index_optimize_status = + index_->Optimize(document_id_old_to_new_or.ValueOrDie(), + document_store_->last_added_document_id()); + if (!index_optimize_status.ok()) { + ICING_LOG(WARNING) << "Failed to optimize index. Error: " + << index_optimize_status.error_message(); + should_rebuild_index = true; + } } + // If we received a DATA_LOSS error from OptimizeDocumentStore, we have a + // valid document store, but it might be the old one or the new one. So throw + // out the index and rebuild from scratch. + // Likewise, if Index::Optimize failed, then attempt to recover the index by + // rebuilding from scratch. + // If ShouldRebuildIndex() returns true, we will also rebuild the index for + // better performance. + if (should_rebuild_index) { + optimize_stats->set_index_restoration_mode( + OptimizeStatsProto::FULL_INDEX_REBUILD); + ICING_LOG(WARNING) << "Resetting the entire index!"; + libtextclassifier3::Status index_reset_status = index_->Reset(); + if (!index_reset_status.ok()) { + status = absl_ports::Annotate( + absl_ports::InternalError("Failed to reset index."), + index_reset_status.error_message()); + TransformStatus(status, result_status); + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); + return result_proto; + } - IndexRestorationResult index_restoration_status = RestoreIndexIfNeeded(); - optimize_stats->set_index_restoration_latency_ms( - optimize_index_timer->GetElapsedMilliseconds()); - // DATA_LOSS means that we have successfully re-added content to the index. - // Some indexed content was lost, but otherwise the index is in a valid state - // and can be queried. - if (!index_restoration_status.status.ok() && - !absl_ports::IsDataLoss(index_restoration_status.status)) { - status = absl_ports::Annotate( - absl_ports::InternalError( - "Failed to reindex documents after optimization."), - index_restoration_status.status.error_message()); + IndexRestorationResult index_restoration_status = RestoreIndexIfNeeded(); + // DATA_LOSS means that we have successfully re-added content to the index. + // Some indexed content was lost, but otherwise the index is in a valid + // state and can be queried. 
+ if (!index_restoration_status.status.ok() && + !absl_ports::IsDataLoss(index_restoration_status.status)) { + status = absl_ports::Annotate( + absl_ports::InternalError( + "Failed to reindex documents after optimization."), + index_restoration_status.status.error_message()); - TransformStatus(status, result_status); - return result_proto; + TransformStatus(status, result_status); + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); + return result_proto; + } } + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); // Read the optimize status to get the time that we last ran. std::string optimize_status_filename = @@ -1278,12 +1329,18 @@ OptimizeResultProto IcingSearchEngine::Optimize() { optimize_status->set_last_successful_optimize_run_time_ms(current_time); optimize_status_file.Write(std::move(optimize_status)); + // Flushes data to disk after doing optimization + status = InternalPersistToDisk(PersistType::FULL); + if (!status.ok()) { + TransformStatus(status, result_status); + return result_proto; + } + int64_t after_size = filesystem_->GetDiskUsage(options_.base_dir().c_str()); optimize_stats->set_storage_size_after( Filesystem::SanitizeFileSize(after_size)); - optimize_stats->set_latency_ms(optimize_timer->GetElapsedMilliseconds()); - TransformStatus(optimization_status, result_status); + TransformStatus(document_id_old_to_new_or.status(), result_status); return result_proto; } @@ -1374,6 +1431,46 @@ StorageInfoResultProto IcingSearchEngine::GetStorageInfo() { return result; } +DebugInfoResultProto IcingSearchEngine::GetDebugInfo( + DebugInfoVerbosity::Code verbosity) { + DebugInfoResultProto debug_info; + StatusProto* result_status = debug_info.mutable_status(); + absl_ports::shared_lock l(&mutex_); + if (!initialized_) { + debug_info.mutable_status()->set_code(StatusProto::FAILED_PRECONDITION); + debug_info.mutable_status()->set_message( + "IcingSearchEngine has not been initialized!"); + return debug_info; + } + + // Index + *debug_info.mutable_debug_info()->mutable_index_info() = + index_->GetDebugInfo(verbosity); + + // Document Store + libtextclassifier3::StatusOr<DocumentDebugInfoProto> document_debug_info = + document_store_->GetDebugInfo(verbosity); + if (!document_debug_info.ok()) { + TransformStatus(document_debug_info.status(), result_status); + return debug_info; + } + *debug_info.mutable_debug_info()->mutable_document_info() = + std::move(document_debug_info).ValueOrDie(); + + // Schema Store + libtextclassifier3::StatusOr<SchemaDebugInfoProto> schema_debug_info = + schema_store_->GetDebugInfo(); + if (!schema_debug_info.ok()) { + TransformStatus(schema_debug_info.status(), result_status); + return debug_info; + } + *debug_info.mutable_debug_info()->mutable_schema_info() = + std::move(schema_debug_info).ValueOrDie(); + + result_status->set_code(StatusProto::OK); + return debug_info; +} + libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk( PersistType::Code persist_type) { if (persist_type == PersistType::LITE) { @@ -1401,7 +1498,9 @@ SearchResultProto IcingSearchEngine::Search( QueryStatsProto* query_stats = result_proto.mutable_query_stats(); query_stats->set_query_length(search_spec.query().length()); - std::unique_ptr<Timer> overall_timer = clock_->GetNewTimer(); + ScopedTimer overall_timer(clock_->GetNewTimer(), [query_stats](int64_t t) { + query_stats->set_latency_ms(t); + }); libtextclassifier3::Status status = ValidateResultSpec(result_spec); if 
(!status.ok()) { @@ -1429,6 +1528,8 @@ SearchResultProto IcingSearchEngine::Search( document_store_.get(), schema_store_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); + query_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<QueryProcessor> query_processor = @@ -1437,6 +1538,8 @@ SearchResultProto IcingSearchEngine::Search( auto query_results_or = query_processor->ParseSearch(search_spec); if (!query_results_or.ok()) { TransformStatus(query_results_or.status(), result_status); + query_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } QueryProcessor::QueryResults query_results = @@ -1457,6 +1560,8 @@ SearchResultProto IcingSearchEngine::Search( scoring_spec, document_store_.get(), schema_store_.get()); if (!scoring_processor_or.ok()) { TransformStatus(scoring_processor_or.status(), result_status); + query_stats->set_scoring_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<ScoringProcessor> scoring_processor = @@ -1476,62 +1581,62 @@ SearchResultProto IcingSearchEngine::Search( } component_timer = clock_->GetNewTimer(); - // Ranks and paginates results - libtextclassifier3::StatusOr<PageResultState> page_result_state_or = - result_state_manager_->RankAndPaginate(ResultState( - std::move(result_document_hits), std::move(query_results.query_terms), - search_spec, scoring_spec, result_spec, *document_store_)); - if (!page_result_state_or.ok()) { - TransformStatus(page_result_state_or.status(), result_status); - return result_proto; - } - PageResultState page_result_state = - std::move(page_result_state_or).ValueOrDie(); + // Ranks results + std::unique_ptr<ScoredDocumentHitsRanker> ranker = + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(result_document_hits), + /*is_descending=*/scoring_spec.order_by() == + ScoringSpecProto::Order::DESC); query_stats->set_ranking_latency_ms( component_timer->GetElapsedMilliseconds()); component_timer = clock_->GetNewTimer(); - // Retrieves the document protos and snippets if requested + // RanksAndPaginates and retrieves the document protos and snippets if + // requested auto result_retriever_or = - ResultRetriever::Create(document_store_.get(), schema_store_.get(), - language_segmenter_.get(), normalizer_.get()); + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get()); if (!result_retriever_or.ok()) { - result_state_manager_->InvalidateResultState( - page_result_state.next_page_token); TransformStatus(result_retriever_or.status(), result_status); + query_stats->set_document_retrieval_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } - std::unique_ptr<ResultRetriever> result_retriever = + std::unique_ptr<ResultRetrieverV2> result_retriever = std::move(result_retriever_or).ValueOrDie(); - libtextclassifier3::StatusOr<std::vector<SearchResultProto::ResultProto>> - results_or = result_retriever->RetrieveResults(page_result_state); - if (!results_or.ok()) { - result_state_manager_->InvalidateResultState( - page_result_state.next_page_token); - TransformStatus(results_or.status(), result_status); + libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> + page_result_info_or = result_state_manager_->CacheAndRetrieveFirstPage( + std::move(ranker), std::move(query_results.query_terms), search_spec, + scoring_spec, 
result_spec, *document_store_, *result_retriever); + if (!page_result_info_or.ok()) { + TransformStatus(page_result_info_or.status(), result_status); + query_stats->set_document_retrieval_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } - std::vector<SearchResultProto::ResultProto> results = - std::move(results_or).ValueOrDie(); + std::pair<uint64_t, PageResult> page_result_info = + std::move(page_result_info_or).ValueOrDie(); // Assembles the final search result proto - result_proto.mutable_results()->Reserve(results.size()); - for (SearchResultProto::ResultProto& result : results) { + result_proto.mutable_results()->Reserve( + page_result_info.second.results.size()); + for (SearchResultProto::ResultProto& result : + page_result_info.second.results) { result_proto.mutable_results()->Add(std::move(result)); } + result_status->set_code(StatusProto::OK); - if (page_result_state.next_page_token != kInvalidNextPageToken) { - result_proto.set_next_page_token(page_result_state.next_page_token); + if (page_result_info.first != kInvalidNextPageToken) { + result_proto.set_next_page_token(page_result_info.first); } + query_stats->set_document_retrieval_latency_ms( component_timer->GetElapsedMilliseconds()); - query_stats->set_latency_ms(overall_timer->GetElapsedMilliseconds()); query_stats->set_num_results_returned_current_page( result_proto.results_size()); query_stats->set_num_results_with_snippets( - std::min(result_proto.results_size(), - result_spec.snippet_spec().num_to_snippet())); + page_result_info.second.num_results_with_snippets); return result_proto; } @@ -1552,53 +1657,46 @@ SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) { query_stats->set_is_first_page(false); std::unique_ptr<Timer> overall_timer = clock_->GetNewTimer(); - libtextclassifier3::StatusOr<PageResultState> page_result_state_or = - result_state_manager_->GetNextPage(next_page_token); - - if (!page_result_state_or.ok()) { - if (absl_ports::IsNotFound(page_result_state_or.status())) { - // NOT_FOUND means an empty result. - result_status->set_code(StatusProto::OK); - } else { - // Real error, pass up. - TransformStatus(page_result_state_or.status(), result_status); - } - return result_proto; - } - - PageResultState page_result_state = - std::move(page_result_state_or).ValueOrDie(); - query_stats->set_requested_page_size(page_result_state.requested_page_size); - - // Retrieves the document protos. auto result_retriever_or = - ResultRetriever::Create(document_store_.get(), schema_store_.get(), - language_segmenter_.get(), normalizer_.get()); + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get()); if (!result_retriever_or.ok()) { TransformStatus(result_retriever_or.status(), result_status); return result_proto; } - std::unique_ptr<ResultRetriever> result_retriever = + std::unique_ptr<ResultRetrieverV2> result_retriever = std::move(result_retriever_or).ValueOrDie(); - libtextclassifier3::StatusOr<std::vector<SearchResultProto::ResultProto>> - results_or = result_retriever->RetrieveResults(page_result_state); - if (!results_or.ok()) { - TransformStatus(results_or.status(), result_status); + libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> + page_result_info_or = result_state_manager_->GetNextPage( + next_page_token, *result_retriever); + if (!page_result_info_or.ok()) { + if (absl_ports::IsNotFound(page_result_info_or.status())) { + // NOT_FOUND means an empty result. 
+ result_status->set_code(StatusProto::OK); + } else { + // Real error, pass up. + TransformStatus(page_result_info_or.status(), result_status); + } return result_proto; } - std::vector<SearchResultProto::ResultProto> results = - std::move(results_or).ValueOrDie(); + + std::pair<uint64_t, PageResult> page_result_info = + std::move(page_result_info_or).ValueOrDie(); + query_stats->set_requested_page_size( + page_result_info.second.requested_page_size); // Assembles the final search result proto - result_proto.mutable_results()->Reserve(results.size()); - for (SearchResultProto::ResultProto& result : results) { + result_proto.mutable_results()->Reserve( + page_result_info.second.results.size()); + for (SearchResultProto::ResultProto& result : + page_result_info.second.results) { result_proto.mutable_results()->Add(std::move(result)); } result_status->set_code(StatusProto::OK); - if (page_result_state.next_page_token != kInvalidNextPageToken) { - result_proto.set_next_page_token(page_result_state.next_page_token); + if (page_result_info.first != kInvalidNextPageToken) { + result_proto.set_next_page_token(page_result_info.first); } // The only thing that we're doing is document retrieval. So document @@ -1609,12 +1707,8 @@ SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) { query_stats->set_latency_ms(overall_timer->GetElapsedMilliseconds()); query_stats->set_num_results_returned_current_page( result_proto.results_size()); - int num_left_to_snippet = - std::max(page_result_state.snippet_context.snippet_spec.num_to_snippet() - - page_result_state.num_previously_returned, - 0); query_stats->set_num_results_with_snippets( - std::min(result_proto.results_size(), num_left_to_snippet)); + page_result_info.second.num_results_with_snippets); return result_proto; } @@ -1627,8 +1721,8 @@ void IcingSearchEngine::InvalidateNextPageToken(uint64_t next_page_token) { result_state_manager_->InvalidateResultState(next_page_token); } -libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( - OptimizeStatsProto* optimize_stats) { +libtextclassifier3::StatusOr<std::vector<DocumentId>> +IcingSearchEngine::OptimizeDocumentStore(OptimizeStatsProto* optimize_stats) { // Gets the current directory path and an empty tmp directory path for // document store optimization. const std::string current_document_dir = @@ -1644,15 +1738,16 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( } // Copies valid document data to tmp directory - auto optimize_status = document_store_->OptimizeInto( - temporary_document_dir, language_segmenter_.get(), optimize_stats); + libtextclassifier3::StatusOr<std::vector<DocumentId>> + document_id_old_to_new_or = document_store_->OptimizeInto( + temporary_document_dir, language_segmenter_.get(), optimize_stats); // Handles error if any - if (!optimize_status.ok()) { + if (!document_id_old_to_new_or.ok()) { filesystem_->DeleteDirectoryRecursively(temporary_document_dir.c_str()); return absl_ports::Annotate( absl_ports::AbortedError("Failed to optimize document store"), - optimize_status.error_message()); + document_id_old_to_new_or.status().error_message()); } // result_state_manager_ depends on document_store_. 
So we need to reset it at @@ -1695,7 +1790,8 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( } document_store_ = std::move(create_result_or.ValueOrDie().document_store); result_state_manager_ = std::make_unique<ResultStateManager>( - performance_configuration_.max_num_total_hits, *document_store_); + performance_configuration_.max_num_total_hits, *document_store_, + clock_.get()); // Potential data loss // TODO(b/147373249): Find a way to detect true data loss error @@ -1717,7 +1813,8 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( } document_store_ = std::move(create_result_or.ValueOrDie().document_store); result_state_manager_ = std::make_unique<ResultStateManager>( - performance_configuration_.max_num_total_hits, *document_store_); + performance_configuration_.max_num_total_hits, *document_store_, + clock_.get()); // Deletes tmp directory if (!filesystem_->DeleteDirectoryRecursively( @@ -1725,7 +1822,7 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( ICING_LOG(ERROR) << "Document store has been optimized, but it failed to " "delete temporary file directory"; } - return libtextclassifier3::Status::OK; + return document_id_old_to_new_or; } IcingSearchEngine::IndexRestorationResult diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index ff9c7fb..2eda803 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -20,13 +20,13 @@ #include <string> #include <string_view> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/mutex.h" #include "icing/absl_ports/thread_annotations.h" #include "icing/file/filesystem.h" #include "icing/index/index.h" +#include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/performance-configuration.h" #include "icing/proto/document.pb.h" @@ -403,6 +403,10 @@ class IcingSearchEngine { // that field will be set to -1. StorageInfoResultProto GetStorageInfo() ICING_LOCKS_EXCLUDED(mutex_); + // Get debug information for Icing. + DebugInfoResultProto GetDebugInfo(DebugInfoVerbosity::Code verbosity) + ICING_LOCKS_EXCLUDED(mutex_); + // Clears all data from Icing and re-initializes. Clients DO NOT need to call // Initialize again. // @@ -578,14 +582,16 @@ class IcingSearchEngine { // would need call Initialize() to reinitialize everything into a valid state. // // Returns: - // OK on success + // On success, a vector that maps from old document id to new document id. A + // value of kInvalidDocumentId indicates that the old document id has been + // deleted. 
// ABORTED_ERROR if any error happens before the actual optimization, the // original document store should be still available // DATA_LOSS_ERROR on errors that could potentially cause data loss, // document store is still available // INTERNAL_ERROR on any IO errors or other errors that we can't recover // from - libtextclassifier3::Status OptimizeDocumentStore( + libtextclassifier3::StatusOr<std::vector<DocumentId>> OptimizeDocumentStore( OptimizeStatsProto* optimize_stats) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); diff --git a/icing/icing-search-engine_benchmark.cc b/icing/icing-search-engine_benchmark.cc index 5e610d5..6db66f6 100644 --- a/icing/icing-search-engine_benchmark.cc +++ b/icing/icing-search-engine_benchmark.cc @@ -51,7 +51,7 @@ // //icing:icing-search-engine_benchmark // // $ blaze-bin/icing/icing-search-engine_benchmark -// --benchmarks=all --benchmark_memory_usage +// --benchmark_filter=all --benchmark_memory_usage // // Run on an Android device: // $ blaze build --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" @@ -61,7 +61,8 @@ // $ adb push blaze-bin/icing/icing-search-engine_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/icing-search-engine_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/icing-search-engine_benchmark +// --benchmark_filter=all namespace icing { namespace lib { @@ -222,24 +223,19 @@ void BM_IndexLatency(benchmark::State& state) { std::unique_ptr<IcingSearchEngine> icing = std::make_unique<IcingSearchEngine>(options); - ASSERT_THAT(icing->Initialize().status(), ProtoIsOk()); - ASSERT_THAT(icing->SetSchema(schema).status(), ProtoIsOk()); - int num_docs = state.range(0); std::vector<std::string> language = CreateLanguages(kLanguageSize, &random); const std::vector<DocumentProto> random_docs = GenerateRandomDocuments(&type_selector, num_docs, language); - Timer timer; - for (const DocumentProto& doc : random_docs) { - ASSERT_THAT(icing->Put(doc).status(), ProtoIsOk()); + for (auto _ : state) { + state.PauseTiming(); + ASSERT_THAT(icing->Reset().status(), ProtoIsOk()); + ASSERT_THAT(icing->SetSchema(schema).status(), ProtoIsOk()); + state.ResumeTiming(); + for (const DocumentProto& doc : random_docs) { + ASSERT_THAT(icing->Put(doc).status(), ProtoIsOk()); + } } - int64_t time_taken_ns = timer.GetElapsedNanoseconds(); - int64_t time_per_doc_ns = time_taken_ns / num_docs; - std::cout << "Number of indexed documents:\t" << num_docs - << "\t\tNumber of indexed sections:\t" << state.range(1) - << "\t\tTime taken (ms):\t" << time_taken_ns / 1000000 - << "\t\tTime taken per doc (us):\t" << time_per_doc_ns / 1000 - << std::endl; } BENCHMARK(BM_IndexLatency) // Arguments: num_indexed_documents, num_sections diff --git a/icing/icing-search-engine_flush_benchmark.cc b/icing/icing-search-engine_flush_benchmark.cc index de8f550..04e83fe 100644 --- a/icing/icing-search-engine_flush_benchmark.cc +++ b/icing/icing-search-engine_flush_benchmark.cc @@ -48,7 +48,7 @@ // //icing:icing-search-engine_flush_benchmark // // $ blaze-bin/icing/icing-search-engine_flush_benchmark -// --benchmarks=all --benchmark_memory_usage +// --benchmark_filter=all --benchmark_memory_usage // // Run on an Android device: // $ blaze build --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" @@ -59,7 +59,7 @@ // /data/local/tmp/ // // $ adb shell /data/local/tmp/icing-search-engine_flush_benchmark -// --benchmarks=all +// --benchmark_filter=all namespace icing { namespace lib { diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index 13e77b8..f862e45 
100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -20,13 +20,13 @@ #include <string> #include <utility> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/file/mock-filesystem.h" +#include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/portable/endian.h" #include "icing/portable/equals-proto.h" @@ -2274,7 +2274,12 @@ TEST_F(IcingSearchEngineTest, SearchReturnsScoresCreationTimestamp) { } TEST_F(IcingSearchEngineTest, SearchReturnsOneResult) { - IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + auto fake_clock = std::make_unique<FakeClock>(); + fake_clock->SetTimerElapsedMilliseconds(1000); + TestIcingSearchEngine icing(GetDefaultIcingOptions(), + std::make_unique<Filesystem>(), + std::make_unique<IcingFilesystem>(), + std::move(fake_clock), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); @@ -2299,6 +2304,15 @@ TEST_F(IcingSearchEngineTest, SearchReturnsOneResult) { SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + + EXPECT_THAT(search_result_proto.query_stats().latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().parse_query_latency_ms(), + Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().scoring_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().ranking_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().document_retrieval_latency_ms(), + Eq(1000)); + // The token is a random number so we don't verify it. 
expected_search_result_proto.set_next_page_token( search_result_proto.next_page_token()); @@ -2347,6 +2361,30 @@ TEST_F(IcingSearchEngineTest, SearchNegativeResultLimitReturnsInvalidArgument) { expected_search_result_proto)); } +TEST_F(IcingSearchEngineTest, + SearchNonPositivePageTotalBytesLimitReturnsInvalidArgument) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query(""); + + ResultSpecProto result_spec; + result_spec.set_num_total_bytes_per_page_threshold(-1); + + SearchResultProto actual_results1 = + icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); + EXPECT_THAT(actual_results1.status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); + + result_spec.set_num_total_bytes_per_page_threshold(0); + SearchResultProto actual_results2 = + icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); + EXPECT_THAT(actual_results2.status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + TEST_F(IcingSearchEngineTest, SearchWithPersistenceReturnsValidResults) { IcingSearchEngineOptions icing_options = GetDefaultIcingOptions(); @@ -2403,7 +2441,12 @@ TEST_F(IcingSearchEngineTest, SearchWithPersistenceReturnsValidResults) { } TEST_F(IcingSearchEngineTest, SearchShouldReturnEmpty) { - IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + auto fake_clock = std::make_unique<FakeClock>(); + fake_clock->SetTimerElapsedMilliseconds(1000); + TestIcingSearchEngine icing(GetDefaultIcingOptions(), + std::make_unique<Filesystem>(), + std::make_unique<IcingFilesystem>(), + std::move(fake_clock), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); @@ -2418,6 +2461,15 @@ TEST_F(IcingSearchEngineTest, SearchShouldReturnEmpty) { SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), ResultSpecProto::default_instance()); + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + + EXPECT_THAT(search_result_proto.query_stats().latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().parse_query_latency_ms(), + Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().scoring_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().ranking_latency_ms(), Eq(0)); + EXPECT_THAT(search_result_proto.query_stats().document_retrieval_latency_ms(), + Eq(0)); EXPECT_THAT(search_result_proto, EqualsSearchResultIgnoreStatsAndScores( expected_search_result_proto)); @@ -2894,10 +2946,11 @@ TEST_F(IcingSearchEngineTest, GetAndPutShouldWorkAfterOptimization) { DocumentProto document1 = CreateMessageDocument("namespace", "uri1"); DocumentProto document2 = CreateMessageDocument("namespace", "uri2"); DocumentProto document3 = CreateMessageDocument("namespace", "uri3"); + DocumentProto document4 = CreateMessageDocument("namespace", "uri4"); + DocumentProto document5 = CreateMessageDocument("namespace", "uri5"); GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); - *expected_get_result_proto.mutable_document() = document1; { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); @@ -2905,27 +2958,97 @@ TEST_F(IcingSearchEngineTest, GetAndPutShouldWorkAfterOptimization) { ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); 
ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Delete("namespace", "uri2").status(), ProtoIsOk()); ASSERT_THAT(icing.Optimize().status(), ProtoIsOk()); // Validates that Get() and Put() are good right after Optimize() + *expected_get_result_proto.mutable_document() = document1; EXPECT_THAT( icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); - EXPECT_THAT(icing.Put(document2).status(), ProtoIsOk()); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status() + .code(), + Eq(StatusProto::NOT_FOUND)); + *expected_get_result_proto.mutable_document() = document3; + EXPECT_THAT( + icing.Get("namespace", "uri3", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); + EXPECT_THAT(icing.Put(document4).status(), ProtoIsOk()); } // Destroys IcingSearchEngine to make sure nothing is cached. IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + *expected_get_result_proto.mutable_document() = document1; EXPECT_THAT( icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); - - *expected_get_result_proto.mutable_document() = document2; EXPECT_THAT( - icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()), + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status() + .code(), + Eq(StatusProto::NOT_FOUND)); + *expected_get_result_proto.mutable_document() = document3; + EXPECT_THAT( + icing.Get("namespace", "uri3", GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); + *expected_get_result_proto.mutable_document() = document4; + EXPECT_THAT( + icing.Get("namespace", "uri4", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); + + EXPECT_THAT(icing.Put(document5).status(), ProtoIsOk()); +} + +TEST_F(IcingSearchEngineTest, + GetAndPutShouldWorkAfterOptimizationWithEmptyDocuments) { + DocumentProto empty_document1 = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Message") + .AddStringProperty("body", "") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto empty_document2 = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Message") + .AddStringProperty("body", "") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto empty_document3 = + DocumentBuilder() + .SetKey("namespace", "uri3") + .SetSchema("Message") + .AddStringProperty("body", "") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); - EXPECT_THAT(icing.Put(document3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(empty_document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(empty_document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Delete("namespace", "uri2").status(), ProtoIsOk()); + ASSERT_THAT(icing.Optimize().status(), ProtoIsOk()); + + // Validates that Get() and Put() are good right after Optimize() + 
*expected_get_result_proto.mutable_document() = empty_document1; + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status() + .code(), + Eq(StatusProto::NOT_FOUND)); + EXPECT_THAT(icing.Put(empty_document3).status(), ProtoIsOk()); } TEST_F(IcingSearchEngineTest, DeleteShouldWorkAfterOptimization) { @@ -3003,7 +3126,6 @@ TEST_F(IcingSearchEngineTest, OptimizationFailureUninitializesIcing) { HasSubstr("document_dir"))) .WillByDefault(swap_lambda); TestIcingSearchEngine icing(options, std::move(mock_filesystem), - std::move(mock_filesystem), std::make_unique<IcingFilesystem>(), std::make_unique<FakeClock>(), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); @@ -3822,8 +3944,11 @@ TEST_F(IcingSearchEngineTest, ProtoIsOk()); // Optimize() fails due to filesystem error - EXPECT_THAT(icing.Optimize().status(), - ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + OptimizeResultProto result = icing.Optimize(); + EXPECT_THAT(result.status(), ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + // Should rebuild the index for data loss. + EXPECT_THAT(result.optimize_stats().index_restoration_mode(), + Eq(OptimizeStatsProto::FULL_INDEX_REBUILD)); // Document is not found because original file directory is missing GetResultProto expected_get_result_proto; @@ -3896,8 +4021,11 @@ TEST_F(IcingSearchEngineTest, OptimizationShouldRecoverIfDataFilesAreMissing) { ProtoIsOk()); // Optimize() fails due to filesystem error - EXPECT_THAT(icing.Optimize().status(), - ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + OptimizeResultProto result = icing.Optimize(); + EXPECT_THAT(result.status(), ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + // Should rebuild the index for data loss. + EXPECT_THAT(result.optimize_stats().index_restoration_mode(), + Eq(OptimizeStatsProto::FULL_INDEX_REBUILD)); // Document is not found because original files are missing GetResultProto expected_get_result_proto; @@ -7868,6 +7996,7 @@ TEST_F(IcingSearchEngineTest, OptimizeStatsProtoTest) { expected.set_num_original_documents(3); expected.set_num_deleted_documents(1); expected.set_num_expired_documents(1); + expected.set_index_restoration_mode(OptimizeStatsProto::INDEX_TRANSLATION); // Run Optimize OptimizeResultProto result = icing->Optimize(); @@ -7900,6 +8029,7 @@ TEST_F(IcingSearchEngineTest, OptimizeStatsProtoTest) { expected.set_num_deleted_documents(0); expected.set_num_expired_documents(0); expected.set_time_since_last_optimize_ms(10000); + expected.set_index_restoration_mode(OptimizeStatsProto::INDEX_TRANSLATION); // Run Optimize result = icing->Optimize(); @@ -7908,6 +8038,29 @@ TEST_F(IcingSearchEngineTest, OptimizeStatsProtoTest) { result.mutable_optimize_stats()->clear_storage_size_before(); result.mutable_optimize_stats()->clear_storage_size_after(); EXPECT_THAT(result.optimize_stats(), EqualsProto(expected)); + + // Delete the last document. + ASSERT_THAT(icing->Delete(document3.namespace_(), document3.uri()).status(), + ProtoIsOk()); + + expected = OptimizeStatsProto(); + expected.set_latency_ms(5); + expected.set_document_store_optimize_latency_ms(5); + expected.set_index_restoration_latency_ms(5); + expected.set_num_original_documents(1); + expected.set_num_deleted_documents(1); + expected.set_num_expired_documents(0); + expected.set_time_since_last_optimize_ms(0); + // Should rebuild the index since all documents are removed. 
+ expected.set_index_restoration_mode(OptimizeStatsProto::FULL_INDEX_REBUILD); + + // Run Optimize + result = icing->Optimize(); + EXPECT_THAT(result.optimize_stats().storage_size_before(), + Ge(result.optimize_stats().storage_size_after())); + result.mutable_optimize_stats()->clear_storage_size_before(); + result.mutable_optimize_stats()->clear_storage_size_after(); + EXPECT_THAT(result.optimize_stats(), EqualsProto(expected)); } TEST_F(IcingSearchEngineTest, StorageInfoTest) { @@ -8680,6 +8833,81 @@ TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_NonPositiveNumToReturn) { ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); } +TEST_F(IcingSearchEngineTest, GetDebugInfoVerbosityBasicSucceeds) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + // Create a document. + DocumentProto document = CreateMessageDocument("namespace", "email"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + DebugInfoResultProto result = icing.GetDebugInfo(DebugInfoVerbosity::BASIC); + EXPECT_THAT(result.status(), ProtoIsOk()); + + // Some sanity checks + DebugInfoProto debug_info = result.debug_info(); + EXPECT_THAT( + debug_info.document_info().document_storage_info().num_alive_documents(), + Eq(1)); + EXPECT_THAT(debug_info.document_info().corpus_info(), + IsEmpty()); // because verbosity=BASIC + EXPECT_THAT(debug_info.schema_info().crc(), Gt(0)); +} + +TEST_F(IcingSearchEngineTest, + GetDebugInfoVerbosityDetailedSucceedsWithCorpusInfo) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + // Create 4 documents. 
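
The OptimizeStatsProto expectations in the test above imply a simple rule for index_restoration_mode: an ordinary optimization translates the surviving hits to their new document ids (INDEX_TRANSLATION), while document-store data loss or an optimization that leaves no documents forces a full rebuild (FULL_INDEX_REBUILD). A hedged sketch of that inferred rule; the function name and the exact nested enum type name are illustrative, not taken from this change:

// Inferred from the tests above, not the actual icing-search-engine.cc logic.
OptimizeStatsProto::IndexRestorationMode ChooseRestorationMode(
    bool document_store_data_loss, bool no_documents_remain) {
  if (document_store_data_loss || no_documents_remain) {
    // There is nothing reliable to translate; rebuild the index instead.
    return OptimizeStatsProto::FULL_INDEX_REBUILD;
  }
  // Keep the existing posting lists and only remap document ids.
  return OptimizeStatsProto::INDEX_TRANSLATION;
}
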
+ DocumentProto document1 = CreateMessageDocument("namespace1", "email/1"); + DocumentProto document2 = CreateMessageDocument("namespace1", "email/2"); + DocumentProto document3 = CreateMessageDocument("namespace2", "email/3"); + DocumentProto document4 = CreateMessageDocument("namespace2", "email/4"); + ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document4).status(), ProtoIsOk()); + + DebugInfoResultProto result = + icing.GetDebugInfo(DebugInfoVerbosity::DETAILED); + EXPECT_THAT(result.status(), ProtoIsOk()); + + // Some sanity checks + DebugInfoProto debug_info = result.debug_info(); + EXPECT_THAT( + debug_info.document_info().document_storage_info().num_alive_documents(), + Eq(4)); + EXPECT_THAT(debug_info.document_info().corpus_info(), SizeIs(2)); + EXPECT_THAT(debug_info.schema_info().crc(), Gt(0)); +} + +TEST_F(IcingSearchEngineTest, GetDebugInfoUninitialized) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + DebugInfoResultProto result = + icing.GetDebugInfo(DebugInfoVerbosity::DETAILED); + EXPECT_THAT(result.status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); +} + +TEST_F(IcingSearchEngineTest, GetDebugInfoNoSchemaNoDocumentsSucceeds) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + DebugInfoResultProto result = + icing.GetDebugInfo(DebugInfoVerbosity::DETAILED); + ASSERT_THAT(result.status(), ProtoIsOk()); +} + +TEST_F(IcingSearchEngineTest, GetDebugInfoWithSchemaNoDocumentsSucceeds) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + DebugInfoResultProto result = + icing.GetDebugInfo(DebugInfoVerbosity::DETAILED); + ASSERT_THAT(result.status(), ProtoIsOk()); +} + #ifndef ICING_JNI_TEST // We skip this test case when we're running in a jni_test since the data files // will be stored in the android-instrumented storage location, rather than the diff --git a/icing/index/hit/hit.cc b/icing/index/hit/hit.cc index 887e6e4..ce1c366 100644 --- a/icing/index/hit/hit.cc +++ b/icing/index/hit/hit.cc @@ -97,6 +97,11 @@ bool Hit::is_in_prefix_section() const { return bit_util::BitfieldGet(value(), kInPrefixSection, 1); } +Hit Hit::TranslateHit(Hit old_hit, DocumentId new_document_id) { + return Hit(old_hit.section_id(), new_document_id, old_hit.term_frequency(), + old_hit.is_in_prefix_section(), old_hit.is_prefix_hit()); +} + bool Hit::EqualsDocumentIdAndSectionId::operator()(const Hit& hit1, const Hit& hit2) const { return (hit1.value() >> kNumFlags) == (hit2.value() >> kNumFlags); diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h index ee1f64b..f8cbd78 100644 --- a/icing/index/hit/hit.h +++ b/icing/index/hit/hit.h @@ -77,6 +77,9 @@ class Hit { bool is_prefix_hit() const; bool is_in_prefix_section() const; + // Creates a new hit based on old_hit but with new_document_id set. 
+ static Hit TranslateHit(Hit old_hit, DocumentId new_document_id); + bool operator<(const Hit& h2) const { return value() < h2.value(); } bool operator==(const Hit& h2) const { return value() == h2.value(); } diff --git a/icing/index/index-processor.cc b/icing/index/index-processor.cc index 207c033..edc7881 100644 --- a/icing/index/index-processor.cc +++ b/icing/index/index-processor.cc @@ -67,6 +67,11 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( uint32_t num_tokens = 0; libtextclassifier3::Status status; for (const TokenizedSection& section : tokenized_document.sections()) { + if (section.metadata.tokenizer == + StringIndexingConfig::TokenizerType::NONE) { + ICING_LOG(WARNING) + << "Unexpected TokenizerType::NONE found when indexing document."; + } // TODO(b/152934343): pass real namespace ids in Index::Editor editor = index_->Edit(document_id, section.metadata.id, @@ -82,8 +87,6 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( status = editor.BufferTerm(token.data()); break; case StringIndexingConfig::TokenizerType::NONE: - ICING_LOG(WARNING) - << "Unexpected TokenizerType::NONE found when indexing document."; [[fallthrough]]; case StringIndexingConfig::TokenizerType::PLAIN: std::string normalized_term = normalizer_.NormalizeTerm(token); diff --git a/icing/index/index-processor_benchmark.cc b/icing/index/index-processor_benchmark.cc index 1aad7d0..68c592c 100644 --- a/icing/index/index-processor_benchmark.cc +++ b/icing/index/index-processor_benchmark.cc @@ -39,7 +39,7 @@ // //icing/index:index-processor_benchmark // // $ blaze-bin/icing/index/index-processor_benchmark -// --benchmarks=all +// --benchmark_filter=all // // Run on an Android device: // Make target //icing/tokenization:language-segmenter depend on @@ -55,7 +55,7 @@ // $ adb push blaze-bin/icing/index/index-processor_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/index-processor_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/index-processor_benchmark --benchmark_filter=all // --adb // Flag to tell the benchmark that it'll be run on an Android device via adb, diff --git a/icing/index/index.cc b/icing/index/index.cc index 02ba699..6004ed3 100644 --- a/icing/index/index.cc +++ b/icing/index/index.cc @@ -264,6 +264,16 @@ IndexStorageInfoProto Index::GetStorageInfo() const { return main_index_->GetStorageInfo(std::move(storage_info)); } +libtextclassifier3::Status Index::Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id) { + if (main_index_->last_added_document_id() != kInvalidDocumentId) { + ICING_RETURN_IF_ERROR(main_index_->Optimize(document_id_old_to_new)); + } + return lite_index_->Optimize(document_id_old_to_new, term_id_codec_.get(), + new_last_added_document_id); +} + libtextclassifier3::Status Index::Editor::BufferTerm(const char* term) { // Step 1: See if this term is already in the lexicon uint32_t tvi; diff --git a/icing/index/index.h b/icing/index/index.h index 5c53349..55f2358 100644 --- a/icing/index/index.h +++ b/icing/index/index.h @@ -140,11 +140,11 @@ class Index { } // Returns debug information for the index in out. - // verbosity <= 0, simplest debug information - just the lexicons and lite - // index. - // verbosity > 0, more detailed debug information including raw postings - // lists. - IndexDebugInfoProto GetDebugInfo(int verbosity) const { + // verbosity = BASIC, simplest debug information - just the lexicons and lite + // index. 
+ // verbosity = DETAILED, more detailed debug information including raw + // postings lists. + IndexDebugInfoProto GetDebugInfo(DebugInfoVerbosity::Code verbosity) const { IndexDebugInfoProto debug_info; *debug_info.mutable_index_storage_info() = GetStorageInfo(); *debug_info.mutable_lite_index_info() = @@ -263,6 +263,18 @@ class Index { return lite_index_->Reset(); } + // Reduces internal file sizes by reclaiming space of deleted documents. + // new_last_added_document_id will be used to update the last added document + // id in the lite index. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. + libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id); + private: Index(const Options& options, std::unique_ptr<TermIdCodec> term_id_codec, std::unique_ptr<LiteIndex> lite_index, diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc index 8355c01..23945de 100644 --- a/icing/index/index_test.cc +++ b/icing/index/index_test.cc @@ -14,6 +14,7 @@ #include "icing/index/index.h" +#include <algorithm> #include <cstdint> #include <limits> #include <memory> @@ -41,12 +42,14 @@ #include "icing/testing/random-string.h" #include "icing/testing/tmp-directory.h" #include "icing/util/crc32.h" +#include "icing/util/logging.h" namespace icing { namespace lib { namespace { +using ::testing::ContainerEq; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Ge; @@ -58,6 +61,8 @@ using ::testing::NiceMock; using ::testing::Not; using ::testing::Return; using ::testing::SizeIs; +using ::testing::StrEq; +using ::testing::StrNe; using ::testing::Test; using ::testing::UnorderedElementsAre; @@ -76,10 +81,27 @@ class IndexTest : public Test { icing_filesystem_.DeleteDirectoryRecursively(index_dir_.c_str()); } - std::unique_ptr<Index> index_; - std::string index_dir_; - IcingFilesystem icing_filesystem_; + std::vector<DocHitInfo> GetHits( + std::unique_ptr<DocHitInfoIterator> iterator) { + std::vector<DocHitInfo> infos; + while (iterator->Advance().ok()) { + infos.push_back(iterator->doc_hit_info()); + } + return infos; + } + + libtextclassifier3::StatusOr<std::vector<DocHitInfo>> GetHits( + std::string term, TermMatchType::Code match_type) { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<DocHitInfoIterator> itr, + index_->GetIterator(term, kSectionIdMaskAll, match_type)); + return GetHits(std::move(itr)); + } + Filesystem filesystem_; + IcingFilesystem icing_filesystem_; + std::string index_dir_; + std::unique_ptr<Index> index_; }; constexpr DocumentId kDocumentId0 = 0; @@ -94,14 +116,6 @@ constexpr DocumentId kDocumentId8 = 8; constexpr SectionId kSectionId2 = 2; constexpr SectionId kSectionId3 = 3; -std::vector<DocHitInfo> GetHits(std::unique_ptr<DocHitInfoIterator> iterator) { - std::vector<DocHitInfo> infos; - while (iterator->Advance().ok()) { - infos.push_back(iterator->doc_hit_info()); - } - return infos; -} - MATCHER_P2(EqualsDocHitInfo, document_id, sections, "") { const DocHitInfo& actual = arg; SectionIdMask section_mask = kSectionIdMaskNone; @@ -246,6 +260,72 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterMerge) { kDocumentId0, std::vector<SectionId>{kSectionId2}))); } +TEST_F(IndexTest, SingleHitSingleTermIndexAfterOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId2, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + 
EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId2, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Mapping to a different docid will translate the hit + ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Mapping to kInvalidDocumentId will remove the hit. + ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId0); +} + +TEST_F(IndexTest, SingleHitSingleTermIndexAfterMergeAndOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId2, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Merge()); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId2, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Mapping to a different docid will translate the hit + ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Mapping to kInvalidDocumentId will remove the hit. 
+ ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), 0); +} + TEST_F(IndexTest, SingleHitMultiTermIndex) { Index::Editor edit = index_->Edit( kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); @@ -278,6 +358,118 @@ TEST_F(IndexTest, SingleHitMultiTermIndexAfterMerge) { kDocumentId0, std::vector<SectionId>{kSectionId2}))); } +TEST_F(IndexTest, MultiHitMultiTermIndexAfterOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId1, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("bar"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId2, kSectionId3, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId2, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Delete document id 1, and document id 2 is translated to 1. + ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Delete all the rest documents. 
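
The document_id_old_to_new vectors exercised above follow the convention documented for OptimizeDocumentStore(): the entry at index i is old document i's new id, and kInvalidDocumentId marks a deleted document, so {0, kInvalidDocumentId, 1} keeps document 0, drops document 1, and renumbers document 2 to 1. A small sketch of applying such a mapping to a batch of hits with the new Hit::TranslateHit() helper, assuming the icing Hit, DocumentId, and kInvalidDocumentId types are in scope; the TranslateHits() wrapper itself is illustrative only:

// Illustrative only: drops hits whose documents were deleted and remaps the rest.
std::vector<Hit> TranslateHits(const std::vector<Hit>& old_hits,
                               const std::vector<DocumentId>& old_to_new) {
  std::vector<Hit> new_hits;
  for (const Hit& hit : old_hits) {
    DocumentId old_id = hit.document_id();
    if (static_cast<size_t>(old_id) >= old_to_new.size() ||
        old_to_new[old_id] == kInvalidDocumentId) {
      continue;  // the document no longer exists after optimization
    }
    new_hits.push_back(Hit::TranslateHit(hit, old_to_new[old_id]));
  }
  return new_hits;
}
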
+ ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId}, + /*new_last_added_document_id=*/kInvalidDocumentId)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); +} + +TEST_F(IndexTest, MultiHitMultiTermIndexAfterMergeAndOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId1, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("bar"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId2, kSectionId3, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Merge()); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId2, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Delete document id 1, and document id 2 is translated to 1. + ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Delete all the rest documents. + ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId}, + /*new_last_added_document_id=*/kInvalidDocumentId)); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); +} + TEST_F(IndexTest, NoHitMultiTermIndex) { Index::Editor edit = index_->Edit( kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); @@ -804,6 +996,118 @@ TEST_F(IndexTest, FullIndexMerge) { EXPECT_THAT(last_itr->doc_hit_info().document_id(), Eq(document_id + 1)); } +TEST_F(IndexTest, OptimizeShouldWorkForEmptyIndex) { + // Optimize an empty index should succeed, but have no effects. 
+ ICING_ASSERT_OK( + index_->Optimize(std::vector<DocumentId>(), + /*new_last_added_document_id=*/kInvalidDocumentId)); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<DocHitInfoIterator> itr, + index_->GetIterator("", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); + EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); + + ICING_ASSERT_OK_AND_ASSIGN( + itr, index_->GetIterator("", kSectionIdMaskAll, TermMatchType::PREFIX)); + EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); +} + +TEST_F(IndexTest, IndexOptimize) { + std::string prefix = "prefix"; + std::default_random_engine random; + std::vector<std::string> query_terms; + // Add 1024 hits to main index, and 1024 hits to lite index. + for (int i = 0; i < 2048; ++i) { + if (i == 1024) { + ICING_ASSERT_OK(index_->Merge()); + } + // Generate a unique term for document i. + query_terms.push_back(prefix + RandomString("abcdefg", 5, &random) + + std::to_string(i)); + TermMatchType::Code term_match_type = TermMatchType::PREFIX; + SectionId section_id = i % 5; + if (section_id == 2) { + // Make section 2 an exact section. + term_match_type = TermMatchType::EXACT_ONLY; + } + Index::Editor edit = index_->Edit(/*document_id=*/i, section_id, + term_match_type, /*namespace_id=*/0); + ICING_ASSERT_OK(edit.BufferTerm(query_terms.at(i).c_str())); + ICING_ASSERT_OK(edit.IndexAllBufferedTerms()); + index_->set_last_added_document_id(i); + } + + // Delete one document for every three documents. + DocumentId document_id = 0; + DocumentId new_last_added_document_id = kInvalidDocumentId; + std::vector<DocumentId> document_id_old_to_new; + for (int i = 0; i < 2048; ++i) { + if (i % 3 == 0) { + document_id_old_to_new.push_back(kInvalidDocumentId); + } else { + new_last_added_document_id = document_id++; + document_id_old_to_new.push_back(new_last_added_document_id); + } + } + + std::vector<DocHitInfo> exp_prefix_hits; + for (int i = 0; i < 2048; ++i) { + if (document_id_old_to_new[i] == kInvalidDocumentId) { + continue; + } + if (i % 5 == 2) { + // Section 2 is an exact section, so we should not see any hits in + // prefix search. + continue; + } + exp_prefix_hits.push_back(DocHitInfo(document_id_old_to_new[i])); + exp_prefix_hits.back().UpdateSection(/*section_id=*/i % 5, + /*hit_term_frequency=*/1); + } + std::reverse(exp_prefix_hits.begin(), exp_prefix_hits.end()); + + // Check that optimize is correct + ICING_ASSERT_OK( + index_->Optimize(document_id_old_to_new, new_last_added_document_id)); + EXPECT_EQ(index_->last_added_document_id(), new_last_added_document_id); + // Check prefix search. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<DocHitInfo> hits, + GetHits(prefix, TermMatchType::PREFIX)); + EXPECT_THAT(hits, ContainerEq(exp_prefix_hits)); + // Check exact search. + for (int i = 0; i < 2048; ++i) { + ICING_ASSERT_OK_AND_ASSIGN( + hits, GetHits(query_terms[i], TermMatchType::EXACT_ONLY)); + if (document_id_old_to_new[i] == kInvalidDocumentId) { + EXPECT_THAT(hits, IsEmpty()); + } else { + EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( + document_id_old_to_new[i], + std::vector<SectionId>{(SectionId)(i % 5)}))); + } + } + + // Check that optimize does not block merge. + ICING_ASSERT_OK(index_->Merge()); + EXPECT_EQ(index_->last_added_document_id(), new_last_added_document_id); + // Check prefix search. + ICING_ASSERT_OK_AND_ASSIGN(hits, GetHits(prefix, TermMatchType::PREFIX)); + EXPECT_THAT(hits, ContainerEq(exp_prefix_hits)); + // Check exact search. 
+ for (int i = 0; i < 2048; ++i) { + ICING_ASSERT_OK_AND_ASSIGN( + hits, GetHits(query_terms[i], TermMatchType::EXACT_ONLY)); + if (document_id_old_to_new[i] == kInvalidDocumentId) { + EXPECT_THAT(hits, IsEmpty()); + } else { + EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( + document_id_old_to_new[i], + std::vector<SectionId>{(SectionId)(i % 5)}))); + } + } +} + TEST_F(IndexTest, IndexCreateIOFailure) { // Create the index with mock filesystem. By default, Mock will return false, // so the first attempted file operation will fail. @@ -1410,17 +1714,19 @@ TEST_F(IndexTest, GetDebugInfo) { ASSERT_THAT(edit.BufferTerm("foo"), IsOk()); EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); - IndexDebugInfoProto out0 = index_->GetDebugInfo(/*verbosity=*/0); - EXPECT_FALSE(out0.main_index_info().has_flash_index_storage_info()); - EXPECT_THAT(out0.main_index_info().last_added_document_id(), - Eq(kDocumentId1)); - EXPECT_THAT(out0.lite_index_info().curr_size(), Eq(2)); - EXPECT_THAT(out0.lite_index_info().last_added_document_id(), - Eq(kDocumentId2)); + IndexDebugInfoProto out0 = index_->GetDebugInfo(DebugInfoVerbosity::BASIC); + ICING_LOG(DBG) << "main_index_info:\n" << out0.main_index_info(); + ICING_LOG(DBG) << "lite_index_info:\n" << out0.lite_index_info(); + EXPECT_THAT(out0.main_index_info(), Not(IsEmpty())); + EXPECT_THAT(out0.lite_index_info(), Not(IsEmpty())); - IndexDebugInfoProto out1 = index_->GetDebugInfo(/*verbosity=*/1); - EXPECT_THAT(out1.main_index_info().flash_index_storage_info(), - Not(IsEmpty())); + IndexDebugInfoProto out1 = index_->GetDebugInfo(DebugInfoVerbosity::DETAILED); + ICING_LOG(DBG) << "main_index_info:\n" << out1.main_index_info(); + ICING_LOG(DBG) << "lite_index_info:\n" << out1.lite_index_info(); + EXPECT_THAT(out1.main_index_info(), + SizeIs(Gt(out0.main_index_info().size()))); + EXPECT_THAT(out1.lite_index_info(), + SizeIs(Gt(out0.lite_index_info().size()))); // Add one more doc to the lite index. Debug strings should change. edit = index_->Edit(kDocumentId3, kSectionId2, TermMatchType::EXACT_ONLY, @@ -1429,26 +1735,25 @@ TEST_F(IndexTest, GetDebugInfo) { ASSERT_THAT(edit.BufferTerm("far"), IsOk()); EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); - IndexDebugInfoProto out2 = index_->GetDebugInfo(/*verbosity=*/0); - EXPECT_THAT(out2.lite_index_info().curr_size(), Eq(3)); - EXPECT_THAT(out2.lite_index_info().last_added_document_id(), - Eq(kDocumentId3)); + IndexDebugInfoProto out2 = index_->GetDebugInfo(DebugInfoVerbosity::BASIC); + ICING_LOG(DBG) << "main_index_info:\n" << out2.main_index_info(); + ICING_LOG(DBG) << "lite_index_info:\n" << out2.lite_index_info(); + EXPECT_THAT(out2.main_index_info(), Not(IsEmpty())); + EXPECT_THAT(out2.lite_index_info(), Not(IsEmpty())); + EXPECT_THAT(out2.main_index_info(), StrEq(out0.main_index_info())); + EXPECT_THAT(out2.lite_index_info(), StrNe(out0.lite_index_info())); - // Merge into the man index. Debuug strings should change again. + // Merge into the man index. Debug strings should change again. 
ICING_ASSERT_OK(index_->Merge()); - IndexDebugInfoProto out3 = index_->GetDebugInfo(/*verbosity=*/0); + IndexDebugInfoProto out3 = index_->GetDebugInfo(DebugInfoVerbosity::BASIC); EXPECT_TRUE(out3.has_index_storage_info()); - EXPECT_THAT(out3.main_index_info().lexicon_info(), Not(IsEmpty())); - EXPECT_THAT(out3.main_index_info().last_added_document_id(), - Eq(kDocumentId3)); - EXPECT_THAT(out3.lite_index_info().curr_size(), Eq(0)); - EXPECT_THAT(out3.lite_index_info().hit_buffer_size(), Gt(0)); - EXPECT_THAT(out3.lite_index_info().last_added_document_id(), - Eq(kInvalidDocumentId)); - EXPECT_THAT(out3.lite_index_info().searchable_end(), Eq(0)); - EXPECT_THAT(out3.lite_index_info().index_crc(), Gt(0)); - EXPECT_THAT(out3.lite_index_info().lexicon_info(), Not(IsEmpty())); + ICING_LOG(DBG) << "main_index_info:\n" << out3.main_index_info(); + ICING_LOG(DBG) << "lite_index_info:\n" << out3.lite_index_info(); + EXPECT_THAT(out3.main_index_info(), Not(IsEmpty())); + EXPECT_THAT(out3.lite_index_info(), Not(IsEmpty())); + EXPECT_THAT(out3.main_index_info(), StrNe(out2.main_index_info())); + EXPECT_THAT(out3.lite_index_info(), StrNe(out2.lite_index_info())); } TEST_F(IndexTest, BackfillingMultipleTermsSucceeds) { diff --git a/icing/index/iterator/doc-hit-info-iterator-filter.cc b/icing/index/iterator/doc-hit-info-iterator-filter.cc index 933f9b5..2e8ba23 100644 --- a/icing/index/iterator/doc-hit-info-iterator-filter.cc +++ b/icing/index/iterator/doc-hit-info-iterator-filter.cc @@ -66,25 +66,19 @@ DocHitInfoIteratorFilter::DocHitInfoIteratorFilter( libtextclassifier3::Status DocHitInfoIteratorFilter::Advance() { while (delegate_->Advance().ok()) { - if (!document_store_.DoesDocumentExist( - delegate_->doc_hit_info().document_id())) { - // Document doesn't exist, keep searching. This handles deletions and - // expired documents. - continue; - } - // Try to get the DocumentFilterData - auto document_filter_data_or = document_store_.GetDocumentFilterData( - delegate_->doc_hit_info().document_id()); - if (!document_filter_data_or.ok()) { + auto document_filter_data_optional = + document_store_.GetAliveDocumentFilterData( + delegate_->doc_hit_info().document_id()); + if (!document_filter_data_optional) { // Didn't find the DocumentFilterData in the filter cache. This could be - // because the DocumentId isn't valid or the filter cache is in some - // invalid state. This is bad, but not the query's responsibility to fix, - // so just skip this result for now. + // because the Document doesn't exist or the DocumentId isn't valid or the + // filter cache is in some invalid state. This is bad, but not the query's + // responsibility to fix, so just skip this result for now. continue; } // We should be guaranteed that this exists now. 
- DocumentFilterData data = std::move(document_filter_data_or).ValueOrDie(); + DocumentFilterData data = document_filter_data_optional.value(); if (!options_.namespaces.empty() && target_namespace_ids_.count(data.namespace_id()) == 0) { diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc index 034c8cb..9d33e2c 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc @@ -51,15 +51,15 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() { SectionIdMask section_id_mask = delegate_->doc_hit_info().hit_section_ids_mask(); - auto data_or = document_store_.GetDocumentFilterData(document_id); - if (!data_or.ok()) { + auto data_optional = + document_store_.GetAliveDocumentFilterData(document_id); + if (!data_optional) { // Ran into some error retrieving information on this hit, skip continue; } // Guaranteed that the DocumentFilterData exists at this point - DocumentFilterData data = std::move(data_or).ValueOrDie(); - SchemaTypeId schema_type_id = data.schema_type_id(); + SchemaTypeId schema_type_id = data_optional.value().schema_type_id(); // A hit can be in multiple sections at once, need to check that at least // one of the confirmed section ids match the name of the target section diff --git a/icing/index/iterator/doc-hit-info-iterator_benchmark.cc b/icing/index/iterator/doc-hit-info-iterator_benchmark.cc index f975989..993c3b8 100644 --- a/icing/index/iterator/doc-hit-info-iterator_benchmark.cc +++ b/icing/index/iterator/doc-hit-info-iterator_benchmark.cc @@ -35,7 +35,7 @@ namespace { // // $ // blaze-bin/icing/index/iterator/doc-hit-info-iterator_benchmark -// --benchmarks=all +// --benchmark_filter=all // // Run on an Android device: // $ blaze build --config=android_arm64 -c opt --dynamic_mode=off @@ -47,7 +47,7 @@ namespace { // /data/local/tmp/ // // $ adb shell /data/local/tmp/doc-hit-info-iterator_benchmark -// --benchmarks=all +// --benchmark_filter=all // Functor to be used with std::generate to create a container of DocHitInfos. // DocHitInfos are generated starting at docid starting_docid and continuing at diff --git a/icing/legacy/index/icing-lite-index-header.h b/icing/index/lite/lite-index-header.h index ac2d3c0..dd6a0a8 100644 --- a/icing/legacy/index/icing-lite-index-header.h +++ b/icing/index/lite/lite-index-header.h @@ -16,15 +16,15 @@ #define ICING_LEGACY_INDEX_ICING_LITE_INDEX_HEADER_H_ #include "icing/legacy/core/icing-string-util.h" -#include "icing/legacy/index/icing-common-types.h" +#include "icing/store/document-id.h" namespace icing { namespace lib { // A wrapper around the actual mmapped header data. -class IcingLiteIndex_Header { +class LiteIndex_Header { public: - virtual ~IcingLiteIndex_Header() = default; + virtual ~LiteIndex_Header() = default; // Returns true if the magic of the header matches the hard-coded magic // value associated with this header format. 
@@ -47,7 +47,7 @@ class IcingLiteIndex_Header { virtual void Reset() = 0; }; -class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { +class LiteIndex_HeaderImpl : public LiteIndex_Header { public: struct HeaderData { static const uint32_t kMagic = 0x6dfba6a0; @@ -66,7 +66,7 @@ class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { uint32_t searchable_end; }; - explicit IcingLiteIndex_HeaderImpl(HeaderData *hdr) : hdr_(hdr) {} + explicit LiteIndex_HeaderImpl(HeaderData *hdr) : hdr_(hdr) {} bool check_magic() const override { return hdr_->magic == HeaderData::kMagic; @@ -97,7 +97,7 @@ class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { void Reset() override { hdr_->lite_index_crc = 0; hdr_->magic = HeaderData::kMagic; - hdr_->last_added_docid = kIcingInvalidDocId; + hdr_->last_added_docid = kInvalidDocumentId; hdr_->cur_size = 0; hdr_->searchable_end = 0; } @@ -105,7 +105,7 @@ class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { private: HeaderData *hdr_; }; -static_assert(24 == sizeof(IcingLiteIndex_HeaderImpl::HeaderData), +static_assert(24 == sizeof(LiteIndex_HeaderImpl::HeaderData), "sizeof(HeaderData) != 24"); } // namespace lib diff --git a/icing/legacy/index/icing-lite-index-options.cc b/icing/index/lite/lite-index-options.cc index 4bf0d38..29075f8 100644 --- a/icing/legacy/index/icing-lite-index-options.cc +++ b/icing/index/lite/lite-index-options.cc @@ -12,13 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/legacy/index/icing-lite-index-options.h" +#include "icing/index/lite/lite-index-options.h" + +#include "icing/index/lite/term-id-hit-pair.h" namespace icing { namespace lib { namespace { +constexpr int kIcingMaxVariantsPerToken = 10; // Maximum number of variants + +constexpr size_t kIcingMaxSearchableDocumentSize = (1u << 16) - 1; // 64K +// Max num tokens per document. 64KB is our original maximum (searchable) +// document size. We clip if document exceeds this. +constexpr uint32_t kIcingMaxNumTokensPerDoc = + kIcingMaxSearchableDocumentSize / 5; +constexpr uint32_t kIcingMaxNumHitsPerDocument = + kIcingMaxNumTokensPerDoc * kIcingMaxVariantsPerToken; + uint32_t CalculateHitBufferSize(uint32_t hit_buffer_want_merge_bytes) { constexpr uint32_t kHitBufferSlopMult = 2; @@ -27,7 +39,7 @@ uint32_t CalculateHitBufferSize(uint32_t hit_buffer_want_merge_bytes) { // TODO(b/111690435) Move LiteIndex::Element to a separate file so that this // can use sizeof(LiteIndex::Element) uint32_t hit_capacity_elts_with_slop = - hit_buffer_want_merge_bytes / sizeof(uint64_t); + hit_buffer_want_merge_bytes / sizeof(TermIdHitPair); // Add some slop for index variants on top of max num tokens. 
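
To make the buffer sizing in CalculateHitBufferSize concrete, assuming sizeof(TermIdHitPair) is 8 bytes (the same size as the uint64_t it replaces): with the constants defined in this file, kIcingMaxNumTokensPerDoc = 65535 / 5 = 13107 and kIcingMaxNumHitsPerDocument = 13107 * 10 = 131070, so a hypothetical 1 MiB hit_buffer_want_merge_bytes works out roughly as follows:

// Illustrative arithmetic only; real values depend on the configured
// hit_buffer_want_merge_bytes and on sizeof(TermIdHitPair).
uint32_t want_merge_bytes = 1024 * 1024;  // 1 MiB
uint32_t elts = want_merge_bytes / 8;     // 131072 pairs fit the merge budget
elts += 131070;                           // + kIcingMaxNumHitsPerDocument slop
elts *= 2;                                // * kHitBufferSlopMult
// elts == 524284 hit-buffer elements; the trie options are derived from this.
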
hit_capacity_elts_with_slop += kIcingMaxNumHitsPerDocument; hit_capacity_elts_with_slop *= kHitBufferSlopMult; @@ -51,8 +63,8 @@ IcingDynamicTrie::Options CalculateTrieOptions(uint32_t hit_buffer_size) { } // namespace -IcingLiteIndexOptions::IcingLiteIndexOptions( - const std::string& filename_base, uint32_t hit_buffer_want_merge_bytes) +LiteIndexOptions::LiteIndexOptions(const std::string& filename_base, + uint32_t hit_buffer_want_merge_bytes) : filename_base(filename_base), hit_buffer_want_merge_bytes(hit_buffer_want_merge_bytes) { hit_buffer_size = CalculateHitBufferSize(hit_buffer_want_merge_bytes); diff --git a/icing/legacy/index/icing-lite-index-options.h b/icing/index/lite/lite-index-options.h index 2922621..ae58802 100644 --- a/icing/legacy/index/icing-lite-index-options.h +++ b/icing/index/lite/lite-index-options.h @@ -15,20 +15,19 @@ #ifndef ICING_LEGACY_INDEX_ICING_LITE_INDEX_OPTIONS_H_ #define ICING_LEGACY_INDEX_ICING_LITE_INDEX_OPTIONS_H_ -#include "icing/legacy/index/icing-common-types.h" #include "icing/legacy/index/icing-dynamic-trie.h" namespace icing { namespace lib { -struct IcingLiteIndexOptions { - IcingLiteIndexOptions() = default; - // Creates IcingLiteIndexOptions based off of the specified parameters. All +struct LiteIndexOptions { + LiteIndexOptions() = default; + // Creates LiteIndexOptions based off of the specified parameters. All // other fields are calculated based on the value of // hit_buffer_want_merge_bytes and the logic in CalculateHitBufferSize and // CalculateTrieOptions. - IcingLiteIndexOptions(const std::string& filename_base, - uint32_t hit_buffer_want_merge_bytes); + LiteIndexOptions(const std::string& filename_base, + uint32_t hit_buffer_want_merge_bytes); IcingDynamicTrie::Options lexicon_options; IcingDynamicTrie::Options display_mappings_options; diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc index a5c6baf..3e614d2 100644 --- a/icing/index/lite/lite-index.cc +++ b/icing/index/lite/lite-index.cc @@ -23,6 +23,7 @@ #include <memory> #include <string> #include <string_view> +#include <unordered_set> #include <utility> #include <vector> @@ -33,13 +34,13 @@ #include "icing/file/filesystem.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/hit/hit.h" +#include "icing/index/lite/lite-index-header.h" #include "icing/index/term-property-id.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/core/icing-timer.h" #include "icing/legacy/index/icing-array-storage.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" -#include "icing/legacy/index/icing-lite-index-header.h" #include "icing/legacy/index/icing-mmapper.h" #include "icing/proto/term.pb.h" #include "icing/schema/section.h" @@ -60,7 +61,7 @@ std::string MakeHitBufferFilename(const std::string& filename_base) { return filename_base + "hb"; } -size_t header_size() { return sizeof(IcingLiteIndex_HeaderImpl::HeaderData); } +size_t header_size() { return sizeof(LiteIndex_HeaderImpl::HeaderData); } } // namespace @@ -156,8 +157,8 @@ libtextclassifier3::Status LiteIndex::Initialize() { // Set up header. 
header_mmap_.Remap(hit_buffer_fd_.get(), 0, header_size()); - header_ = std::make_unique<IcingLiteIndex_HeaderImpl>( - reinterpret_cast<IcingLiteIndex_HeaderImpl::HeaderData*>( + header_ = std::make_unique<LiteIndex_HeaderImpl>( + reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>( header_mmap_.address())); header_->Reset(); @@ -171,8 +172,8 @@ libtextclassifier3::Status LiteIndex::Initialize() { UpdateChecksum(); } else { header_mmap_.Remap(hit_buffer_fd_.get(), 0, header_size()); - header_ = std::make_unique<IcingLiteIndex_HeaderImpl>( - reinterpret_cast<IcingLiteIndex_HeaderImpl::HeaderData*>( + header_ = std::make_unique<LiteIndex_HeaderImpl>( + reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>( header_mmap_.address())); if (!hit_buffer_.Init(hit_buffer_fd_.get(), header_padded_size, true, @@ -197,8 +198,7 @@ libtextclassifier3::Status LiteIndex::Initialize() { } } - ICING_VLOG(2) << IcingStringUtil::StringPrintf("Lite index init ok in %.3fms", - timer.Elapsed() * 1000); + ICING_VLOG(2) << "Lite index init ok in " << timer.Elapsed() * 1000 << "ms"; return status; error: @@ -230,8 +230,7 @@ Crc32 LiteIndex::ComputeChecksum() { Crc32 all_crc(header_->CalculateHeaderCrc()); all_crc.Append(std::string_view(reinterpret_cast<const char*>(dependent_crcs), sizeof(dependent_crcs))); - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Lite index crc computed in %.3fms", timer.Elapsed() * 1000); + ICING_VLOG(2) << "Lite index crc computed in " << timer.Elapsed() * 1000 << "ms"; return all_crc; } @@ -246,8 +245,7 @@ libtextclassifier3::Status LiteIndex::Reset() { header_->Reset(); UpdateChecksum(); - ICING_VLOG(2) << IcingStringUtil::StringPrintf("Lite index clear in %.3fms", - timer.Elapsed() * 1000); + ICING_VLOG(2) << "Lite index clear in " << timer.Elapsed() * 1000 << "ms"; return libtextclassifier3::Status::OK; } @@ -391,15 +389,22 @@ bool LiteIndex::is_full() const { lexicon_.min_free_fraction() < (1.0 - kTrieFullFraction)); } -IndexDebugInfoProto::LiteIndexDebugInfoProto LiteIndex::GetDebugInfo( - int verbosity) { - IndexDebugInfoProto::LiteIndexDebugInfoProto res; - res.set_curr_size(header_->cur_size()); - res.set_hit_buffer_size(options_.hit_buffer_size); - res.set_last_added_document_id(header_->last_added_docid()); - res.set_searchable_end(header_->searchable_end()); - res.set_index_crc(ComputeChecksum().Get()); - lexicon_.GetDebugInfo(verbosity, res.mutable_lexicon_info()); +std::string LiteIndex::GetDebugInfo(DebugInfoVerbosity::Code verbosity) { + std::string res; + std::string lexicon_info; + lexicon_.GetDebugInfo(verbosity, &lexicon_info); + IcingStringUtil::SStringAppendF( + &res, 0, + "curr_size: %u\n" + "hit_buffer_size: %u\n" + "last_added_document_id %u\n" + "searchable_end: %u\n" + "index_crc: %u\n" + "\n" + "lite_lexicon_info:\n%s\n", + header_->cur_size(), options_.hit_buffer_size, + header_->last_added_docid(), header_->searchable_end(), + ComputeChecksum().Get(), lexicon_info.c_str()); return res; } @@ -432,34 +437,38 @@ IndexStorageInfoProto LiteIndex::GetStorageInfo( return storage_info; } -uint32_t LiteIndex::Seek(uint32_t term_id) { +void LiteIndex::SortHits() { // Make searchable by sorting by hit buffer. 
uint32_t sort_len = header_->cur_size() - header_->searchable_end(); - if (sort_len > 0) { - IcingTimer timer; - - auto* array_start = - hit_buffer_.GetMutableMem<TermIdHitPair::Value>(0, header_->cur_size()); - TermIdHitPair::Value* sort_start = array_start + header_->searchable_end(); - std::sort(sort_start, array_start + header_->cur_size()); - - // Now merge with previous region. Since the previous region is already - // sorted and deduplicated, optimize the merge by skipping everything less - // than the new region's smallest value. - if (header_->searchable_end() > 0) { - std::inplace_merge(array_start, array_start + header_->searchable_end(), - array_start + header_->cur_size()); - } - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Lite index sort and merge %u into %u in %.3fms", sort_len, - header_->searchable_end(), timer.Elapsed() * 1000); - - // Now the entire array is sorted. - header_->set_searchable_end(header_->cur_size()); + if (sort_len <= 0) { + return; + } + IcingTimer timer; - // Update crc in-line. - UpdateChecksum(); + auto* array_start = + hit_buffer_.GetMutableMem<TermIdHitPair::Value>(0, header_->cur_size()); + TermIdHitPair::Value* sort_start = array_start + header_->searchable_end(); + std::sort(sort_start, array_start + header_->cur_size()); + + // Now merge with previous region. Since the previous region is already + // sorted and deduplicated, optimize the merge by skipping everything less + // than the new region's smallest value. + if (header_->searchable_end() > 0) { + std::inplace_merge(array_start, array_start + header_->searchable_end(), + array_start + header_->cur_size()); } + ICING_VLOG(2) << "Lite index sort and merge " << sort_len << " into " + << header_->searchable_end() << " in " << timer.Elapsed() * 1000 << "ms"; + + // Now the entire array is sorted. + header_->set_searchable_end(header_->cur_size()); + + // Update crc in-line. + UpdateChecksum(); +} + +uint32_t LiteIndex::Seek(uint32_t term_id) { + SortHits(); // Binary search for our term_id. Make sure we get the first // element. Using kBeginSortValue ensures this for the hit value. @@ -473,5 +482,80 @@ uint32_t LiteIndex::Seek(uint32_t term_id) { return ptr - array; } +libtextclassifier3::Status LiteIndex::Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + const TermIdCodec* term_id_codec, DocumentId new_last_added_document_id) { + header_->set_last_added_docid(new_last_added_document_id); + if (header_->cur_size() == 0) { + return libtextclassifier3::Status::OK; + } + // Sort the hits so that hits with the same term id will be grouped together, + // which helps later to determine which terms will be unused after compaction. + SortHits(); + uint32_t new_size = 0; + uint32_t curr_term_id = 0; + uint32_t curr_tvi = 0; + std::unordered_set<uint32_t> tvi_to_delete; + for (uint32_t idx = 0; idx < header_->cur_size(); ++idx) { + TermIdHitPair term_id_hit_pair( + hit_buffer_.array_cast<TermIdHitPair>()[idx]); + if (idx == 0 || term_id_hit_pair.term_id() != curr_term_id) { + curr_term_id = term_id_hit_pair.term_id(); + ICING_ASSIGN_OR_RETURN(TermIdCodec::DecodedTermInfo term_info, + term_id_codec->DecodeTermInfo(curr_term_id)); + curr_tvi = term_info.tvi; + // Mark the property of the current term as not having hits in prefix + // section. The property will be set below if there are any valid hits + // from a prefix section. + lexicon_.ClearProperty(curr_tvi, GetHasHitsInPrefixSectionPropertyId()); + // Add curr_tvi to tvi_to_delete. 
It will be removed from tvi_to_delete + // below if there are any valid hits pointing to that termid. + tvi_to_delete.insert(curr_tvi); + } + DocumentId new_document_id = + document_id_old_to_new[term_id_hit_pair.hit().document_id()]; + if (new_document_id == kInvalidDocumentId) { + continue; + } + if (term_id_hit_pair.hit().is_in_prefix_section()) { + lexicon_.SetProperty(curr_tvi, GetHasHitsInPrefixSectionPropertyId()); + } + tvi_to_delete.erase(curr_tvi); + TermIdHitPair new_term_id_hit_pair( + term_id_hit_pair.term_id(), + Hit::TranslateHit(term_id_hit_pair.hit(), new_document_id)); + // Rewriting the hit_buffer in place. + // new_size is weakly less than idx so we are okay to overwrite the entry at + // new_size, and valp should never be nullptr since it is within the already + // allocated region of hit_buffer_. + TermIdHitPair::Value* valp = + hit_buffer_.GetMutableMem<TermIdHitPair::Value>(new_size++, 1); + *valp = new_term_id_hit_pair.value(); + } + header_->set_cur_size(new_size); + header_->set_searchable_end(new_size); + + // Delete unused terms. + std::unordered_set<std::string> terms_to_delete; + for (IcingDynamicTrie::Iterator term_iter(lexicon_, /*prefix=*/""); + term_iter.IsValid(); term_iter.Advance()) { + if (tvi_to_delete.find(term_iter.GetValueIndex()) != tvi_to_delete.end()) { + terms_to_delete.insert(term_iter.GetKey()); + } + } + for (const std::string& term : terms_to_delete) { + // Mark "term" as deleted. This won't actually free space in the lexicon. It + // will simply make it impossible to Find "term" in subsequent calls (which + // saves an unnecessary search through the hit buffer). This is acceptable + // because the free space will eventually be reclaimed the next time that + // the lite index is merged with the main index. + if (!lexicon_.Delete(term)) { + return absl_ports::InternalError( + "Could not delete invalid terms in lite lexicon during compaction."); + } + } + return libtextclassifier3::Status::OK; +} + } // namespace lib } // namespace icing diff --git a/icing/index/lite/lite-index.h b/icing/index/lite/lite-index.h index 378fc94..be629b8 100644 --- a/icing/index/lite/lite-index.h +++ b/icing/index/lite/lite-index.h @@ -30,12 +30,13 @@ #include "icing/file/filesystem.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/hit/hit.h" +#include "icing/index/lite/lite-index-header.h" +#include "icing/index/lite/lite-index-options.h" #include "icing/index/lite/term-id-hit-pair.h" +#include "icing/index/term-id-codec.h" #include "icing/legacy/index/icing-array-storage.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" -#include "icing/legacy/index/icing-lite-index-header.h" -#include "icing/legacy/index/icing-lite-index-options.h" #include "icing/legacy/index/icing-mmapper.h" #include "icing/proto/debug.pb.h" #include "icing/proto/storage.pb.h" @@ -53,7 +54,7 @@ namespace lib { class LiteIndex { public: // An entry in the hit buffer. - using Options = IcingLiteIndexOptions; + using Options = LiteIndexOptions; // Updates checksum of subcomponents. ~LiteIndex(); @@ -240,9 +241,9 @@ class LiteIndex { const IcingDynamicTrie& lexicon() const { return lexicon_; } // Returns debug information for the index in out. - // verbosity <= 0, simplest debug information - size of lexicon, hit buffer - // verbosity > 0, more detailed debug information from the lexicon. 
- IndexDebugInfoProto::LiteIndexDebugInfoProto GetDebugInfo(int verbosity); + // verbosity = BASIC, simplest debug information - size of lexicon, hit buffer + // verbosity = DETAILED, more detailed debug information from the lexicon. + std::string GetDebugInfo(DebugInfoVerbosity::Code verbosity); // Returns the byte size of all the elements held in the index. This excludes // the size of any internal metadata of the index, e.g. the index's header. @@ -260,6 +261,19 @@ class LiteIndex { IndexStorageInfoProto GetStorageInfo( IndexStorageInfoProto storage_info) const; + // Reduces internal file sizes by reclaiming space of deleted documents. + // + // This method also sets the last_added_docid of the index to + // new_last_added_document_id. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. + libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + const TermIdCodec* term_id_codec, DocumentId new_last_added_document_id); + private: static IcingDynamicTrie::RuntimeOptions MakeTrieRuntimeOptions(); @@ -279,6 +293,9 @@ class LiteIndex { // Sets the computed checksum in the header void UpdateChecksum(); + // Sort hits stored in the index. + void SortHits(); + // Returns the position of the first element with term_id, or the size of the // hit buffer if term_id is not present. uint32_t Seek(uint32_t term_id); @@ -301,7 +318,7 @@ class LiteIndex { IcingMMapper header_mmap_; // Wrapper around the mmapped header that contains stats on the lite index. - std::unique_ptr<IcingLiteIndex_Header> header_; + std::unique_ptr<LiteIndex_Header> header_; // Options used to initialize the LiteIndex. const Options options_; diff --git a/icing/index/main/flash-index-storage.cc b/icing/index/main/flash-index-storage.cc index 3c52375..33dacf9 100644 --- a/icing/index/main/flash-index-storage.cc +++ b/icing/index/main/flash-index-storage.cc @@ -133,9 +133,7 @@ bool FlashIndexStorage::CreateHeader() { posting_list_bytes /= 2) { uint32_t aligned_posting_list_bytes = (posting_list_bytes / sizeof(Hit) * sizeof(Hit)); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Block size %u: %u", header_block_->header()->num_index_block_infos, - aligned_posting_list_bytes); + ICING_VLOG(1) << "Block size " << header_block_->header()->num_index_block_infos << ": " << aligned_posting_list_bytes; // Initialize free list to empty. 
HeaderBlock::Header::IndexBlockInfo* block_info = @@ -169,23 +167,18 @@ bool FlashIndexStorage::OpenHeader(int64_t file_size) { return false; } if (file_size % read_header.header()->block_size != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Index size %" PRIu64 " not a multiple of block size %u", file_size, - read_header.header()->block_size); + ICING_LOG(ERROR) << "Index size " << file_size << " not a multiple of block size " << read_header.header()->block_size; return false; } if (file_size < static_cast<int64_t>(read_header.header()->block_size)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Index size %" PRIu64 " shorter than block size %u", file_size, - read_header.header()->block_size); + ICING_LOG(ERROR) << "Index size " << file_size << " shorter than block size " << read_header.header()->block_size; return false; } if (read_header.header()->block_size % getpagesize() != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Block size %u is not a multiple of page size %d", - read_header.header()->block_size, getpagesize()); + ICING_LOG(ERROR) << "Block size " << read_header.header()->block_size + << " is not a multiple of page size " << getpagesize(); return false; } num_blocks_ = file_size / read_header.header()->block_size; @@ -215,11 +208,10 @@ bool FlashIndexStorage::OpenHeader(int64_t file_size) { int posting_list_bytes = header_block_->header()->index_block_infos[i].posting_list_bytes; if (posting_list_bytes % sizeof(Hit) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Posting list size misaligned, index %u, size %u, hit %zu, " - "file_size %" PRIu64, - i, header_block_->header()->index_block_infos[i].posting_list_bytes, - sizeof(Hit), file_size); + ICING_LOG(ERROR) << "Posting list size misaligned, index " << i + << ", size " + << header_block_->header()->index_block_infos[i].posting_list_bytes + << ", hit " << sizeof(Hit) << ", file_size " << file_size; return false; } } @@ -229,8 +221,7 @@ bool FlashIndexStorage::OpenHeader(int64_t file_size) { bool FlashIndexStorage::PersistToDisk() { // First, write header. if (!header_block_->Write(block_fd_.get())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Write index header failed: %s", strerror(errno)); + ICING_LOG(ERROR) << "Write index header failed: " << strerror(errno); return false; } @@ -456,8 +447,7 @@ void FlashIndexStorage::FreePostingList(PostingListHolder holder) { int FlashIndexStorage::GrowIndex() { if (num_blocks_ >= kMaxBlockIndex) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Reached max block index %u", - kMaxBlockIndex); + ICING_VLOG(1) << "Reached max block index " << kMaxBlockIndex; return kInvalidBlockIndex; } @@ -465,8 +455,7 @@ int FlashIndexStorage::GrowIndex() { if (!filesystem_->Grow( block_fd_.get(), static_cast<uint64_t>(num_blocks_ + 1) * block_size())) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Error growing index file: %s", strerror(errno)); + ICING_VLOG(1) << "Error growing index file: " << strerror(errno); return kInvalidBlockIndex; } @@ -503,7 +492,8 @@ void FlashIndexStorage::FlushInMemoryFreeList() { } } -void FlashIndexStorage::GetDebugInfo(int verbosity, std::string* out) const { +void FlashIndexStorage::GetDebugInfo(DebugInfoVerbosity::Code verbosity, + std::string* out) const { // Dump and check integrity of the index block free lists. 
out->append("Free lists:\n"); for (size_t i = 0; i < header_block_->header()->num_index_block_infos; ++i) { diff --git a/icing/index/main/flash-index-storage.h b/icing/index/main/flash-index-storage.h index 6c6fbb8..fceb26f 100644 --- a/icing/index/main/flash-index-storage.h +++ b/icing/index/main/flash-index-storage.h @@ -160,7 +160,7 @@ class FlashIndexStorage { libtextclassifier3::Status Reset(); // TODO(b/222349894) Convert the string output to a protocol buffer instead. - void GetDebugInfo(int verbosity, std::string* out) const; + void GetDebugInfo(DebugInfoVerbosity::Code verbosity, std::string* out) const; private: FlashIndexStorage(const std::string& index_filename, diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index 2d6007b..9f591c0 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -16,9 +16,12 @@ #include <cstdint> #include <cstring> #include <memory> +#include <string> +#include <unordered_set> #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/file/destructible-directory.h" #include "icing/index/main/index-block.h" #include "icing/index/term-id-codec.h" #include "icing/index/term-property-id.h" @@ -83,35 +86,40 @@ FindTermResult FindShortestValidTermWithPrefixHits( } // namespace +MainIndex::MainIndex(const std::string& index_directory, + const Filesystem* filesystem, + const IcingFilesystem* icing_filesystem) + : base_dir_(index_directory), + filesystem_(filesystem), + icing_filesystem_(icing_filesystem) {} + libtextclassifier3::StatusOr<std::unique_ptr<MainIndex>> MainIndex::Create( const std::string& index_directory, const Filesystem* filesystem, const IcingFilesystem* icing_filesystem) { ICING_RETURN_ERROR_IF_NULL(filesystem); ICING_RETURN_ERROR_IF_NULL(icing_filesystem); - auto main_index = std::make_unique<MainIndex>(); - ICING_RETURN_IF_ERROR( - main_index->Init(index_directory, filesystem, icing_filesystem)); + std::unique_ptr<MainIndex> main_index( + new MainIndex(index_directory, filesystem, icing_filesystem)); + ICING_RETURN_IF_ERROR(main_index->Init()); return main_index; } // TODO(b/139087650) : Migrate off of IcingFilesystem. 
-libtextclassifier3::Status MainIndex::Init( - const std::string& index_directory, const Filesystem* filesystem, - const IcingFilesystem* icing_filesystem) { - if (!filesystem->CreateDirectoryRecursively(index_directory.c_str())) { +libtextclassifier3::Status MainIndex::Init() { + if (!filesystem_->CreateDirectoryRecursively(base_dir_.c_str())) { return absl_ports::InternalError("Unable to create main index directory."); } - std::string flash_index_file = index_directory + "/main_index"; + std::string flash_index_file = base_dir_ + "/main_index"; ICING_ASSIGN_OR_RETURN( FlashIndexStorage flash_index, - FlashIndexStorage::Create(flash_index_file, filesystem)); + FlashIndexStorage::Create(flash_index_file, filesystem_)); flash_index_storage_ = std::make_unique<FlashIndexStorage>(std::move(flash_index)); - std::string lexicon_file = index_directory + "/main-lexicon"; + std::string lexicon_file = base_dir_ + "/main-lexicon"; IcingDynamicTrie::RuntimeOptions runtime_options; main_lexicon_ = std::make_unique<IcingDynamicTrie>( - lexicon_file, runtime_options, icing_filesystem); + lexicon_file, runtime_options, icing_filesystem_); IcingDynamicTrie::Options lexicon_options; if (!main_lexicon_->CreateIfNotExist(lexicon_options) || !main_lexicon_->Init()) { @@ -489,8 +497,7 @@ libtextclassifier3::Status MainIndex::AddHits( } // Now copy remaining backfills. - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Remaining backfills %zu", - backfill_map.size()); + ICING_VLOG(1) << "Remaining backfills " << backfill_map.size(); for (auto other_tvi_main_tvi_pair : backfill_map) { PostingListIdentifier backfill_posting_list_id = PostingListIdentifier::kInvalid; @@ -523,9 +530,7 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm( std::unique_ptr<PostingListAccessor> pl_accessor; if (posting_list_id.is_valid()) { if (posting_list_id.block_index() >= flash_index_storage_->num_blocks()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Index dropped hits. Invalid block index %u >= %u", - posting_list_id.block_index(), flash_index_storage_->num_blocks()); + ICING_LOG(ERROR) << "Index dropped hits. Invalid block index " << posting_list_id.block_index() << " >= " << flash_index_storage_->num_blocks(); // TODO(b/159918304) : Consider revising the checksumming strategy in the // main index. Providing some mechanism to check for corruption - either // during initialization or some later time would allow us to avoid @@ -607,23 +612,167 @@ libtextclassifier3::Status MainIndex::AddPrefixBackfillHits( return libtextclassifier3::Status::OK; } -IndexDebugInfoProto::MainIndexDebugInfoProto MainIndex::GetDebugInfo( - int verbosity) const { - IndexDebugInfoProto::MainIndexDebugInfoProto res; +std::string MainIndex::GetDebugInfo(DebugInfoVerbosity::Code verbosity) const { + std::string res; // Lexicon. 
- main_lexicon_->GetDebugInfo(verbosity, res.mutable_lexicon_info()); + std::string lexicon_info; + main_lexicon_->GetDebugInfo(verbosity, &lexicon_info); - res.set_last_added_document_id(last_added_document_id()); + IcingStringUtil::SStringAppendF(&res, 0, + "last_added_document_id: %u\n" + "\n" + "main_lexicon_info:\n%s\n", + last_added_document_id(), + lexicon_info.c_str()); - if (verbosity <= 0) { + if (verbosity == DebugInfoVerbosity::BASIC) { return res; } - flash_index_storage_->GetDebugInfo(verbosity, - res.mutable_flash_index_storage_info()); + std::string flash_index_storage_info; + flash_index_storage_->GetDebugInfo(verbosity, &flash_index_storage_info); + IcingStringUtil::SStringAppendF(&res, 0, "flash_index_storage_info:\n%s\n", + flash_index_storage_info.c_str()); return res; } +libtextclassifier3::Status MainIndex::Optimize( + const std::vector<DocumentId>& document_id_old_to_new) { + std::string temporary_index_dir_path = base_dir_ + "_temp"; + if (!filesystem_->DeleteDirectoryRecursively( + temporary_index_dir_path.c_str())) { + ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_dir_path; + return absl_ports::InternalError( + "Unable to delete temp directory to prepare to build new index."); + } + + DestructibleDirectory temporary_index_dir( + filesystem_, std::move(temporary_index_dir_path)); + if (!temporary_index_dir.is_valid()) { + return absl_ports::InternalError( + "Unable to create temp directory to build new index."); + } + + ICING_ASSIGN_OR_RETURN(std::unique_ptr<MainIndex> new_index, + MainIndex::Create(temporary_index_dir.dir(), + filesystem_, icing_filesystem_)); + ICING_RETURN_IF_ERROR(TransferIndex(document_id_old_to_new, new_index.get())); + ICING_RETURN_IF_ERROR(new_index->PersistToDisk()); + new_index = nullptr; + flash_index_storage_ = nullptr; + main_lexicon_ = nullptr; + + if (!filesystem_->SwapFiles(temporary_index_dir.dir().c_str(), + base_dir_.c_str())) { + return absl_ports::InternalError( + "Unable to apply new index due to failed swap!"); + } + + // Reinitialize the index so that flash_index_storage_ and main_lexicon_ are + // properly updated. + return Init(); +} + +libtextclassifier3::StatusOr<DocumentId> MainIndex::TransferAndAddHits( + const std::vector<DocumentId>& document_id_old_to_new, const char* term, + PostingListAccessor& old_pl_accessor, MainIndex* new_index) { + std::vector<Hit> new_hits; + bool has_no_exact_hits = true; + bool has_hits_in_prefix_section = false; + // The largest document id after translating hits. + DocumentId largest_document_id = kInvalidDocumentId; + ICING_ASSIGN_OR_RETURN(std::vector<Hit> tmp, + old_pl_accessor.GetNextHitsBatch()); + while (!tmp.empty()) { + for (const Hit& hit : tmp) { + DocumentId new_document_id = document_id_old_to_new[hit.document_id()]; + // Transfer the document id of the hit, if the document is not deleted + // or outdated. + if (new_document_id != kInvalidDocumentId) { + if (hit.is_in_prefix_section()) { + has_hits_in_prefix_section = true; + } + if (!hit.is_prefix_hit()) { + has_no_exact_hits = false; + } + if (largest_document_id == kInvalidDocumentId || + new_document_id > largest_document_id) { + largest_document_id = new_document_id; + } + new_hits.push_back(Hit::TranslateHit(hit, new_document_id)); + } + } + ICING_ASSIGN_OR_RETURN(tmp, old_pl_accessor.GetNextHitsBatch()); + } + // A term without exact hits indicates that it is a purely backfill term. 
If + // the term is not branching in the new trie, it means backfilling is no + // longer necessary, so that we can skip. + if (new_hits.empty() || + (has_no_exact_hits && !new_index->main_lexicon_->IsBranchingTerm(term))) { + return largest_document_id; + } + + ICING_ASSIGN_OR_RETURN( + PostingListAccessor hit_accum, + PostingListAccessor::Create(new_index->flash_index_storage_.get())); + for (auto itr = new_hits.rbegin(); itr != new_hits.rend(); ++itr) { + ICING_RETURN_IF_ERROR(hit_accum.PrependHit(*itr)); + } + PostingListAccessor::FinalizeResult result = + PostingListAccessor::Finalize(std::move(hit_accum)); + uint32_t tvi; + if (!result.id.is_valid() || + !new_index->main_lexicon_->Insert(term, &result.id, &tvi, + /*replace=*/false)) { + return absl_ports::InternalError( + absl_ports::StrCat("Could not transfer main index for term: ", term)); + } + if (has_no_exact_hits && !new_index->main_lexicon_->SetProperty( + tvi, GetHasNoExactHitsPropertyId())) { + return absl_ports::InternalError("Setting prefix prop failed"); + } + if (has_hits_in_prefix_section && + !new_index->main_lexicon_->SetProperty( + tvi, GetHasHitsInPrefixSectionPropertyId())) { + return absl_ports::InternalError("Setting prefix prop failed"); + } + return largest_document_id; +} + +libtextclassifier3::Status MainIndex::TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + MainIndex* new_index) { + DocumentId largest_document_id = kInvalidDocumentId; + for (IcingDynamicTrie::Iterator term_itr(*main_lexicon_, /*prefix=*/"", + /*reverse=*/true); + term_itr.IsValid(); term_itr.Advance()) { + PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid; + memcpy(&posting_list_id, term_itr.GetValue(), sizeof(posting_list_id)); + if (posting_list_id == PostingListIdentifier::kInvalid) { + // Why? + ICING_LOG(ERROR) + << "Got invalid posting_list_id from previous main index"; + continue; + } + ICING_ASSIGN_OR_RETURN(PostingListAccessor pl_accessor, + PostingListAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_id)); + ICING_ASSIGN_OR_RETURN( + DocumentId curr_largest_document_id, + TransferAndAddHits(document_id_old_to_new, term_itr.GetKey(), + pl_accessor, new_index)); + if (curr_largest_document_id == kInvalidDocumentId) { + continue; + } + if (largest_document_id == kInvalidDocumentId || + curr_largest_document_id > largest_document_id) { + largest_document_id = curr_largest_document_id; + } + } + new_index->flash_index_storage_->set_last_indexed_docid(largest_document_id); + return libtextclassifier3::Status::OK; +} + } // namespace lib } // namespace icing diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h index abb0418..4ed2e94 100644 --- a/icing/index/main/main-index.h +++ b/icing/index/main/main-index.h @@ -183,16 +183,28 @@ class MainIndex { IndexStorageInfoProto storage_info) const; // Returns debug information for the main index in out. - // verbosity <= 0, simplest debug information - just the lexicon - // verbosity > 0, more detailed debug information including raw postings - // lists. - IndexDebugInfoProto::MainIndexDebugInfoProto GetDebugInfo( - int verbosity) const; + // verbosity = BASIC, simplest debug information - just the lexicon + // verbosity = DETAILED, more detailed debug information including raw + // postings lists. + std::string GetDebugInfo(DebugInfoVerbosity::Code verbosity) const; + + // Reduces internal file sizes by reclaiming space of deleted documents. 
+ // + // This method will update the last_added_docid of the index to the largest + // document id that still appears in the index after compaction. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. + libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new); private: - libtextclassifier3::Status Init(const std::string& index_directory, - const Filesystem* filesystem, - const IcingFilesystem* icing_filesystem); + MainIndex(const std::string& index_directory, const Filesystem* filesystem, + const IcingFilesystem* icing_filesystem); + + libtextclassifier3::Status Init(); // Helpers for merging the lexicon // Add all 'backfill' branch points. Backfill branch points are prefix @@ -288,6 +300,27 @@ class MainIndex { PostingListIdentifier backfill_posting_list_id, PostingListAccessor* hit_accum); + // Transfer hits from old_pl_accessor to new_index for term. + // + // Returns: + // largest document id added to the translated posting list, on success + // INTERNAL_ERROR on IO error + static libtextclassifier3::StatusOr<DocumentId> TransferAndAddHits( + const std::vector<DocumentId>& document_id_old_to_new, const char* term, + PostingListAccessor& old_pl_accessor, MainIndex* new_index); + + // Transfer hits from the current main index to new_index. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error + libtextclassifier3::Status TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + MainIndex* new_index); + + std::string base_dir_; + const Filesystem* filesystem_; + const IcingFilesystem* icing_filesystem_; std::unique_ptr<FlashIndexStorage> flash_index_storage_; std::unique_ptr<IcingDynamicTrie> main_lexicon_; }; diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index bcc35e6..c9e7127 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -15,11 +15,13 @@ #include <jni.h> #include <string> +#include <utility> -#include "icing/jni/jni-cache.h" #include <google/protobuf/message_lite.h> -#include "icing/absl_ports/status_imports.h" #include "icing/icing-search-engine.h" +#include "icing/jni/jni-cache.h" +#include "icing/jni/scoped-primitive-array-critical.h" +#include "icing/jni/scoped-utf-chars.h" #include "icing/proto/document.pb.h" #include "icing/proto/initialize.pb.h" #include "icing/proto/optimize.pb.h" @@ -29,6 +31,7 @@ #include "icing/proto/search.pb.h" #include "icing/proto/storage.pb.h" #include "icing/proto/usage.pb.h" +#include "icing/util/logging.h" #include "icing/util/status-macros.h" namespace { @@ -39,13 +42,8 @@ const char kNativePointerField[] = "nativePointer"; bool ParseProtoFromJniByteArray(JNIEnv* env, jbyteArray bytes, google::protobuf::MessageLite* protobuf) { - int bytes_size = env->GetArrayLength(bytes); - uint8_t* bytes_ptr = static_cast<uint8_t*>( - env->GetPrimitiveArrayCritical(bytes, /*isCopy=*/nullptr)); - bool parsed = protobuf->ParseFromArray(bytes_ptr, bytes_size); - env->ReleasePrimitiveArrayCritical(bytes, bytes_ptr, /*mode=*/0); - - return parsed; + icing::lib::ScopedPrimitiveArrayCritical<uint8_t> scoped_array(env, bytes); + return protobuf->ParseFromArray(scoped_array.data(), scoped_array.size()); } jbyteArray SerializeProtoToJniByteArray( @@ -57,10 +55,8 @@ jbyteArray SerializeProtoToJniByteArray( return nullptr; } - uint8_t* ret_buf = static_cast<uint8_t*>( - env->GetPrimitiveArrayCritical(ret, 
/*isCopy=*/nullptr)); - protobuf.SerializeWithCachedSizesToArray(ret_buf); - env->ReleasePrimitiveArrayCritical(ret, ret_buf, 0); + icing::lib::ScopedPrimitiveArrayCritical<uint8_t> scoped_array(env, ret); + protobuf.SerializeWithCachedSizesToArray(scoped_array.data()); return ret; } @@ -162,11 +158,9 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType( icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); - const char* native_schema_type = - env->GetStringUTFChars(schema_type, /*isCopy=*/nullptr); + icing::lib::ScopedUtfChars scoped_schema_type_chars(env, schema_type); icing::lib::GetSchemaTypeResultProto get_schema_type_result_proto = - icing->GetSchemaType(native_schema_type); - env->ReleaseStringUTFChars(schema_type, native_schema_type); + icing->GetSchemaType(scoped_schema_type_chars.c_str()); return SerializeProtoToJniByteArray(env, get_schema_type_result_proto); } @@ -193,20 +187,19 @@ JNIEXPORT jbyteArray JNICALL Java_com_google_android_icing_IcingSearchEngine_nativeGet( JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri, jbyteArray result_spec_bytes) { + icing::lib::IcingSearchEngine* icing = + GetIcingSearchEnginePointer(env, object); + icing::lib::GetResultSpecProto get_result_spec; if (!ParseProtoFromJniByteArray(env, result_spec_bytes, &get_result_spec)) { ICING_LOG(ERROR) << "Failed to parse GetResultSpecProto in nativeGet"; return nullptr; } - icing::lib::IcingSearchEngine* icing = - GetIcingSearchEnginePointer(env, object); - const char* native_name_space = - env->GetStringUTFChars(name_space, /*isCopy=*/nullptr); - const char* native_uri = env->GetStringUTFChars(uri, /*isCopy=*/nullptr); + icing::lib::ScopedUtfChars scoped_name_space_chars(env, name_space); + icing::lib::ScopedUtfChars scoped_uri_chars(env, uri); icing::lib::GetResultProto get_result_proto = - icing->Get(native_name_space, native_uri, get_result_spec); - env->ReleaseStringUTFChars(uri, native_uri); - env->ReleaseStringUTFChars(name_space, native_name_space); + icing->Get(scoped_name_space_chars.c_str(), scoped_uri_chars.c_str(), + get_result_spec); return SerializeProtoToJniByteArray(env, get_result_proto); } @@ -303,13 +296,10 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDelete( icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); - const char* native_name_space = - env->GetStringUTFChars(name_space, /*isCopy=*/nullptr); - const char* native_uri = env->GetStringUTFChars(uri, /*isCopy=*/nullptr); + icing::lib::ScopedUtfChars scoped_name_space_chars(env, name_space); + icing::lib::ScopedUtfChars scoped_uri_chars(env, uri); icing::lib::DeleteResultProto delete_result_proto = - icing->Delete(native_name_space, native_uri); - env->ReleaseStringUTFChars(uri, native_uri); - env->ReleaseStringUTFChars(name_space, native_name_space); + icing->Delete(scoped_name_space_chars.c_str(), scoped_uri_chars.c_str()); return SerializeProtoToJniByteArray(env, delete_result_proto); } @@ -320,11 +310,9 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace( icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); - const char* native_name_space = - env->GetStringUTFChars(name_space, /*isCopy=*/nullptr); + icing::lib::ScopedUtfChars scoped_name_space_chars(env, name_space); icing::lib::DeleteByNamespaceResultProto delete_by_namespace_result_proto = - icing->DeleteByNamespace(native_name_space); - env->ReleaseStringUTFChars(name_space, native_name_space); + 
icing->DeleteByNamespace(scoped_name_space_chars.c_str()); return SerializeProtoToJniByteArray(env, delete_by_namespace_result_proto); } @@ -335,18 +323,17 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType( icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); - const char* native_schema_type = - env->GetStringUTFChars(schema_type, /*isCopy=*/nullptr); + icing::lib::ScopedUtfChars scoped_schema_type_chars(env, schema_type); icing::lib::DeleteBySchemaTypeResultProto delete_by_schema_type_result_proto = - icing->DeleteBySchemaType(native_schema_type); - env->ReleaseStringUTFChars(schema_type, native_schema_type); + icing->DeleteBySchemaType(scoped_schema_type_chars.c_str()); return SerializeProtoToJniByteArray(env, delete_by_schema_type_result_proto); } JNIEXPORT jbyteArray JNICALL Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery( - JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes) { + JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes, + jboolean return_deleted_document_info) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -356,7 +343,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery( return nullptr; } icing::lib::DeleteByQueryResultProto delete_result_proto = - icing->DeleteByQuery(search_spec_proto); + icing->DeleteByQuery(search_spec_proto, return_deleted_document_info); return SerializeProtoToJniByteArray(env, delete_result_proto); } @@ -445,4 +432,49 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions( return SerializeProtoToJniByteArray(env, suggestionResponse); } +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo( + JNIEnv* env, jclass clazz, jobject object, jint verbosity) { + icing::lib::IcingSearchEngine* icing = + GetIcingSearchEnginePointer(env, object); + + if (!icing::lib::DebugInfoVerbosity::Code_IsValid(verbosity)) { + ICING_LOG(ERROR) << "Invalid value for Debug Info verbosity: " << verbosity; + return nullptr; + } + + icing::lib::DebugInfoResultProto debug_info_result_proto = + icing->GetDebugInfo( + static_cast<icing::lib::DebugInfoVerbosity::Code>(verbosity)); + + return SerializeProtoToJniByteArray(env, debug_info_result_proto); +} + +JNIEXPORT jboolean JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog( + JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { + if (!icing::lib::LogSeverity::Code_IsValid(severity)) { + ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity; + return false; + } + return icing::lib::ShouldLog( + static_cast<icing::lib::LogSeverity::Code>(severity), verbosity); +} + +JNIEXPORT jboolean JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel( + JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { + if (!icing::lib::LogSeverity::Code_IsValid(severity)) { + ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity; + return false; + } + return icing::lib::SetLoggingLevel( + static_cast<icing::lib::LogSeverity::Code>(severity), verbosity); +} + +JNIEXPORT jstring JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetLoggingTag( + JNIEnv* env, jclass clazz) { + return env->NewStringUTF(icing::lib::kIcingLoggingTag); +} } // extern "C" diff --git a/icing/jni/jni-cache.cc b/icing/jni/jni-cache.cc index 9b75db6..1804b9a 100644 --- a/icing/jni/jni-cache.cc +++ b/icing/jni/jni-cache.cc @@ -159,8 +159,7 
@@ libtextclassifier3::StatusOr<std::unique_ptr<JniCache>> JniCache::Create( // BreakIteratorBatcher ICING_GET_CLASS_OR_RETURN_NULL( - breakiterator, - "com/google/android/icing/BreakIteratorBatcher"); + breakiterator, "com/google/android/icing/BreakIteratorBatcher"); ICING_GET_METHOD(breakiterator, constructor, "<init>", "(Ljava/util/Locale;)V"); ICING_GET_METHOD(breakiterator, settext, "setText", "(Ljava/lang/String;)V"); diff --git a/icing/jni/scoped-primitive-array-critical.h b/icing/jni/scoped-primitive-array-critical.h new file mode 100644 index 0000000..062c145 --- /dev/null +++ b/icing/jni/scoped-primitive-array-critical.h @@ -0,0 +1,86 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_JNI_SCOPED_PRIMITIVE_ARRAY_CRITICAL_H_ +#define ICING_JNI_SCOPED_PRIMITIVE_ARRAY_CRITICAL_H_ + +#include <jni.h> + +#include <utility> + +namespace icing { +namespace lib { + +template <typename T> +class ScopedPrimitiveArrayCritical { + public: + ScopedPrimitiveArrayCritical(JNIEnv* env, jarray array) + : env_(env), array_(array) { + if (array_ == nullptr) { + array_critical_ = nullptr; + array_critical_size_ = 0; + } else { + array_critical_size_ = env->GetArrayLength(array); + array_critical_ = static_cast<T*>( + env->GetPrimitiveArrayCritical(array, /*isCopy=*/nullptr)); + } + } + + ScopedPrimitiveArrayCritical(ScopedPrimitiveArrayCritical&& rhs) + : env_(nullptr), + array_(nullptr), + array_critical_(nullptr), + array_critical_size_(0) { + Swap(rhs); + } + + ScopedPrimitiveArrayCritical(const ScopedPrimitiveArrayCritical&) = delete; + + ScopedPrimitiveArrayCritical& operator=(ScopedPrimitiveArrayCritical&& rhs) { + Swap(rhs); + return *this; + } + + ScopedPrimitiveArrayCritical& operator=(const ScopedPrimitiveArrayCritical&) = + delete; + + ~ScopedPrimitiveArrayCritical() { + if (array_critical_ != nullptr && array_ != nullptr) { + env_->ReleasePrimitiveArrayCritical(array_, array_critical_, /*mode=*/0); + } + } + + T* data() { return array_critical_; } + const T* data() const { return array_critical_; } + + size_t size() const { return array_critical_size_; } + + private: + void Swap(ScopedPrimitiveArrayCritical& other) { + std::swap(env_, other.env_); + std::swap(array_, other.array_); + std::swap(array_critical_, other.array_critical_); + std::swap(array_critical_size_, other.array_critical_size_); + } + + JNIEnv* env_; + jarray array_; + T* array_critical_; + size_t array_critical_size_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_JNI_SCOPED_PRIMITIVE_ARRAY_CRITICAL_H_ diff --git a/icing/jni/scoped-primitive-array-critical_test.cc b/icing/jni/scoped-primitive-array-critical_test.cc new file mode 100644 index 0000000..3655378 --- /dev/null +++ b/icing/jni/scoped-primitive-array-critical_test.cc @@ -0,0 +1,140 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/jni/scoped-primitive-array-critical.h" + +#include <jni.h> + +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/java/mock_jni_env.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::Return; +using util::java::test::MockJNIEnv; + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayNull) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array( + env_mock.get(), /*array=*/nullptr); + EXPECT_THAT(scoped_primitive_array.data(), IsNull()); + EXPECT_THAT(scoped_primitive_array.size(), Eq(0)); + + // Move construct a scoped utf chars + ScopedPrimitiveArrayCritical<uint8_t> moved_scoped_primitive_array( + std::move(scoped_primitive_array)); + EXPECT_THAT(moved_scoped_primitive_array.data(), IsNull()); + EXPECT_THAT(moved_scoped_primitive_array.size(), Eq(0)); + + // Move assign a scoped utf chars + ScopedPrimitiveArrayCritical<uint8_t> move_assigned_scoped_primitive_array = + std::move(moved_scoped_primitive_array); + EXPECT_THAT(move_assigned_scoped_primitive_array.data(), IsNull()); + EXPECT_THAT(move_assigned_scoped_primitive_array.size(), Eq(0)); +} + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + jarray fake_jarray = reinterpret_cast<jarray>(-303); + uint8_t fake_array[] = {1, 8, 63, 90}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray), IsNull())) + .WillByDefault(Return(fake_array)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray))).WillByDefault(Return(4)); + + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array( + env_mock.get(), + /*array=*/fake_jarray); + EXPECT_THAT(scoped_primitive_array.data(), Eq(fake_array)); + EXPECT_THAT(scoped_primitive_array.size(), Eq(4)); + + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray), + Eq(fake_array), Eq(0))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayMoveConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. 
+ jarray fake_jarray = reinterpret_cast<jarray>(-303); + uint8_t fake_array[] = {1, 8, 63, 90}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray), IsNull())) + .WillByDefault(Return(fake_array)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray))).WillByDefault(Return(4)); + + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array( + env_mock.get(), + /*array=*/fake_jarray); + + // Move construct a scoped utf chars + ScopedPrimitiveArrayCritical<uint8_t> moved_scoped_primitive_array( + std::move(scoped_primitive_array)); + EXPECT_THAT(moved_scoped_primitive_array.data(), Eq(fake_array)); + EXPECT_THAT(moved_scoped_primitive_array.size(), Eq(4)); + + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray), + Eq(fake_array), Eq(0))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayMoveAssignment) { + // Setup the mock to return: + // {1, 8, 63, 90} for jstring (-303) + // {5, 9, 82} for jstring (-505) + auto env_mock = std::make_unique<MockJNIEnv>(); + jarray fake_jarray1 = reinterpret_cast<jarray>(-303); + uint8_t fake_array1[] = {1, 8, 63, 90}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray1), IsNull())) + .WillByDefault(Return(fake_array1)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray1))).WillByDefault(Return(4)); + + jarray fake_jarray2 = reinterpret_cast<jarray>(-505); + uint8_t fake_array2[] = {5, 9, 82}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray2), IsNull())) + .WillByDefault(Return(fake_array2)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray2))).WillByDefault(Return(3)); + + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array1( + env_mock.get(), + /*array=*/fake_jarray1); + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array2( + env_mock.get(), + /*array=*/fake_jarray2); + + // Move assign a scoped utf chars + scoped_primitive_array2 = std::move(scoped_primitive_array1); + EXPECT_THAT(scoped_primitive_array2.data(), Eq(fake_array1)); + EXPECT_THAT(scoped_primitive_array2.size(), Eq(4)); + + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray1), + Eq(fake_array1), Eq(0))) + .Times(1); + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray2), + Eq(fake_array2), Eq(0))) + .Times(1); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/jni/scoped-utf-chars.h b/icing/jni/scoped-utf-chars.h new file mode 100644 index 0000000..5a3ac6a --- /dev/null +++ b/icing/jni/scoped-utf-chars.h @@ -0,0 +1,81 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_JNI_SCOPED_UTF_CHARS_H_ +#define ICING_JNI_SCOPED_UTF_CHARS_H_ + +#include <jni.h> + +#include <cstddef> +#include <cstring> +#include <utility> + +namespace icing { +namespace lib { + +// An RAII class to manage access and allocation of a Java string's UTF chars. 
+class ScopedUtfChars { + public: + ScopedUtfChars(JNIEnv* env, jstring s) : env_(env), string_(s) { + if (s == nullptr) { + utf_chars_ = nullptr; + size_ = 0; + } else { + utf_chars_ = env->GetStringUTFChars(s, /*isCopy=*/nullptr); + size_ = strlen(utf_chars_); + } + } + + ScopedUtfChars(ScopedUtfChars&& rhs) + : env_(nullptr), string_(nullptr), utf_chars_(nullptr) { + Swap(rhs); + } + + ScopedUtfChars(const ScopedUtfChars&) = delete; + + ScopedUtfChars& operator=(ScopedUtfChars&& rhs) { + Swap(rhs); + return *this; + } + + ScopedUtfChars& operator=(const ScopedUtfChars&) = delete; + + ~ScopedUtfChars() { + if (utf_chars_ != nullptr) { + env_->ReleaseStringUTFChars(string_, utf_chars_); + } + } + + const char* c_str() const { return utf_chars_; } + + size_t size() const { return size_; } + + private: + void Swap(ScopedUtfChars& other) { + std::swap(env_, other.env_); + std::swap(string_, other.string_); + std::swap(utf_chars_, other.utf_chars_); + std::swap(size_, other.size_); + } + + JNIEnv* env_; + jstring string_; + const char* utf_chars_; + size_t size_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_JNI_SCOPED_UTF_CHARS_H_ diff --git a/icing/jni/scoped-utf-chars_test.cc b/icing/jni/scoped-utf-chars_test.cc new file mode 100644 index 0000000..d249f69 --- /dev/null +++ b/icing/jni/scoped-utf-chars_test.cc @@ -0,0 +1,126 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/jni/scoped-utf-chars.h" + +#include <jni.h> + +#include <string> +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/java/mock_jni_env.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::Return; +using util::java::test::MockJNIEnv; + +TEST(ScopedJniClassesTest, ScopedUtfCharsNull) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + ScopedUtfChars scoped_utf_chars(env_mock.get(), /*s=*/nullptr); + EXPECT_THAT(scoped_utf_chars.c_str(), IsNull()); + EXPECT_THAT(scoped_utf_chars.size(), Eq(0)); + + // Move construct a scoped utf chars + ScopedUtfChars moved_scoped_utf_chars(std::move(scoped_utf_chars)); + EXPECT_THAT(moved_scoped_utf_chars.c_str(), IsNull()); + EXPECT_THAT(moved_scoped_utf_chars.size(), Eq(0)); + + // Move assign a scoped utf chars + ScopedUtfChars move_assigned_scoped_utf_chars = + std::move(moved_scoped_utf_chars); + EXPECT_THAT(move_assigned_scoped_utf_chars.c_str(), IsNull()); + EXPECT_THAT(move_assigned_scoped_utf_chars.size(), Eq(0)); +} + +TEST(ScopedJniClassesTest, ScopedUtfCharsConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. 
+ jstring fake_jstring = reinterpret_cast<jstring>(-303); + std::string fake_string = "foo"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring), IsNull())) + .WillByDefault(Return(fake_string.c_str())); + + ScopedUtfChars scoped_utf_chars(env_mock.get(), /*s=*/fake_jstring); + EXPECT_THAT(scoped_utf_chars.c_str(), Eq(fake_string.c_str())); + EXPECT_THAT(scoped_utf_chars.size(), Eq(3)); + + EXPECT_CALL(*env_mock, + ReleaseStringUTFChars(Eq(fake_jstring), Eq(fake_string.c_str()))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedUtfCharsMoveConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + jstring fake_jstring = reinterpret_cast<jstring>(-303); + std::string fake_string = "foo"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring), IsNull())) + .WillByDefault(Return(fake_string.c_str())); + + ScopedUtfChars scoped_utf_chars(env_mock.get(), /*s=*/fake_jstring); + + // Move construct a scoped utf chars + ScopedUtfChars moved_scoped_utf_chars(std::move(scoped_utf_chars)); + EXPECT_THAT(moved_scoped_utf_chars.c_str(), Eq(fake_string.c_str())); + EXPECT_THAT(moved_scoped_utf_chars.size(), Eq(3)); + + EXPECT_CALL(*env_mock, + ReleaseStringUTFChars(Eq(fake_jstring), Eq(fake_string.c_str()))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedUtfCharsMoveAssignment) { + // Setup the mock to return: + // "foo" for jstring (-303) + // "bar baz" for jstring (-505) + auto env_mock = std::make_unique<MockJNIEnv>(); + jstring fake_jstring1 = reinterpret_cast<jstring>(-303); + std::string fake_string1 = "foo"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring1), IsNull())) + .WillByDefault(Return(fake_string1.c_str())); + + jstring fake_jstring2 = reinterpret_cast<jstring>(-505); + std::string fake_string2 = "bar baz"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring2), IsNull())) + .WillByDefault(Return(fake_string2.c_str())); + + ScopedUtfChars scoped_utf_chars1(env_mock.get(), /*s=*/fake_jstring1); + ScopedUtfChars scoped_utf_chars2(env_mock.get(), /*s=*/fake_jstring2); + + // Move assign a scoped utf chars + scoped_utf_chars2 = std::move(scoped_utf_chars1); + EXPECT_THAT(scoped_utf_chars2.c_str(), Eq(fake_string1.c_str())); + EXPECT_THAT(scoped_utf_chars2.size(), Eq(3)); + + EXPECT_CALL(*env_mock, ReleaseStringUTFChars(Eq(fake_jstring1), + Eq(fake_string1.c_str()))) + .Times(1); + EXPECT_CALL(*env_mock, ReleaseStringUTFChars(Eq(fake_jstring2), + Eq(fake_string2.c_str()))) + .Times(1); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/legacy/index/icing-array-storage.cc b/icing/legacy/index/icing-array-storage.cc index 4d2ef67..de5178a 100644 --- a/icing/legacy/index/icing-array-storage.cc +++ b/icing/legacy/index/icing-array-storage.cc @@ -65,17 +65,13 @@ bool IcingArrayStorage::Init(int fd, size_t fd_offset, bool map_shared, return false; } if (file_size < fd_offset) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Array storage file size %" PRIu64 " less than offset %zu", file_size, - fd_offset); + ICING_LOG(ERROR) << "Array storage file size " << file_size << " less than offset " << fd_offset; return false; } uint32_t capacity_num_elts = (file_size - fd_offset) / elt_size; if (capacity_num_elts < num_elts) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Array storage num elts %u > capacity num elts %u", num_elts, - capacity_num_elts); + ICING_LOG(ERROR) << "Array storage num elts " << num_elts << " > capacity num elts " << capacity_num_elts; return false; } @@ 
-108,8 +104,7 @@ bool IcingArrayStorage::Init(int fd, size_t fd_offset, bool map_shared, if (init_crc) { *crc_ptr_ = crc; } else if (crc != *crc_ptr_) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Array storage bad crc %u vs %u", crc, *crc_ptr_); + ICING_LOG(ERROR) << "Array storage bad crc " << crc << " vs " << *crc_ptr_; goto failed; } } @@ -276,9 +271,9 @@ void IcingArrayStorage::UpdateCrc() { cur_offset += change.elt_len * elt_size_; } if (!changes_.empty()) { - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Array update partial crcs %d truncated %d overlapped %d duplicate %d", - num_partial_crcs, num_truncated, num_overlapped, num_duplicate); + ICING_VLOG(2) << "Array update partial crcs " << num_partial_crcs + << " truncated " << num_truncated << " overlapped " << num_overlapped + << " duplicate " << num_duplicate; } // Now update with grown area. @@ -286,8 +281,7 @@ void IcingArrayStorage::UpdateCrc() { cur_crc = IcingStringUtil::UpdateCrc32( cur_crc, array_cast<char>() + changes_end_ * elt_size_, (cur_num_ - changes_end_) * elt_size_); - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Array update tail crc offset %u -> %u", changes_end_, cur_num_); + ICING_VLOG(2) << "Array update tail crc offset " << changes_end_ << " -> " << cur_num_; } // Clear, now that we've applied changes. @@ -341,8 +335,7 @@ uint32_t IcingArrayStorage::Sync() { if (pwrite(fd_, array() + dirty_start, dirty_end - dirty_start, fd_offset_ + dirty_start) != static_cast<ssize_t>(dirty_end - dirty_start)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing pages failed (%u, %u)", dirty_start, dirty_end); + ICING_LOG(ERROR) << "Flushing pages failed (" << dirty_start << ", " << dirty_end << ")"; } in_dirty = false; } else if (!in_dirty && is_dirty) { @@ -361,8 +354,7 @@ uint32_t IcingArrayStorage::Sync() { if (pwrite(fd_, array() + dirty_start, dirty_end - dirty_start, fd_offset_ + dirty_start) != static_cast<ssize_t>(dirty_end - dirty_start)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing pages failed (%u, %u)", dirty_start, dirty_end); + ICING_LOG(ERROR) << "Flushing pages failed (" << dirty_start << ", " << dirty_end << ")"; } } @@ -377,9 +369,7 @@ uint32_t IcingArrayStorage::Sync() { } if (num_flushed > 0) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Flushing %u/%u %u contiguous pages in %.3fms", num_flushed, - dirty_pages_size, num_contiguous, timer.Elapsed() * 1000.); + ICING_VLOG(1) << "Flushing " << num_flushed << "/" << dirty_pages_size << " " << num_contiguous << " contiguous pages in " << timer.Elapsed() * 1000 << "ms."; } return num_flushed; diff --git a/icing/legacy/index/icing-common-types.h b/icing/legacy/index/icing-common-types.h deleted file mode 100644 index 592b549..0000000 --- a/icing/legacy/index/icing-common-types.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2014 Google Inc. All Rights Reserved. 
-// Author: sbanacho@google.com (Scott Banachowski) -// Author: csyoung@google.com (C. Sean Young) - -#ifndef ICING_LEGACY_INDEX_ICING_COMMON_TYPES_H_ -#define ICING_LEGACY_INDEX_ICING_COMMON_TYPES_H_ - -#include "icing/legacy/core/icing-core-types.h" - -// Protocol buffers are shared across several components. -namespace com { -namespace google { -namespace android { -namespace gms { -namespace icing { -namespace lib { - -class ClientFileGroup; -class Document; -class Document_Section; -class DocumentStoreStatusProto; -class IMEUpdate; -class IMEUpdateResponse; -class IndexCorpusScoringConfig; -class IndexCorpusScoringConfig_Section; -class IndexScoringConfig; -class InitStatus; -class InitStatus_CorpusInitInfo; -class PendingDeleteUsageReport; -class PhraseAffinityRequest; -class QueryResponse; -class QueryResponse_Corpus; -class QueryResponse_Corpus_Section; -class QueryResponse_Corpus_Tag; -class QueryRequestSpec; -class QueryRequestSpec_CorpusSpec; -class QueryRequestSpec_SectionSpec; -class ResponseDebugInfo; -class ResultDebugInfo; -class SectionConfig; -class SuggestionResponse; -class SuggestionResponse_Suggestion; -class UsageReportsResponse; -class UsageStats; -class UsageStats_Corpus; - -} // namespace lib -} // namespace icing -} // namespace gms -} // namespace android -} // namespace google -} // namespace com - -namespace icing { -namespace lib { - -// Typedefs. -using IcingDocId = uint32_t; - -using IcingSectionId = uint32_t; - -using IcingCorpusId = uint16_t; -using IcingSectionIdMask = uint16_t; - -using IcingTagsCount = uint16_t; - -using IcingSequenceNumber = int64_t; - -using IcingScore = uint64_t; - -constexpr size_t kIcingMaxTokenLen = 30; // default shared between query - // processor and indexer -constexpr int kIcingQueryTermLimit = 50; // Maximum number of terms in a query -constexpr int kIcingMaxVariantsPerToken = 10; // Maximum number of variants - -// LINT.IfChange -constexpr int kIcingDocIdBits = 20; // 1M docs -constexpr IcingDocId kIcingInvalidDocId = (1u << kIcingDocIdBits) - 1; -constexpr IcingDocId kIcingMaxDocId = kIcingInvalidDocId - 1; -// LINT.ThenChange(//depot/google3/wireless/android/icing/plx/google_sql_common_macros.sql) - -constexpr int kIcingDocScoreBits = 32; - -constexpr int kIcingSectionIdBits = 4; // 4 bits for 16 values -constexpr IcingSectionId kIcingMaxSectionId = (1u << kIcingSectionIdBits) - 1; -constexpr IcingSectionId kIcingInvalidSectionId = kIcingMaxSectionId + 1; -constexpr IcingSectionIdMask kIcingSectionIdMaskAll = ~IcingSectionIdMask{0}; -constexpr IcingSectionIdMask kIcingSectionIdMaskNone = IcingSectionIdMask{0}; - -constexpr int kIcingCorpusIdBits = 15; // 32K -constexpr IcingCorpusId kIcingInvalidCorpusId = (1u << kIcingCorpusIdBits) - 1; -constexpr IcingCorpusId kIcingMaxCorpusId = kIcingInvalidCorpusId - 1; - -constexpr size_t kIcingMaxSearchableDocumentSize = (1u << 16) - 1; // 64K -// Max num tokens per document. 64KB is our original maximum (searchable) -// document size. We clip if document exceeds this. -constexpr uint32_t kIcingMaxNumTokensPerDoc = - kIcingMaxSearchableDocumentSize / 5; -constexpr uint32_t kIcingMaxNumHitsPerDocument = - kIcingMaxNumTokensPerDoc * kIcingMaxVariantsPerToken; - -constexpr IcingTagsCount kIcingInvalidTagCount = ~IcingTagsCount{0}; -constexpr IcingTagsCount kIcingMaxTagCount = kIcingInvalidTagCount - 1; - -// Location refers to document storage. -constexpr uint64_t kIcingInvalidLocation = ~uint64_t{0}; -constexpr uint64_t kIcingMaxDocStoreWriteLocation = uint64_t{1} - << 32; // 4bytes. 
- -// Dump symbols in the proto namespace. -using namespace ::com::google::android::gms::icing; // NOLINT(build/namespaces) -} // namespace lib -} // namespace icing - -#endif // ICING_LEGACY_INDEX_ICING_COMMON_TYPES_H_ diff --git a/icing/legacy/index/icing-dynamic-trie.cc b/icing/legacy/index/icing-dynamic-trie.cc index 77876c4..c6816ad 100644 --- a/icing/legacy/index/icing-dynamic-trie.cc +++ b/icing/legacy/index/icing-dynamic-trie.cc @@ -101,15 +101,9 @@ namespace { constexpr uint32_t kInvalidNodeIndex = (1U << 24) - 1; constexpr uint32_t kInvalidNextIndex = ~0U; -// Returns the number of valid nexts in the array. -int GetValidNextsSize(IcingDynamicTrie::Next *next_array_start, - int next_array_length) { - int valid_nexts_length = 0; - for (; valid_nexts_length < next_array_length && - next_array_start[valid_nexts_length].node_index() != kInvalidNodeIndex; - ++valid_nexts_length) { - } - return valid_nexts_length; +void ResetMutableNext(IcingDynamicTrie::Next &mutable_next) { + mutable_next.set_val(0xff); + mutable_next.set_node_index(kInvalidNodeIndex); } } // namespace @@ -466,8 +460,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Init() { if (i == 0) { // Header. if (file_size != IcingMMapper::system_page_size()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Trie hdr wrong size: %" PRIu64, file_size); + ICING_LOG(ERROR) << "Trie hdr wrong size: " << file_size; goto failed; } @@ -528,8 +521,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Init() { sizeof(char), hdr_.hdr.suffixes_size(), hdr_.hdr.max_suffixes_size(), &crcs_->array_crcs[SUFFIX], init_crcs)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Trie mmap suffix failed"); + ICING_LOG(ERROR) << "Trie mmap suffix failed"; goto failed; } @@ -677,8 +669,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Sync() { } if (!WriteHeader()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing trie header failed: %s", strerror(errno)); + ICING_LOG(ERROR) << "Flushing trie header failed: " << strerror(errno); success = false; } @@ -692,8 +683,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Sync() { } if (total_flushed > 0) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Flushing %u pages of trie", - total_flushed); + ICING_VLOG(1) << "Flushing " << total_flushed << " pages of trie"; } return success; @@ -769,8 +759,7 @@ IcingDynamicTrie::IcingDynamicTrieStorage::AllocNextArray(int size) { // Fill with char 0xff so we are sorted properly. for (int i = 0; i < aligned_size; i++) { - ret[i].set_val(0xff); - ret[i].set_node_index(kInvalidNodeIndex); + ResetMutableNext(ret[i]); } return ret; } @@ -824,8 +813,7 @@ uint32_t IcingDynamicTrie::IcingDynamicTrieStorage::UpdateCrc() { uint32_t IcingDynamicTrie::IcingDynamicTrieStorage::UpdateCrcInternal( bool write_hdr) { if (write_hdr && !WriteHeader()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing trie header failed: %s", strerror(errno)); + ICING_LOG(ERROR) << "Flushing trie header failed: " << strerror(errno); } crcs_->header_crc = GetHeaderCrc(); @@ -919,8 +907,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Header::SerializeToArray( bool IcingDynamicTrie::IcingDynamicTrieStorage::Header::Verify() { // Check version. 
if (hdr.version() != kCurVersion) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Trie version %u mismatch", hdr.version()); + ICING_LOG(ERROR) << "Trie version " << hdr.version() << " mismatch"; return false; } @@ -1162,9 +1149,8 @@ bool IcingDynamicTrie::Sync() { Warm(); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Syncing dynamic trie %s took %.3fms", filename_base_.c_str(), - timer.Elapsed() * 1000.); + ICING_VLOG(1) << "Syncing dynamic trie " << filename_base_.c_str() + << " took " << timer.Elapsed() * 1000 << "ms"; return success; } @@ -1214,8 +1200,7 @@ std::unique_ptr<IcingFlashBitmap> IcingDynamicTrie::OpenAndInitBitmap( const IcingFilesystem *filesystem) { auto bitmap = std::make_unique<IcingFlashBitmap>(filename, filesystem); if (!bitmap->Init() || (verify && !bitmap->Verify())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Init of %s failed", - filename.c_str()); + ICING_LOG(ERROR) << "Init of " << filename.c_str() << " failed"; return nullptr; } return bitmap; @@ -1245,16 +1230,14 @@ bool IcingDynamicTrie::InitPropertyBitmaps() { vector<std::string> files; if (!filesystem_->GetMatchingFiles((property_bitmaps_prefix_ + "*").c_str(), &files)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Could not get files at prefix %s", property_bitmaps_prefix_.c_str()); + ICING_LOG(ERROR) << "Could not get files at prefix " << property_bitmaps_prefix_; goto failed; } for (size_t i = 0; i < files.size(); i++) { // Decode property id from filename. size_t property_id_start_idx = files[i].rfind('.'); if (property_id_start_idx == std::string::npos) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Malformed filename %s", - files[i].c_str()); + ICING_LOG(ERROR) << "Malformed filename " << files[i]; continue; } property_id_start_idx++; // skip dot @@ -1262,8 +1245,7 @@ bool IcingDynamicTrie::InitPropertyBitmaps() { uint32_t property_id = strtol(files[i].c_str() + property_id_start_idx, &end, 10); // NOLINT if (!end || end != (files[i].c_str() + files[i].size())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Malformed filename %s", - files[i].c_str()); + ICING_LOG(ERROR) << "Malformed filename " << files[i]; continue; } std::unique_ptr<IcingFlashBitmap> bitmap = OpenAndInitBitmap( @@ -1271,8 +1253,7 @@ bool IcingDynamicTrie::InitPropertyBitmaps() { runtime_options_.storage_policy == RuntimeOptions::kMapSharedWithCrc, filesystem_); if (!bitmap) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Open prop bitmap failed: %s", files[i].c_str()); + ICING_LOG(ERROR) << "Open prop bitmap failed: " << files[i]; goto failed; } bitmap->Truncate(truncate_idx); @@ -1550,9 +1531,7 @@ bool IcingDynamicTrie::ResetNext(uint32_t next_index) { if (mutable_next == nullptr) { return false; } - - mutable_next->set_val(0); - mutable_next->set_node_index(kInvalidNodeIndex); + ResetMutableNext(*mutable_next); return true; } @@ -1570,7 +1549,7 @@ bool IcingDynamicTrie::SortNextArray(const Node *node) { return false; } - std::sort(next_array_start, next_array_start + next_array_buffer_size - 1); + std::sort(next_array_start, next_array_start + next_array_buffer_size); return true; } @@ -1804,11 +1783,12 @@ bool IcingDynamicTrie::Find(const char *key, void *value, } IcingDynamicTrie::Iterator::Iterator(const IcingDynamicTrie &trie, - const char *prefix) + const char *prefix, bool reverse) : cur_key_(prefix), cur_suffix_(nullptr), cur_suffix_len_(0), single_leaf_match_(false), + reverse_(reverse), trie_(trie) { if (!trie.is_initialized()) { ICING_LOG(FATAL) << 
"DynamicTrie not initialized"; @@ -1817,19 +1797,29 @@ IcingDynamicTrie::Iterator::Iterator(const IcingDynamicTrie &trie, Reset(); } -void IcingDynamicTrie::Iterator::LeftBranchToLeaf(uint32_t node_index) { +void IcingDynamicTrie::Iterator::BranchToLeaf(uint32_t node_index, + BranchType branch_type) { // Go down the trie, following the left-most child until we hit a // leaf. Push to stack and cur_key nodes and chars as we go. - for (; !trie_.storage_->GetNode(node_index)->is_leaf(); - node_index = - trie_.storage_ - ->GetNext(trie_.storage_->GetNode(node_index)->next_index(), 0) - ->node_index()) { - branch_stack_.push_back(Branch(node_index)); - cur_key_.push_back( - trie_.storage_ - ->GetNext(trie_.storage_->GetNode(node_index)->next_index(), 0) - ->val()); + // When reverse_ is true, the method will follow the right-most child. + const Node *node = trie_.storage_->GetNode(node_index); + while (!node->is_leaf()) { + const Next *next_start = trie_.storage_->GetNext(node->next_index(), 0); + int child_idx; + if (branch_type == BranchType::kRightMost) { + uint32_t next_array_size = 1u << node->log2_num_children(); + child_idx = trie_.GetValidNextsSize(next_start, next_array_size) - 1; + } else { + // node isn't a leaf. So it must have >0 children. + // 0 is the left-most child. + child_idx = 0; + } + const Next &child_next = next_start[child_idx]; + branch_stack_.push_back(Branch(node_index, child_idx)); + cur_key_.push_back(child_next.val()); + + node_index = child_next.node_index(); + node = trie_.storage_->GetNode(node_index); } // We're at a leaf. @@ -1865,7 +1855,7 @@ void IcingDynamicTrie::Iterator::Reset() { // Two cases/states: // // - Found an intermediate node. If we matched all of prefix - // (cur_key_), LeftBranchToLeaf. + // (cur_key_), BranchToLeaf. // // - Found a leaf node, which is the ONLY matching key for this // prefix. Check that suffix matches the prefix. Then we set @@ -1888,7 +1878,9 @@ void IcingDynamicTrie::Iterator::Reset() { cur_suffix_len_ = strlen(cur_suffix_); single_leaf_match_ = true; } else if (static_cast<size_t>(key_offset) == cur_key_.size()) { - LeftBranchToLeaf(node_index); + BranchType branch_type = + (reverse_) ? BranchType::kRightMost : BranchType::kLeftMost; + BranchToLeaf(node_index, branch_type); } } @@ -1915,19 +1907,25 @@ bool IcingDynamicTrie::Iterator::Advance() { while (!branch_stack_.empty()) { Branch *branch = &branch_stack_.back(); const Node *node = trie_.storage_->GetNode(branch->node_idx); - branch->child_idx++; - if (branch->child_idx < (1 << node->log2_num_children()) && - trie_.storage_->GetNext(node->next_index(), branch->child_idx) - ->node_index() != kInvalidNodeIndex) { - // Successfully incremented to the next child. Update the char - // value at this depth. - cur_key_[cur_key_.size() - 1] = - trie_.storage_->GetNext(node->next_index(), branch->child_idx)->val(); - // We successfully found a sub-trie to explore. - LeftBranchToLeaf( - trie_.storage_->GetNext(node->next_index(), branch->child_idx) - ->node_index()); - return true; + if (reverse_) { + branch->child_idx--; + } else { + branch->child_idx++; + } + if (branch->child_idx >= 0 && + branch->child_idx < (1 << node->log2_num_children())) { + const Next *child_next = + trie_.storage_->GetNext(node->next_index(), branch->child_idx); + if (child_next->node_index() != kInvalidNodeIndex) { + // Successfully incremented to the next child. Update the char + // value at this depth. 
+ cur_key_[cur_key_.size() - 1] = child_next->val(); + // We successfully found a sub-trie to explore. + BranchType branch_type = + (reverse_) ? BranchType::kRightMost : BranchType::kLeftMost; + BranchToLeaf(child_next->node_index(), branch_type); + return true; + } } branch_stack_.pop_back(); cur_key_.resize(cur_key_.size() - 1); @@ -2116,22 +2114,34 @@ const IcingDynamicTrie::Next *IcingDynamicTrie::GetNextByChar( return found; } +int IcingDynamicTrie::GetValidNextsSize( + const IcingDynamicTrie::Next *next_array_start, + int next_array_length) const { + // Only searching for key char 0xff is not sufficient, as 0xff can be a valid + // character. We must also specify kInvalidNodeIndex as the target node index + // when searching the next array. + return LowerBound(next_array_start, next_array_start + next_array_length, + /*key_char=*/0xff, /*node_index=*/kInvalidNodeIndex) - + next_array_start; +} + const IcingDynamicTrie::Next *IcingDynamicTrie::LowerBound( - const Next *start, const Next *end, uint8_t key_char) const { + const Next *start, const Next *end, uint8_t key_char, + uint32_t node_index) const { // Above this value will use binary search instead of linear // search. 16 was chosen from running some benchmarks with // different values. static const uint32_t kBinarySearchCutoff = 16; + Next key_next(key_char, node_index); if (end - start >= kBinarySearchCutoff) { // Binary search. - Next key_next(key_char, 0); return lower_bound(start, end, key_next); } else { // Linear search. const Next *found; for (found = start; found < end; found++) { - if (found->val() >= key_char) { + if (!(*found < key_next)) { // Should have gotten match. break; } @@ -2275,6 +2285,41 @@ std::vector<int> IcingDynamicTrie::FindBranchingPrefixLengths(const char *key, return prefix_lengths; } +bool IcingDynamicTrie::IsBranchingTerm(const char *key) const { + if (!is_initialized()) { + ICING_LOG(FATAL) << "DynamicTrie not initialized"; + } + + if (storage_->empty()) { + return false; + } + + uint32_t best_node_index; + int key_offset; + FindBestNode(key, &best_node_index, &key_offset, /*prefix=*/true); + const Node *cur_node = storage_->GetNode(best_node_index); + + if (cur_node->is_leaf()) { + return false; + } + + // There is no intermediate node for key in the trie. + if (key[key_offset] != '\0') { + return false; + } + + // Found key as an intermediate node, but key is not a valid term stored in + // the trie. In this case, we need at least two children for key to be a + // branching term. + if (GetNextByChar(cur_node, '\0') == nullptr) { + return cur_node->log2_num_children() >= 1; + } + + // The intermediate node for key must have more than two children for key to + // be a branching term, one of which represents the leaf node for key itself. 
+ return cur_node->log2_num_children() > 1; +} + void IcingDynamicTrie::GetDebugInfo(int verbosity, std::string *out) const { Stats stats; CollectStats(&stats); @@ -2284,8 +2329,7 @@ void IcingDynamicTrie::GetDebugInfo(int verbosity, std::string *out) const { vector<std::string> files; if (!filesystem_->GetMatchingFiles((property_bitmaps_prefix_ + "*").c_str(), &files)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Could not get files at prefix %s", property_bitmaps_prefix_.c_str()); + ICING_LOG(ERROR) << "Could not get files at prefix " << property_bitmaps_prefix_; return; } for (size_t i = 0; i < files.size(); i++) { @@ -2357,8 +2401,7 @@ IcingFlashBitmap *IcingDynamicTrie::OpenOrCreatePropertyBitmap( } if (property_id > kMaxPropertyId) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Property id %u out of range", property_id); + ICING_LOG(ERROR) << "Property id " << property_id << " out of range"; return nullptr; } @@ -2500,7 +2543,26 @@ bool IcingDynamicTrie::Delete(const std::string_view key) { for (uint32_t next_index : nexts_to_reset) { ResetNext(next_index); } - SortNextArray(last_multichild_node); + + if (last_multichild_node != nullptr) { + SortNextArray(last_multichild_node); + uint32_t next_array_buffer_size = + 1u << last_multichild_node->log2_num_children(); + Next *next_array_start = this->storage_->GetMutableNextArray( + last_multichild_node->next_index(), next_array_buffer_size); + uint32_t num_children = + GetValidNextsSize(next_array_start, next_array_buffer_size); + // Shrink the next array if we can. + if (num_children == next_array_buffer_size / 2) { + Node *mutable_node = storage_->GetMutableNode( + storage_->GetNodeIndex(last_multichild_node)); + mutable_node->set_log2_num_children(mutable_node->log2_num_children() - + 1); + // Add the unused second half of the next array to the free list. + storage_->FreeNextArray(next_array_start + next_array_buffer_size / 2, + mutable_node->log2_num_children()); + } + } return true; } @@ -2512,8 +2574,7 @@ bool IcingDynamicTrie::ClearPropertyForAllValues(uint32_t property_id) { PropertyReadersAll readers(*this); if (!readers.Exists(property_id)) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Properties for id %u don't exist", property_id); + ICING_VLOG(1) << "Properties for id " << property_id << " don't exist"; return true; } diff --git a/icing/legacy/index/icing-dynamic-trie.h b/icing/legacy/index/icing-dynamic-trie.h index 013b926..b172632 100644 --- a/icing/legacy/index/icing-dynamic-trie.h +++ b/icing/legacy/index/icing-dynamic-trie.h @@ -400,6 +400,13 @@ class IcingDynamicTrie : public IIcingStorage { // itself. If utf8 is true, does not cut key mid-utf8. std::vector<int> FindBranchingPrefixLengths(const char *key, bool utf8) const; + // Check if key is a branching term. + // + // key is a branching term, if and only if there exists terms s1 and s2 in the + // trie such that key is the maximum common prefix of s1 and s2, but s1 and s2 + // are not prefixes of each other. + bool IsBranchingTerm(const char *key) const; + void GetDebugInfo(int verbosity, std::string *out) const override; double min_free_fraction() const; @@ -510,7 +517,8 @@ class IcingDynamicTrie : public IIcingStorage { // Change in underlying trie invalidates iterator. 
class Iterator { public: - Iterator(const IcingDynamicTrie &trie, const char *prefix); + Iterator(const IcingDynamicTrie &trie, const char *prefix, + bool reverse = false); void Reset(); bool Advance(); @@ -527,9 +535,10 @@ class IcingDynamicTrie : public IIcingStorage { Iterator(); // Copy is ok. - // Helper function that takes the left-most branch down - // intermediate nodes to a leaf. - void LeftBranchToLeaf(uint32_t node_index); + enum class BranchType { kLeftMost = 0, kRightMost = 1 }; + // Helper function that takes the left-most or the right-most branch down + // intermediate nodes to a leaf, based on branch_type. + void BranchToLeaf(uint32_t node_index, BranchType branch_type); std::string cur_key_; const char *cur_suffix_; @@ -538,10 +547,12 @@ class IcingDynamicTrie : public IIcingStorage { uint32_t node_idx; int child_idx; - explicit Branch(uint32_t ni) : node_idx(ni), child_idx(0) {} + explicit Branch(uint32_t node_index, int child_index) + : node_idx(node_index), child_idx(child_index) {} }; std::vector<Branch> branch_stack_; bool single_leaf_match_; + bool reverse_; const IcingDynamicTrie &trie_; }; @@ -612,8 +623,11 @@ class IcingDynamicTrie : public IIcingStorage { // Helpers for Find and Insert. const Next *GetNextByChar(const Node *node, uint8_t key_char) const; - const Next *LowerBound(const Next *start, const Next *end, - uint8_t key_char) const; + const Next *LowerBound(const Next *start, const Next *end, uint8_t key_char, + uint32_t node_index = 0) const; + // Returns the number of valid nexts in the array. + int GetValidNextsSize(const IcingDynamicTrie::Next *next_array_start, + int next_array_length) const; void FindBestNode(const char *key, uint32_t *best_node_index, int *key_offset, bool prefix, bool utf8 = false) const; diff --git a/icing/legacy/index/icing-dynamic-trie_test.cc b/icing/legacy/index/icing-dynamic-trie_test.cc index 193765b..850fcdc 100644 --- a/icing/legacy/index/icing-dynamic-trie_test.cc +++ b/icing/legacy/index/icing-dynamic-trie_test.cc @@ -20,6 +20,7 @@ #include <memory> #include <string> #include <unordered_map> +#include <unordered_set> #include <vector> #include "icing/text_classifier/lib3/utils/hash/farmhash.h" @@ -27,15 +28,19 @@ #include "gtest/gtest.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/index/icing-filesystem.h" +#include "icing/testing/random-string.h" #include "icing/testing/tmp-directory.h" - -using testing::ElementsAre; +#include "icing/util/logging.h" namespace icing { namespace lib { namespace { +using testing::ContainerEq; +using testing::ElementsAre; +using testing::StrEq; + constexpr std::string_view kKeys[] = { "", "ab", "ac", "abd", "bac", "bb", "bacd", "abbb", "abcdefg", }; @@ -105,6 +110,17 @@ class IcingDynamicTrieTest : public ::testing::Test { std::string trie_files_prefix_; }; +std::vector<std::pair<std::string, int>> RetrieveKeyValuePairs( + IcingDynamicTrie::Iterator& term_iter) { + std::vector<std::pair<std::string, int>> key_value; + for (; term_iter.IsValid(); term_iter.Advance()) { + uint32_t val; + memcpy(&val, term_iter.GetValue(), sizeof(val)); + key_value.push_back(std::make_pair(term_iter.GetKey(), val)); + } + return key_value; +} + constexpr std::string_view kCommonEnglishWords[] = { "that", "was", "for", "on", "are", "with", "they", "be", "at", "one", "have", "this", "from", "word", "but", "what", "some", "you", @@ -157,7 +173,6 @@ TEST_F(IcingDynamicTrieTest, Init) { TEST_F(IcingDynamicTrieTest, Iterator) { // Test iterator. 
IcingFilesystem filesystem; - uint32_t val; IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), &filesystem); ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); @@ -167,104 +182,161 @@ TEST_F(IcingDynamicTrieTest, Iterator) { ASSERT_TRUE(trie.Insert(kKeys[i].data(), &i)); } - // We try everything twice to test that Reset also works. - // Should get the entire trie. + std::vector<std::pair<std::string, int>> exp_key_values = { + {"", 0}, {"ab", 1}, {"abbb", 7}, {"abcdefg", 8}, {"abd", 3}, + {"ac", 2}, {"bac", 4}, {"bacd", 6}, {"bb", 5}}; IcingDynamicTrie::Iterator it_all(trie, ""); - for (int i = 0; i < 2; i++) { - uint32_t count = 0; - for (; it_all.IsValid(); it_all.Advance()) { - uint32_t val_idx = it_all.GetValueIndex(); - EXPECT_EQ(it_all.GetValue(), trie.GetValueAtIndex(val_idx)); - count++; - } - EXPECT_EQ(count, kNumKeys); - it_all.Reset(); - } + std::vector<std::pair<std::string, int>> key_values = + RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Should get same results after calling Reset + it_all.Reset(); + key_values = RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); // Get everything under "a". + exp_key_values = { + {"ab", 1}, {"abbb", 7}, {"abcdefg", 8}, {"abd", 3}, {"ac", 2}}; IcingDynamicTrie::Iterator it1(trie, "a"); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "ab"); - static const uint32_t kOne = 1; - ASSERT_TRUE(it1.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it1.GetValue(), &kOne, sizeof(kOne))); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "abbb"); + // Should get same results after calling Reset + it1.Reset(); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "abcdefg"); + // Now "b". + exp_key_values = {{"bac", 4}, {"bacd", 6}, {"bb", 5}}; + IcingDynamicTrie::Iterator it2(trie, "b"); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "abd"); + // Should get same results after calling Reset + it2.Reset(); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "ac"); + // Get everything under "ab". + exp_key_values = {{"ab", 1}, {"abbb", 7}, {"abcdefg", 8}, {"abd", 3}}; + IcingDynamicTrie::Iterator it3(trie, "ab"); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - EXPECT_FALSE(it1.Advance()); - EXPECT_FALSE(it1.IsValid()); + // Should get same results after calling Reset + it3.Reset(); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - it1.Reset(); + // Should match only one key exactly. 
+ constexpr std::string_view kOneMatch[] = { + "abd", + "abcd", + "abcdef", + "abcdefg", + }; + // With the following match: + constexpr std::string_view kOneMatchMatched[] = { + "abd", + "abcdefg", + "abcdefg", + "abcdefg", + }; + + for (size_t k = 0; k < ABSL_ARRAYSIZE(kOneMatch); k++) { + IcingDynamicTrie::Iterator it_single(trie, kOneMatch[k].data()); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; + + // Should get same results after calling Reset + it_single.Reset(); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; } - // Now "b". - IcingDynamicTrie::Iterator it2(trie, "b"); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it2.IsValid()); - EXPECT_STREQ(it2.GetKey(), "bac"); - val = 1; - ASSERT_TRUE(it1.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it1.GetValue(), &val, sizeof(val))); - val = 4; - ASSERT_TRUE(it2.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it2.GetValue(), &val, sizeof(val))); - - ASSERT_TRUE(it2.Advance()); - ASSERT_TRUE(it2.IsValid()); - EXPECT_STREQ(it2.GetKey(), "bacd"); - - ASSERT_TRUE(it2.Advance()); - ASSERT_TRUE(it2.IsValid()); - EXPECT_STREQ(it2.GetKey(), "bb"); - - EXPECT_FALSE(it2.Advance()); - EXPECT_FALSE(it2.IsValid()); - - it2.Reset(); + // Matches nothing. + constexpr std::string_view kNoMatch[] = { + "abbd", + "abcdeg", + "abcdefh", + }; + for (size_t k = 0; k < ABSL_ARRAYSIZE(kNoMatch); k++) { + IcingDynamicTrie::Iterator it_empty(trie, kNoMatch[k].data()); + EXPECT_FALSE(it_empty.IsValid()); + it_empty.Reset(); + EXPECT_FALSE(it_empty.IsValid()); } - // Get everything under "ab". - IcingDynamicTrie::Iterator it3(trie, "ab"); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "ab"); - val = 1; - ASSERT_TRUE(it3.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it3.GetValue(), &val, sizeof(val))); + // Clear. + trie.Clear(); + EXPECT_FALSE(IcingDynamicTrie::Iterator(trie, "").IsValid()); + EXPECT_EQ(0u, trie.size()); + EXPECT_EQ(1.0, trie.min_free_fraction()); +} - ASSERT_TRUE(it3.Advance()); - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "abbb"); +TEST_F(IcingDynamicTrieTest, IteratorReverse) { + // Test iterator. + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); - ASSERT_TRUE(it3.Advance()); - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "abcdefg"); + for (uint32_t i = 0; i < kNumKeys; i++) { + ASSERT_TRUE(trie.Insert(kKeys[i].data(), &i)); + } - ASSERT_TRUE(it3.Advance()); - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "abd"); + // Should get the entire trie. 
+ std::vector<std::pair<std::string, int>> exp_key_values = { + {"bb", 5}, {"bacd", 6}, {"bac", 4}, {"ac", 2}, {"abd", 3}, + {"abcdefg", 8}, {"abbb", 7}, {"ab", 1}, {"", 0}}; + IcingDynamicTrie::Iterator it_all(trie, "", /*reverse=*/true); + std::vector<std::pair<std::string, int>> key_values = + RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + it_all.Reset(); + key_values = RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Get everything under "a". + exp_key_values = { + {"ac", 2}, {"abd", 3}, {"abcdefg", 8}, {"abbb", 7}, {"ab", 1}}; + IcingDynamicTrie::Iterator it1(trie, "a", /*reverse=*/true); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - EXPECT_FALSE(it3.Advance()); - EXPECT_FALSE(it3.IsValid()); + // Should get same results after calling Reset + it1.Reset(); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - it3.Reset(); - } + // Now "b". + exp_key_values = {{"bb", 5}, {"bacd", 6}, {"bac", 4}}; + IcingDynamicTrie::Iterator it2(trie, "b", /*reverse=*/true); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Should get same results after calling Reset + it2.Reset(); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Get everything under "ab". + exp_key_values = {{"abd", 3}, {"abcdefg", 8}, {"abbb", 7}, {"ab", 1}}; + IcingDynamicTrie::Iterator it3(trie, "ab", /*reverse=*/true); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Should get same results after calling Reset + it3.Reset(); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); // Should match only one key exactly. constexpr std::string_view kOneMatch[] = { @@ -282,15 +354,19 @@ TEST_F(IcingDynamicTrieTest, Iterator) { }; for (size_t k = 0; k < ABSL_ARRAYSIZE(kOneMatch); k++) { - IcingDynamicTrie::Iterator it_single(trie, kOneMatch[k].data()); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; - EXPECT_STREQ(it_single.GetKey(), kOneMatchMatched[k].data()); - EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; - EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; - - it_single.Reset(); - } + IcingDynamicTrie::Iterator it_single(trie, kOneMatch[k].data(), + /*reverse=*/true); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; + + // Should get same results after calling Reset + it_single.Reset(); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; } // Matches nothing. @@ -300,21 +376,65 @@ TEST_F(IcingDynamicTrieTest, Iterator) { "abcdefh", }; for (size_t k = 0; k < ABSL_ARRAYSIZE(kNoMatch); k++) { - IcingDynamicTrie::Iterator it_empty(trie, kNoMatch[k].data()); - for (int i = 0; i < 2; i++) { - EXPECT_FALSE(it_empty.IsValid()); - - it_empty.Reset(); - } + IcingDynamicTrie::Iterator it_empty(trie, kNoMatch[k].data(), + /*reverse=*/true); + EXPECT_FALSE(it_empty.IsValid()); + it_empty.Reset(); + EXPECT_FALSE(it_empty.IsValid()); } // Clear. 
trie.Clear(); - EXPECT_FALSE(IcingDynamicTrie::Iterator(trie, "").IsValid()); + EXPECT_FALSE( + IcingDynamicTrie::Iterator(trie, "", /*reverse=*/true).IsValid()); EXPECT_EQ(0u, trie.size()); EXPECT_EQ(1.0, trie.min_free_fraction()); } +TEST_F(IcingDynamicTrieTest, IteratorLoadTest) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + std::default_random_engine random; + ICING_LOG(ERROR) << "Seed: " << std::default_random_engine::default_seed; + + std::vector<std::pair<std::string, int>> exp_key_values; + // Randomly generate 1024 terms. + for (int i = 0; i < 1024; ++i) { + std::string term = RandomString("abcdefg", 5, &random) + std::to_string(i); + ASSERT_TRUE(trie.Insert(term.c_str(), &i)); + exp_key_values.push_back(std::make_pair(term, i)); + } + // Lexicographically sort the expected keys. + std::sort(exp_key_values.begin(), exp_key_values.end()); + + // Check that the iterator works. + IcingDynamicTrie::Iterator term_iter(trie, /*prefix=*/""); + std::vector<std::pair<std::string, int>> key_values = + RetrieveKeyValuePairs(term_iter); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Check that Reset works. + term_iter.Reset(); + key_values = RetrieveKeyValuePairs(term_iter); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + std::reverse(exp_key_values.begin(), exp_key_values.end()); + // Check that the reverse iterator works. + IcingDynamicTrie::Iterator term_iter_reverse(trie, /*prefix=*/"", + /*reverse=*/true); + key_values = RetrieveKeyValuePairs(term_iter_reverse); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Check that Reset works. + term_iter_reverse.Reset(); + key_values = RetrieveKeyValuePairs(term_iter_reverse); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); +} + TEST_F(IcingDynamicTrieTest, Persistence) { // Test persistence on the English dictionary. 
IcingFilesystem filesystem; @@ -962,6 +1082,102 @@ TEST_F(IcingDynamicTrieTest, DeletingNonExistingKeyShouldReturnTrue) { EXPECT_TRUE(trie.Find("bed", &value)); } +TEST_F(IcingDynamicTrieTest, DeletionResortsFullNextArray) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + uint32_t value = 1; + // 'f' -> [ 'a', 'j', 'o', 'u' ] + ASSERT_TRUE(trie.Insert("foul", &value)); + ASSERT_TRUE(trie.Insert("far", &value)); + ASSERT_TRUE(trie.Insert("fudge", &value)); + ASSERT_TRUE(trie.Insert("fjord", &value)); + + // Delete the third child + EXPECT_TRUE(trie.Delete("foul")); + + std::vector<std::string> remaining; + for (IcingDynamicTrie::Iterator term_iter(trie, /*prefix=*/""); + term_iter.IsValid(); term_iter.Advance()) { + remaining.push_back(term_iter.GetKey()); + } + EXPECT_THAT(remaining, ElementsAre("far", "fjord", "fudge")); +} + +TEST_F(IcingDynamicTrieTest, DeletionResortsPartiallyFilledNextArray) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + uint32_t value = 1; + // 'f' -> [ 'a', 'o', 'u', 0xFF ] + ASSERT_TRUE(trie.Insert("foul", &value)); + ASSERT_TRUE(trie.Insert("far", &value)); + ASSERT_TRUE(trie.Insert("fudge", &value)); + + // Delete the second child + EXPECT_TRUE(trie.Delete("foul")); + + std::vector<std::string> remaining; + for (IcingDynamicTrie::Iterator term_iter(trie, /*prefix=*/""); + term_iter.IsValid(); term_iter.Advance()) { + remaining.push_back(term_iter.GetKey()); + } + EXPECT_THAT(remaining, ElementsAre("far", "fudge")); +} + +TEST_F(IcingDynamicTrieTest, DeletionLoadTest) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + std::default_random_engine random; + ICING_LOG(ERROR) << "Seed: " << std::default_random_engine::default_seed; + std::vector<std::string> terms; + uint32_t value; + // Randomly generate 2048 terms. + for (int i = 0; i < 2048; ++i) { + terms.push_back(RandomString("abcdefg", 5, &random)); + ASSERT_TRUE(trie.Insert(terms.back().c_str(), &value)); + } + + // Randomly delete 1024 terms. + std::unordered_set<std::string> exp_remaining(terms.begin(), terms.end()); + std::shuffle(terms.begin(), terms.end(), random); + for (int i = 0; i < 1024; ++i) { + exp_remaining.erase(terms[i]); + ASSERT_TRUE(trie.Delete(terms[i].c_str())); + } + + // Check that the iterator still works, and the remaining terms are correct. + std::unordered_set<std::string> remaining; + for (IcingDynamicTrie::Iterator term_iter(trie, /*prefix=*/""); + term_iter.IsValid(); term_iter.Advance()) { + remaining.insert(term_iter.GetKey()); + } + EXPECT_THAT(remaining, ContainerEq(exp_remaining)); + + // Check that we can still insert terms after delete. 
+ for (int i = 0; i < 2048; ++i) { + std::string term = RandomString("abcdefg", 5, &random); + ASSERT_TRUE(trie.Insert(term.c_str(), &value)); + exp_remaining.insert(term); + } + remaining.clear(); + for (IcingDynamicTrie::Iterator term_iter(trie, /*prefix=*/""); + term_iter.IsValid(); term_iter.Advance()) { + remaining.insert(term_iter.GetKey()); + } + EXPECT_THAT(remaining, ContainerEq(exp_remaining)); +} + } // namespace // The tests below are accessing private methods and fields of IcingDynamicTrie @@ -1133,5 +1349,142 @@ TEST_F(IcingDynamicTrieTest, BitmapsClosedWhenInitFails) { ASSERT_EQ(0, trie.property_bitmaps_.size()); } +TEST_F(IcingDynamicTrieTest, IsBranchingTermShouldWorkForExistingTerms) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + uint32_t value = 1; + + ASSERT_TRUE(trie.Insert("", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + + ASSERT_TRUE(trie.Insert("ab", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + + ASSERT_TRUE(trie.Insert("ac", &value)); + // "" is a prefix of "ab" and "ac", but it is not a branching term. + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("ac")); + + ASSERT_TRUE(trie.Insert("ba", &value)); + // "" now branches to "ba" + EXPECT_TRUE(trie.IsBranchingTerm("")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("ac")); + EXPECT_FALSE(trie.IsBranchingTerm("ba")); + + ASSERT_TRUE(trie.Insert("a", &value)); + EXPECT_TRUE(trie.IsBranchingTerm("")); + // "a" branches to "ab" and "ac" + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("ac")); + EXPECT_FALSE(trie.IsBranchingTerm("ba")); + + ASSERT_TRUE(trie.Insert("abc", &value)); + ASSERT_TRUE(trie.Insert("acd", &value)); + EXPECT_TRUE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + // "ab" is a prefix of "abc", but it is not a branching term. + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + // "ac" is a prefix of "acd", but it is not a branching term. + EXPECT_FALSE(trie.IsBranchingTerm("ac")); + EXPECT_FALSE(trie.IsBranchingTerm("ba")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + EXPECT_FALSE(trie.IsBranchingTerm("acd")); + + ASSERT_TRUE(trie.Insert("abcd", &value)); + EXPECT_TRUE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + // "ab" is a prefix of "abc" and "abcd", but it is not a branching term. + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("ac")); + EXPECT_FALSE(trie.IsBranchingTerm("ba")); + // "abc" is a prefix of "abcd", but it is not a branching term. 
+ EXPECT_FALSE(trie.IsBranchingTerm("abc")); + EXPECT_FALSE(trie.IsBranchingTerm("acd")); + EXPECT_FALSE(trie.IsBranchingTerm("abcd")); + + ASSERT_TRUE(trie.Insert("abd", &value)); + EXPECT_TRUE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + // "ab" branches to "abc" and "abd" + EXPECT_TRUE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("ac")); + EXPECT_FALSE(trie.IsBranchingTerm("ba")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + EXPECT_FALSE(trie.IsBranchingTerm("acd")); + EXPECT_FALSE(trie.IsBranchingTerm("abcd")); + EXPECT_FALSE(trie.IsBranchingTerm("abd")); +} + +TEST_F(IcingDynamicTrieTest, IsBranchingTermShouldWorkForNonExistingTerms) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + uint32_t value = 1; + + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_FALSE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("aa", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_FALSE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("ac", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + // "a" does not exist in the trie, but now it branches to "aa" and "ac". + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("ad", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("abcd", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("abd", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + // "ab" does not exist in the trie, but now it branches to "abcd" and "abd". + EXPECT_TRUE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("abce", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_TRUE(trie.IsBranchingTerm("ab")); + // "abc" does not exist in the trie, but now it branches to "abcd" and "abce". 
+ EXPECT_TRUE(trie.IsBranchingTerm("abc")); + + ASSERT_TRUE(trie.Insert("abc_suffix", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_TRUE(trie.IsBranchingTerm("ab")); + EXPECT_TRUE(trie.IsBranchingTerm("abc")); + EXPECT_FALSE(trie.IsBranchingTerm("abc_s")); + EXPECT_FALSE(trie.IsBranchingTerm("abc_su")); + EXPECT_FALSE(trie.IsBranchingTerm("abc_suffi")); +} + } // namespace lib } // namespace icing diff --git a/icing/legacy/index/icing-filesystem.cc b/icing/legacy/index/icing-filesystem.cc index 4f5e571..fbf5a27 100644 --- a/icing/legacy/index/icing-filesystem.cc +++ b/icing/legacy/index/icing-filesystem.cc @@ -65,18 +65,15 @@ void LogOpenFileDescriptors() { constexpr int kMaxFileDescriptorsToStat = 4096; struct rlimit rlim = {0, 0}; if (getrlimit(RLIMIT_NOFILE, &rlim) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "getrlimit() failed (errno=%d)", errno); + ICING_LOG(ERROR) << "getrlimit() failed (errno=" << errno << ")"; return; } int fd_lim = rlim.rlim_cur; if (fd_lim > kMaxFileDescriptorsToStat) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Maximum number of file descriptors (%d) too large.", fd_lim); + ICING_LOG(ERROR) << "Maximum number of file descriptors (" << fd_lim << ") too large."; fd_lim = kMaxFileDescriptorsToStat; } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Listing up to %d file descriptors.", fd_lim); + ICING_LOG(ERROR) << "Listing up to " << fd_lim << " file descriptors."; // Verify that /proc/self/fd is a directory. If not, procfs is not mounted or // inaccessible for some other reason. In that case, there's no point trying @@ -98,15 +95,12 @@ void LogOpenFileDescriptors() { if (len >= 0) { // Zero-terminate the buffer, because readlink() won't. target[len < target_size ? len : target_size - 1] = '\0'; - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> \"%s\"", fd, - target); + ICING_LOG(ERROR) << "fd " << fd << " -> \"" << target << "\""; } else if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> ? (errno=%d)", - fd, errno); + ICING_LOG(ERROR) << "fd " << fd << " -> ? (errno=" << errno << ")"; } } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "File descriptor list complete."); + ICING_LOG(ERROR) << "File descriptor list complete."; } // Logs an error formatted as: desc1 + file_name + desc2 + strerror(errnum). @@ -115,8 +109,7 @@ void LogOpenFileDescriptors() { // file descriptors (see LogOpenFileDescriptors() above). 
void LogOpenError(const char *desc1, const char *file_name, const char *desc2, int errnum) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "%s%s%s%s", desc1, file_name, desc2, strerror(errnum)); + ICING_LOG(ERROR) << desc1 << file_name << desc2 << strerror(errnum); if (errnum == EMFILE) { LogOpenFileDescriptors(); } @@ -157,8 +150,7 @@ bool ListDirectoryInternal(const char *dir_name, } } if (closedir(dir) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Error closing %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Error closing " << dir_name << ": " << strerror(errno); } return true; } @@ -181,12 +173,11 @@ void IcingScopedFd::reset(int fd) { const uint64_t IcingFilesystem::kBadFileSize; bool IcingFilesystem::DeleteFile(const char *file_name) const { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Deleting file %s", file_name); + ICING_VLOG(1) << "Deleting file " << file_name; int ret = unlink(file_name); bool success = (ret == 0) || (errno == ENOENT); if (!success) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting file %s failed: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting file " << file_name << " failed: " << strerror(errno); } return success; } @@ -195,8 +186,7 @@ bool IcingFilesystem::DeleteDirectory(const char *dir_name) const { int ret = rmdir(dir_name); bool success = (ret == 0) || (errno == ENOENT); if (!success) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting directory " << dir_name << " failed: " << strerror(errno); } return success; } @@ -208,8 +198,7 @@ bool IcingFilesystem::DeleteDirectoryRecursively(const char *dir_name) const { if (errno == ENOENT) { return true; // If directory didn't exist, this was successful. } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Stat " << dir_name << " failed: " << strerror(errno); return false; } vector<std::string> entries; @@ -222,8 +211,7 @@ bool IcingFilesystem::DeleteDirectoryRecursively(const char *dir_name) const { ++i) { std::string filename = std::string(dir_name) + '/' + *i; if (stat(filename.c_str(), &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", filename.c_str(), strerror(errno)); + ICING_LOG(ERROR) << "Stat " << filename << " failed: " << strerror(errno); success = false; } else if (S_ISDIR(st.st_mode)) { success = DeleteDirectoryRecursively(filename.c_str()) && success; @@ -246,8 +234,7 @@ bool IcingFilesystem::FileExists(const char *file_name) const { exists = S_ISREG(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file " << file_name << ": " << strerror(errno); } exists = false; } @@ -261,8 +248,7 @@ bool IcingFilesystem::DirectoryExists(const char *dir_name) const { exists = S_ISDIR(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat directory %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat directory " << dir_name << ": " << strerror(errno); } exists = false; } @@ -317,8 +303,7 @@ bool IcingFilesystem::GetMatchingFiles(const char *glob, int basename_idx = GetBasenameIndex(glob); if (basename_idx == 0) { // We need a directory. 
- ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Expected directory, no matching files for: %s", glob); + ICING_VLOG(1) << "Expected directory, no matching files for: " << glob; return true; } const char *basename_glob = glob + basename_idx; @@ -374,8 +359,7 @@ uint64_t IcingFilesystem::GetFileSize(int fd) const { struct stat st; uint64_t size = kBadFileSize; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); } else { size = st.st_size; } @@ -386,8 +370,7 @@ uint64_t IcingFilesystem::GetFileSize(const char *filename) const { struct stat st; uint64_t size = kBadFileSize; if (stat(filename, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", filename, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file " << filename << ": " << strerror(errno); } else { size = st.st_size; } @@ -399,8 +382,7 @@ bool IcingFilesystem::Truncate(int fd, uint64_t new_size) const { if (ret == 0) { lseek(fd, new_size, SEEK_SET); } else { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to truncate file: %s", strerror(errno)); + ICING_LOG(ERROR) << "Unable to truncate file: " << strerror(errno); } return (ret == 0); } @@ -418,8 +400,7 @@ bool IcingFilesystem::Truncate(const char *filename, uint64_t new_size) const { bool IcingFilesystem::Grow(int fd, uint64_t new_size) const { int ret = ftruncate(fd, new_size); if (ret != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to grow file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to grow file: " << strerror(errno); } return (ret == 0); } @@ -431,8 +412,7 @@ bool IcingFilesystem::Write(int fd, const void *data, size_t data_size) const { size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = write(fd, data, chunk_size); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t *>(data) + wrote; @@ -449,8 +429,7 @@ bool IcingFilesystem::PWrite(int fd, off_t offset, const void *data, size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = pwrite(fd, data, chunk_size, offset); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t *>(data) + wrote; @@ -468,8 +447,7 @@ bool IcingFilesystem::DataSync(int fd) const { #endif if (result < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to sync data: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to sync data: " << strerror(errno); return false; } return true; @@ -478,9 +456,7 @@ bool IcingFilesystem::DataSync(int fd) const { bool IcingFilesystem::RenameFile(const char *old_name, const char *new_name) const { if (rename(old_name, new_name) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to rename file %s to %s: %s", old_name, new_name, - strerror(errno)); + ICING_LOG(ERROR) << "Unable to rename file " << old_name << " to " << new_name << ": " << strerror(errno); return false; } return true; @@ -518,8 +494,7 @@ bool IcingFilesystem::CreateDirectory(const char *dir_name) const { if (mkdir(dir_name, S_IRUSR | S_IWUSR | S_IXUSR) == 0) { success = true; } else { - ICING_LOG(ERROR) << 
IcingStringUtil::StringPrintf( - "Creating directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Creating directory " << dir_name << " failed: " << strerror(errno); } } return success; @@ -561,8 +536,7 @@ end: if (src_fd > 0) close(src_fd); if (dst_fd > 0) close(dst_fd); if (!success) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Couldn't copy file %s to %s", src, dst); + ICING_LOG(ERROR) << "Couldn't copy file " << src << " to " << dst; } return success; } @@ -583,8 +557,7 @@ bool IcingFilesystem::ComputeChecksum(int fd, uint32_t *checksum, uint64_t IcingFilesystem::GetDiskUsage(int fd) const { struct stat st; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -593,8 +566,7 @@ uint64_t IcingFilesystem::GetDiskUsage(int fd) const { uint64_t IcingFilesystem::GetFileDiskUsage(const char *path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -603,8 +575,7 @@ uint64_t IcingFilesystem::GetFileDiskUsage(const char *path) const { uint64_t IcingFilesystem::GetDiskUsage(const char *path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } uint64_t result = st.st_blocks * kStatBlockSize; diff --git a/icing/legacy/index/icing-flash-bitmap.cc b/icing/legacy/index/icing-flash-bitmap.cc index 56dec00..774308f 100644 --- a/icing/legacy/index/icing-flash-bitmap.cc +++ b/icing/legacy/index/icing-flash-bitmap.cc @@ -73,8 +73,7 @@ class IcingFlashBitmap::Accessor { bool IcingFlashBitmap::Verify() const { if (!is_initialized()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Can't verify unopened flash bitmap %s", filename_.c_str()); + ICING_LOG(ERROR) << "Can't verify unopened flash bitmap " << filename_; return false; } if (mmapper_ == nullptr) { @@ -83,26 +82,21 @@ bool IcingFlashBitmap::Verify() const { } Accessor accessor(mmapper_.get()); if (accessor.header()->magic != kMagic) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s has incorrect magic header", filename_.c_str()); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " has incorrect magic header"; return false; } if (accessor.header()->version != kCurVersion) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s has incorrect version", filename_.c_str()); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " has incorrect version"; return false; } if (accessor.header()->dirty) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s is dirty", filename_.c_str()); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " is dirty"; return false; } uint32_t crc = IcingStringUtil::UpdateCrc32(0, accessor.data(), accessor.data_size()); if (accessor.header()->crc != crc) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s has incorrect CRC32 %u %u", filename_.c_str(), - accessor.header()->crc, crc); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " has incorrect CRC32 " << 
accessor.header()->crc << " " << crc; return false; } return true; @@ -265,17 +259,14 @@ uint32_t IcingFlashBitmap::UpdateCrc() const { bool IcingFlashBitmap::Grow(size_t new_file_size) { IcingScopedFd fd(filesystem_->OpenForWrite(filename_.c_str())); if (!filesystem_->Grow(fd.get(), new_file_size)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Grow %s to new size %zu failed", filename_.c_str(), new_file_size); + ICING_LOG(ERROR) << "Grow " << filename_ << " to new size " << new_file_size << " failed"; return false; } if (!mmapper_->Remap(fd.get(), 0, new_file_size)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Remap of %s after grow failed", filename_.c_str()); + ICING_LOG(ERROR) << "Remap of " << filename_ << " after grow failed"; return false; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Grew %s new size %zu", filename_.c_str(), new_file_size); + ICING_VLOG(1) << "Grew " << filename_ << " new size " << new_file_size; Accessor accessor(mmapper_.get()); accessor.header()->dirty = true; return true; diff --git a/icing/legacy/index/icing-mmapper.cc b/icing/legacy/index/icing-mmapper.cc index 7946c82..d086da2 100644 --- a/icing/legacy/index/icing-mmapper.cc +++ b/icing/legacy/index/icing-mmapper.cc @@ -67,8 +67,7 @@ void IcingMMapper::DoMapping(int fd, uint64_t location, size_t size) { address_ = reinterpret_cast<uint8_t *>(mmap_result_) + alignment_adjustment; } else { const char *errstr = strerror(errno); - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Could not mmap file for reading: %s", errstr); + ICING_LOG(ERROR) << "Could not mmap file for reading: " << errstr; mmap_result_ = nullptr; } } @@ -95,8 +94,7 @@ IcingMMapper::~IcingMMapper() { Unmap(); } bool IcingMMapper::Sync() { if (is_valid() && !read_only_) { if (msync(mmap_result_, mmap_len_, MS_SYNC) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("msync failed: %s", - strerror(errno)); + ICING_LOG(ERROR) << "msync failed: " << strerror(errno); return false; } } diff --git a/icing/legacy/index/icing-storage-file.cc b/icing/legacy/index/icing-storage-file.cc index 35a4418..bbc6b81 100644 --- a/icing/legacy/index/icing-storage-file.cc +++ b/icing/legacy/index/icing-storage-file.cc @@ -69,22 +69,18 @@ bool IcingStorageFile::Sync() { IcingTimer timer; if (!PreSync()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Pre-sync %s failed", - filename_.c_str()); + ICING_LOG(ERROR) << "Pre-sync " << filename_ << " failed"; return false; } if (!filesystem_->DataSync(fd_.get())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Sync %s failed", - filename_.c_str()); + ICING_LOG(ERROR) << "Sync " << filename_ << " failed"; return false; } if (!PostSync()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Post-sync %s failed", - filename_.c_str()); + ICING_LOG(ERROR) << "Post-sync " << filename_ << " failed"; return false; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Syncing %s took %.3fms", filename_.c_str(), timer.Elapsed() * 1000.); + ICING_VLOG(1) << "Syncing " << filename_ << " took " << timer.Elapsed() * 1000 << "ms"; return true; } diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc index e48fe78..b505ac5 100644 --- a/icing/query/query-processor_benchmark.cc +++ b/icing/query/query-processor_benchmark.cc @@ -37,7 +37,7 @@ // //icing/query:query-processor_benchmark // // $ blaze-bin/icing/query/query-processor_benchmark -// --benchmarks=all +// --benchmark_filter=all // // Run on an Android device: // Make target 
//icing/tokenization:language-segmenter depend on @@ -53,7 +53,7 @@ // $ adb push blaze-bin/icing/query/query-processor_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/query-processor_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/query-processor_benchmark --benchmark_filter=all // --adb // Flag to tell the benchmark that it'll be run on an Android device via adb, diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index eaa0efc..d1cce87 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -17,7 +17,6 @@ #include <memory> #include <string> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -27,6 +26,7 @@ #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/platform.h" #include "icing/proto/schema.pb.h" @@ -127,22 +127,23 @@ class QueryProcessorTest : public Test { schema_store_.reset(); filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); } - Filesystem filesystem_; const std::string test_dir_; const std::string store_dir_; const std::string schema_store_dir_; + + private: + IcingFilesystem icing_filesystem_; + const std::string index_dir_; + + protected: std::unique_ptr<Index> index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; - std::unique_ptr<SchemaStore> schema_store_; - std::unique_ptr<DocumentStore> document_store_; FakeClock fake_clock_; std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache(); - - private: - IcingFilesystem icing_filesystem_; - const std::string index_dir_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentStore> document_store_; }; TEST_F(QueryProcessorTest, CreationWithNullPointerShouldFail) { diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc index cfa53f6..b1a5a9e 100644 --- a/icing/query/suggestion-processor.cc +++ b/icing/query/suggestion-processor.cc @@ -93,4 +93,4 @@ SuggestionProcessor::SuggestionProcessor( normalizer_(*normalizer) {} } // namespace lib -} // namespace icing +} // namespace icing
\ No newline at end of file diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc index ba4c90a..b3012e9 100644 --- a/icing/query/suggestion-processor_test.cc +++ b/icing/query/suggestion-processor_test.cc @@ -99,16 +99,18 @@ class SuggestionProcessorTest : public Test { Filesystem filesystem_; const std::string test_dir_; const std::string store_dir_; + + private: + IcingFilesystem icing_filesystem_; + const std::string index_dir_; + + protected: std::unique_ptr<Index> index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; + FakeClock fake_clock_; std::unique_ptr<SchemaStore> schema_store_; std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache(); - FakeClock fake_clock_; - - private: - IcingFilesystem icing_filesystem_; - const std::string index_dir_; }; constexpr DocumentId kDocumentId0 = 0; diff --git a/icing/result/page-result.h b/icing/result/page-result.h new file mode 100644 index 0000000..6645593 --- /dev/null +++ b/icing/result/page-result.h @@ -0,0 +1,46 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_RESULT_PAGE_RESULT_H_ +#define ICING_RESULT_PAGE_RESULT_H_ + +#include <vector> + +#include "icing/proto/search.pb.h" + +namespace icing { +namespace lib { + +// Contains information of the search result of one page. +struct PageResult { + PageResult(std::vector<SearchResultProto::ResultProto> results_in, + int num_results_with_snippets_in, int requested_page_size_in) + : results(std::move(results_in)), + num_results_with_snippets(num_results_with_snippets_in), + requested_page_size(requested_page_size_in) {} + + // Results of one page + std::vector<SearchResultProto::ResultProto> results; + + // Number of results with snippets. + int num_results_with_snippets; + + // The page size for this query. This should always be >= results.size(). + int requested_page_size; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_RESULT_PAGE_RESULT_H_ diff --git a/icing/result/projection-tree.h b/icing/result/projection-tree.h index b2e5ffc..8e38aaf 100644 --- a/icing/result/projection-tree.h +++ b/icing/result/projection-tree.h @@ -18,7 +18,6 @@ #include <string_view> #include <vector> -#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/proto/search.pb.h" namespace icing { @@ -31,14 +30,23 @@ class ProjectionTree { struct Node { explicit Node(std::string_view name = "") : name(name) {} + // TODO: change string_view to string std::string_view name; std::vector<Node> children; + + bool operator==(const Node& other) const { + return name == other.name && children == other.children; + } }; explicit ProjectionTree(const TypePropertyMask& type_field_mask); const Node& root() const { return root_; } + bool operator==(const ProjectionTree& other) const { + return root_ == other.root_; + } + private: // Add a child node with property_name to current_children and returns a // pointer to the child node. 
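The PageResult struct added above is the per-page value type that ResultRetrieverV2::RetrieveNextPage (introduced in the next file) returns together with a boolean indicating whether more pages remain. A minimal paging-loop sketch over that API follows; it is illustrative only, assumes a result_retriever and result_state constructed as in the tests added later in this change, and DrainAllPages and ProcessDocument are hypothetical names, not part of this change.

// Illustrative sketch only; not part of this change. result_retriever and
// result_state are assumed to be set up as in the tests added later in this
// change; DrainAllPages and ProcessDocument are hypothetical stand-ins.
void DrainAllPages(const ResultRetrieverV2& result_retriever,
                   ResultStateV2& result_state) {
  while (true) {
    // RetrieveNextPage pops the next best hits, fetches the documents,
    // applies any projection, and returns one PageResult plus a flag that is
    // true when more pages can still be fetched.
    auto [page_result, has_more_results] =
        result_retriever.RetrieveNextPage(result_state);
    for (const SearchResultProto::ResultProto& result : page_result.results) {
      // Each entry carries the (possibly projected) document and its score;
      // page_result.num_results_with_snippets says how many also carry a
      // snippet.
      ProcessDocument(result.document(), result.score());
    }
    if (!has_more_results) {
      // Note: results.size() can be smaller than
      // page_result.requested_page_size on the last (or a filtered) page.
      break;
    }
  }
}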
diff --git a/icing/result/result-retriever-v2.cc b/icing/result/result-retriever-v2.cc new file mode 100644 index 0000000..92ab048 --- /dev/null +++ b/icing/result/result-retriever-v2.cc @@ -0,0 +1,186 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/result/result-retriever-v2.h" + +#include <memory> +#include <string_view> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/proto/search.pb.h" +#include "icing/proto/term.pb.h" +#include "icing/result/page-result.h" +#include "icing/result/projection-tree.h" +#include "icing/result/projector.h" +#include "icing/result/snippet-context.h" +#include "icing/result/snippet-retriever.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-store.h" +#include "icing/store/namespace-id.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/transform/normalizer.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +bool GroupResultLimiterV2::ShouldBeRemoved( + const ScoredDocumentHit& scored_document_hit, + const std::unordered_map<NamespaceId, int>& namespace_group_id_map, + const DocumentStore& document_store, + std::vector<int>& group_result_limits) const { + auto document_filter_data_optional = + document_store.GetAliveDocumentFilterData( + scored_document_hit.document_id()); + if (!document_filter_data_optional) { + // The document doesn't exist. + return true; + } + NamespaceId namespace_id = + document_filter_data_optional.value().namespace_id(); + auto iter = namespace_group_id_map.find(namespace_id); + if (iter == namespace_group_id_map.end()) { + // If a namespace id isn't found in namespace_group_id_map, then there are + // no limits placed on results from this namespace. 
+ return false; + } + int& count = group_result_limits.at(iter->second); + if (count <= 0) { + return true; + } + --count; + return false; +} + +libtextclassifier3::StatusOr<std::unique_ptr<ResultRetrieverV2>> +ResultRetrieverV2::Create( + const DocumentStore* doc_store, const SchemaStore* schema_store, + const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, + std::unique_ptr<const GroupResultLimiterV2> group_result_limiter) { + ICING_RETURN_ERROR_IF_NULL(doc_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(language_segmenter); + ICING_RETURN_ERROR_IF_NULL(normalizer); + ICING_RETURN_ERROR_IF_NULL(group_result_limiter); + + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<SnippetRetriever> snippet_retriever, + SnippetRetriever::Create(schema_store, language_segmenter, normalizer)); + + return std::unique_ptr<ResultRetrieverV2>( + new ResultRetrieverV2(doc_store, std::move(snippet_retriever), + std::move(group_result_limiter))); +} + +std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage( + ResultStateV2& result_state) const { + absl_ports::unique_lock l(&result_state.mutex); + + // For calculating page + int original_scored_document_hits_ranker_size = + result_state.scored_document_hits_ranker->size(); + int num_results_with_snippets = 0; + + const SnippetContext& snippet_context = result_state.snippet_context(); + const std::unordered_map<std::string, ProjectionTree>& projection_tree_map = + result_state.projection_tree_map(); + auto wildcard_projection_tree_itr = projection_tree_map.find( + std::string(ProjectionTree::kSchemaTypeWildcard)); + + // Calculates how many snippets to return for this page. + int remaining_num_to_snippet = + snippet_context.snippet_spec.num_to_snippet() - result_state.num_returned; + if (remaining_num_to_snippet < 0) { + remaining_num_to_snippet = 0; + } + + // Retrieve info + std::vector<SearchResultProto::ResultProto> results; + int32_t num_total_bytes = 0; + while (results.size() < result_state.num_per_page() && + !result_state.scored_document_hits_ranker->empty()) { + ScoredDocumentHit next_best_document_hit = + result_state.scored_document_hits_ranker->PopNext(); + if (group_result_limiter_->ShouldBeRemoved( + next_best_document_hit, result_state.namespace_group_id_map(), + doc_store_, result_state.group_result_limits)) { + continue; + } + + libtextclassifier3::StatusOr<DocumentProto> document_or = + doc_store_.Get(next_best_document_hit.document_id()); + if (!document_or.ok()) { + // Skip the document if getting errors. + ICING_LOG(WARNING) << "Fail to fetch document from document store: " + << document_or.status().error_message(); + continue; + } + + DocumentProto document = std::move(document_or).ValueOrDie(); + // Apply projection + auto itr = projection_tree_map.find(document.schema()); + if (itr != projection_tree_map.end()) { + projector::Project(itr->second.root().children, &document); + } else if (wildcard_projection_tree_itr != projection_tree_map.end()) { + projector::Project(wildcard_projection_tree_itr->second.root().children, + &document); + } + + SearchResultProto::ResultProto result; + // Add the snippet if requested. 
+ if (snippet_context.snippet_spec.num_matches_per_property() > 0 && + remaining_num_to_snippet > results.size()) { + SnippetProto snippet_proto = snippet_retriever_->RetrieveSnippet( + snippet_context.query_terms, snippet_context.match_type, + snippet_context.snippet_spec, document, + next_best_document_hit.hit_section_id_mask()); + *result.mutable_snippet() = std::move(snippet_proto); + ++num_results_with_snippets; + } + + // Add the document, itself. + *result.mutable_document() = std::move(document); + result.set_score(next_best_document_hit.score()); + size_t result_bytes = result.ByteSizeLong(); + results.push_back(std::move(result)); + + // Check if num_total_bytes + result_bytes reaches or exceeds + // num_total_bytes_per_page_threshold. Use subtraction to avoid integer + // overflow. + if (result_bytes >= + result_state.num_total_bytes_per_page_threshold() - num_total_bytes) { + break; + } + num_total_bytes += result_bytes; + } + + // Update numbers in ResultState + result_state.num_returned += results.size(); + result_state.IncrementNumTotalHits( + result_state.scored_document_hits_ranker->size() - + original_scored_document_hits_ranker_size); + + bool has_more_results = !result_state.scored_document_hits_ranker->empty(); + + return std::make_pair( + PageResult(std::move(results), num_results_with_snippets, + result_state.num_per_page()), + has_more_results); +} + +} // namespace lib +} // namespace icing diff --git a/icing/result/result-retriever-v2.h b/icing/result/result-retriever-v2.h new file mode 100644 index 0000000..b481cfc --- /dev/null +++ b/icing/result/result-retriever-v2.h @@ -0,0 +1,108 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_RESULT_RETRIEVER_V2_H_ +#define ICING_RESULT_RETRIEVER_V2_H_ + +#include <memory> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/result/page-result.h" +#include "icing/result/result-state-v2.h" +#include "icing/result/snippet-retriever.h" +#include "icing/schema/schema-store.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-store.h" +#include "icing/store/namespace-id.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/transform/normalizer.h" + +namespace icing { +namespace lib { + +class GroupResultLimiterV2 { + public: + GroupResultLimiterV2() {} + + virtual ~GroupResultLimiterV2() = default; + + // Returns true if the scored_document_hit should be removed. 
+ virtual bool ShouldBeRemoved( + const ScoredDocumentHit& scored_document_hit, + const std::unordered_map<NamespaceId, int>& namespace_group_id_map, + const DocumentStore& document_store, + std::vector<int>& group_result_limits) const; +}; + +class ResultRetrieverV2 { + public: + // Factory function to create a ResultRetrieverV2 which does not take + // ownership of any input components, and all pointers must refer to valid + // objects that outlive the created ResultRetrieverV2 instance. + // + // Returns: + // A ResultRetrieverV2 on success + // FAILED_PRECONDITION on any null pointer input + static libtextclassifier3::StatusOr<std::unique_ptr<ResultRetrieverV2>> + Create(const DocumentStore* doc_store, const SchemaStore* schema_store, + const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer, + std::unique_ptr<const GroupResultLimiterV2> group_result_limiter = + std::make_unique<const GroupResultLimiterV2>()); + + // Retrieves results (pairs of DocumentProtos and SnippetProtos) with the + // given ResultState which holds document and snippet information. It pulls + // out the next top rank documents from ResultState, retrieves the documents + // from storage, updates ResultState, and finally wraps the result + other + // information into PageResult. The expected number of documents to return is + // min(num_per_page, the number of all scored document hits) inside + // ResultState. + // + // The number of snippets to return is based on the total number of snippets + // needed and number of snippets that have already been returned previously + // for the same query. The order of results returned will be sorted by + // scored_document_hit_comparator inside ResultState. + // + // An additional boolean value will be returned, indicating if ResultState has + // remaining documents to be retrieved next round. + // + // All errors will be ignored. It will keep retrieving the next document and + // valid documents will be included in PageResult. + // + // Returns: + // std::pair<PageResult, bool> + std::pair<PageResult, bool> RetrieveNextPage( + ResultStateV2& result_state) const; + + private: + explicit ResultRetrieverV2( + const DocumentStore* doc_store, + std::unique_ptr<SnippetRetriever> snippet_retriever, + std::unique_ptr<const GroupResultLimiterV2> group_result_limiter) + : doc_store_(*doc_store), + snippet_retriever_(std::move(snippet_retriever)), + group_result_limiter_(std::move(group_result_limiter)) {} + + const DocumentStore& doc_store_; + std::unique_ptr<SnippetRetriever> snippet_retriever_; + const std::unique_ptr<const GroupResultLimiterV2> group_result_limiter_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_RESULT_RETRIEVER_V2_H_ diff --git a/icing/result/result-retriever-v2_group-result-limiter_test.cc b/icing/result/result-retriever-v2_group-result-limiter_test.cc new file mode 100644 index 0000000..e0a6c79 --- /dev/null +++ b/icing/result/result-retriever-v2_group-result-limiter_test.cc @@ -0,0 +1,775 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <vector> + +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/portable/equals-proto.h" +#include "icing/portable/platform.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/proto/term.pb.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/store/namespace-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::icing::lib::portable_equals_proto::EqualsProto; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::Pointee; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +class ResultRetrieverV2GroupResultLimiterTest : public testing::Test { + protected: + ResultRetrieverV2GroupResultLimiterTest() + : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. + icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + + SchemaProto schema; + schema.add_types()->set_schema_type("Document"); + ICING_ASSERT_OK(schema_store_->SetSchema(std::move(schema))); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + const Filesystem filesystem_; + const std::string test_dir_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + FakeClock fake_clock_; +}; + +// TODO(sungyc): Refactor helper functions below (builder classes or common test +// utility). 
+ +SearchSpecProto CreateSearchSpec(TermMatchType::Code match_type) { + SearchSpecProto search_spec; + search_spec.set_term_match_type(match_type); + return search_spec; +} + +ScoringSpecProto CreateScoringSpec(bool is_descending_order) { + ScoringSpecProto scoring_spec; + scoring_spec.set_order_by(is_descending_order ? ScoringSpecProto::Order::DESC + : ScoringSpecProto::Order::ASC); + return scoring_spec; +} + +ResultSpecProto CreateResultSpec(int num_per_page) { + ResultSpecProto result_spec; + result_spec.set_num_per_page(num_per_page); + return result_spec; +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingShouldLimitResults) { + // Creates 2 documents and ensures the relationship in terms of document + // score is: document1 < document2 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score())}; + + // Create a ResultSpec that limits "namespace" to a single result. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/5); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(1); + result_grouping->add_namespaces("namespace"); + + // Creates a ResultState with 2 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Only the top ranked document in "namespace" (document2), should be + // returned. + auto [page_result, has_more_results] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result.results, SizeIs(1)); + EXPECT_THAT(page_result.results.at(0).document(), EqualsProto(document2)); + // Document1 has not been returned due to GroupResultLimiter, but since it was + // "filtered out", there should be no more results. 
+ EXPECT_FALSE(has_more_results); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingHasEmptyFirstPage) { + // Creates 2 documents and ensures the relationship in terms of document + // score is: document1 < document2 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score())}; + + // Create a ResultSpec that limits "namespace" to 0 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(0); + result_grouping->add_namespaces("namespace"); + + // Creates a ResultState with 2 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // First page: empty page + auto [page_result, has_more_results] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result.results, IsEmpty()); + EXPECT_FALSE(has_more_results); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingHasEmptyLastPage) { + // Creates 4 documents and ensures the relationship in terms of document + // score is: document1 < document2 < document3 < document4 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + DocumentProto document3 = DocumentBuilder() + .SetKey("namespace", "uri/3") + .SetSchema("Document") + .SetScore(3) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(document3)); + + DocumentProto document4 = DocumentBuilder() + .SetKey("namespace", "uri/4") + .SetSchema("Document") + .SetScore(4) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store_->Put(document4)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score()), + ScoredDocumentHit(document_id3, kSectionIdMaskNone, 
document3.score()), + ScoredDocumentHit(document_id4, kSectionIdMaskNone, document4.score())}; + + // Create a ResultSpec that limits "namespace" to 2 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(2); + result_grouping->add_namespaces("namespace"); + + // Creates a ResultState with 4 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // First page: document4 and document3 should be returned. + auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result1.results, SizeIs(2)); + EXPECT_THAT(page_result1.results.at(0).document(), EqualsProto(document4)); + EXPECT_THAT(page_result1.results.at(1).document(), EqualsProto(document3)); + EXPECT_TRUE(has_more_results1); + + // Second page: although there are valid document hits in result state, all of + // them will be filtered out by group result limiter, so we should get an + // empty page. + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result2.results, SizeIs(0)); + EXPECT_FALSE(has_more_results2); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingDoesNotLimitOtherNamespaceResults) { + // Creates 4 documents and ensures the relationship in terms of document + // score is: document1 < document2 < document3 < document4 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace1", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace1", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + DocumentProto document3 = DocumentBuilder() + .SetKey("namespace2", "uri/3") + .SetSchema("Document") + .SetScore(3) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(document3)); + + DocumentProto document4 = DocumentBuilder() + .SetKey("namespace2", "uri/4") + .SetSchema("Document") + .SetScore(4) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store_->Put(document4)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score()), + ScoredDocumentHit(document_id3, kSectionIdMaskNone, document3.score()), + ScoredDocumentHit(document_id4, kSectionIdMaskNone, document4.score())}; + + // Create a ResultSpec that limits "namespace1" to a single result, but + // doesn't limit "namespace2". 
+ ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/5); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(1); + result_grouping->add_namespaces("namespace1"); + + // Creates a ResultState with 4 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // All documents in "namespace2" should be returned. + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.results.at(0).document(), EqualsProto(document4)); + EXPECT_THAT(page_result.results.at(1).document(), EqualsProto(document3)); + EXPECT_THAT(page_result.results.at(2).document(), EqualsProto(document2)); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingNonexistentNamespaceShouldBeIgnored) { + // Creates 2 documents and ensures the relationship in terms of document + // score is: document1 < document2 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score())}; + + // Create a ResultSpec that limits "namespace"+"nonExistentNamespace" to a + // single result. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/5); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(1); + result_grouping->add_namespaces("namespace"); + result_grouping->add_namespaces("nonexistentNamespace"); + + // Creates a ResultState with 2 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Only the top ranked document in "namespace" (document2), should be + // returned. The presence of "nonexistentNamespace" in the same result + // grouping should have no effect. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(1)); + EXPECT_THAT(page_result.results.at(0).document(), EqualsProto(document2)); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingMultiNamespaceGrouping) { + // Creates 6 documents and ensures the relationship in terms of document + // score is: document1 < document2 < document3 < document4 < document5 < + // document6 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace1", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace1", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + DocumentProto document3 = DocumentBuilder() + .SetKey("namespace2", "uri/3") + .SetSchema("Document") + .SetScore(3) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(document3)); + + DocumentProto document4 = DocumentBuilder() + .SetKey("namespace2", "uri/4") + .SetSchema("Document") + .SetScore(4) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store_->Put(document4)); + + DocumentProto document5 = DocumentBuilder() + .SetKey("namespace3", "uri/5") + .SetSchema("Document") + .SetScore(5) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + document_store_->Put(document5)); + + DocumentProto document6 = DocumentBuilder() + .SetKey("namespace3", "uri/6") + .SetSchema("Document") + .SetScore(6) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id6, + document_store_->Put(document6)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score()), + ScoredDocumentHit(document_id3, kSectionIdMaskNone, document3.score()), + ScoredDocumentHit(document_id4, kSectionIdMaskNone, document4.score()), + ScoredDocumentHit(document_id5, kSectionIdMaskNone, document5.score()), + ScoredDocumentHit(document_id6, kSectionIdMaskNone, document6.score())}; + + // Create a ResultSpec that limits "namespace1" to a single result and limits + // "namespace2"+"namespace3" to a total of two results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/5); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(1); + result_grouping->add_namespaces("namespace1"); + result_grouping = result_spec.add_result_groupings(); + result_grouping->set_max_results(2); + result_grouping->add_namespaces("namespace2"); + result_grouping->add_namespaces("namespace3"); + + // Creates a ResultState with 6 ScoredDocumentHits. 
+ ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Only the top-ranked result in "namespace1" (document2) should be returned. + // Only the top-ranked results across "namespace2" and "namespace3" + // (document6, document5) should be returned. + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.results.at(0).document(), EqualsProto(document6)); + EXPECT_THAT(page_result.results.at(1).document(), EqualsProto(document5)); + EXPECT_THAT(page_result.results.at(2).document(), EqualsProto(document2)); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingOnlyNonexistentNamespaces) { + // Creates 2 documents and ensures the relationship in terms of document + // score is: document1 < document2 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score())}; + + // Create a ResultSpec that limits "nonexistentNamespace" to a single result. + // but doesn't limit "namespace" + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/5); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(1); + result_grouping->add_namespaces("nonexistentNamespace"); + + // Creates a ResultState with 2 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // All documents in "namespace" should be returned. The presence of + // "nonexistentNamespace" should have no effect. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + EXPECT_THAT(page_result.results.at(0).document(), EqualsProto(document2)); + EXPECT_THAT(page_result.results.at(1).document(), EqualsProto(document1)); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ShouldUpdateResultStateCorrectlyWithGroupResultLimiter) { + // Creates 5 documents and ensures the relationship in terms of document + // score is: document1 < document2 < document3 < document4 < document5 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace2", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace1", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + DocumentProto document3 = DocumentBuilder() + .SetKey("namespace1", "uri/3") + .SetSchema("Document") + .SetScore(3) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(document3)); + + DocumentProto document4 = DocumentBuilder() + .SetKey("namespace2", "uri/4") + .SetSchema("Document") + .SetScore(4) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store_->Put(document4)); + + DocumentProto document5 = DocumentBuilder() + .SetKey("namespace2", "uri/5") + .SetSchema("Document") + .SetScore(5) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + document_store_->Put(document5)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score()), + ScoredDocumentHit(document_id3, kSectionIdMaskNone, document3.score()), + ScoredDocumentHit(document_id4, kSectionIdMaskNone, document4.score()), + ScoredDocumentHit(document_id5, kSectionIdMaskNone, document5.score())}; + + // Create a ResultSpec that limits "namespace1" to 3 results and "namespace2" + // to a single result. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(3); + result_grouping->add_namespaces("namespace1"); + result_grouping = result_spec.add_result_groupings(); + result_grouping->set_max_results(1); + result_grouping->add_namespaces("namespace2"); + + // Get namespace ids. + ICING_ASSERT_OK_AND_ASSIGN(NamespaceId namespace_id1, + document_store_->GetNamespaceId("namespace1")); + ICING_ASSERT_OK_AND_ASSIGN(NamespaceId namespace_id2, + document_store_->GetNamespaceId("namespace2")); + + // Creates a ResultState with 5 ScoredDocumentHits. 
+ ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + { + absl_ports::shared_lock l(&result_state.mutex); + + ASSERT_THAT( + result_state.namespace_group_id_map(), + UnorderedElementsAre(Pair(namespace_id1, 0), Pair(namespace_id2, 1))); + ASSERT_THAT(result_state.group_result_limits, ElementsAre(3, 1)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // document5, document4, document1 belong to namespace2 (with max_results = + // 1). + // document3, document2 belong to namespace1 (with max_results = 3). + // Since num_per_page is 2, we expect to get document5 and document3 in the + // first page. + auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result1.results, SizeIs(2)); + ASSERT_THAT(page_result1.results.at(0).document(), EqualsProto(document5)); + ASSERT_THAT(page_result1.results.at(1).document(), EqualsProto(document3)); + ASSERT_TRUE(has_more_results1); + { + absl_ports::shared_lock l(&result_state.mutex); + + // Should remove document5, document4 and document3 from + // scored_document_hits. It removes more than num_per_page documents because + // document4 is filtered out by GroupResultLimiter and ResultRetriever has + // to fetch the next one until returning num_per_page documents or no + // remaining documents in scored_document_hits. + ScoredDocumentHit scored_document_hit1(document_id1, kSectionIdMaskNone, + document1.score()); + ScoredDocumentHit scored_document_hit2(document_id2, kSectionIdMaskNone, + document2.score()); + EXPECT_THAT(result_state.scored_document_hits_ranker, Pointee(SizeIs(2))); + + // Even though we removed 3 document hits from scored_document_hits this + // round, num_returned should still be 2, since document4 was "filtered out" + // and should not be counted into num_returned. + EXPECT_THAT(result_state.num_returned, Eq(2)); + // namespace_group_id_map should be unchanged. + EXPECT_THAT( + result_state.namespace_group_id_map(), + UnorderedElementsAre(Pair(namespace_id1, 0), Pair(namespace_id2, 1))); + // GroupResultLimiter should decrement the # in group_result_limits. + EXPECT_THAT(result_state.group_result_limits, ElementsAre(2, 0)); + } + + // Although there are document2 and document1 left, since namespace2 has + // reached its max results, document1 should be excluded from the second page. + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result2.results, SizeIs(1)); + ASSERT_THAT(page_result2.results.at(0).document(), EqualsProto(document2)); + ASSERT_FALSE(has_more_results2); + { + absl_ports::shared_lock l(&result_state.mutex); + + // Should remove document2 and document1 from scored_document_hits. + EXPECT_THAT(result_state.scored_document_hits_ranker, Pointee(IsEmpty())); + // Even though we removed 2 document hits from scored_document_hits this + // round, num_returned should only be incremented by 1 (and thus become 3), + // since document1 was "filtered out" and should not be counted into + // num_returned. + EXPECT_THAT(result_state.num_returned, Eq(3)); + // namespace_group_id_map should be unchanged. 
+ EXPECT_THAT( + result_state.namespace_group_id_map(), + UnorderedElementsAre(Pair(namespace_id1, 0), Pair(namespace_id2, 1))); + // GroupResultLimiter should decrement the # in group_result_limits. + EXPECT_THAT(result_state.group_result_limits, ElementsAre(1, 0)); + } +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/result/result-retriever-v2_projection_test.cc b/icing/result/result-retriever-v2_projection_test.cc new file mode 100644 index 0000000..bdd1715 --- /dev/null +++ b/icing/result/result-retriever-v2_projection_test.cc @@ -0,0 +1,1281 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <vector> + +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/portable/equals-proto.h" +#include "icing/portable/platform.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/proto/term.pb.h" +#include "icing/result/page-result.h" +#include "icing/result/projection-tree.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::icing::lib::portable_equals_proto::EqualsProto; +using ::testing::SizeIs; + +constexpr PropertyConfigProto::Cardinality::Code CARDINALITY_OPTIONAL = + PropertyConfigProto::Cardinality::OPTIONAL; + +constexpr StringIndexingConfig::TokenizerType::Code TOKENIZER_PLAIN = + StringIndexingConfig::TokenizerType::PLAIN; + +constexpr TermMatchType::Code MATCH_EXACT = TermMatchType::EXACT_ONLY; +constexpr TermMatchType::Code MATCH_PREFIX = TermMatchType::PREFIX; + +class ResultRetrieverV2ProjectionTest : public testing::Test { + protected: + ResultRetrieverV2ProjectionTest() : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "Person", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType( + SchemaTypeConfigBuilder() + .SetType("Person") + .AddProperty( + PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + SectionId GetSectionId(const std::string& type, const std::string& property) { + auto type_id_or = schema_store_->GetSchemaTypeId(type); + if (!type_id_or.ok()) { + return kInvalidSectionId; + } + SchemaTypeId type_id = type_id_or.ValueOrDie(); + for (SectionId section_id = 0; section_id <= kMaxSectionId; ++section_id) { + auto metadata_or = schema_store_->GetSectionMetadata(type_id, section_id); + if (!metadata_or.ok()) { + break; + } + const SectionMetadata* metadata = metadata_or.ValueOrDie(); + if (metadata->path == property) { + return metadata->id; + } + } + return kInvalidSectionId; + } + + const Filesystem filesystem_; + const std::string test_dir_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + FakeClock fake_clock_; +}; + +// TODO(sungyc): Refactor helper functions below (builder classes or common test +// utility). + +SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { + SectionIdMask mask = 0; + for (SectionId section_id : section_ids) { + mask |= (1u << section_id); + } + return mask; +} + +SearchSpecProto CreateSearchSpec(TermMatchType::Code match_type) { + SearchSpecProto search_spec; + search_spec.set_term_match_type(match_type); + return search_spec; +} + +ScoringSpecProto CreateScoringSpec(bool is_descending_order) { + ScoringSpecProto scoring_spec; + scoring_spec.set_order_by(is_descending_order ? 
ScoringSpecProto::Order::DESC
+ : ScoringSpecProto::Order::ASC);
+ return scoring_spec;
+}
+
+ResultSpecProto CreateResultSpec(int num_per_page) {
+ ResultSpecProto result_spec;
+ result_spec.set_num_per_page(num_per_page);
+ return result_spec;
+}
+
+TEST_F(ResultRetrieverV2ProjectionTest, ProjectionTopLevelLeafNodeFieldPath) {
+ // 1. Add two Email documents
+ DocumentProto document_one =
+   DocumentBuilder()
+     .SetKey("namespace", "uri1")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddStringProperty("name", "Hello World!")
+     .AddStringProperty(
+       "body", "Oh what a beautiful morning! Oh what a beautiful day!")
+     .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1,
+                            document_store_->Put(document_one));
+
+ DocumentProto document_two =
+   DocumentBuilder()
+     .SetKey("namespace", "uri2")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddStringProperty("name", "Goodnight Moon!")
+     .AddStringProperty("body",
+                        "Count all the sheep and tell them 'Hello'.")
+     .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2,
+                            document_store_->Put(document_two));
+
+ // 2. Setup the scored results.
+ std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"),
+                                           GetSectionId("Email", "body")};
+ SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids);
+ std::vector<ScoredDocumentHit> scored_document_hits = {
+     {document_id1, hit_section_id_mask, /*score=*/0},
+     {document_id2, hit_section_id_mask, /*score=*/0}};
+
+ // 3. Create a ResultSpec with type property mask.
+ ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2);
+ TypePropertyMask* type_property_mask = result_spec.add_type_property_masks();
+ type_property_mask->set_schema_type("Email");
+ type_property_mask->add_paths("name");
+
+ // 4. Create ResultState with custom ResultSpec.
+ ResultStateV2 result_state(
+     std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+         std::move(scored_document_hits), /*is_descending=*/false),
+     /*query_terms=*/SectionRestrictQueryTermsMap{},
+     CreateSearchSpec(TermMatchType::EXACT_ONLY),
+     CreateScoringSpec(/*is_descending_order=*/false), result_spec,
+     *document_store_);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+     std::unique_ptr<ResultRetrieverV2> result_retriever,
+     ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(),
+                               language_segmenter_.get(), normalizer_.get()));
+
+ // 5. Verify that the returned results only contain the 'name' property.
+ PageResult page_result =
+     result_retriever->RetrieveNextPage(result_state).first;
+ ASSERT_THAT(page_result.results, SizeIs(2));
+
+ DocumentProto projected_document_one =
+   DocumentBuilder()
+     .SetKey("namespace", "uri1")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddStringProperty("name", "Hello World!")
+     .Build();
+ EXPECT_THAT(page_result.results.at(0).document(),
+             EqualsProto(projected_document_one));
+
+ DocumentProto projected_document_two =
+   DocumentBuilder()
+     .SetKey("namespace", "uri2")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddStringProperty("name", "Goodnight Moon!")
+     .Build();
+ EXPECT_THAT(page_result.results.at(1).document(),
+             EqualsProto(projected_document_two));
+}
+
+TEST_F(ResultRetrieverV2ProjectionTest, ProjectionNestedLeafNodeFieldPath) {
+ // 1.
Add two Email documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Person") + .AddStringProperty("name", "Tom Hanks") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build()) + .AddStringProperty("name", "Goodnight Moon!") + .AddStringProperty("body", + "Count all the sheep and tell them 'Hello'.") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* type_property_mask = result_spec.add_type_property_masks(); + type_property_mask->set_schema_type("Email"); + type_property_mask->add_paths("sender.name"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned results only contain the 'sender.name' + // property. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty("sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .Build()) + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty("sender", + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Person") + .AddStringProperty("name", "Tom Hanks") + .Build()) + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, ProjectionIntermediateNodeFieldPath) { + // 1. Add two Email documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Person") + .AddStringProperty("name", "Tom Hanks") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build()) + .AddStringProperty("name", "Goodnight Moon!") + .AddStringProperty("body", + "Count all the sheep and tell them 'Hello'.") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* type_property_mask = result_spec.add_type_property_masks(); + type_property_mask->set_schema_type("Email"); + type_property_mask->add_paths("sender"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. 
Verify that the returned results only contain the 'sender' + // property and all of the subproperties of 'sender'. + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Person") + .AddStringProperty("name", "Tom Hanks") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build()) + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleNestedFieldPaths) { + // 1. Add two Email documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Person") + .AddStringProperty("name", "Tom Hanks") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build()) + .AddStringProperty("name", "Goodnight Moon!") + .AddStringProperty("body", + "Count all the sheep and tell them 'Hello'.") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* type_property_mask = result_spec.add_type_property_masks(); + type_property_mask->set_schema_type("Email"); + type_property_mask->add_paths("sender.name"); + type_property_mask->add_paths("sender.emailAddress"); + + // 4. Create ResultState with custom ResultSpec. 
+ ResultStateV2 result_state(
+     std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+         std::move(scored_document_hits), /*is_descending=*/false),
+     /*query_terms=*/SectionRestrictQueryTermsMap{},
+     CreateSearchSpec(TermMatchType::EXACT_ONLY),
+     CreateScoringSpec(/*is_descending_order=*/false), result_spec,
+     *document_store_);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+     std::unique_ptr<ResultRetrieverV2> result_retriever,
+     ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(),
+                               language_segmenter_.get(), normalizer_.get()));
+
+ // 5. Verify that the returned results only contain the 'sender.name' and
+ // 'sender.emailAddress' properties.
+ PageResult page_result =
+     result_retriever->RetrieveNextPage(result_state).first;
+ ASSERT_THAT(page_result.results, SizeIs(2));
+
+ DocumentProto projected_document_one =
+   DocumentBuilder()
+     .SetKey("namespace", "uri1")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddDocumentProperty(
+         "sender",
+         DocumentBuilder()
+             .SetKey("namespace", "uri1")
+             .SetSchema("Person")
+             .AddStringProperty("name", "Meg Ryan")
+             .AddStringProperty("emailAddress", "shopgirl@aol.com")
+             .Build())
+     .Build();
+ EXPECT_THAT(page_result.results.at(0).document(),
+             EqualsProto(projected_document_one));
+
+ DocumentProto projected_document_two =
+   DocumentBuilder()
+     .SetKey("namespace", "uri2")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddDocumentProperty(
+         "sender", DocumentBuilder()
+                       .SetKey("namespace", "uri2")
+                       .SetSchema("Person")
+                       .AddStringProperty("name", "Tom Hanks")
+                       .AddStringProperty("emailAddress", "ny152@aol.com")
+                       .Build())
+     .Build();
+ EXPECT_THAT(page_result.results.at(1).document(),
+             EqualsProto(projected_document_two));
+}
+
+TEST_F(ResultRetrieverV2ProjectionTest, ProjectionEmptyFieldPath) {
+ // 1. Add two Email documents
+ DocumentProto document_one =
+   DocumentBuilder()
+     .SetKey("namespace", "uri1")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddStringProperty("name", "Hello World!")
+     .AddStringProperty(
+       "body", "Oh what a beautiful morning! Oh what a beautiful day!")
+     .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1,
+                            document_store_->Put(document_one));
+
+ DocumentProto document_two =
+   DocumentBuilder()
+     .SetKey("namespace", "uri2")
+     .SetCreationTimestampMs(1000)
+     .SetSchema("Email")
+     .AddStringProperty("name", "Goodnight Moon!")
+     .AddStringProperty("body",
+                        "Count all the sheep and tell them 'Hello'.")
+     .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2,
+                            document_store_->Put(document_two));
+
+ // 2. Setup the scored results.
+ std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"),
+                                           GetSectionId("Email", "body")};
+ SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids);
+ std::vector<ScoredDocumentHit> scored_document_hits = {
+     {document_id1, hit_section_id_mask, /*score=*/0},
+     {document_id2, hit_section_id_mask, /*score=*/0}};
+
+ // 3. Create a ResultSpec with type property mask.
+ ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2);
+ TypePropertyMask* type_property_mask = result_spec.add_type_property_masks();
+ type_property_mask->set_schema_type("Email");
+
+ // 4. Create ResultState with custom ResultSpec.
+ ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned results contain *no* properties. + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, ProjectionInvalidFieldPath) { + // 1. Add two Email documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Goodnight Moon!") + .AddStringProperty("body", + "Count all the sheep and tell them 'Hello'.") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* type_property_mask = result_spec.add_type_property_masks(); + type_property_mask->set_schema_type("Email"); + type_property_mask->add_paths("nonExistentProperty"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned results contain *no* properties. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, ProjectionValidAndInvalidFieldPath) { + // 1. Add two Email documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Goodnight Moon!") + .AddStringProperty("body", + "Count all the sheep and tell them 'Hello'.") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* type_property_mask = result_spec.add_type_property_masks(); + type_property_mask->set_schema_type("Email"); + type_property_mask->add_paths("name"); + type_property_mask->add_paths("nonExistentProperty"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned results only contain the 'name' property. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Goodnight Moon!") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleTypesNoWildcards) { + // 1. Add two documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* type_property_mask = result_spec.add_type_property_masks(); + type_property_mask->set_schema_type("Email"); + type_property_mask->add_paths("name"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned Email results only contain the 'name' + // property and the returned Person results have all of their properties. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleTypesWildcard) { + // 1. Add two documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* wildcard_type_property_mask = + result_spec.add_type_property_masks(); + wildcard_type_property_mask->set_schema_type( + std::string(ProjectionTree::kSchemaTypeWildcard)); + wildcard_type_property_mask->add_paths("name"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned Email results only contain the 'name' + // property and the returned Person results only contain the 'name' property. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, + ProjectionMultipleTypesWildcardWithOneOverride) { + // 1. Add two documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* email_type_property_mask = + result_spec.add_type_property_masks(); + email_type_property_mask->set_schema_type("Email"); + email_type_property_mask->add_paths("body"); + TypePropertyMask* wildcard_type_property_mask = + result_spec.add_type_property_masks(); + wildcard_type_property_mask->set_schema_type( + std::string(ProjectionTree::kSchemaTypeWildcard)); + wildcard_type_property_mask->add_paths("name"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned Email results only contain the 'body' + // property and the returned Person results only contain the 'name' property. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, + ProjectionSingleTypesWildcardAndOverride) { + // 1. Add two documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri") + .SetSchema("Person") + .AddStringProperty("name", "Mr. Body") + .AddStringProperty("emailAddress", "mr.body123@gmail.com") + .Build()) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* email_type_property_mask = + result_spec.add_type_property_masks(); + email_type_property_mask->set_schema_type("Email"); + email_type_property_mask->add_paths("sender.name"); + TypePropertyMask* wildcard_type_property_mask = + result_spec.add_type_property_masks(); + wildcard_type_property_mask->set_schema_type( + std::string(ProjectionTree::kSchemaTypeWildcard)); + wildcard_type_property_mask->add_paths("name"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. Verify that the returned Email results only contain the 'sender.name' + // property and the returned Person results only contain the 'name' property. 
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty("sender", + DocumentBuilder() + .SetKey("namespace", "uri") + .SetSchema("Person") + .AddStringProperty("name", "Mr. Body") + .Build()) + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +TEST_F(ResultRetrieverV2ProjectionTest, + ProjectionSingleTypesWildcardAndOverrideNestedProperty) { + // 1. Add two documents + DocumentProto document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty("name", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri") + .SetSchema("Person") + .AddStringProperty("name", "Mr. Body") + .AddStringProperty("emailAddress", "mr.body123@gmail.com") + .Build()) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document_one)); + + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .AddStringProperty("name", "Joe Fox") + .AddStringProperty("emailAddress", "ny152@aol.com") + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document_two)); + + // 2. Setup the scored results. + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + // 3. Create a ResultSpec with type property mask. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* email_type_property_mask = + result_spec.add_type_property_masks(); + email_type_property_mask->set_schema_type("Email"); + email_type_property_mask->add_paths("sender.name"); + TypePropertyMask* wildcard_type_property_mask = + result_spec.add_type_property_masks(); + wildcard_type_property_mask->set_schema_type( + std::string(ProjectionTree::kSchemaTypeWildcard)); + wildcard_type_property_mask->add_paths("sender"); + + // 4. Create ResultState with custom ResultSpec. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // 5. 
Verify that the returned Email results only contain the 'sender.name' + // property and the returned Person results contain no properties. + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(2)); + + DocumentProto projected_document_one = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty("sender", + DocumentBuilder() + .SetKey("namespace", "uri") + .SetSchema("Person") + .AddStringProperty("name", "Mr. Body") + .Build()) + .Build(); + EXPECT_THAT(page_result.results.at(0).document(), + EqualsProto(projected_document_one)); + + DocumentProto projected_document_two = DocumentBuilder() + .SetKey("namespace", "uri2") + .SetCreationTimestampMs(1000) + .SetSchema("Person") + .Build(); + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(projected_document_two)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/result/result-retriever-v2_snippet_test.cc b/icing/result/result-retriever-v2_snippet_test.cc new file mode 100644 index 0000000..afb31cf --- /dev/null +++ b/icing/result/result-retriever-v2_snippet_test.cc @@ -0,0 +1,573 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <limits> +#include <memory> +#include <string_view> +#include <vector> + +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/portable/equals-proto.h" +#include "icing/portable/platform.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/proto/term.pb.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/snippet-helpers.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::icing::lib::portable_equals_proto::EqualsProto; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; + +constexpr PropertyConfigProto::Cardinality::Code CARDINALITY_OPTIONAL = + PropertyConfigProto::Cardinality::OPTIONAL; + +constexpr StringIndexingConfig::TokenizerType::Code TOKENIZER_PLAIN = + StringIndexingConfig::TokenizerType::PLAIN; + +constexpr TermMatchType::Code MATCH_EXACT = TermMatchType::EXACT_ONLY; +constexpr TermMatchType::Code MATCH_PREFIX = TermMatchType::PREFIX; + +class ResultRetrieverV2SnippetTest : public testing::Test { + protected: + ResultRetrieverV2SnippetTest() : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "Person", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType( + SchemaTypeConfigBuilder() + .SetType("Person") + .AddProperty( + PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + SectionId GetSectionId(const std::string& type, const std::string& property) { + auto type_id_or = schema_store_->GetSchemaTypeId(type); + if (!type_id_or.ok()) { + return kInvalidSectionId; + } + SchemaTypeId type_id = type_id_or.ValueOrDie(); + for (SectionId section_id = 0; section_id <= kMaxSectionId; ++section_id) { + auto metadata_or = schema_store_->GetSectionMetadata(type_id, section_id); + if (!metadata_or.ok()) { + break; + } + const SectionMetadata* metadata = metadata_or.ValueOrDie(); + if (metadata->path == property) { + return metadata->id; + } + } + return kInvalidSectionId; + } + + const Filesystem filesystem_; + const std::string test_dir_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + FakeClock fake_clock_; +}; + +// TODO(sungyc): Refactor helper functions below (builder classes or common test +// utility). 
+ +ResultSpecProto::SnippetSpecProto CreateSnippetSpec() { + ResultSpecProto::SnippetSpecProto snippet_spec; + snippet_spec.set_num_to_snippet(std::numeric_limits<int>::max()); + snippet_spec.set_num_matches_per_property(std::numeric_limits<int>::max()); + snippet_spec.set_max_window_utf32_length(1024); + return snippet_spec; +} + +DocumentProto CreateDocument(int id) { + return DocumentBuilder() + .SetKey("icing", "Email/" + std::to_string(id)) + .SetSchema("Email") + .AddStringProperty("name", "subject foo " + std::to_string(id)) + .AddStringProperty("body", "body bar " + std::to_string(id)) + .SetCreationTimestampMs(1574365086666 + id) + .Build(); +} + +SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { + SectionIdMask mask = 0; + for (SectionId section_id : section_ids) { + mask |= (1u << section_id); + } + return mask; +} + +SearchSpecProto CreateSearchSpec(TermMatchType::Code match_type) { + SearchSpecProto search_spec; + search_spec.set_term_match_type(match_type); + return search_spec; +} + +ScoringSpecProto CreateScoringSpec(bool is_descending_order) { + ScoringSpecProto scoring_spec; + scoring_spec.set_order_by(is_descending_order ? ScoringSpecProto::Order::DESC + : ScoringSpecProto::Order::ASC); + return scoring_spec; +} + +ResultSpecProto CreateResultSpec(int num_per_page) { + ResultSpecProto result_spec; + result_spec.set_num_per_page(num_per_page); + return result_spec; +} + +TEST_F(ResultRetrieverV2SnippetTest, + DefaultSnippetSpecShouldDisableSnippeting) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(CreateDocument(/*id=*/3))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/3), *document_store_); + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.results.at(0).snippet(), + EqualsProto(SnippetProto::default_instance())); + EXPECT_THAT(page_result.results.at(1).snippet(), + EqualsProto(SnippetProto::default_instance())); + EXPECT_THAT(page_result.results.at(2).snippet(), + EqualsProto(SnippetProto::default_instance())); + EXPECT_THAT(page_result.num_results_with_snippets, Eq(0)); +} + +TEST_F(ResultRetrieverV2SnippetTest, SimpleSnippeted) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(CreateDocument(/*id=*/2))); + 
ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(CreateDocument(/*id=*/3))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Create ResultSpec with custom snippet spec. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/3); + *result_spec.mutable_snippet_spec() = CreateSnippetSpec(); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{{"", {"foo", "bar"}}}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.num_results_with_snippets, Eq(3)); + + const DocumentProto& result_document_one = + page_result.results.at(0).document(); + const SnippetProto& result_snippet_one = page_result.results.at(0).snippet(); + EXPECT_THAT(result_document_one, EqualsProto(CreateDocument(/*id=*/1))); + EXPECT_THAT(result_snippet_one.entries(), SizeIs(2)); + EXPECT_THAT(result_snippet_one.entries(0).property_name(), Eq("body")); + std::string_view content = GetString( + &result_document_one, result_snippet_one.entries(0).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_one.entries(0)), + ElementsAre("body bar 1")); + EXPECT_THAT(GetMatches(content, result_snippet_one.entries(0)), + ElementsAre("bar")); + EXPECT_THAT(result_snippet_one.entries(1).property_name(), Eq("name")); + content = GetString(&result_document_one, + result_snippet_one.entries(1).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_one.entries(1)), + ElementsAre("subject foo 1")); + EXPECT_THAT(GetMatches(content, result_snippet_one.entries(1)), + ElementsAre("foo")); + + const DocumentProto& result_document_two = + page_result.results.at(1).document(); + const SnippetProto& result_snippet_two = page_result.results.at(1).snippet(); + EXPECT_THAT(result_document_two, EqualsProto(CreateDocument(/*id=*/2))); + EXPECT_THAT(result_snippet_two.entries(), SizeIs(2)); + EXPECT_THAT(result_snippet_two.entries(0).property_name(), Eq("body")); + content = GetString(&result_document_two, + result_snippet_two.entries(0).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_two.entries(0)), + ElementsAre("body bar 2")); + EXPECT_THAT(GetMatches(content, result_snippet_two.entries(0)), + ElementsAre("bar")); + EXPECT_THAT(result_snippet_two.entries(1).property_name(), Eq("name")); + content = GetString(&result_document_two, + result_snippet_two.entries(1).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_two.entries(1)), + ElementsAre("subject foo 2")); + EXPECT_THAT(GetMatches(content, result_snippet_two.entries(1)), + ElementsAre("foo")); + + const DocumentProto& result_document_three = + page_result.results.at(2).document(); + const 
SnippetProto& result_snippet_three = + page_result.results.at(2).snippet(); + EXPECT_THAT(result_document_three, EqualsProto(CreateDocument(/*id=*/3))); + EXPECT_THAT(result_snippet_three.entries(), SizeIs(2)); + EXPECT_THAT(result_snippet_three.entries(0).property_name(), Eq("body")); + content = GetString(&result_document_three, + result_snippet_three.entries(0).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_three.entries(0)), + ElementsAre("body bar 3")); + EXPECT_THAT(GetMatches(content, result_snippet_three.entries(0)), + ElementsAre("bar")); + EXPECT_THAT(result_snippet_three.entries(1).property_name(), Eq("name")); + content = GetString(&result_document_three, + result_snippet_three.entries(1).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_three.entries(1)), + ElementsAre("subject foo 3")); + EXPECT_THAT(GetMatches(content, result_snippet_three.entries(1)), + ElementsAre("foo")); +} + +TEST_F(ResultRetrieverV2SnippetTest, OnlyOneDocumentSnippeted) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(CreateDocument(/*id=*/3))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Create ResultSpec with custom snippet spec. 
+ ResultSpecProto::SnippetSpecProto snippet_spec = CreateSnippetSpec(); + snippet_spec.set_num_to_snippet(1); + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/3); + *result_spec.mutable_snippet_spec() = std::move(snippet_spec); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{{"", {"foo", "bar"}}}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.num_results_with_snippets, Eq(1)); + + const DocumentProto& result_document = page_result.results.at(0).document(); + const SnippetProto& result_snippet = page_result.results.at(0).snippet(); + EXPECT_THAT(result_document, EqualsProto(CreateDocument(/*id=*/1))); + EXPECT_THAT(result_snippet.entries(), SizeIs(2)); + EXPECT_THAT(result_snippet.entries(0).property_name(), Eq("body")); + std::string_view content = + GetString(&result_document, result_snippet.entries(0).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet.entries(0)), + ElementsAre("body bar 1")); + EXPECT_THAT(GetMatches(content, result_snippet.entries(0)), + ElementsAre("bar")); + EXPECT_THAT(result_snippet.entries(1).property_name(), Eq("name")); + content = + GetString(&result_document, result_snippet.entries(1).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet.entries(1)), + ElementsAre("subject foo 1")); + EXPECT_THAT(GetMatches(content, result_snippet.entries(1)), + ElementsAre("foo")); + + EXPECT_THAT(page_result.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/2))); + EXPECT_THAT(page_result.results.at(1).snippet(), + EqualsProto(SnippetProto::default_instance())); + + EXPECT_THAT(page_result.results.at(2).document(), + EqualsProto(CreateDocument(/*id=*/3))); + EXPECT_THAT(page_result.results.at(2).snippet(), + EqualsProto(SnippetProto::default_instance())); +} + +TEST_F(ResultRetrieverV2SnippetTest, ShouldSnippetAllResults) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(CreateDocument(/*id=*/3))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Create ResultSpec with custom snippet spec. 
+ ResultSpecProto::SnippetSpecProto snippet_spec = CreateSnippetSpec(); + snippet_spec.set_num_to_snippet(5); + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/3); + *result_spec.mutable_snippet_spec() = std::move(snippet_spec); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{{"", {"foo", "bar"}}}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + // num_to_snippet = 5, num_previously_returned_in = 0, + // We can return 5 - 0 = 5 snippets at most. We're able to return all 3 + // snippets here. + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.results.at(0).snippet().entries(), Not(IsEmpty())); + EXPECT_THAT(page_result.results.at(1).snippet().entries(), Not(IsEmpty())); + EXPECT_THAT(page_result.results.at(2).snippet().entries(), Not(IsEmpty())); + EXPECT_THAT(page_result.num_results_with_snippets, Eq(3)); +} + +TEST_F(ResultRetrieverV2SnippetTest, ShouldSnippetSomeResults) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(CreateDocument(/*id=*/3))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Create ResultSpec with custom snippet spec. + ResultSpecProto::SnippetSpecProto snippet_spec = CreateSnippetSpec(); + snippet_spec.set_num_to_snippet(5); + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/3); + *result_spec.mutable_snippet_spec() = std::move(snippet_spec); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{{"", {"foo", "bar"}}}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + { + absl_ports::unique_lock l(&result_state.mutex); + + // Set (previously) num_returned = 3 docs + result_state.num_returned = 3; + } + + // num_to_snippet = 5, (previously) num_returned = 3, + // We can return 5 - 3 = 2 snippets. 
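The snippet count expected here follows the budgeting rule described in the comment above. A minimal sketch of that rule, assuming a free-standing helper (the function name and flattened parameters are illustrative, not part of the Icing API):

  #include <algorithm>

  // How many results of the current page may carry snippets, given the spec's
  // num_to_snippet and how many results earlier pages have already returned.
  int SnippetBudgetForPage(int num_to_snippet, int num_previously_returned,
                           int page_size) {
    int remaining = std::max(0, num_to_snippet - num_previously_returned);
    return std::min(remaining, page_size);
  }

With num_to_snippet = 5, num_previously_returned = 3 and a page of 3 results, the budget is 2, which matches the num_results_with_snippets expectation below.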
+ PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.results.at(0).snippet().entries(), Not(IsEmpty())); + EXPECT_THAT(page_result.results.at(1).snippet().entries(), Not(IsEmpty())); + EXPECT_THAT(page_result.results.at(2).snippet().entries(), IsEmpty()); + EXPECT_THAT(page_result.num_results_with_snippets, Eq(2)); +} + +TEST_F(ResultRetrieverV2SnippetTest, ShouldNotSnippetAnyResults) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(CreateDocument(/*id=*/3))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Create ResultSpec with custom snippet spec. + ResultSpecProto::SnippetSpecProto snippet_spec = CreateSnippetSpec(); + snippet_spec.set_num_to_snippet(5); + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/3); + *result_spec.mutable_snippet_spec() = std::move(snippet_spec); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{{"", {"foo", "bar"}}}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/false), result_spec, + *document_store_); + { + absl_ports::unique_lock l(&result_state.mutex); + + // Set (previously) num_returned = 6 docs + result_state.num_returned = 6; + } + + // num_to_snippet = 5, (previously) num_returned = 6, + // We can't return any snippets for this page. + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result.results, SizeIs(3)); + EXPECT_THAT(page_result.results.at(0).snippet().entries(), IsEmpty()); + EXPECT_THAT(page_result.results.at(1).snippet().entries(), IsEmpty()); + EXPECT_THAT(page_result.results.at(2).snippet().entries(), IsEmpty()); + EXPECT_THAT(page_result.num_results_with_snippets, Eq(0)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/result/result-retriever-v2_test.cc b/icing/result/result-retriever-v2_test.cc new file mode 100644 index 0000000..0998754 --- /dev/null +++ b/icing/result/result-retriever-v2_test.cc @@ -0,0 +1,815 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/result/result-retriever-v2.h" + +#include <atomic> +#include <memory> +#include <unordered_map> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/file/mock-filesystem.h" +#include "icing/portable/equals-proto.h" +#include "icing/portable/platform.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/proto/term.pb.h" +#include "icing/result/page-result.h" +#include "icing/result/result-state-v2.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::icing::lib::portable_equals_proto::EqualsProto; +using ::testing::DoDefault; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Gt; +using ::testing::IsEmpty; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::SizeIs; +using NamespaceIdMap = std::unordered_map<NamespaceId, int>; + +constexpr PropertyConfigProto::Cardinality::Code CARDINALITY_OPTIONAL = + PropertyConfigProto::Cardinality::OPTIONAL; + +constexpr StringIndexingConfig::TokenizerType::Code TOKENIZER_PLAIN = + StringIndexingConfig::TokenizerType::PLAIN; + +constexpr TermMatchType::Code MATCH_EXACT = TermMatchType::EXACT_ONLY; +constexpr TermMatchType::Code MATCH_PREFIX = TermMatchType::PREFIX; + +// Mock the behavior of GroupResultLimiter::ShouldBeRemoved. +class MockGroupResultLimiter : public GroupResultLimiterV2 { + public: + MockGroupResultLimiter() : GroupResultLimiterV2() { + ON_CALL(*this, ShouldBeRemoved).WillByDefault(Return(false)); + } + + MOCK_METHOD(bool, ShouldBeRemoved, + (const ScoredDocumentHit&, const NamespaceIdMap&, + const DocumentStore&, std::vector<int>&), + (const, override)); +}; + +class ResultRetrieverV2Test : public ::testing::Test { + protected: + ResultRetrieverV2Test() : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "Person", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType( + SchemaTypeConfigBuilder() + .SetType("Person") + .AddProperty( + PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + num_total_hits_ = 0; + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + SectionId GetSectionId(const std::string& type, const std::string& property) { + auto type_id_or = schema_store_->GetSchemaTypeId(type); + if (!type_id_or.ok()) { + return kInvalidSectionId; + } + SchemaTypeId type_id = type_id_or.ValueOrDie(); + for (SectionId section_id = 0; section_id <= kMaxSectionId; ++section_id) { + auto metadata_or = schema_store_->GetSectionMetadata(type_id, section_id); + if (!metadata_or.ok()) { + break; + } + const SectionMetadata* metadata = metadata_or.ValueOrDie(); + if (metadata->path == property) { + return metadata->id; + } + } + return kInvalidSectionId; + } + + const Filesystem filesystem_; + const std::string test_dir_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::atomic<int> num_total_hits_; + FakeClock fake_clock_; +}; + +// TODO(sungyc): Refactor helper functions below (builder classes or common test +// utility). + +DocumentProto CreateDocument(int id) { + return DocumentBuilder() + .SetKey("icing", "Email/" + std::to_string(id)) + .SetSchema("Email") + .AddStringProperty("name", "subject foo " + std::to_string(id)) + .AddStringProperty("body", "body bar " + std::to_string(id)) + .SetCreationTimestampMs(1574365086666 + id) + .Build(); +} + +SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { + SectionIdMask mask = 0; + for (SectionId section_id : section_ids) { + mask |= (1u << section_id); + } + return mask; +} + +SearchSpecProto CreateSearchSpec(TermMatchType::Code match_type) { + SearchSpecProto search_spec; + search_spec.set_term_match_type(match_type); + return search_spec; +} + +ScoringSpecProto CreateScoringSpec(bool is_descending_order) { + ScoringSpecProto scoring_spec; + scoring_spec.set_order_by(is_descending_order ? 
ScoringSpecProto::Order::DESC + : ScoringSpecProto::Order::ASC); + return scoring_spec; +} + +ResultSpecProto CreateResultSpec(int num_per_page) { + ResultSpecProto result_spec; + result_spec.set_num_per_page(num_per_page); + return result_spec; +} + +TEST_F(ResultRetrieverV2Test, CreationWithNullPointerShouldFail) { + EXPECT_THAT( + ResultRetrieverV2::Create(/*doc_store=*/nullptr, schema_store_.get(), + language_segmenter_.get(), normalizer_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + EXPECT_THAT( + ResultRetrieverV2::Create(doc_store.get(), /*schema_store=*/nullptr, + language_segmenter_.get(), normalizer_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + /*language_segmenter=*/nullptr, + normalizer_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), + /*normalizer=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ResultRetrieverV2Test, ShouldRetrieveSimpleResults) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + doc_store->Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + doc_store->Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + doc_store->Put(CreateDocument(/*id=*/5))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/19}, + {document_id2, hit_section_id_mask, /*score=*/12}, + {document_id3, hit_section_id_mask, /*score=*/8}, + {document_id4, hit_section_id_mask, /*score=*/3}, + {document_id5, hit_section_id_mask, /*score=*/1}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(19); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(12); + SearchResultProto::ResultProto result3; + *result3.mutable_document() = CreateDocument(/*id=*/3); + result3.set_score(8); + SearchResultProto::ResultProto result4; + *result4.mutable_document() = CreateDocument(/*id=*/4); + result4.set_score(3); + SearchResultProto::ResultProto result5; + *result5.mutable_document() = CreateDocument(/*id=*/5); + result5.set_score(1); + + ResultStateV2 
result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/2), *doc_store); + + // First page, 2 results + auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result1.results, + ElementsAre(EqualsProto(result1), EqualsProto(result2))); + // num_results_with_snippets is 0 when there is no snippet. + EXPECT_THAT(page_result1.num_results_with_snippets, Eq(0)); + // Requested page size is same as num_per_page. + EXPECT_THAT(page_result1.requested_page_size, Eq(2)); + // Has more results. + EXPECT_TRUE(has_more_results1); + + // Second page, 2 results + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result2.results, + ElementsAre(EqualsProto(result3), EqualsProto(result4))); + // num_results_with_snippets is 0 when there is no snippet. + EXPECT_THAT(page_result2.num_results_with_snippets, Eq(0)); + // Requested page size is same as num_per_page. + EXPECT_THAT(page_result2.requested_page_size, Eq(2)); + // Has more results. + EXPECT_TRUE(has_more_results2); + + // Third page, 1 result + auto [page_result3, has_more_results3] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result3.results, ElementsAre(EqualsProto(result5))); + // num_results_with_snippets is 0 when there is no snippet. + EXPECT_THAT(page_result3.num_results_with_snippets, Eq(0)); + // Requested page size is same as num_per_page. + EXPECT_THAT(page_result3.requested_page_size, Eq(2)); + // No more results. + EXPECT_FALSE(has_more_results3); +} + +TEST_F(ResultRetrieverV2Test, ShouldIgnoreNonInternalErrors) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + DocumentId invalid_document_id = -1; + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/12}, + {document_id2, hit_section_id_mask, /*score=*/4}, + {invalid_document_id, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get(), + std::make_unique<MockGroupResultLimiter>())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(12); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(4); + + ResultStateV2 result_state1( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + 
CreateResultSpec(/*num_per_page=*/3), *doc_store); + PageResult page_result1 = + result_retriever->RetrieveNextPage(result_state1).first; + EXPECT_THAT(page_result1.results, + ElementsAre(EqualsProto(result1), EqualsProto(result2))); + + DocumentId non_existing_document_id = 4; + scored_document_hits = { + {non_existing_document_id, hit_section_id_mask, /*score=*/15}, + {document_id1, hit_section_id_mask, /*score=*/12}, + {document_id2, hit_section_id_mask, /*score=*/4}}; + ResultStateV2 result_state2( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/3), *doc_store); + PageResult page_result2 = + result_retriever->RetrieveNextPage(result_state2).first; + EXPECT_THAT(page_result2.results, + ElementsAre(EqualsProto(result1), EqualsProto(result2))); +} + +TEST_F(ResultRetrieverV2Test, ShouldIgnoreInternalErrors) { + MockFilesystem mock_filesystem; + EXPECT_CALL(mock_filesystem, + PRead(A<int>(), A<void*>(), A<size_t>(), A<off_t>())) + .WillOnce(Return(false)) + .WillRepeatedly(DoDefault()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&mock_filesystem, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get(), + std::make_unique<MockGroupResultLimiter>())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(0); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/2), *doc_store); + PageResult page_result = + result_retriever->RetrieveNextPage(result_state).first; + // We mocked mock_filesystem to return an internal error when retrieving doc2, + // so doc2 should be skipped and doc1 should still be returned. 
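Both of these error-handling tests pin down the same behavior: a hit whose document cannot be read back is dropped from the page rather than failing the whole retrieval. A minimal sketch of that behavior with simplified stand-in types (Doc, the fetch callback, and the function name are illustrative, not the actual ResultRetrieverV2 code):

  #include <functional>
  #include <optional>
  #include <vector>

  struct Doc {
    int id;
  };

  // Builds a page from ranked document ids, skipping any id whose document
  // cannot be fetched (deleted, out of range, or an I/O failure) so that the
  // remaining results are still returned.
  std::vector<Doc> BuildPageSkippingFailures(
      const std::vector<int>& ranked_document_ids,
      const std::function<std::optional<Doc>(int)>& fetch) {
    std::vector<Doc> page;
    for (int document_id : ranked_document_ids) {
      std::optional<Doc> doc = fetch(document_id);
      if (!doc.has_value()) {
        continue;  // skip the unreadable hit, keep building the page
      }
      page.push_back(*doc);
    }
    return page;
  }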
+ EXPECT_THAT(page_result.results, ElementsAre(EqualsProto(result1))); +} + +TEST_F(ResultRetrieverV2Test, ShouldUpdateResultState) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + doc_store->Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + doc_store->Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + doc_store->Put(CreateDocument(/*id=*/5))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}, + {document_id3, hit_section_id_mask, /*score=*/0}, + {document_id4, hit_section_id_mask, /*score=*/0}, + {document_id5, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/2), *doc_store); + + // First page, 2 results + PageResult page_result1 = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result1.results, SizeIs(2)); + { + absl_ports::shared_lock l(&result_state.mutex); + + // num_returned = size of first page + EXPECT_THAT(result_state.num_returned, Eq(2)); + // Should remove the 2 returned docs from scored_document_hits and only + // contain the remaining 3. + EXPECT_THAT(result_state.scored_document_hits_ranker, Pointee(SizeIs(3))); + } + + // Second page, 2 results + PageResult page_result2 = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result2.results, SizeIs(2)); + { + absl_ports::shared_lock l(&result_state.mutex); + + // num_returned = size of first and second pages + EXPECT_THAT(result_state.num_returned, Eq(4)); + // Should remove the 2 returned docs from scored_document_hits and only + // contain the remaining 1. + EXPECT_THAT(result_state.scored_document_hits_ranker, Pointee(SizeIs(1))); + } + + // Third page, 1 result + PageResult page_result3 = + result_retriever->RetrieveNextPage(result_state).first; + ASSERT_THAT(page_result3.results, SizeIs(1)); + { + absl_ports::shared_lock l(&result_state.mutex); + + // num_returned = size of first, second and third pages + EXPECT_THAT(result_state.num_returned, Eq(5)); + // Should remove the 1 returned doc from scored_document_hits and become + // empty. 
+ EXPECT_THAT(result_state.scored_document_hits_ranker, Pointee(IsEmpty())); + } +} + +TEST_F(ResultRetrieverV2Test, ShouldUpdateNumTotalHits) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + std::vector<ScoredDocumentHit> scored_document_hits1 = { + {document_id1, hit_section_id_mask, /*score=*/0}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + std::shared_ptr<ResultStateV2> result_state1 = + std::make_shared<ResultStateV2>( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), + /*is_descending=*/true), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/1), *doc_store); + { + absl_ports::unique_lock l(&result_state1->mutex); + + result_state1->RegisterNumTotalHits(&num_total_hits_); + ASSERT_THAT(num_total_hits_, Eq(2)); + } + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + doc_store->Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + doc_store->Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + doc_store->Put(CreateDocument(/*id=*/5))); + std::vector<ScoredDocumentHit> scored_document_hits2 = { + {document_id3, hit_section_id_mask, /*score=*/0}, + {document_id4, hit_section_id_mask, /*score=*/0}, + {document_id5, hit_section_id_mask, /*score=*/0}}; + std::shared_ptr<ResultStateV2> result_state2 = + std::make_shared<ResultStateV2>( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), + /*is_descending=*/true), + /*query_terms=*/SectionRestrictQueryTermsMap{}, + CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/2), *doc_store); + { + absl_ports::unique_lock l(&result_state2->mutex); + + result_state2->RegisterNumTotalHits(&num_total_hits_); + ASSERT_THAT(num_total_hits_, Eq(5)); + } + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // Should get 1 doc in the first page of result_state1, and num_total_hits + // should be decremented by 1. + PageResult page_result1 = + result_retriever->RetrieveNextPage(*result_state1).first; + ASSERT_THAT(page_result1.results, SizeIs(1)); + EXPECT_THAT(num_total_hits_, Eq(4)); + + // Should get 2 docs in the first page of result_state2, and num_total_hits + // should be decremented by 2. + PageResult page_result2 = + result_retriever->RetrieveNextPage(*result_state2).first; + ASSERT_THAT(page_result2.results, SizeIs(2)); + EXPECT_THAT(num_total_hits_, Eq(2)); + + // Should get 1 doc in the second page of result_state2 (although num_per_page + // is 2, there is only 1 doc left), and num_total_hits should be decremented + // by 1. 
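The bookkeeping these expectations describe can be pictured as a small RAII counter: registration adds a state's hits to the shared total, each retrieved page subtracts what it hands out, and destroying a state subtracts whatever it still holds. This is only an illustrative sketch under those assumptions, not the actual ResultStateV2 implementation; the class and method names are made up:

  #include <atomic>

  class HitCountGuard {
   public:
    HitCountGuard(std::atomic<int>* counter, int num_hits)
        : counter_(counter), remaining_(num_hits) {
      counter_->fetch_add(remaining_);
    }

    // Called when a page of `page_size` hits has been returned to the client.
    void OnPageReturned(int page_size) {
      remaining_ -= page_size;
      counter_->fetch_sub(page_size);
    }

    // Whatever was never returned is released when the state goes away, so
    // the shared counter eventually drops back to zero either way.
    ~HitCountGuard() { counter_->fetch_sub(remaining_); }

   private:
    std::atomic<int>* counter_;
    int remaining_;
  };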
+  PageResult page_result3 =
+      result_retriever->RetrieveNextPage(*result_state2).first;
+  ASSERT_THAT(page_result3.results, SizeIs(1));
+  EXPECT_THAT(num_total_hits_, Eq(1));
+
+  // Destruct result_state1. There is 1 doc left, so num_total_hits should be
+  // decremented by 1 when destructing it.
+  result_state1.reset();
+  EXPECT_THAT(num_total_hits_, Eq(0));
+
+  // Destruct result_state2. There is 0 doc left, so num_total_hits should be
+  // unchanged when destructing it.
+  result_state2.reset();
+  EXPECT_THAT(num_total_hits_, Eq(0));
+}
+
+TEST_F(ResultRetrieverV2Test, ShouldLimitNumTotalBytesPerPage) {
+  ICING_ASSERT_OK_AND_ASSIGN(
+      DocumentStore::CreateResult create_result,
+      DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_,
+                            schema_store_.get()));
+  std::unique_ptr<DocumentStore> doc_store =
+      std::move(create_result.document_store);
+
+  ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1,
+                             doc_store->Put(CreateDocument(/*id=*/1)));
+  ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2,
+                             doc_store->Put(CreateDocument(/*id=*/2)));
+
+  std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"),
+                                            GetSectionId("Email", "body")};
+  SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids);
+  std::vector<ScoredDocumentHit> scored_document_hits = {
+      {document_id1, hit_section_id_mask, /*score=*/5},
+      {document_id2, hit_section_id_mask, /*score=*/0}};
+  ICING_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<ResultRetrieverV2> result_retriever,
+      ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(),
+                                language_segmenter_.get(), normalizer_.get()));
+
+  SearchResultProto::ResultProto result1;
+  *result1.mutable_document() = CreateDocument(/*id=*/1);
+  result1.set_score(5);
+  SearchResultProto::ResultProto result2;
+  *result2.mutable_document() = CreateDocument(/*id=*/2);
+  result2.set_score(0);
+
+  ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2);
+  result_spec.set_num_total_bytes_per_page_threshold(result1.ByteSizeLong());
+  ResultStateV2 result_state(
+      std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+          std::move(scored_document_hits),
+          /*is_descending=*/true),
+      /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
+      CreateScoringSpec(/*is_descending_order=*/true), result_spec, *doc_store);
+
+  // First page. Only result1 should be returned, since its byte size meets
+  // num_total_bytes_per_page_threshold and ResultRetriever should terminate
+  // early even though # of results is still below num_per_page.
+  auto [page_result1, has_more_results1] =
+      result_retriever->RetrieveNextPage(result_state);
+  EXPECT_THAT(page_result1.results, ElementsAre(EqualsProto(result1)));
+  // Has more results.
+  EXPECT_TRUE(has_more_results1);
+
+  // Second page, result2.
+  auto [page_result2, has_more_results2] =
+      result_retriever->RetrieveNextPage(result_state);
+  EXPECT_THAT(page_result2.results, ElementsAre(EqualsProto(result2)));
+  // No more results.
+ EXPECT_FALSE(has_more_results2); +} + +TEST_F(ResultRetrieverV2Test, + ShouldReturnSingleLargeResultAboveNumTotalBytesPerPageThreshold) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/5}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(5); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(0); + + int threshold = 1; + ASSERT_THAT(result1.ByteSizeLong(), Gt(threshold)); + + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(threshold); + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, *doc_store); + + // First page. Should return single result1 even though its byte size exceeds + // num_total_bytes_per_page_threshold. + auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result1.results, ElementsAre(EqualsProto(result1))); + // Has more results. + EXPECT_TRUE(has_more_results1); + + // Second page, result2. + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result2.results, ElementsAre(EqualsProto(result2))); + // No more results. 
+ EXPECT_FALSE(has_more_results2); +} + +TEST_F(ResultRetrieverV2Test, + ShouldRetrieveNextResultWhenBelowNumTotalBytesPerPageThreshold) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/5}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(5); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(0); + + int threshold = result1.ByteSizeLong() + 1; + ASSERT_THAT(result1.ByteSizeLong() + result2.ByteSizeLong(), Gt(threshold)); + + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(threshold); + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, *doc_store); + + // After retrieving result1, total bytes are still below the threshold and # + // of results is still below num_per_page, so ResultRetriever should continue + // the retrieval process and thus include result2 into this page, even though + // finally total bytes of result1 + result2 exceed the threshold. + auto [page_result, has_more_results] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result.results, + ElementsAre(EqualsProto(result1), EqualsProto(result2))); + // No more results. 
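Taken together, the three num_total_bytes_per_page_threshold tests in this file pin down a simple page-filling rule: the first result of a page is always admitted, and the page stops growing once the accumulated bytes reach the threshold. A rough sketch of that rule under those assumptions (the helper is illustrative, not the actual ResultRetrieverV2 logic):

  #include <cstddef>

  // Decides whether one more result may be appended to the current page.
  bool ShouldAddToPage(size_t current_page_bytes, size_t threshold_bytes,
                       int current_page_size, int num_per_page) {
    if (current_page_size >= num_per_page) {
      return false;  // page is already full by count
    }
    if (current_page_size == 0) {
      return true;  // always return at least one result, however large
    }
    return current_page_bytes < threshold_bytes;  // stop once threshold is met
  }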
+ EXPECT_FALSE(has_more_results); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/result/result-state-manager.cc b/icing/result/result-state-manager.cc index d606e79..2783fe2 100644 --- a/icing/result/result-state-manager.cc +++ b/icing/result/result-state-manager.cc @@ -14,7 +14,16 @@ #include "icing/result/result-state-manager.h" +#include <memory> +#include <queue> +#include <utility> + #include "icing/proto/search.pb.h" +#include "icing/query/query-terms.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/util/clock.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" @@ -23,100 +32,116 @@ namespace icing { namespace lib { ResultStateManager::ResultStateManager(int max_total_hits, - const DocumentStore& document_store) + const DocumentStore& document_store, + const Clock* clock) : document_store_(document_store), max_total_hits_(max_total_hits), num_total_hits_(0), - random_generator_(GetSteadyTimeNanoseconds()) {} - -libtextclassifier3::StatusOr<PageResultState> -ResultStateManager::RankAndPaginate(ResultState result_state) { - if (!result_state.HasMoreResults()) { - return absl_ports::InvalidArgumentError("ResultState has no results"); + random_generator_(GetSteadyTimeNanoseconds()), + clock_(*clock) {} + +libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> +ResultStateManager::CacheAndRetrieveFirstPage( + std::unique_ptr<ScoredDocumentHitsRanker> ranker, + SectionRestrictQueryTermsMap query_terms, + const SearchSpecProto& search_spec, const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec, const DocumentStore& document_store, + const ResultRetrieverV2& result_retriever) { + if (ranker == nullptr) { + return absl_ports::InvalidArgumentError("Should not provide null ranker"); } - // Gets the number before calling GetNextPage() because num_returned() may - // change after returning more results. - int num_previously_returned = result_state.num_returned(); - int num_per_page = result_state.num_per_page(); - - std::vector<ScoredDocumentHit> page_result_document_hits = - result_state.GetNextPage(document_store_); - - SnippetContext snippet_context_copy = result_state.snippet_context(); - - std::unordered_map<std::string, ProjectionTree> projection_tree_map_copy = - result_state.projection_tree_map(); - if (!result_state.HasMoreResults()) { + // Create shared pointer of ResultState. + // ResultState should be created by ResultStateManager only. + std::shared_ptr<ResultStateV2> result_state = std::make_shared<ResultStateV2>( + std::move(ranker), std::move(query_terms), search_spec, scoring_spec, + result_spec, document_store); + + // Retrieve docs outside of ResultStateManager critical section. + // Will enter ResultState critical section inside ResultRetriever. 
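The comments in this function describe a deliberate split between two critical sections: document retrieval holds only the per-state lock, and the manager-wide lock is taken afterwards just for cache bookkeeping. A condensed sketch of that ordering, using std::mutex as a stand-in for the absl_ports locks (the names here are illustrative):

  #include <mutex>

  struct StateSketch {
    std::mutex mutex;  // guards this state's ranker and counters
    // ... scored hits, num_returned, etc.
  };

  void CacheFirstPageSketch(StateSketch& state, std::mutex& manager_mutex) {
    // Expensive work (retrieving documents for the first page) happens while
    // holding only the per-state lock, inside the retriever.
    {
      std::lock_guard<std::mutex> state_lock(state.mutex);
      // truncate hits to the budget, register the global hit counter
    }
    // The manager-wide lock is held only for short map/queue bookkeeping, so
    // concurrent queries do not serialize on document retrieval.
    {
      std::lock_guard<std::mutex> manager_lock(manager_mutex);
      // evict expired or over-budget states, insert the new token
    }
  }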
+ auto [page_result, has_more_results] = + result_retriever.RetrieveNextPage(*result_state); + if (!has_more_results) { // No more pages, won't store ResultState, returns directly - return PageResultState( - std::move(page_result_document_hits), kInvalidNextPageToken, - std::move(snippet_context_copy), std::move(projection_tree_map_copy), - num_previously_returned, num_per_page); + return std::make_pair(kInvalidNextPageToken, std::move(page_result)); } - absl_ports::unique_lock l(&mutex_); - // ResultState has multiple pages, storing it - uint64_t next_page_token = Add(std::move(result_state)); + int num_hits_to_add = 0; + { + // ResultState critical section + absl_ports::unique_lock l(&result_state->mutex); + + result_state->scored_document_hits_ranker->TruncateHitsTo(max_total_hits_); + result_state->RegisterNumTotalHits(&num_total_hits_); + num_hits_to_add = result_state->scored_document_hits_ranker->size(); + } - return PageResultState(std::move(page_result_document_hits), next_page_token, - std::move(snippet_context_copy), - std::move(projection_tree_map_copy), - num_previously_returned, num_per_page); -} + // It is fine to exit ResultState critical section, since it is just created + // above and only this thread (this call stack) has access to it. Thus, it + // won't be changed during the gap before we enter ResultStateManager critical + // section. + uint64_t next_page_token = kInvalidNextPageToken; + { + // ResultStateManager critical section + absl_ports::unique_lock l(&mutex_); + + // Remove expired result states first. + InternalInvalidateExpiredResultStates(kDefaultResultStateTtlInMs); + // Remove states to make room for this new state. + RemoveStatesIfNeeded(num_hits_to_add); + // Generate a new unique token and add it into result_state_map_. 
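+    // Add() (shown below) generates a token that is unique among the live
+    // ones, stores the state in result_state_map_, and records the pair
+    // (token, current system time) in token_queue_ so that the TTL check
+    // above can expire this state later.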
+ next_page_token = Add(std::move(result_state)); + } -uint64_t ResultStateManager::Add(ResultState result_state) { - RemoveStatesIfNeeded(result_state); - result_state.TruncateHitsTo(max_total_hits_); + return std::make_pair(next_page_token, std::move(page_result)); +} +uint64_t ResultStateManager::Add(std::shared_ptr<ResultStateV2> result_state) { uint64_t new_token = GetUniqueToken(); - num_total_hits_ += result_state.num_remaining(); result_state_map_.emplace(new_token, std::move(result_state)); // Tracks the insertion order - token_queue_.push(new_token); + token_queue_.push( + std::make_pair(new_token, clock_.GetSystemTimeMilliseconds())); return new_token; } -libtextclassifier3::StatusOr<PageResultState> ResultStateManager::GetNextPage( - uint64_t next_page_token) { - absl_ports::unique_lock l(&mutex_); - - const auto& state_iterator = result_state_map_.find(next_page_token); - if (state_iterator == result_state_map_.end()) { - return absl_ports::NotFoundError("next_page_token not found"); +libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> +ResultStateManager::GetNextPage(uint64_t next_page_token, + const ResultRetrieverV2& result_retriever) { + std::shared_ptr<ResultStateV2> result_state = nullptr; + { + // ResultStateManager critical section + absl_ports::unique_lock l(&mutex_); + + // Remove expired result states before fetching + InternalInvalidateExpiredResultStates(kDefaultResultStateTtlInMs); + + const auto& state_iterator = result_state_map_.find(next_page_token); + if (state_iterator == result_state_map_.end()) { + return absl_ports::NotFoundError("next_page_token not found"); + } + result_state = state_iterator->second; } - int num_returned = state_iterator->second.num_returned(); - int num_per_page = state_iterator->second.num_per_page(); - std::vector<ScoredDocumentHit> result_of_page = - state_iterator->second.GetNextPage(document_store_); - if (result_of_page.empty()) { - // This shouldn't happen, all our active states should contain results, but - // a sanity check here in case of any data inconsistency. - InternalInvalidateResultState(next_page_token); - return absl_ports::NotFoundError( - "No more results, token has been invalidated."); - } + // Retrieve docs outside of ResultStateManager critical section. + // Will enter ResultState critical section inside ResultRetriever. + auto [page_result, has_more_results] = + result_retriever.RetrieveNextPage(*result_state); - // Copies the SnippetContext in case the ResultState is invalidated. 
- SnippetContext snippet_context_copy = - state_iterator->second.snippet_context(); + if (!has_more_results) { + { + // ResultStateManager critical section + absl_ports::unique_lock l(&mutex_); - std::unordered_map<std::string, ProjectionTree> projection_tree_map_copy = - state_iterator->second.projection_tree_map(); + InternalInvalidateResultState(next_page_token); + } - if (!state_iterator->second.HasMoreResults()) { - InternalInvalidateResultState(next_page_token); next_page_token = kInvalidNextPageToken; } - - num_total_hits_ -= result_of_page.size(); - return PageResultState( - result_of_page, next_page_token, std::move(snippet_context_copy), - std::move(projection_tree_map_copy), num_returned, num_per_page); + return std::make_pair(next_page_token, std::move(page_result)); } void ResultStateManager::InvalidateResultState(uint64_t next_page_token) { @@ -135,10 +160,12 @@ void ResultStateManager::InvalidateAllResultStates() { } void ResultStateManager::InternalInvalidateAllResultStates() { + // We don't have to reset num_total_hits_ (to 0) here, since clearing + // result_state_map_ will "eventually" invoke the destructor of ResultState + // (which decrements num_total_hits_) and num_total_hits_ will become 0. result_state_map_.clear(); invalidated_token_set_.clear(); - token_queue_ = std::queue<uint64_t>(); - num_total_hits_ = 0; + token_queue_ = std::queue<std::pair<uint64_t, int64_t>>(); } uint64_t ResultStateManager::GetUniqueToken() { @@ -154,14 +181,14 @@ uint64_t ResultStateManager::GetUniqueToken() { return new_token; } -void ResultStateManager::RemoveStatesIfNeeded(const ResultState& result_state) { +void ResultStateManager::RemoveStatesIfNeeded(int num_hits_to_add) { if (result_state_map_.empty() || token_queue_.empty()) { return; } // 1. Check if this new result_state would take up the entire result state // manager budget. - if (result_state.num_remaining() > max_total_hits_) { + if (num_hits_to_add > max_total_hits_) { // This single result state will exceed our budget. Drop everything else to // accomodate it. InternalInvalidateAllResultStates(); @@ -170,16 +197,22 @@ void ResultStateManager::RemoveStatesIfNeeded(const ResultState& result_state) { // 2. Remove any tokens that were previously invalidated. while (!token_queue_.empty() && - invalidated_token_set_.find(token_queue_.front()) != + invalidated_token_set_.find(token_queue_.front().first) != invalidated_token_set_.end()) { - invalidated_token_set_.erase(token_queue_.front()); + invalidated_token_set_.erase(token_queue_.front().first); token_queue_.pop(); } // 3. If we're over budget, remove states from oldest to newest until we fit // into our budget. - while (result_state.num_remaining() + num_total_hits_ > max_total_hits_) { - InternalInvalidateResultState(token_queue_.front()); + // Note: num_total_hits_ may not be decremented immediately after invalidating + // a result state, since other threads may still hold the shared pointer. + // Thus, we have to check if token_queue_ is empty or not, since it is + // possible that num_total_hits_ is non-zero and still greater than + // max_total_hits_ when token_queue_ is empty. Still "eventually" it will be + // decremented after the last thread releases the shared pointer. 
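+  // For example, if max_total_hits_ is 100 and a state holding 40 hits is
+  // erased below while another thread still owns its shared_ptr, those 40
+  // hits stay counted in num_total_hits_ until that thread drops its
+  // reference and the state's destructor finally runs.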
+ while (!token_queue_.empty() && num_total_hits_ > max_total_hits_) { + InternalInvalidateResultState(token_queue_.front().first); token_queue_.pop(); } invalidated_token_set_.clear(); @@ -192,11 +225,34 @@ void ResultStateManager::InternalInvalidateResultState(uint64_t token) { // remove the token in RemoveStatesIfNeeded(). auto itr = result_state_map_.find(token); if (itr != result_state_map_.end()) { - num_total_hits_ -= itr->second.num_remaining(); + // We don't have to decrement num_total_hits_ here, since erasing the shared + // ptr instance will "eventually" invoke the destructor of ResultState and + // it will handle this. result_state_map_.erase(itr); invalidated_token_set_.insert(token); } } +void ResultStateManager::InternalInvalidateExpiredResultStates( + int64_t result_state_ttl) { + int64_t current_time = clock_.GetSystemTimeMilliseconds(); + while (!token_queue_.empty() && + current_time - token_queue_.front().second >= result_state_ttl) { + auto itr = result_state_map_.find(token_queue_.front().first); + if (itr != result_state_map_.end()) { + // We don't have to decrement num_total_hits_ here, since erasing the + // shared ptr instance will "eventually" invoke the destructor of + // ResultState and it will handle this. + result_state_map_.erase(itr); + } else { + // Since result_state_map_ and invalidated_token_set_ are mutually + // exclusive, we remove the token from invalidated_token_set_ only if it + // isn't present in result_state_map_. + invalidated_token_set_.erase(token_queue_.front().first); + } + token_queue_.pop(); + } +} + } // namespace lib } // namespace icing diff --git a/icing/result/result-state-manager.h b/icing/result/result-state-manager.h index c04217f..0684864 100644 --- a/icing/result/result-state-manager.h +++ b/icing/result/result-state-manager.h @@ -15,6 +15,8 @@ #ifndef ICING_RESULT_RESULT_STATE_MANAGER_H_ #define ICING_RESULT_RESULT_STATE_MANAGER_H_ +#include <atomic> +#include <memory> #include <queue> #include <random> #include <unordered_map> @@ -24,8 +26,12 @@ #include "icing/absl_ports/mutex.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" -#include "icing/result/page-result-state.h" -#include "icing/result/result-state.h" +#include "icing/query/query-terms.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/scoring/scored-document-hits-ranker.h" +#include "icing/util/clock.h" namespace icing { namespace lib { @@ -34,39 +40,60 @@ namespace lib { // SearchResultProto.next_page_token. inline constexpr uint64_t kInvalidNextPageToken = 0; +// 1 hr as the default ttl for a ResultState after being pushed into +// token_queue_. +inline constexpr int64_t kDefaultResultStateTtlInMs = 1LL * 60 * 60 * 1000; + // Used to store and manage ResultState. class ResultStateManager { public: explicit ResultStateManager(int max_total_hits, - const DocumentStore& document_store); + const DocumentStore& document_store, + const Clock* clock); ResultStateManager(const ResultStateManager&) = delete; ResultStateManager& operator=(const ResultStateManager&) = delete; - // Ranks the results and returns the first page of them. The result object - // PageResultState contains a next_page_token which can be used to fetch more - // pages later. It will be set to a default value 0 if there're no more pages. + // Creates a new result state, retrieves and returns PageResult for the first + // page. 
Also caches the new result state and returns a next_page_token which
+  // can be used to fetch more pages from the same result state later. Before
+  // caching the result state, adjusts (truncates) the size and evicts some old
+  // result states if exceeding the cache size limit. next_page_token will be
+  // set to a default value kInvalidNextPageToken if there're no more pages.
   //
-  // NOTE: it's caller's responsibility not to call this method with the same
-  // ResultState more than once, otherwise duplicate states will be stored
-  // internally.
+  // NOTE: it is possible to have empty result for the first page even if the
+  //       ranker was not empty before the retrieval, since GroupResultLimiter
+  //       may filter out all docs. In this case, the first page is also the
+  //       last page and next_page_token will be set to kInvalidNextPageToken.
   //
   // Returns:
-  //   A PageResultState on success
-  //   INVALID_ARGUMENT if the input state contains no results
-  libtextclassifier3::StatusOr<PageResultState> RankAndPaginate(
-      ResultState result_state) ICING_LOCKS_EXCLUDED(mutex_);
+  //   A token and PageResult wrapped by std::pair on success
+  //   INVALID_ARGUMENT if the input ranker is null or contains no results
+  libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>>
+  CacheAndRetrieveFirstPage(std::unique_ptr<ScoredDocumentHitsRanker> ranker,
+                            SectionRestrictQueryTermsMap query_terms,
+                            const SearchSpecProto& search_spec,
+                            const ScoringSpecProto& scoring_spec,
+                            const ResultSpecProto& result_spec,
+                            const DocumentStore& document_store,
+                            const ResultRetrieverV2& result_retriever)
+      ICING_LOCKS_EXCLUDED(mutex_);

-  // Retrieves and returns the next page of results wrapped in PageResultState.
+  // Retrieves and returns PageResult for the next page.
   // The returned results won't exist in ResultStateManager anymore. If the
   // query has no more pages after this retrieval, the input token will be
   // invalidated.
   //
+  // NOTE: it is possible to have empty result for the last page even if the
+  //       ranker was not empty before the retrieval, since GroupResultLimiter
+  //       may filter out all remaining docs.
+  //
   // Returns:
-  //   PageResultState on success, guaranteed to have non-empty results
+  //   A token and PageResult wrapped by std::pair on success
   //   NOT_FOUND if failed to find any more results
-  libtextclassifier3::StatusOr<PageResultState> GetNextPage(
-      uint64_t next_page_token) ICING_LOCKS_EXCLUDED(mutex_);
+  libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> GetNextPage(
+      uint64_t next_page_token, const ResultRetrieverV2& result_retriever)
+      ICING_LOCKS_EXCLUDED(mutex_);

   // Invalidates the result state associated with the given next-page token.
   void InvalidateResultState(uint64_t next_page_token)
@@ -88,14 +115,15 @@ class ResultStateManager {

   // The number of scored document hits that all result states currently held by
   // the result state manager have.
-  int num_total_hits_;
+  std::atomic<int> num_total_hits_;

   // A hash map of (next-page token -> result state)
-  std::unordered_map<uint64_t, ResultState> result_state_map_
+  std::unordered_map<uint64_t, std::shared_ptr<ResultStateV2>> result_state_map_
       ICING_GUARDED_BY(mutex_);

-  // A queue used to track the insertion order of tokens
-  std::queue<uint64_t> token_queue_ ICING_GUARDED_BY(mutex_);
+  // A queue used to track the insertion order of tokens with pushed timestamps.
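+  // Each entry is (next-page token, system time in ms when it was cached);
+  // InternalInvalidateExpiredResultStates() compares that timestamp against
+  // kDefaultResultStateTtlInMs to expire states cached for an hour or longer.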
+ std::queue<std::pair<uint64_t, int64_t>> token_queue_ + ICING_GUARDED_BY(mutex_); // A set to temporarily store the invalidated tokens before they're finally // removed from token_queue_. We store the invalidated tokens to ensure the @@ -105,19 +133,23 @@ class ResultStateManager { // A random 64-bit number generator std::mt19937_64 random_generator_ ICING_GUARDED_BY(mutex_); + const Clock& clock_; // Does not own. + // Puts a new result state into the internal storage and returns a next-page // token associated with it. The token is guaranteed to be unique among all // currently valid tokens. When the maximum number of result states is // reached, the oldest / firstly added result state will be removed to make // room for the new state. - uint64_t Add(ResultState result_state) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + uint64_t Add(std::shared_ptr<ResultStateV2> result_state) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Helper method to generate a next-page token that is unique among all // existing tokens in token_queue_. uint64_t GetUniqueToken() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Helper method to remove old states to make room for incoming states. - void RemoveStatesIfNeeded(const ResultState& result_state) + // Helper method to remove old states to make room for incoming states with + // size num_hits_to_add. + void RemoveStatesIfNeeded(int num_hits_to_add) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Helper method to remove a result state from result_state_map_, the token @@ -126,12 +158,18 @@ class ResultStateManager { void InternalInvalidateResultState(uint64_t token) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Internal method to invalidates all result states / tokens currently in + // Internal method to invalidate all result states / tokens currently in // ResultStateManager. We need this separate method so that other public // methods don't need to call InvalidateAllResultStates(). Public methods // calling each other may cause deadlock issues. void InternalInvalidateAllResultStates() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Internal method to invalidate and remove expired result states / tokens + // currently in ResultStateManager that were created before + // current_time - result_state_ttl. 
+ void InternalInvalidateExpiredResultStates(int64_t result_state_ttl) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); }; } // namespace lib diff --git a/icing/result/result-state-manager_test.cc b/icing/result/result-state-manager_test.cc index 8a9005d..7025c63 100644 --- a/icing/result/result-state-manager_test.cc +++ b/icing/result/result-state-manager_test.cc @@ -16,22 +16,39 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/portable/equals-proto.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" #include "icing/util/clock.h" +#include "unicode/uloc.h" namespace icing { namespace lib { namespace { + using ::icing::lib::portable_equals_proto::EqualsProto; -using ::testing::ElementsAre; using ::testing::Eq; -using ::testing::Gt; using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::SizeIs; +using PageResultInfo = std::pair<uint64_t, PageResult>; + +// TODO(sungyc): Refactor helper functions below (builder classes or common test +// utility). ScoringSpecProto CreateScoringSpec() { ScoringSpecProto scoring_spec; @@ -45,963 +62,1355 @@ ResultSpecProto CreateResultSpec(int num_per_page) { return result_spec; } -ScoredDocumentHit CreateScoredHit(DocumentId document_id) { - return ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1); +DocumentProto CreateDocument(int id) { + return DocumentBuilder() + .SetNamespace("namespace") + .SetUri(std::to_string(id)) + .SetSchema("Document") + .SetCreationTimestampMs(1574365086666 + id) + .SetScore(1) + .Build(); } class ResultStateManagerTest : public testing::Test { protected: + ResultStateManagerTest() : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + void SetUp() override { - schema_store_base_dir_ = GetTestTempDir() + "/schema_store"; - filesystem_.CreateDirectoryRecursively(schema_store_base_dir_.c_str()); + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + clock_ = std::make_unique<FakeClock>(); + + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + ICING_ASSERT_OK_AND_ASSIGN( schema_store_, - SchemaStore::Create(&filesystem_, schema_store_base_dir_, &clock_)); + SchemaStore::Create(&filesystem_, test_dir_, clock_.get())); SchemaProto schema; schema.add_types()->set_schema_type("Document"); ICING_ASSERT_OK(schema_store_->SetSchema(std::move(schema))); - doc_store_base_dir_ = GetTestTempDir() + "/document_store"; - filesystem_.CreateDirectoryRecursively(doc_store_base_dir_.c_str()); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + ICING_ASSERT_OK_AND_ASSIGN( DocumentStore::CreateResult result, - DocumentStore::Create(&filesystem_, doc_store_base_dir_, &clock_, + DocumentStore::Create(&filesystem_, test_dir_, clock_.get(), schema_store_.get())); document_store_ = std::move(result.document_store); - } - void TearDown() override { - filesystem_.DeleteDirectoryRecursively(doc_store_base_dir_.c_str()); - filesystem_.DeleteDirectoryRecursively(schema_store_base_dir_.c_str()); + ICING_ASSERT_OK_AND_ASSIGN( + result_retriever_, ResultRetrieverV2::Create( + document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); } - ResultState CreateResultState( - const std::vector<ScoredDocumentHit>& scored_document_hits, - int num_per_page) { - return ResultState(scored_document_hits, /*query_terms=*/{}, - SearchSpecProto::default_instance(), CreateScoringSpec(), - CreateResultSpec(num_per_page), *document_store_); + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + clock_.reset(); } - ScoredDocumentHit AddScoredDocument(DocumentId document_id) { + std::pair<ScoredDocumentHit, DocumentProto> AddScoredDocument( + DocumentId document_id) { DocumentProto document; document.set_namespace_("namespace"); document.set_uri(std::to_string(document_id)); document.set_schema("Document"); - document_store_->Put(std::move(document)); - return ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1); + document.set_creation_timestamp_ms(1574365086666 + document_id); + document_store_->Put(document); + return std::make_pair( + ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1), + std::move(document)); } + std::pair<std::vector<ScoredDocumentHit>, std::vector<DocumentProto>> + AddScoredDocuments(const std::vector<DocumentId>& document_ids) { + std::vector<ScoredDocumentHit> scored_document_hits; + std::vector<DocumentProto> document_protos; + + for (DocumentId document_id : document_ids) { + std::pair<ScoredDocumentHit, DocumentProto> pair = + AddScoredDocument(document_id); + scored_document_hits.emplace_back(std::move(pair.first)); + document_protos.emplace_back(std::move(pair.second)); + } + + std::reverse(document_protos.begin(), document_protos.end()); + + return std::make_pair(std::move(scored_document_hits), + std::move(document_protos)); + } + + FakeClock* clock() { return clock_.get(); } + const FakeClock* clock() const { return clock_.get(); } + + DocumentStore& document_store() { return *document_store_; } const DocumentStore& document_store() const { return *document_store_; } + const ResultRetrieverV2& result_retriever() const { + return *result_retriever_; + } + private: Filesystem filesystem_; - 
std::string doc_store_base_dir_; - std::string schema_store_base_dir_; - Clock clock_; - std::unique_ptr<DocumentStore> document_store_; + const std::string test_dir_; + std::unique_ptr<FakeClock> clock_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<ResultRetrieverV2> result_retriever_; }; -TEST_F(ResultStateManagerTest, ShouldRankAndPaginateOnePage) { - ResultState original_result_state = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/10); +TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageOnePage) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}}; + std::unique_ptr<ScoredDocumentHitsRanker> ranker = + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true); ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - - EXPECT_THAT(page_result_state.next_page_token, Eq(kInvalidNextPageToken)); - - // Should get the original scored document hits - EXPECT_THAT( - page_result_state.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/2)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/1)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/0)))); + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::move(ranker), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/10), + document_store(), result_retriever())); + + EXPECT_THAT(page_result_info.first, Eq(kInvalidNextPageToken)); + + // Should get docs. 
+ ASSERT_THAT(page_result_info.second.results, SizeIs(3)); + EXPECT_THAT(page_result_info.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/3))); + EXPECT_THAT(page_result_info.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/2))); + EXPECT_THAT(page_result_info.second.results.at(2).document(), + EqualsProto(CreateDocument(/*id=*/1))); } -TEST_F(ResultStateManagerTest, ShouldRankAndPaginateMultiplePages) { - ResultState original_result_state = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4)}, - /*num_per_page=*/2); +TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageMultiplePages) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store().Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + document_store().Put(CreateDocument(/*id=*/5))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}, + {document_id4, kSectionIdMaskNone, /*score=*/1}, + {document_id5, kSectionIdMaskNone, /*score=*/1}}; + std::unique_ptr<ScoredDocumentHitsRanker> ranker = + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true); ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); // First page, 2 results ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - EXPECT_THAT( - page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/4)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/3)))); - - uint64_t next_page_token = page_result_state1.next_page_token; + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::move(ranker), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/2), + document_store(), result_retriever())); + EXPECT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + ASSERT_THAT(page_result_info1.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/5))); + EXPECT_THAT(page_result_info1.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/4))); + + uint64_t next_page_token = page_result_info1.first; // Second page, 2 results - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state2, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT( - page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/2)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/1)))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + 
result_state_manager.GetNextPage(next_page_token, result_retriever())); + EXPECT_THAT(page_result_info2.first, Eq(next_page_token)); + ASSERT_THAT(page_result_info2.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/3))); + EXPECT_THAT(page_result_info2.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/2))); // Third page, 1 result - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state3, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT( - page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/0)))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info3, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(kInvalidNextPageToken)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/1))); // No results - EXPECT_THAT(result_state_manager.GetNextPage(next_page_token), + EXPECT_THAT( + result_state_manager.GetNextPage(next_page_token, result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} + +TEST_F(ResultStateManagerTest, NullRankerShouldReturnError) { + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + EXPECT_THAT(result_state_manager.CacheAndRetrieveFirstPage( + /*ranker=*/nullptr, + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(ResultStateManagerTest, EmptyRankerShouldReturnEmptyFirstPage) { + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(page_result_info.first, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info.second.results, IsEmpty()); +} + +TEST_F(ResultStateManagerTest, ShouldAllowEmptyFirstPage) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}}; + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Create a ResultSpec that limits "namespace" to 0 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(0); + result_grouping->add_namespaces("namespace"); + + // First page, no result. 
+ ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), result_spec, document_store(), + result_retriever())); + // If the first page has no result, then it should be the last page. + EXPECT_THAT(page_result_info.first, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info.second.results, IsEmpty()); +} + +TEST_F(ResultStateManagerTest, ShouldAllowEmptyLastPage) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store().Put(CreateDocument(/*id=*/4))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}, + {document_id4, kSectionIdMaskNone, /*score=*/1}}; + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Create a ResultSpec that limits "namespace" to 2 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(2); + result_grouping->add_namespaces("namespace"); + + // First page, 2 results. + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), result_spec, document_store(), + result_retriever())); + EXPECT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + ASSERT_THAT(page_result_info1.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/4))); + EXPECT_THAT(page_result_info1.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/3))); + + uint64_t next_page_token = page_result_info1.first; + + // Second page, all remaining documents will be filtered out by group result + // limiter, so we should get an empty page. 
+ ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + EXPECT_THAT(page_result_info2.first, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info2.second.results, IsEmpty()); +} + +TEST_F(ResultStateManagerTest, + ShouldInvalidateExpiredTokensWhenCacheAndRetrieveFirstPage) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + SectionRestrictQueryTermsMap query_terms; + SearchSpecProto search_spec; + ScoringSpecProto scoring_spec = CreateScoringSpec(); + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); + + // Set time as 1s and add state 1. + clock()->SetSystemTimeMilliseconds(1000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + query_terms, search_spec, scoring_spec, result_spec, document_store(), + result_retriever())); + ASSERT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + + // Set time as 1hr1s and add state 2. + clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + query_terms, search_spec, scoring_spec, result_spec, document_store(), + result_retriever())); + + // Calling CacheAndRetrieveFirstPage() on state 2 should invalidate the + // expired state 1 internally. + // + // We test the behavior by setting time back to 1s, to make sure the + // invalidation of state 1 was done by the previous + // CacheAndRetrieveFirstPage() instead of the following GetNextPage(). + clock()->SetSystemTimeMilliseconds(1000); + // page_result_info1's token (page_result_info1.first) shouldn't be found. + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } -TEST_F(ResultStateManagerTest, EmptyStateShouldReturnError) { - ResultState empty_result_state = CreateResultState({}, /*num_per_page=*/1); +TEST_F(ResultStateManagerTest, + ShouldInvalidateExpiredTokensWhenGetNextPageOnOthers) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); - EXPECT_THAT( - result_state_manager.RankAndPaginate(std::move(empty_result_state)), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Set time as 1s and add state 1. 
+ clock()->SetSystemTimeMilliseconds(1000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ASSERT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + + // Set time as 2s and add state 2. + clock()->SetSystemTimeMilliseconds(2000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ASSERT_THAT(page_result_info2.first, Not(Eq(kInvalidNextPageToken))); + + // 1. Set time as 1hr1s. + // 2. Call GetNextPage() on state 2. It should correctly invalidate the + // expired state 1. + // 3. Then calling GetNextPage() on state 1 shouldn't get anything. + clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1000); + // page_result_info2's token (page_result_info2.first) should be found + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + // We test the behavior by setting time back to 2s, to make sure the + // invalidation of state 1 was done by the previous GetNextPage() instead of + // the following GetNextPage(). + clock()->SetSystemTimeMilliseconds(2000); + // page_result_info1's token (page_result_info1.first) shouldn't be found. + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} + +TEST_F(ResultStateManagerTest, + ShouldInvalidateExpiredTokensWhenGetNextPageOnItself) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Set time as 1s and add state. + clock()->SetSystemTimeMilliseconds(1000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ASSERT_THAT(page_result_info.first, Not(Eq(kInvalidNextPageToken))); + + // 1. Set time as 1hr1s. + // 2. Then calling GetNextPage() on the state shouldn't get anything. + clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1000); + // page_result_info's token (page_result_info.first) shouldn't be found. 
+ EXPECT_THAT(result_state_manager.GetNextPage(page_result_info.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store().Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + document_store().Put(CreateDocument(/*id=*/5))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id6, + document_store().Put(CreateDocument(/*id=*/6))); + std::vector<ScoredDocumentHit> scored_document_hits1 = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}}; + std::vector<ScoredDocumentHit> scored_document_hits2 = { + {document_id4, kSectionIdMaskNone, /*score=*/1}, + {document_id5, kSectionIdMaskNone, /*score=*/1}, + {document_id6, kSectionIdMaskNone, /*score=*/1}}; ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - result_state_manager.InvalidateResultState( - page_result_state1.next_page_token); + // Invalidate first result state by the token. 
+ result_state_manager.InvalidateResultState(page_result_info1.first); - // page_result_state1.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // page_result_info1's token (page_result_info1.first) shouldn't be found + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state2.next_page_token() should still exist - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT( - page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/4)))); + // page_result_info2's token (page_result_info2.first) should still exist + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + // Should get docs. + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/5))); } TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); result_state_manager.InvalidateAllResultStates(); - // page_result_state1.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // page_result_info1's token (page_result_info1.first) shouldn't be found + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + 
result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state2.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // page_result_info2's token (page_result_info2.first) shouldn't be found + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); ResultStateManager result_state_manager(/*max_total_hits=*/2, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + // Adding state 3 should cause state 1 to be removed. 
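+ // (Each cached state keeps one hit beyond its already-returned first page, so with max_total_hits=2 the budget is full after states 1 and 2; adding state 3 evicts the oldest state, state 1.)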
ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/2)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos2.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); } TEST_F(ResultStateManagerTest, InvalidatedResultStateShouldDecreaseCurrentHitsCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. 
ResultStateManager result_state_manager(/*max_total_hits=*/3, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Invalidates state 2, so that the number of hits current cached should be // decremented to 2. - result_state_manager.InvalidateResultState( - page_result_state2.next_page_token); + result_state_manager.InvalidateResultState(page_result_info2.first); // If invalidating state 2 correctly decremented the current hit count to 2, // then adding state 4 should still be within our budget and no other result // states should be evicted. 
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state1, - result_state_manager.GetNextPage(page_result_state1.next_page_token)); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/0)))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info1, + result_state_manager.GetNextPage( + page_result_info1.first, result_retriever())); + ASSERT_THAT(page_result_info1.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos1.at(1))); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); } TEST_F(ResultStateManagerTest, InvalidatedAllResultStatesShouldResetCurrentHitCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + 
auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Invalidates all states so that the current hit count will be 0. result_state_manager.InvalidateAllResultStates(); // If invalidating all states correctly reset the current hit count to 0, - // then the entirety of state 4 should still be within our budget and no other + // then adding state 4, 5, 6 should still be within our budget and no other // result states should be evicted. 
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - ResultState result_state6 = - CreateResultState({AddScoredDocument(/*document_id=*/10), - AddScoredDocument(/*document_id=*/11)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state6, - result_state_manager.RankAndPaginate(std::move(result_state6))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state3.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); + auto [scored_document_hits5, document_protos5] = + AddScoredDocuments({/*document_id=*/8, /*document_id=*/9}); + auto [scored_document_hits6, document_protos6] = + AddScoredDocuments({/*document_id=*/10, /*document_id=*/11}); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info5, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits5), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info6, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits6), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state5, - result_state_manager.GetNextPage(page_result_state5.next_page_token)); - 
EXPECT_THAT(page_result_state5.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info3.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state6, - result_state_manager.GetNextPage(page_result_state6.next_page_token)); - EXPECT_THAT(page_result_state6.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/10)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info5, + result_state_manager.GetNextPage( + page_result_info5.first, result_retriever())); + ASSERT_THAT(page_result_info5.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info5.second.results.at(0).document(), + EqualsProto(document_protos5.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info6, + result_state_manager.GetNextPage( + page_result_info6.first, result_retriever())); + ASSERT_THAT(page_result_info6.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info6.second.results.at(0).document(), + EqualsProto(document_protos6.at(1))); } TEST_F( ResultStateManagerTest, InvalidatedResultStateShouldDecreaseCurrentHitsCountByExactStateHitCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. 
ResultStateManager result_state_manager(/*max_total_hits=*/3, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Invalidates state 2, so that the number of hits current cached should be // decremented to 2. - result_state_manager.InvalidateResultState( - page_result_state2.next_page_token); + result_state_manager.InvalidateResultState(page_result_info2.first); // If invalidating state 2 correctly decremented the current hit count to 2, // then adding state 4 should still be within our budget and no other result // states should be evicted. - ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // If invalidating result state 2 correctly decremented the current hit count // to 2 and adding state 4 correctly incremented it to 3, then adding this // result state should trigger the eviction of state 1. 
- ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits5, document_protos5] = + AddScoredDocuments({/*document_id=*/8, /*document_id=*/9}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info5, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits5), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state5, - result_state_manager.GetNextPage(page_result_state5.next_page_token)); - EXPECT_THAT(page_result_state5.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info5, + result_state_manager.GetNextPage( + page_result_info5.first, result_retriever())); + ASSERT_THAT(page_result_info5.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info5.second.results.at(0).document(), + EqualsProto(document_protos5.at(1))); } TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - 
AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // GetNextPage for result state 1 should return its result and decrement the // number of cached hits to 2. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state1, - result_state_manager.GetNextPage(page_result_state1.next_page_token)); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/0)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info1, + result_state_manager.GetNextPage( + page_result_info1.first, result_retriever())); + ASSERT_THAT(page_result_info1.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos1.at(1))); // If retrieving the next page for result state 1 correctly decremented the // current hit count to 2, then adding state 4 should still be within our // budget and no other result states should be evicted. 
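It may help to restate the bookkeeping that the comments above walk through. The constants below just mirror this test's own setup (a budget of three, three two-hit states with a page size of one), so nothing here comes from the Icing API itself:

// GetNextPageShouldDecreaseCurrentHitsCount, restated as arithmetic.
constexpr int kBudget = 3;              // max_total_hits passed to the manager
constexpr int kCachedPerState = 2 - 1;  // two hits per state, one returned on page one

constexpr int kAfterCachingThreeStates = 3 * kCachedPerState;                 // 3, budget full
constexpr int kAfterPagingState1 = kAfterCachingThreeStates - 1;              // 2
constexpr int kAfterCachingStateFour = kAfterPagingState1 + kCachedPerState;  // 3

static_assert(kAfterCachingStateFour <= kBudget,
              "state 4 fits without evicting anything, as the test expects");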
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/2)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos2.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); } TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCountByExactlyOnePage) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, 
/*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // GetNextPage for result state 1 should return its result and decrement the // number of cached hits to 2. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state1, - result_state_manager.GetNextPage(page_result_state1.next_page_token)); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/0)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info1, + result_state_manager.GetNextPage( + page_result_info1.first, result_retriever())); + ASSERT_THAT(page_result_info1.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos1.at(1))); // If retrieving the next page for result state 1 correctly decremented the // current hit count to 2, then adding state 4 should still be within our // budget and no other result states should be evicted. 
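Since every page after the first comes from GetNextPage with the token returned alongside the previous page, a caller effectively loops until the manager either hands back kInvalidNextPageToken or reports NOT_FOUND. A rough helper in the spirit of these tests might look like the following; only the call shape is taken from this file, so the retriever is left as a template parameter and the whole thing should be read as a sketch rather than the actual API surface:

// Sketch only: drains every remaining page for one cached result state.
// ResultStateManager, PageResultInfo, DocumentProto and kInvalidNextPageToken
// are the names this test file already uses; Retriever stays generic on purpose.
template <typename Retriever>
std::vector<DocumentProto> DrainPages(ResultStateManager& manager,
                                      uint64_t token,
                                      const Retriever& retriever) {
  std::vector<DocumentProto> retrieved;
  while (token != kInvalidNextPageToken) {
    auto page_or = manager.GetNextPage(token, retriever);
    if (!page_or.ok()) {
      break;  // NOT_FOUND once the state was exhausted, evicted, or invalidated.
    }
    PageResultInfo info = std::move(page_or.ValueOrDie());
    for (const auto& result : info.second.results) {
      retrieved.push_back(result.document());  // one document per hit on the page
    }
    token = info.first;  // kInvalidNextPageToken after the final page
  }
  return retrieved;
}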
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // If retrieving the next page for result state 1 correctly decremented the // current hit count to 2 and adding state 4 correctly incremented it to 3, // then adding this result state should trigger the eviction of state 2. - ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits5, document_protos5] = + AddScoredDocuments({/*document_id=*/8, /*document_id=*/9}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info5, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits5), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state5, - result_state_manager.GetNextPage(page_result_state5.next_page_token)); - EXPECT_THAT(page_result_state5.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + 
page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info5, + result_state_manager.GetNextPage( + page_result_info5.first, result_retriever())); + ASSERT_THAT(page_result_info5.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info5.second.results.at(0).document(), + EqualsProto(document_protos5.at(1))); } TEST_F(ResultStateManagerTest, AddingOverBudgetResultStateShouldEvictAllStates) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/3, /*document_id=*/4}); // Add the first two states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1. So 3 hits will remain cached. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1. So 3 + // hits will remain cached. ResultStateManager result_state_manager(/*max_total_hits=*/4, - document_store()); + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Add a result state that is larger than the entire budget. This should // result in all previous result states being evicted, the first hit from // result state 3 being returned and the next four hits being cached (the last // hit should be dropped because it exceeds the max). 
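To make the truncation in this scenario concrete: six hits minus a one-hit first page leaves five candidates for a four-hit budget, so exactly one trailing hit has to go. The constants below only restate the setup described in the comment above:

// AddingOverBudgetResultStateShouldEvictAllStates, truncation restated.
constexpr int kMaxTotalHits = 4;  // manager budget
constexpr int kTotalHits = 6;     // document_ids 5 through 10
constexpr int kNumPerPage = 1;    // the first page is returned, not cached

constexpr int kHitsLeftToCache = kTotalHits - kNumPerPage;  // 5
constexpr int kHitsKept =
    kHitsLeftToCache < kMaxTotalHits ? kHitsLeftToCache : kMaxTotalHits;  // 4

static_assert(kHitsKept == 4, "four hits remain retrievable via GetNextPage");
static_assert(kHitsLeftToCache - kHitsKept == 1,
              "the final hit is dropped because it exceeds the budget");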
- ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/5), - AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7), - AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9), - AddScoredDocument(/*document_id=*/10)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + auto [scored_document_hits3, document_protos3] = AddScoredDocuments( + {/*document_id=*/5, /*document_id=*/6, /*document_id=*/7, + /*document_id=*/8, /*document_id=*/9, /*document_id=*/10}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + EXPECT_THAT(page_result_info3.first, Not(Eq(kInvalidNextPageToken))); // GetNextPage for result state 1 and 2 should return NOT_FOUND. - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); // Only the next four results in state 3 should be retrievable. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/9)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/7)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); - - // The final result should have been dropped because it exceeded the budget. 
+ uint64_t next_page_token3 = page_result_info3.first; + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(next_page_token3)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(next_page_token3)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(2))); + + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(next_page_token3)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(3))); + + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + // The final document should have been dropped because it exceeded the budget, + // so the next page token of the second last round should be + // kInvalidNextPageToken. + EXPECT_THAT(page_result_info3.first, Eq(kInvalidNextPageToken)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(4))); + + // Double check that next_page_token3 is not retrievable anymore. EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state3.next_page_token), + result_state_manager.GetNextPage(next_page_token3, result_retriever()), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } TEST_F(ResultStateManagerTest, AddingResultStateShouldEvictOverBudgetResultState) { - ResultStateManager result_state_manager(/*max_total_hits=*/4, - document_store()); // Add a result state that is larger than the entire budget. The entire result // state will still be cached - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); - - // Add a result state. Because state2 + state1 is larger than the budget, - // state1 should be evicted. - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); - - // state1 should have been evicted and state2 should still be retrievable. 
- EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); -} - -TEST_F(ResultStateManagerTest, ShouldGetSnippetContext) { - ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); - result_spec.mutable_snippet_spec()->set_num_to_snippet(5); - result_spec.mutable_snippet_spec()->set_num_matches_per_property(5); - result_spec.mutable_snippet_spec()->set_max_window_utf32_length(5); + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2, + /*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); - - SectionRestrictQueryTermsMap query_terms_map; - query_terms_map.emplace("term1", std::unordered_set<std::string>()); - - ResultState original_result_state = ResultState( - /*scored_document_hits=*/{AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - query_terms_map, search_spec, CreateScoringSpec(), result_spec, - document_store()); + ResultStateManager result_state_manager(/*max_total_hits=*/4, + document_store(), clock()); - ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state, - result_state_manager.RankAndPaginate(std::move(original_result_state))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - ASSERT_THAT(page_result_state.next_page_token, Gt(kInvalidNextPageToken)); - - EXPECT_THAT(page_result_state.snippet_context.match_type, - Eq(TermMatchType::EXACT_ONLY)); - EXPECT_TRUE(page_result_state.snippet_context.query_terms.find("term1") != - page_result_state.snippet_context.query_terms.end()); - EXPECT_THAT(page_result_state.snippet_context.snippet_spec, - EqualsProto(result_spec.snippet_spec())); -} - -TEST_F(ResultStateManagerTest, ShouldGetDefaultSnippetContext) { - ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); - // 0 indicates no snippeting - result_spec.mutable_snippet_spec()->set_num_to_snippet(0); - result_spec.mutable_snippet_spec()->set_num_matches_per_property(0); - result_spec.mutable_snippet_spec()->set_max_window_utf32_length(0); - - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); - - SectionRestrictQueryTermsMap query_terms_map; - query_terms_map.emplace("term1", std::unordered_set<std::string>()); - - ResultState original_result_state = ResultState( - /*scored_document_hits=*/{AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - query_terms_map, search_spec, CreateScoringSpec(), result_spec, - document_store()); - - ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state, - 
result_state_manager.RankAndPaginate(std::move(original_result_state))); - - ASSERT_THAT(page_result_state.next_page_token, Gt(kInvalidNextPageToken)); - - EXPECT_THAT(page_result_state.snippet_context.query_terms, IsEmpty()); - EXPECT_THAT( - page_result_state.snippet_context.snippet_spec, - EqualsProto(ResultSpecProto::SnippetSpecProto::default_instance())); - EXPECT_THAT(page_result_state.snippet_context.match_type, - Eq(TermMatchType::UNKNOWN)); -} - -TEST_F(ResultStateManagerTest, ShouldGetCorrectNumPreviouslyReturned) { - ResultState original_result_state = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4)}, - /*num_per_page=*/2); - - ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store()); - - // First page, 2 results + // Add a result state. Because state2 + state1 is larger than the budget, + // state1 should be evicted. + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - ASSERT_THAT(page_result_state1.scored_document_hits.size(), Eq(2)); - - // No previously returned results - EXPECT_THAT(page_result_state1.num_previously_returned, Eq(0)); - - uint64_t next_page_token = page_result_state1.next_page_token; - - // Second page, 2 results - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state2, - result_state_manager.GetNextPage(next_page_token)); - ASSERT_THAT(page_result_state2.scored_document_hits.size(), Eq(2)); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - // num_previously_returned = size of first page - EXPECT_THAT(page_result_state2.num_previously_returned, Eq(2)); - - // Third page, 1 result - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state3, - result_state_manager.GetNextPage(next_page_token)); - ASSERT_THAT(page_result_state3.scored_document_hits.size(), Eq(1)); - - // num_previously_returned = size of first and second pages - EXPECT_THAT(page_result_state3.num_previously_returned, Eq(4)); - - // No more results - EXPECT_THAT(result_state_manager.GetNextPage(next_page_token), + // state1 should have been evicted and state2 should still be retrievable. 
+ EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); -} -TEST_F(ResultStateManagerTest, ShouldStoreAllHits) { - ScoredDocumentHit scored_hit_1 = AddScoredDocument(/*document_id=*/0); - ScoredDocumentHit scored_hit_2 = AddScoredDocument(/*document_id=*/1); - ScoredDocumentHit scored_hit_3 = AddScoredDocument(/*document_id=*/2); - ScoredDocumentHit scored_hit_4 = AddScoredDocument(/*document_id=*/3); - ScoredDocumentHit scored_hit_5 = AddScoredDocument(/*document_id=*/4); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos2.at(1))); +} - ResultState original_result_state = CreateResultState( - {scored_hit_1, scored_hit_2, scored_hit_3, scored_hit_4, scored_hit_5}, - /*num_per_page=*/2); +TEST_F(ResultStateManagerTest, + AddingResultStateShouldNotTruncatedAfterFirstPage) { + // Add a result state that is larger than the entire budget, but within the + // entire budget after the first page. The entire result state will still be + // cached and not truncated. + auto [scored_document_hits, document_protos] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2, + /*document_id=*/3, /*document_id=*/4}); ResultStateManager result_state_manager(/*max_total_hits=*/4, - document_store()); + document_store(), clock()); // The 5 input scored document hits will not be truncated. The first page of // two hits will be returned immediately and the other three hits will fit // within our caching budget. + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/2), + document_store(), result_retriever())); // First page, 2 results - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_5), - EqualsScoredDocumentHit(scored_hit_4))); + ASSERT_THAT(page_result_info1.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos.at(0))); + EXPECT_THAT(page_result_info1.second.results.at(1).document(), + EqualsProto(document_protos.at(1))); - uint64_t next_page_token = page_result_state1.next_page_token; + uint64_t next_page_token = page_result_info1.first; // Second page, 2 results. 
- ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state2, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_2))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos.at(2))); + EXPECT_THAT(page_result_info2.second.results.at(1).document(), + EqualsProto(document_protos.at(3))); // Third page, 1 result. - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state3, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_1))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info3, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos.at(4))); // Fourth page, 0 results. - EXPECT_THAT(result_state_manager.GetNextPage(next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT( + result_state_manager.GetNextPage(next_page_token, result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } } // namespace diff --git a/icing/result/result-state-v2.cc b/icing/result/result-state-v2.cc new file mode 100644 index 0000000..9cb3838 --- /dev/null +++ b/icing/result/result-state-v2.cc @@ -0,0 +1,96 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/result/result-state-v2.h" + +#include <atomic> +#include <memory> + +#include "icing/proto/scoring.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/result/projection-tree.h" +#include "icing/result/snippet-context.h" +#include "icing/scoring/scored-document-hits-ranker.h" + +namespace icing { +namespace lib { + +namespace { +SnippetContext CreateSnippetContext(SectionRestrictQueryTermsMap query_terms, + const SearchSpecProto& search_spec, + const ResultSpecProto& result_spec) { + if (result_spec.snippet_spec().num_to_snippet() > 0 && + result_spec.snippet_spec().num_matches_per_property() > 0) { + // Needs snippeting + return SnippetContext(std::move(query_terms), result_spec.snippet_spec(), + search_spec.term_match_type()); + } + return SnippetContext(/*query_terms_in=*/{}, + ResultSpecProto::SnippetSpecProto::default_instance(), + TermMatchType::UNKNOWN); +} +} // namespace + +ResultStateV2::ResultStateV2( + std::unique_ptr<ScoredDocumentHitsRanker> scored_document_hits_ranker_in, + SectionRestrictQueryTermsMap query_terms, + const SearchSpecProto& search_spec, const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec, const DocumentStore& document_store) + : scored_document_hits_ranker(std::move(scored_document_hits_ranker_in)), + num_returned(0), + snippet_context_(CreateSnippetContext(std::move(query_terms), search_spec, + result_spec)), + num_per_page_(result_spec.num_per_page()), + num_total_bytes_per_page_threshold_( + result_spec.num_total_bytes_per_page_threshold()), + num_total_hits_(nullptr) { + for (const TypePropertyMask& type_field_mask : + result_spec.type_property_masks()) { + projection_tree_map_.insert( + {type_field_mask.schema_type(), ProjectionTree(type_field_mask)}); + } + + for (const ResultSpecProto::ResultGrouping& result_grouping : + result_spec.result_groupings()) { + int group_id = group_result_limits.size(); + group_result_limits.push_back(result_grouping.max_results()); + for (const std::string& name_space : result_grouping.namespaces()) { + auto namespace_id_or = document_store.GetNamespaceId(name_space); + if (!namespace_id_or.ok()) { + continue; + } + namespace_group_id_map_.insert({namespace_id_or.ValueOrDie(), group_id}); + } + } +} + +ResultStateV2::~ResultStateV2() { + IncrementNumTotalHits(-1 * scored_document_hits_ranker->size()); +} + +void ResultStateV2::RegisterNumTotalHits(std::atomic<int>* num_total_hits) { + // Decrement the original num_total_hits_ before registering a new one. + IncrementNumTotalHits(-1 * scored_document_hits_ranker->size()); + num_total_hits_ = num_total_hits; + IncrementNumTotalHits(scored_document_hits_ranker->size()); +} + +void ResultStateV2::IncrementNumTotalHits(int increment_by) { + if (num_total_hits_ != nullptr) { + *num_total_hits_ += increment_by; + } +} + +} // namespace lib +} // namespace icing diff --git a/icing/result/result-state-v2.h b/icing/result/result-state-v2.h new file mode 100644 index 0000000..97ff4b6 --- /dev/null +++ b/icing/result/result-state-v2.h @@ -0,0 +1,138 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_RESULT_RESULT_STATE_V2_H_ +#define ICING_RESULT_RESULT_STATE_V2_H_ + +#include <atomic> +#include <memory> +#include <unordered_map> +#include <vector> + +#include "icing/absl_ports/mutex.h" +#include "icing/proto/scoring.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/result/projection-tree.h" +#include "icing/result/snippet-context.h" +#include "icing/scoring/scored-document-hits-ranker.h" +#include "icing/store/document-store.h" +#include "icing/store/namespace-id.h" + +namespace icing { +namespace lib { + +// Used to hold information needed across multiple pagination requests of the +// same query. Stored in ResultStateManager. +class ResultStateV2 { + public: + explicit ResultStateV2( + std::unique_ptr<ScoredDocumentHitsRanker> scored_document_hits_ranker_in, + SectionRestrictQueryTermsMap query_terms, + const SearchSpecProto& search_spec, const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec, const DocumentStore& document_store); + + ~ResultStateV2(); + + // Register num_total_hits_ and add current scored_document_hits_ranker.size() + // to it. When re-registering, it will subtract + // scored_document_hits_ranker.size() from the original counter. + void RegisterNumTotalHits(std::atomic<int>* num_total_hits) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex); + + // Increment the global counter num_total_hits_ by increment_by, if + // num_total_hits_ has been registered (is not nullptr). + // Note that providing a negative value for increment_by is a valid usage, + // which will actually decrement num_total_hits_. + // + // It has to be called when we change scored_document_hits_ranker. + void IncrementNumTotalHits(int increment_by) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex); + + const SnippetContext& snippet_context() const + ICING_SHARED_LOCKS_REQUIRED(mutex) { + return snippet_context_; + } + + const std::unordered_map<std::string, ProjectionTree>& projection_tree_map() + const ICING_SHARED_LOCKS_REQUIRED(mutex) { + return projection_tree_map_; + } + + const std::unordered_map<NamespaceId, int>& namespace_group_id_map() const + ICING_SHARED_LOCKS_REQUIRED(mutex) { + return namespace_group_id_map_; + } + + int num_per_page() const ICING_SHARED_LOCKS_REQUIRED(mutex) { + return num_per_page_; + } + + int32_t num_total_bytes_per_page_threshold() const + ICING_SHARED_LOCKS_REQUIRED(mutex) { + return num_total_bytes_per_page_threshold_; + } + + absl_ports::shared_mutex mutex; + + // When evaluating the next top K hits from scored_document_hits_ranker, some + // of them may be filtered out by group_result_limits and won't return to the + // client, so they shouldn't be counted into num_returned. Also the logic of + // group result limiting depends on retrieval, so it is impossible for + // ResultState itself to correctly modify these fields. Thus, we make them + // public, so users of this class can modify them directly. + + // The scored document hits ranker. + std::unique_ptr<ScoredDocumentHitsRanker> scored_document_hits_ranker + ICING_GUARDED_BY(mutex); + + // The count of remaining results to return for a group where group id is the + // index. + std::vector<int> group_result_limits ICING_GUARDED_BY(mutex); + + // Number of results that have already been returned. + int num_returned ICING_GUARDED_BY(mutex); + + private: + // Information needed for snippeting. 
+ SnippetContext snippet_context_ ICING_GUARDED_BY(mutex); + + // Information needed for projection. + std::unordered_map<std::string, ProjectionTree> projection_tree_map_ + ICING_GUARDED_BY(mutex); + + // A map between namespace id and the id of the group that it appears in. + std::unordered_map<NamespaceId, int> namespace_group_id_map_ + ICING_GUARDED_BY(mutex); + + // Number of results to return in each page. + int num_per_page_ ICING_GUARDED_BY(mutex); + + // The threshold of total bytes of all documents to cutoff, in order to limit + // # of bytes in a single page. + // Note that it doesn't guarantee the result # of bytes will be smaller, equal + // to, or larger than the threshold. Instead, it is just a threshold to + // cutoff, and only guarantees total bytes of search results won't exceed the + // threshold too much. + int32_t num_total_bytes_per_page_threshold_ ICING_GUARDED_BY(mutex); + + // Pointer to a global counter to sum up the size of scored_document_hits in + // all ResultStates. + // Does not own. + std::atomic<int>* num_total_hits_ ICING_GUARDED_BY(mutex); +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_RESULT_RESULT_STATE_V2_H_ diff --git a/icing/result/result-state-v2_test.cc b/icing/result/result-state-v2_test.cc new file mode 100644 index 0000000..360e03a --- /dev/null +++ b/icing/result/result-state-v2_test.cc @@ -0,0 +1,486 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/result/result-state-v2.h" + +#include <atomic> +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "gtest/gtest.h" +#include "icing/absl_ports/mutex.h" +#include "icing/file/filesystem.h" +#include "icing/portable/equals-proto.h" +#include "icing/proto/scoring.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/result/projection-tree.h" +#include "icing/result/snippet-context.h" +#include "icing/schema/schema-store.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/scored-document-hits-ranker.h" +#include "icing/store/document-store.h" +#include "icing/store/namespace-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" +#include "icing/util/clock.h" + +namespace icing { +namespace lib { +namespace { + +using ::icing::lib::portable_equals_proto::EqualsProto; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +SearchSpecProto CreateSearchSpec(TermMatchType::Code match_type) { + SearchSpecProto search_spec; + search_spec.set_term_match_type(match_type); + return search_spec; +} + +ScoringSpecProto CreateScoringSpec(bool is_descending_order) { + ScoringSpecProto scoring_spec; + scoring_spec.set_order_by(is_descending_order ? 
ScoringSpecProto::Order::DESC + : ScoringSpecProto::Order::ASC); + return scoring_spec; +} + +ResultSpecProto CreateResultSpec(int num_per_page) { + ResultSpecProto result_spec; + result_spec.set_num_per_page(num_per_page); + return result_spec; +} + +class ResultStateV2Test : public ::testing::Test { + protected: + void SetUp() override { + schema_store_base_dir_ = GetTestTempDir() + "/schema_store"; + filesystem_.CreateDirectoryRecursively(schema_store_base_dir_.c_str()); + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_base_dir_, &clock_)); + SchemaProto schema; + schema.add_types()->set_schema_type("Document"); + ICING_ASSERT_OK(schema_store_->SetSchema(std::move(schema))); + + doc_store_base_dir_ = GetTestTempDir() + "/document_store"; + filesystem_.CreateDirectoryRecursively(doc_store_base_dir_.c_str()); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult result, + DocumentStore::Create(&filesystem_, doc_store_base_dir_, &clock_, + schema_store_.get())); + document_store_ = std::move(result.document_store); + + num_total_hits_ = 0; + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(doc_store_base_dir_.c_str()); + filesystem_.DeleteDirectoryRecursively(schema_store_base_dir_.c_str()); + } + + ScoredDocumentHit AddScoredDocument(DocumentId document_id) { + DocumentProto document; + document.set_namespace_("namespace"); + document.set_uri(std::to_string(document_id)); + document.set_schema("Document"); + document_store_->Put(std::move(document)); + return ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1); + } + + DocumentStore& document_store() { return *document_store_; } + + std::atomic<int>& num_total_hits() { return num_total_hits_; } + + const std::atomic<int>& num_total_hits() const { return num_total_hits_; } + + private: + Filesystem filesystem_; + std::string doc_store_base_dir_; + std::string schema_store_base_dir_; + Clock clock_; + std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<SchemaStore> schema_store_; + std::atomic<int> num_total_hits_; +}; + +TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToSpecs) { + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(4096); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + EXPECT_THAT(result_state.num_returned, Eq(0)); + EXPECT_THAT(result_state.num_per_page(), Eq(result_spec.num_per_page())); + EXPECT_THAT(result_state.num_total_bytes_per_page_threshold(), + Eq(result_spec.num_total_bytes_per_page_threshold())); +} + +TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToDefaultSpecs) { + ResultSpecProto default_result_spec = ResultSpecProto::default_instance(); + ASSERT_THAT(default_result_spec.num_per_page(), Eq(10)); + ASSERT_THAT(default_result_spec.num_total_bytes_per_page_threshold(), + Eq(std::numeric_limits<int32_t>::max())); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), default_result_spec, + 
document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + EXPECT_THAT(result_state.num_returned, Eq(0)); + EXPECT_THAT(result_state.num_per_page(), + Eq(default_result_spec.num_per_page())); + EXPECT_THAT(result_state.num_total_bytes_per_page_threshold(), + Eq(default_result_spec.num_total_bytes_per_page_threshold())); +} + +TEST_F(ResultStateV2Test, ShouldReturnSnippetContextAccordingToSpecs) { + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.mutable_snippet_spec()->set_num_to_snippet(5); + result_spec.mutable_snippet_spec()->set_num_matches_per_property(5); + result_spec.mutable_snippet_spec()->set_max_window_utf32_length(5); + + SectionRestrictQueryTermsMap query_terms_map; + query_terms_map.emplace("term1", std::unordered_set<std::string>()); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + query_terms_map, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + const SnippetContext snippet_context = result_state.snippet_context(); + + // Snippet context should be derived from the specs above. + EXPECT_TRUE(snippet_context.query_terms.find("term1") != + snippet_context.query_terms.end()); + EXPECT_THAT(snippet_context.snippet_spec, + EqualsProto(result_spec.snippet_spec())); + EXPECT_THAT(snippet_context.match_type, Eq(TermMatchType::EXACT_ONLY)); + + // The same copy can be fetched multiple times. + const SnippetContext snippet_context2 = result_state.snippet_context(); + EXPECT_TRUE(snippet_context2.query_terms.find("term1") != + snippet_context2.query_terms.end()); + EXPECT_THAT(snippet_context2.snippet_spec, + EqualsProto(result_spec.snippet_spec())); + EXPECT_THAT(snippet_context2.match_type, Eq(TermMatchType::EXACT_ONLY)); +} + +TEST_F(ResultStateV2Test, NoSnippetingShouldReturnNull) { + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + // Setting num_to_snippet to 0 so that snippeting info won't be + // stored. + result_spec.mutable_snippet_spec()->set_num_to_snippet(0); + result_spec.mutable_snippet_spec()->set_num_matches_per_property(5); + result_spec.mutable_snippet_spec()->set_max_window_utf32_length(5); + + SectionRestrictQueryTermsMap query_terms_map; + query_terms_map.emplace("term1", std::unordered_set<std::string>()); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + query_terms_map, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + const SnippetContext snippet_context = result_state.snippet_context(); + EXPECT_THAT(snippet_context.query_terms, IsEmpty()); + EXPECT_THAT( + snippet_context.snippet_spec, + EqualsProto(ResultSpecProto::SnippetSpecProto::default_instance())); + EXPECT_THAT(snippet_context.match_type, TermMatchType::UNKNOWN); +} + +TEST_F(ResultStateV2Test, ShouldConstructProjectionTreeMapAccordingToSpecs) { + // Create a ResultSpec with type property mask. 
+ ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + TypePropertyMask* email_type_property_mask = + result_spec.add_type_property_masks(); + email_type_property_mask->set_schema_type("Email"); + email_type_property_mask->add_paths("sender.name"); + email_type_property_mask->add_paths("sender.emailAddress"); + TypePropertyMask* phone_type_property_mask = + result_spec.add_type_property_masks(); + phone_type_property_mask->set_schema_type("Phone"); + phone_type_property_mask->add_paths("caller"); + TypePropertyMask* wildcard_type_property_mask = + result_spec.add_type_property_masks(); + wildcard_type_property_mask->set_schema_type( + std::string(ProjectionTree::kSchemaTypeWildcard)); + wildcard_type_property_mask->add_paths("wild.card"); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + const std::unordered_map<std::string, ProjectionTree>& projection_tree_map = + result_state.projection_tree_map(); + EXPECT_THAT(projection_tree_map, + UnorderedElementsAre( + Pair("Email", ProjectionTree(*email_type_property_mask)), + Pair("Phone", ProjectionTree(*phone_type_property_mask)), + Pair(std::string(ProjectionTree::kSchemaTypeWildcard), + ProjectionTree(*wildcard_type_property_mask)))); +} + +TEST_F(ResultStateV2Test, + ShouldConstructNamespaceGroupIdMapAndGroupResultLimitsAccordingToSpecs) { + // Create 3 docs under namespace1, namespace2, namespace3. + DocumentProto document1; + document1.set_namespace_("namespace1"); + document1.set_uri("uri/1"); + document1.set_schema("Document"); + ICING_ASSERT_OK(document_store().Put(std::move(document1))); + + DocumentProto document2; + document2.set_namespace_("namespace2"); + document2.set_uri("uri/2"); + document2.set_schema("Document"); + ICING_ASSERT_OK(document_store().Put(std::move(document2))); + + DocumentProto document3; + document3.set_namespace_("namespace3"); + document3.set_uri("uri/3"); + document3.set_schema("Document"); + ICING_ASSERT_OK(document_store().Put(std::move(document3))); + + // Create a ResultSpec that limits "namespace1" to 3 results and limits + // "namespace2"+"namespace3" to a total of 2 results. Also add + // "nonexistentNamespace1" and "nonexistentNamespace2" to test the behavior. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/5); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(3); + result_grouping->add_namespaces("namespace1"); + result_grouping = result_spec.add_result_groupings(); + result_grouping->set_max_results(5); + result_grouping->add_namespaces("nonexistentNamespace2"); + result_grouping = result_spec.add_result_groupings(); + result_grouping->set_max_results(2); + result_grouping->add_namespaces("namespace2"); + result_grouping->add_namespaces("namespace3"); + result_grouping->add_namespaces("nonexistentNamespace1"); + + // Get namespace ids. 
+ ICING_ASSERT_OK_AND_ASSIGN(NamespaceId namespace_id1, + document_store().GetNamespaceId("namespace1")); + ICING_ASSERT_OK_AND_ASSIGN(NamespaceId namespace_id2, + document_store().GetNamespaceId("namespace2")); + ICING_ASSERT_OK_AND_ASSIGN(NamespaceId namespace_id3, + document_store().GetNamespaceId("namespace3")); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + // "namespace1" should be in group 0, and "namespace2"+"namespace3" should be + // in group 2. + // "nonexistentNamespace1" and "nonexistentNamespace2" shouldn't exist. + EXPECT_THAT( + result_state.namespace_group_id_map(), + UnorderedElementsAre(Pair(namespace_id1, 0), Pair(namespace_id2, 2), + Pair(namespace_id3, 2))); + + // group_result_limits should contain 3 (at index 0 for group 0), 5 (at index + // 1 for group 1), 2 (at index 2 for group 2), even though there is no valid + // namespace in group 1. + EXPECT_THAT(result_state.group_result_limits, ElementsAre(3, 5, 2)); +} + +TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHits) { + std::vector<ScoredDocumentHit> scored_document_hits = { + AddScoredDocument(/*document_id=*/1), + AddScoredDocument(/*document_id=*/0), + AddScoredDocument(/*document_id=*/2), + AddScoredDocument(/*document_id=*/4), + AddScoredDocument(/*document_id=*/3)}; + + // Creates a ResultState with 5 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/5), document_store()); + + absl_ports::unique_lock l(&result_state.mutex); + + EXPECT_THAT(num_total_hits(), Eq(0)); + result_state.RegisterNumTotalHits(&num_total_hits()); + EXPECT_THAT(num_total_hits(), Eq(5)); + result_state.IncrementNumTotalHits(500); + EXPECT_THAT(num_total_hits(), Eq(505)); +} + +TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHitsWhenDestructed) { + std::vector<ScoredDocumentHit> scored_document_hits1 = { + AddScoredDocument(/*document_id=*/1), + AddScoredDocument(/*document_id=*/0), + AddScoredDocument(/*document_id=*/2), + AddScoredDocument(/*document_id=*/4), + AddScoredDocument(/*document_id=*/3)}; + + std::vector<ScoredDocumentHit> scored_document_hits2 = { + AddScoredDocument(/*document_id=*/6), + AddScoredDocument(/*document_id=*/5)}; + + num_total_hits() = 2; + { + // Creates a ResultState with 5 ScoredDocumentHits. + ResultStateV2 result_state1( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/5), document_store()); + + absl_ports::unique_lock l(&result_state1.mutex); + + result_state1.RegisterNumTotalHits(&num_total_hits()); + ASSERT_THAT(num_total_hits(), Eq(7)); + + { + // Creates another ResultState with 2 ScoredDocumentHits. 
+ ResultStateV2 result_state2( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/5), document_store()); + + absl_ports::unique_lock l(&result_state2.mutex); + + result_state2.RegisterNumTotalHits(&num_total_hits()); + ASSERT_THAT(num_total_hits(), Eq(9)); + } + + EXPECT_THAT(num_total_hits(), Eq(7)); + } + EXPECT_THAT(num_total_hits(), Eq(2)); +} + +TEST_F(ResultStateV2Test, ShouldNotUpdateNumTotalHitsWhenNotRegistered) { + std::vector<ScoredDocumentHit> scored_document_hits = { + AddScoredDocument(/*document_id=*/1), + AddScoredDocument(/*document_id=*/0), + AddScoredDocument(/*document_id=*/2), + AddScoredDocument(/*document_id=*/4), + AddScoredDocument(/*document_id=*/3)}; + + // Creates a ResultState with 5 ScoredDocumentHits. + { + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/5), document_store()); + + { + absl_ports::unique_lock l(&result_state.mutex); + + EXPECT_THAT(num_total_hits(), Eq(0)); + result_state.IncrementNumTotalHits(500); + EXPECT_THAT(num_total_hits(), Eq(0)); + } + } + EXPECT_THAT(num_total_hits(), Eq(0)); +} + +TEST_F(ResultStateV2Test, ShouldDecrementOriginalNumTotalHitsWhenReregister) { + std::atomic<int> another_num_total_hits = 11; + + std::vector<ScoredDocumentHit> scored_document_hits = { + AddScoredDocument(/*document_id=*/1), + AddScoredDocument(/*document_id=*/0), + AddScoredDocument(/*document_id=*/2), + AddScoredDocument(/*document_id=*/4), + AddScoredDocument(/*document_id=*/3)}; + + // Creates a ResultState with 5 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), + CreateResultSpec(/*num_per_page=*/5), document_store()); + + absl_ports::unique_lock l(&result_state.mutex); + + num_total_hits() = 7; + result_state.RegisterNumTotalHits(&num_total_hits()); + EXPECT_THAT(num_total_hits(), Eq(12)); + + result_state.RegisterNumTotalHits(&another_num_total_hits); + // The original num_total_hits should be decremented after re-registration. + EXPECT_THAT(num_total_hits(), Eq(7)); + // another_num_total_hits should be incremented after re-registration. + EXPECT_THAT(another_num_total_hits, Eq(16)); + + result_state.IncrementNumTotalHits(500); + // The original num_total_hits should be unchanged. + EXPECT_THAT(num_total_hits(), Eq(7)); + // Increment should be done on another_num_total_hits. + EXPECT_THAT(another_num_total_hits, Eq(516)); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/result/result-state.cc b/icing/result/result-state.cc index fc89185..24f5c09 100644 --- a/icing/result/result-state.cc +++ b/icing/result/result-state.cc @@ -82,13 +82,15 @@ class GroupResultLimiter { // Returns true if the scored_document_hit should be removed. 
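The GroupResultLimiter change that follows switches from the StatusOr-returning GetDocumentFilterData to the std::optional-returning GetAliveDocumentFilterData. A minimal sketch of the new calling pattern, using only the signatures visible elsewhere in this change (not part of the commit itself):

    // Illustrative sketch only. GetAliveDocumentFilterData returns
    // std::nullopt for deleted or expired documents, so the limiter can drop
    // such hits with a single check instead of unpacking a StatusOr.
    std::optional<DocumentFilterData> filter_data =
        document_store_.GetAliveDocumentFilterData(
            scored_document_hit.document_id());
    if (!filter_data) {
      return true;  // Hit belongs to a document that no longer exists.
    }
    NamespaceId namespace_id = filter_data.value().namespace_id();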
bool operator()(const ScoredDocumentHit& scored_document_hit) { - auto document_filter_data_or = document_store_.GetDocumentFilterData( - scored_document_hit.document_id()); - if (!document_filter_data_or.ok()) { + auto document_filter_data_optional = + document_store_.GetAliveDocumentFilterData( + scored_document_hit.document_id()); + if (!document_filter_data_optional) { + // Document doesn't exist. return true; } NamespaceId namespace_id = - document_filter_data_or.ValueOrDie().namespace_id(); + document_filter_data_optional.value().namespace_id(); auto iter = namespace_group_id_map_.find(namespace_id); if (iter == namespace_group_id_map_.end()) { return false; diff --git a/icing/result/snippet-retriever.cc b/icing/result/snippet-retriever.cc index bd1524e..2391900 100644 --- a/icing/result/snippet-retriever.cc +++ b/icing/result/snippet-retriever.cc @@ -80,6 +80,20 @@ inline std::string AddIndexToPath(int values_size, int index, // is applied based on the Token's type. std::string NormalizeToken(const Normalizer& normalizer, const Token& token) { switch (token.type) { + case Token::Type::RFC822_NAME: + [[fallthrough]]; + case Token::Type::RFC822_COMMENT: + [[fallthrough]]; + case Token::Type::RFC822_LOCAL_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_HOST: + [[fallthrough]]; + case Token::Type::RFC822_TOKEN: + [[fallthrough]]; case Token::Type::REGULAR: return normalizer.NormalizeTerm(token.text); case Token::Type::VERBATIM: @@ -126,6 +140,20 @@ CharacterIterator FindMatchEnd(const Normalizer& normalizer, const Token& token, [[fallthrough]]; case Token::Type::QUERY_PROPERTY: [[fallthrough]]; + case Token::Type::RFC822_NAME: + [[fallthrough]]; + case Token::Type::RFC822_COMMENT: + [[fallthrough]]; + case Token::Type::RFC822_LOCAL_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_HOST: + [[fallthrough]]; + case Token::Type::RFC822_TOKEN: + [[fallthrough]]; case Token::Type::INVALID: ICING_LOG(WARNING) << "Unexpected Token type " << static_cast<int>(token.type) diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index fc50ea6..653f34f 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -27,6 +27,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/file/destructible-directory.h" #include "icing/file/file-backed-proto.h" #include "icing/file/filesystem.h" #include "icing/proto/document.pb.h" @@ -35,7 +36,7 @@ #include "icing/schema/section-manager.h" #include "icing/schema/section.h" #include "icing/store/document-filter-data.h" -#include "icing/store/key-mapper.h" +#include "icing/store/dynamic-trie-key-mapper.h" #include "icing/util/crc32.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" @@ -49,8 +50,9 @@ constexpr char kSchemaStoreHeaderFilename[] = "schema_store_header"; constexpr char kSchemaFilename[] = "schema.pb"; constexpr char kSchemaTypeMapperFilename[] = "schema_type_mapper"; -// A KeyMapper stores its data across 3 arrays internally. Giving each array -// 128KiB for storage means the entire KeyMapper requires 384KiB. +// A DynamicTrieKeyMapper stores its data across 3 arrays internally. 
Giving +// each array 128KiB for storage means the entire DynamicTrieKeyMapper requires +// 384KiB. constexpr int32_t kSchemaTypeMapperMaxSize = 3 * 128 * 1024; // 384 KiB const std::string MakeHeaderFilename(const std::string& base_dir) { @@ -196,8 +198,8 @@ libtextclassifier3::Status SchemaStore::InitializeInternal( if (initialize_stats != nullptr) { initialize_stats->set_num_schema_types(type_config_map_.size()); } - has_schema_successfully_set_ = true; + return libtextclassifier3::Status::OK; } @@ -222,9 +224,9 @@ libtextclassifier3::Status SchemaStore::InitializeDerivedFiles() { ICING_ASSIGN_OR_RETURN( schema_type_mapper_, - KeyMapper<SchemaTypeId>::Create(*filesystem_, - MakeSchemaTypeMapperFilename(base_dir_), - kSchemaTypeMapperMaxSize)); + DynamicTrieKeyMapper<SchemaTypeId>::Create( + *filesystem_, MakeSchemaTypeMapperFilename(base_dir_), + kSchemaTypeMapperMaxSize)); ICING_ASSIGN_OR_RETURN(Crc32 checksum, ComputeChecksum()); if (checksum.Get() != header.checksum) { @@ -307,8 +309,9 @@ libtextclassifier3::Status SchemaStore::ResetSchemaTypeMapper() { schema_type_mapper_.reset(); // TODO(b/216487496): Implement a more robust version of TC_RETURN_IF_ERROR // that can support error logging. - libtextclassifier3::Status status = KeyMapper<SchemaTypeId>::Delete( - *filesystem_, MakeSchemaTypeMapperFilename(base_dir_)); + libtextclassifier3::Status status = + DynamicTrieKeyMapper<SchemaTypeId>::Delete( + *filesystem_, MakeSchemaTypeMapperFilename(base_dir_)); if (!status.ok()) { ICING_LOG(ERROR) << status.error_message() << "Failed to delete old schema_type mapper"; @@ -316,9 +319,9 @@ libtextclassifier3::Status SchemaStore::ResetSchemaTypeMapper() { } ICING_ASSIGN_OR_RETURN( schema_type_mapper_, - KeyMapper<SchemaTypeId>::Create(*filesystem_, - MakeSchemaTypeMapperFilename(base_dir_), - kSchemaTypeMapperMaxSize)); + DynamicTrieKeyMapper<SchemaTypeId>::Create( + *filesystem_, MakeSchemaTypeMapperFilename(base_dir_), + kSchemaTypeMapperMaxSize)); return libtextclassifier3::Status::OK; } @@ -447,46 +450,29 @@ libtextclassifier3::Status SchemaStore::ApplySchemaChange( std::string temp_schema_store_dir_path = base_dir_ + "_temp"; if (!filesystem_->DeleteDirectoryRecursively( temp_schema_store_dir_path.c_str())) { - ICING_LOG(WARNING) << "Failed to recursively delete " + ICING_LOG(ERROR) << "Recursively deleting " << temp_schema_store_dir_path.c_str(); return absl_ports::InternalError( "Unable to delete temp directory to prepare to build new schema " "store."); } - if (!filesystem_->CreateDirectoryRecursively( - temp_schema_store_dir_path.c_str())) { + DestructibleDirectory temp_schema_store_dir( + filesystem_, std::move(temp_schema_store_dir_path)); + if (!temp_schema_store_dir.is_valid()) { return absl_ports::InternalError( "Unable to create temp directory to build new schema store."); } // Then we create our new schema store with the new schema. - auto new_schema_store_or = - SchemaStore::Create(filesystem_, temp_schema_store_dir_path, clock_, - std::move(new_schema)); - if (!new_schema_store_or.ok()) { - // Attempt to clean up the temp directory. - if (!filesystem_->DeleteDirectoryRecursively( - temp_schema_store_dir_path.c_str())) { - // Nothing to do here. Just log an error. 
- ICING_LOG(WARNING) << "Failed to recursively delete " - << temp_schema_store_dir_path.c_str(); - } - return new_schema_store_or.status(); - } - std::unique_ptr<SchemaStore> new_schema_store = - std::move(new_schema_store_or).ValueOrDie(); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<SchemaStore> new_schema_store, + SchemaStore::Create(filesystem_, temp_schema_store_dir.dir(), clock_, + std::move(new_schema))); // Then we swap the new schema file + new derived files with the old files. if (!filesystem_->SwapFiles(base_dir_.c_str(), - temp_schema_store_dir_path.c_str())) { - // Attempt to clean up the temp directory. - if (!filesystem_->DeleteDirectoryRecursively( - temp_schema_store_dir_path.c_str())) { - // Nothing to do here. Just log an error. - ICING_LOG(WARNING) << "Failed to recursively delete " - << temp_schema_store_dir_path.c_str(); - } + temp_schema_store_dir.dir().c_str())) { return absl_ports::InternalError( "Unable to apply new schema due to failed swap!"); } diff --git a/icing/schema/schema-store.h b/icing/schema/schema-store.h index 58e5477..82f4ffa 100644 --- a/icing/schema/schema-store.h +++ b/icing/schema/schema-store.h @@ -130,7 +130,7 @@ class SchemaStore { static libtextclassifier3::StatusOr<std::unique_ptr<SchemaStore>> Create( const Filesystem* filesystem, const std::string& base_dir, const Clock* clock, InitializeStatsProto* initialize_stats = nullptr); - + SchemaStore(SchemaStore&&) = default; SchemaStore& operator=(SchemaStore&&) = default; @@ -282,7 +282,6 @@ class SchemaStore { const Filesystem* filesystem, const std::string& base_dir, const Clock* clock, SchemaProto schema); - // Use SchemaStore::Create instead. explicit SchemaStore(const Filesystem* filesystem, std::string base_dir, const Clock* clock); diff --git a/icing/schema/schema-store_test.cc b/icing/schema/schema-store_test.cc index 3fd41c4..ffd1292 100644 --- a/icing/schema/schema-store_test.cc +++ b/icing/schema/schema-store_test.cc @@ -18,6 +18,7 @@ #include <string> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" @@ -35,7 +36,6 @@ #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/tmp-directory.h" -#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/util/crc32.h" namespace icing { @@ -73,8 +73,8 @@ constexpr PropertyConfigProto::DataType::Code TYPE_DOUBLE = class SchemaStoreTest : public ::testing::Test { protected: void SetUp() override { - temp_dir_ = GetTestTempDir() + "/icing"; - schema_store_dir_ = temp_dir_ + "/schema_store"; + test_dir_ = GetTestTempDir() + "/icing"; + schema_store_dir_ = test_dir_ + "/schema_store"; filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()); schema_ = @@ -93,24 +93,24 @@ class SchemaStoreTest : public ::testing::Test { // schema_store_dir_. IOW, ensure that all temporary directories have been // properly cleaned up. std::vector<std::string> sub_dirs; - ASSERT_TRUE(filesystem_.ListDirectory(temp_dir_.c_str(), &sub_dirs)); + ASSERT_TRUE(filesystem_.ListDirectory(test_dir_.c_str(), &sub_dirs)); ASSERT_THAT(sub_dirs, ElementsAre("schema_store")); // Finally, clean everything up. 
- ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(temp_dir_.c_str())); + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); } Filesystem filesystem_; - std::string temp_dir_; + std::string test_dir_; std::string schema_store_dir_; SchemaProto schema_; FakeClock fake_clock_; }; TEST_F(SchemaStoreTest, CreationWithNullPointerShouldFail) { - EXPECT_THAT( - SchemaStore::Create(/*filesystem=*/nullptr, schema_store_dir_, &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(SchemaStore::Create(/*filesystem=*/nullptr, schema_store_dir_, + &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(SchemaStoreTest, SchemaStoreMoveConstructible) { @@ -215,15 +215,17 @@ TEST_F(SchemaStoreTest, CorruptSchemaError) { .AddType(SchemaTypeConfigBuilder().SetType("corrupted")) .Build(); - const std::string schema_file = absl_ports::StrCat(schema_store_dir_, "/schema.pb"); + const std::string schema_file = + absl_ports::StrCat(schema_store_dir_, "/schema.pb"); const std::string serialized_schema = corrupt_schema.SerializeAsString(); filesystem_.Write(schema_file.c_str(), serialized_schema.data(), serialized_schema.size()); // If ground truth was corrupted, we won't know what to do - EXPECT_THAT(SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::INTERNAL)); + EXPECT_THAT( + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } TEST_F(SchemaStoreTest, RecoverCorruptDerivedFileOk) { @@ -350,8 +352,9 @@ TEST_F(SchemaStoreTest, CreateWithPreviousSchemaOk) { IsOkAndHolds(EqualsSetSchemaResult(result))); schema_store.reset(); - EXPECT_THAT(SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_), - IsOk()); + EXPECT_THAT( + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_), + IsOk()); } TEST_F(SchemaStoreTest, MultipleCreateOk) { @@ -383,7 +386,8 @@ TEST_F(SchemaStoreTest, MultipleCreateOk) { schema_store.reset(); ICING_ASSERT_OK_AND_ASSIGN( - schema_store, SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + schema_store, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); // Verify that our in-memory structures are ok EXPECT_THAT(schema_store->GetSchemaTypeConfig("email"), @@ -1017,7 +1021,8 @@ TEST_F(SchemaStoreTest, ComputeChecksumSameAcrossInstances) { schema_store.reset(); ICING_ASSERT_OK_AND_ASSIGN( - schema_store, SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + schema_store, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); EXPECT_THAT(schema_store->ComputeChecksum(), IsOkAndHolds(checksum)); } @@ -1082,7 +1087,8 @@ TEST_F(SchemaStoreTest, PersistToDiskPreservesAcrossInstances) { // And we get the same schema back on reinitialization ICING_ASSERT_OK_AND_ASSIGN( - schema_store, SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + schema_store, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ICING_ASSERT_OK_AND_ASSIGN(actual_schema, schema_store->GetSchema()); EXPECT_THAT(*actual_schema, EqualsProto(schema)); } diff --git a/icing/schema/section-manager_test.cc b/icing/schema/section-manager_test.cc index 3dcc5a9..cb7c561 100644 --- a/icing/schema/section-manager_test.cc +++ b/icing/schema/section-manager_test.cc @@ -23,6 +23,7 @@ #include "icing/proto/schema.pb.h" #include "icing/proto/term.pb.h" #include "icing/schema/schema-util.h" 
+#include "icing/store/dynamic-trie-key-mapper.h" #include "icing/store/key-mapper.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" @@ -78,11 +79,11 @@ class SectionManagerTest : public ::testing::Test { } void SetUp() override { - // KeyMapper uses 3 internal arrays for bookkeeping. Give each one 128KiB so - // the total KeyMapper should get 384KiB + // DynamicTrieKeyMapper uses 3 internal arrays for bookkeeping. Give each + // one 128KiB so the total DynamicTrieKeyMapper should get 384KiB int key_mapper_size = 3 * 128 * 1024; ICING_ASSERT_OK_AND_ASSIGN(schema_type_mapper_, - KeyMapper<SchemaTypeId>::Create( + DynamicTrieKeyMapper<SchemaTypeId>::Create( filesystem_, test_dir_, key_mapper_size)); ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeEmail, 0)); ICING_ASSERT_OK(schema_type_mapper_->Put(kTypeConversation, 1)); @@ -397,13 +398,14 @@ TEST_F(SectionManagerTest, type_with_non_string_properties); type_config_map.emplace(empty_type.schema_type(), empty_type); - // KeyMapper uses 3 internal arrays for bookkeeping. Give each one 128KiB so - // the total KeyMapper should get 384KiB + // DynamicTrieKeyMapper uses 3 internal arrays for bookkeeping. Give each one + // 128KiB so the total DynamicTrieKeyMapper should get 384KiB int key_mapper_size = 3 * 128 * 1024; std::string dir = GetTestTempDir() + "/non_string_fields"; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper, - KeyMapper<SchemaTypeId>::Create(filesystem_, dir, key_mapper_size)); + DynamicTrieKeyMapper<SchemaTypeId>::Create(filesystem_, dir, + key_mapper_size)); ICING_ASSERT_OK(schema_type_mapper->Put( type_with_non_string_properties.schema_type(), /*schema_type_id=*/0)); ICING_ASSERT_OK(schema_type_mapper->Put(empty_type.schema_type(), @@ -486,13 +488,14 @@ TEST_F(SectionManagerTest, AssignSectionsRecursivelyForDocumentFields) { type_config_map.emplace(type.schema_type(), type); type_config_map.emplace(document_type.schema_type(), document_type); - // KeyMapper uses 3 internal arrays for bookkeeping. Give each one 128KiB so - // the total KeyMapper should get 384KiB + // DynamicTrieKeyMapper uses 3 internal arrays for bookkeeping. Give each one + // 128KiB so the total DynamicTrieKeyMapper should get 384KiB int key_mapper_size = 3 * 128 * 1024; std::string dir = GetTestTempDir() + "/recurse_into_document"; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper, - KeyMapper<SchemaTypeId>::Create(filesystem_, dir, key_mapper_size)); + DynamicTrieKeyMapper<SchemaTypeId>::Create(filesystem_, dir, + key_mapper_size)); int type_schema_type_id = 0; int document_type_schema_type_id = 1; ICING_ASSERT_OK( @@ -560,13 +563,14 @@ TEST_F(SectionManagerTest, DontAssignSectionsRecursivelyForDocumentFields) { type_config_map.emplace(type.schema_type(), type); type_config_map.emplace(document_type.schema_type(), document_type); - // KeyMapper uses 3 internal arrays for bookkeeping. Give each one 128KiB so - // the total KeyMapper should get 384KiB + // DynamicTrieKeyMapper uses 3 internal arrays for bookkeeping. 
Give each one + // 128KiB so the total DynamicTrieKeyMapper should get 384KiB int key_mapper_size = 3 * 128 * 1024; std::string dir = GetTestTempDir() + "/recurse_into_document"; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<KeyMapper<SchemaTypeId>> schema_type_mapper, - KeyMapper<SchemaTypeId>::Create(filesystem_, dir, key_mapper_size)); + DynamicTrieKeyMapper<SchemaTypeId>::Create(filesystem_, dir, + key_mapper_size)); int type_schema_type_id = 0; int document_type_schema_type_id = 1; ICING_ASSERT_OK( diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc index 28d385e..28ee2ba 100644 --- a/icing/scoring/bm25f-calculator.cc +++ b/icing/scoring/bm25f-calculator.cc @@ -20,7 +20,6 @@ #include <unordered_set> #include <vector> -#include "icing/absl_ports/str_cat.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/store/corpus-associated-scoring-data.h" @@ -116,9 +115,8 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it, score += idf_weight * normalized_tf; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "BM25F: corpus_id:%d docid:%d score:%f\n", data.corpus_id(), - hit_info.document_id(), score); + ICING_VLOG(1) << "BM25F: corpus_id:" << data.corpus_id() << " docid:" + << hit_info.document_id() << " score:" << score; return score; } @@ -144,8 +142,7 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, // First, figure out corpus scoring data. auto status_or = document_store_->GetCorpusAssociatedScoreData(corpus_id); if (!status_or.ok()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "No scoring data for corpus [%d]", corpus_id); + ICING_LOG(ERROR) << "No scoring data for corpus [" << corpus_id << "]"; return 0; } CorpusAssociatedScoreData csdata = status_or.ValueOrDie(); @@ -155,9 +152,8 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, float idf = nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f; corpus_idf_map_.insert({corpus_term_info.value, idf}); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id, - std::string(term).c_str(), num_docs, nqi, idf); + ICING_VLOG(1) << "corpus_id:" << corpus_id << " term:" + << term << " N:" << num_docs << "nqi:" << nqi << " idf:" << idf; return idf; } @@ -176,8 +172,7 @@ float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) { // First, figure out corpus scoring data. 
auto status_or = document_store_->GetCorpusAssociatedScoreData(corpus_id); if (!status_or.ok()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "No scoring data for corpus [%d]", corpus_id); + ICING_LOG(ERROR) << "No scoring data for corpus [" << corpus_id << "]"; return 0; } CorpusAssociatedScoreData csdata = status_or.ValueOrDie(); @@ -205,9 +200,9 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( float normalized_tf = f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "corpus_id:%d docid:%d dl:%d avgdl:%f f_q:%f norm_tf:%f\n", - data.corpus_id(), hit_info.document_id(), dl, avgdl, f_q, normalized_tf); + ICING_VLOG(1) << "corpus_id:" << data.corpus_id() << " docid:" + << hit_info.document_id() << " dl:" << dl << " avgdl:" << avgdl << " f_q:" + << f_q << " norm_tf:" << normalized_tf; return normalized_tf; } @@ -233,18 +228,18 @@ float Bm25fCalculator::ComputeTermFrequencyForMatchedSections( } SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const { - auto filter_data_or = document_store_->GetDocumentFilterData(document_id); - if (!filter_data_or.ok()) { + auto filter_data_optional = + document_store_->GetAliveDocumentFilterData(document_id); + if (!filter_data_optional) { // This should never happen. The only failure case for // GetDocumentFilterData is if the document_id is outside of the range of // allocated document_ids, which shouldn't be possible since we're getting // this document_id from the posting lists. - ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( - "No document filter data for document [%d]", document_id); + ICING_LOG(WARNING) << "No document filter data for document [" + << document_id << "]"; return kInvalidSchemaTypeId; } - DocumentFilterData data = filter_data_or.ValueOrDie(); - return data.schema_type_id(); + return filter_data_optional.value().schema_type_id(); } } // namespace lib diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.cc b/icing/scoring/priority-queue-scored-document-hits-ranker.cc new file mode 100644 index 0000000..691b088 --- /dev/null +++ b/icing/scoring/priority-queue-scored-document-hits-ranker.cc @@ -0,0 +1,53 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" + +#include <queue> +#include <vector> + +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +PriorityQueueScoredDocumentHitsRanker::PriorityQueueScoredDocumentHitsRanker( + std::vector<ScoredDocumentHit>&& scored_document_hits, bool is_descending) + : comparator_(/*is_ascending=*/!is_descending), + scored_document_hits_pq_(comparator_, std::move(scored_document_hits)) {} + +ScoredDocumentHit PriorityQueueScoredDocumentHitsRanker::PopNext() { + ScoredDocumentHit ret = scored_document_hits_pq_.top(); + scored_document_hits_pq_.pop(); + return ret; +} + +void PriorityQueueScoredDocumentHitsRanker::TruncateHitsTo(int new_size) { + if (new_size < 0 || scored_document_hits_pq_.size() <= new_size) { + return; + } + + // Copying the best new_size results. + std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>, + Comparator> + new_pq(comparator_); + for (int i = 0; i < new_size; ++i) { + new_pq.push(scored_document_hits_pq_.top()); + scored_document_hits_pq_.pop(); + } + scored_document_hits_pq_ = std::move(new_pq); +} + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.h b/icing/scoring/priority-queue-scored-document-hits-ranker.h new file mode 100644 index 0000000..e0ae4b0 --- /dev/null +++ b/icing/scoring/priority-queue-scored-document-hits-ranker.h @@ -0,0 +1,72 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_PRIORITY_QUEUE_SCORED_DOCUMENT_HITS_RANKER_H_ +#define ICING_SCORING_PRIORITY_QUEUE_SCORED_DOCUMENT_HITS_RANKER_H_ + +#include <queue> +#include <vector> + +#include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/scored-document-hits-ranker.h" + +namespace icing { +namespace lib { + +// ScoredDocumentHitsRanker interface implementation, based on +// std::priority_queue. We can get next top hit in O(lgN) time. +class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { + public: + explicit PriorityQueueScoredDocumentHitsRanker( + std::vector<ScoredDocumentHit>&& scored_document_hits, + bool is_descending = true); + + ~PriorityQueueScoredDocumentHitsRanker() override = default; + + ScoredDocumentHit PopNext() override; + + void TruncateHitsTo(int new_size) override; + + int size() const override { return scored_document_hits_pq_.size(); } + + bool empty() const override { return scored_document_hits_pq_.empty(); } + + private: + // Comparator for std::priority_queue. Since std::priority is a max heap + // (descending order), reverse it if we want ascending order. 
+ class Comparator { + public: + explicit Comparator(bool is_ascending) : is_ascending_(is_ascending) {} + + bool operator()(const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) const { + return is_ascending_ == !(lhs < rhs); + } + + private: + bool is_ascending_; + }; + + Comparator comparator_; + + // Use priority queue to get top K hits in O(KlgN) time. + std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>, + Comparator> + scored_document_hits_pq_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_PRIORITY_QUEUE_SCORED_DOCUMENT_HITS_RANKER_H_ diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc b/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc new file mode 100644 index 0000000..a575eaf --- /dev/null +++ b/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc @@ -0,0 +1,239 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" + +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; + +std::vector<ScoredDocumentHit> PopAll( + PriorityQueueScoredDocumentHitsRanker& ranker) { + std::vector<ScoredDocumentHit> hits; + while (!ranker.empty()) { + hits.push_back(ranker.PopNext()); + } + return hits; +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldGetCorrectSizeAndEmpty) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_1, scored_hit_0, scored_hit_2}, + /*is_descending=*/true); + EXPECT_THAT(ranker.size(), Eq(3)); + EXPECT_FALSE(ranker.empty()); + + ranker.PopNext(); + EXPECT_THAT(ranker.size(), Eq(2)); + EXPECT_FALSE(ranker.empty()); + + ranker.PopNext(); + EXPECT_THAT(ranker.size(), Eq(1)); + EXPECT_FALSE(ranker.empty()); + + ranker.PopNext(); + EXPECT_THAT(ranker.size(), Eq(0)); + EXPECT_TRUE(ranker.empty()); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInDescendingOrder) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_3(/*document_id=*/3, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_1, scored_hit_0, scored_hit_2, 
scored_hit_4, scored_hit_3}, + /*is_descending=*/true); + + EXPECT_THAT(ranker, SizeIs(5)); + std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT(scored_document_hits, + ElementsAre(EqualsScoredDocumentHit(scored_hit_4), + EqualsScoredDocumentHit(scored_hit_3), + EqualsScoredDocumentHit(scored_hit_2), + EqualsScoredDocumentHit(scored_hit_1), + EqualsScoredDocumentHit(scored_hit_0))); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInAscendingOrder) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_3(/*document_id=*/3, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, + /*is_descending=*/false); + + EXPECT_THAT(ranker, SizeIs(5)); + std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT(scored_document_hits, + ElementsAre(EqualsScoredDocumentHit(scored_hit_0), + EqualsScoredDocumentHit(scored_hit_1), + EqualsScoredDocumentHit(scored_hit_2), + EqualsScoredDocumentHit(scored_hit_3), + EqualsScoredDocumentHit(scored_hit_4))); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, + ShouldRankDuplicateScoredDocumentHits) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_3(/*document_id=*/3, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_2, scored_hit_4, scored_hit_1, scored_hit_0, scored_hit_2, + scored_hit_2, scored_hit_4, scored_hit_3}, + /*is_descending=*/true); + + EXPECT_THAT(ranker, SizeIs(8)); + std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT(scored_document_hits, + ElementsAre(EqualsScoredDocumentHit(scored_hit_4), + EqualsScoredDocumentHit(scored_hit_4), + EqualsScoredDocumentHit(scored_hit_3), + EqualsScoredDocumentHit(scored_hit_2), + EqualsScoredDocumentHit(scored_hit_2), + EqualsScoredDocumentHit(scored_hit_2), + EqualsScoredDocumentHit(scored_hit_1), + EqualsScoredDocumentHit(scored_hit_0))); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, + ShouldRankEmptyScoredDocumentHits) { + PriorityQueueScoredDocumentHitsRanker ranker(/*scored_document_hits=*/{}, + /*is_descending=*/true); + EXPECT_THAT(ranker, IsEmpty()); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToNewSize) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_3(/*document_id=*/3, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, + 
/*is_descending=*/true); + ASSERT_THAT(ranker, SizeIs(5)); + + ranker.TruncateHitsTo(/*new_size=*/3); + EXPECT_THAT(ranker, SizeIs(3)); + std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT(scored_document_hits, + ElementsAre(EqualsScoredDocumentHit(scored_hit_4), + EqualsScoredDocumentHit(scored_hit_3), + EqualsScoredDocumentHit(scored_hit_2))); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToZero) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_3(/*document_id=*/3, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, + /*is_descending=*/true); + ASSERT_THAT(ranker, SizeIs(5)); + + ranker.TruncateHitsTo(/*new_size=*/0); + EXPECT_THAT(ranker, IsEmpty()); +} + +TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldNotTruncateToNegative) { + ScoredDocumentHit scored_hit_0(/*document_id=*/0, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_1(/*document_id=*/1, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_3(/*document_id=*/3, kSectionIdMaskNone, + /*score=*/1); + ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, + /*score=*/1); + + PriorityQueueScoredDocumentHitsRanker ranker( + {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, + /*is_descending=*/true); + ASSERT_THAT(ranker, SizeIs(Eq(5))); + + ranker.TruncateHitsTo(/*new_size=*/-1); + EXPECT_THAT(ranker, SizeIs(Eq(5))); + // Contents are not affected. + std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT(scored_document_hits, + ElementsAre(EqualsScoredDocumentHit(scored_hit_4), + EqualsScoredDocumentHit(scored_hit_3), + EqualsScoredDocumentHit(scored_hit_2), + EqualsScoredDocumentHit(scored_hit_1), + EqualsScoredDocumentHit(scored_hit_0))); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/ranker.cc b/icing/scoring/ranker.cc index 117f44c..ad971d3 100644 --- a/icing/scoring/ranker.cc +++ b/icing/scoring/ranker.cc @@ -103,8 +103,7 @@ void HeapifyTermDown(std::vector<TermMetadata>& scored_terms, // If the minimum is not the subtree root, swap and continue heapifying the // lower level subtree. if (min != target_subtree_root_index) { - std::swap(scored_terms.at(min), - scored_terms.at(target_subtree_root_index)); + std::swap(scored_terms.at(min), scored_terms.at(target_subtree_root_index)); HeapifyTermDown(scored_terms, min); } } @@ -146,35 +145,6 @@ TermMetadata PopRootTerm(std::vector<TermMetadata>& scored_terms) { return root; } -// Helper function to extract the root from the heap. The heap structure will be -// maintained. 
-// -// Returns: -// The current root element on success -// RESOURCE_EXHAUSTED_ERROR if heap is empty -libtextclassifier3::StatusOr<ScoredDocumentHit> PopRoot( - std::vector<ScoredDocumentHit>* scored_document_hits_heap, - const ScoredDocumentHitComparator& scored_document_hit_comparator) { - if (scored_document_hits_heap->empty()) { - // An invalid ScoredDocumentHit - return absl_ports::ResourceExhaustedError("Heap is empty"); - } - - // Steps to extract root from heap: - // 1. copy out root - ScoredDocumentHit root = scored_document_hits_heap->at(0); - const size_t last_node_index = scored_document_hits_heap->size() - 1; - // 2. swap root and the last node - std::swap(scored_document_hits_heap->at(0), - scored_document_hits_heap->at(last_node_index)); - // 3. remove last node - scored_document_hits_heap->pop_back(); - // 4. heapify root - Heapify(scored_document_hits_heap, /*target_subtree_root_index=*/0, - scored_document_hit_comparator); - return root; -} - } // namespace void BuildHeapInPlace( @@ -203,6 +173,29 @@ void PushToTermHeap(TermMetadata term, int number_to_return, } } +libtextclassifier3::StatusOr<ScoredDocumentHit> PopNextTopResultFromHeap( + std::vector<ScoredDocumentHit>* scored_document_hits_heap, + const ScoredDocumentHitComparator& scored_document_hit_comparator) { + if (scored_document_hits_heap->empty()) { + // An invalid ScoredDocumentHit + return absl_ports::ResourceExhaustedError("Heap is empty"); + } + + // Steps to extract root from heap: + // 1. copy out root + ScoredDocumentHit root = scored_document_hits_heap->at(0); + const size_t last_node_index = scored_document_hits_heap->size() - 1; + // 2. swap root and the last node + std::swap(scored_document_hits_heap->at(0), + scored_document_hits_heap->at(last_node_index)); + // 3. remove last node + scored_document_hits_heap->pop_back(); + // 4. heapify root + Heapify(scored_document_hits_heap, /*target_subtree_root_index=*/0, + scored_document_hit_comparator); + return root; +} + std::vector<ScoredDocumentHit> PopTopResultsFromHeap( std::vector<ScoredDocumentHit>* scored_document_hits_heap, int num_results, const ScoredDocumentHitComparator& scored_document_hit_comparator) { @@ -211,7 +204,8 @@ std::vector<ScoredDocumentHit> PopTopResultsFromHeap( num_results, static_cast<int>(scored_document_hits_heap->size())); while (result_size-- > 0) { libtextclassifier3::StatusOr<ScoredDocumentHit> next_best_document_hit_or = - PopRoot(scored_document_hits_heap, scored_document_hit_comparator); + PopNextTopResultFromHeap(scored_document_hits_heap, + scored_document_hit_comparator); if (next_best_document_hit_or.ok()) { scored_document_hit_result.push_back( std::move(next_best_document_hit_or).ValueOrDie()); diff --git a/icing/scoring/ranker.h b/icing/scoring/ranker.h index 81838f3..bfe1077 100644 --- a/icing/scoring/ranker.h +++ b/icing/scoring/ranker.h @@ -17,6 +17,7 @@ #include <vector> +#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/term-metadata.h" #include "icing/scoring/scored-document-hit.h" @@ -32,6 +33,17 @@ void BuildHeapInPlace( std::vector<ScoredDocumentHit>* scored_document_hits, const ScoredDocumentHitComparator& scored_document_hit_comparator); +// Returns the single next top result (i.e. the current root element) from the +// given heap and remove it from the heap. The heap structure will be +// maintained. 
+// +// Returns: +// The next top result element on success +// RESOURCE_EXHAUSTED_ERROR if heap is empty +libtextclassifier3::StatusOr<ScoredDocumentHit> PopNextTopResultFromHeap( + std::vector<ScoredDocumentHit>* scored_document_hits_heap, + const ScoredDocumentHitComparator& scored_document_hit_comparator); + // Returns the top num_results results from the given heap and remove those // results from the heap. An empty vector will be returned if heap is empty. // diff --git a/icing/scoring/ranker_benchmark.cc b/icing/scoring/ranker_benchmark.cc index 8983dd9..c2f13de 100644 --- a/icing/scoring/ranker_benchmark.cc +++ b/icing/scoring/ranker_benchmark.cc @@ -27,7 +27,7 @@ namespace { // $ blaze build -c opt --dynamic_mode=off --copt=-gmlt // //icing/scoring:ranker_benchmark // -// $ blaze-bin/icing/scoring/ranker_benchmark --benchmarks=all +// $ blaze-bin/icing/scoring/ranker_benchmark --benchmark_filter=all // --benchmark_memory_usage // // Run on an Android device: @@ -38,7 +38,7 @@ namespace { // $ adb push blaze-bin/icing/scoring/ranker_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/ranker_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/ranker_benchmark --benchmark_filter=all void BM_GetTopN(benchmark::State& state) { int num_to_score = state.range(0); diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index cc1d995..44dda3c 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -49,7 +49,7 @@ // //icing/scoring:score-and-rank_benchmark // // $ blaze-bin/icing/scoring/score-and-rank_benchmark -// --benchmarks=all --benchmark_memory_usage +// --benchmark_filter=all --benchmark_memory_usage // // Run on an Android device: // $ blaze build --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" @@ -59,7 +59,7 @@ // $ adb push blaze-bin/icing/scoring/score-and-rank_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/score-and-rank_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/score-and-rank_benchmark --benchmark_filter=all namespace icing { namespace lib { diff --git a/icing/scoring/scored-document-hits-ranker.h b/icing/scoring/scored-document-hits-ranker.h new file mode 100644 index 0000000..0287452 --- /dev/null +++ b/icing/scoring/scored-document-hits-ranker.h @@ -0,0 +1,53 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_SCORED_DOCUMENT_HITS_RANKER_H_ +#define ICING_SCORING_SCORED_DOCUMENT_HITS_RANKER_H_ + +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +// TODO(sungyc): re-evaluate other similar implementations (e.g. std::sort + +// std::queue/std::vector). Also revisit the capacity shrinking +// issue for PopNext(). + +// ScoredDocumentHitsRanker is an interface class for ranking +// ScoredDocumentHits. 
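For orientation before the interface declaration below: the ResultStateV2 tests earlier in this change hold the concrete ranker behind this interface via std::unique_ptr. A minimal sketch of that polymorphic use; the paging loop and num_per_page are hypothetical, while the types and methods are the ones declared in this change:

    // Illustrative sketch only, not part of the commit.
    std::unique_ptr<ScoredDocumentHitsRanker> ranker =
        std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
            std::move(scored_document_hits), /*is_descending=*/true);

    std::vector<ScoredDocumentHit> page;
    int num_per_page = 10;  // hypothetical page size
    while (!ranker->empty() && static_cast<int>(page.size()) < num_per_page) {
      page.push_back(ranker->PopNext());  // best remaining hit
    }
    // Optionally cap how many hits are kept around for later pages.
    ranker->TruncateHitsTo(/*new_size=*/100);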
+class ScoredDocumentHitsRanker { + public: + virtual ~ScoredDocumentHitsRanker() = default; + + // Pop the next top ScoredDocumentHit and return. It is undefined to call + // PopNext on an empty ranker, so the caller should check if it is not empty + // before calling. + virtual ScoredDocumentHit PopNext() = 0; + + // Truncates the remaining ScoredDocumentHits to the given size. The best + // ScoredDocumentHits (according to the ranking policy) should be kept. + // If new_size is invalid (< 0), or greater or equal to # of remaining + // ScoredDocumentHits, then no action will be taken. Otherwise truncates the + // the remaining ScoredDocumentHits to the given size. + virtual void TruncateHitsTo(int new_size) = 0; + + virtual int size() const = 0; + + virtual bool empty() const = 0; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_SCORED_DOCUMENT_HITS_RANKER_H_ diff --git a/icing/store/document-log-creator.cc b/icing/store/document-log-creator.cc index 5e23a8e..1739a50 100644 --- a/icing/store/document-log-creator.cc +++ b/icing/store/document-log-creator.cc @@ -18,7 +18,6 @@ #include <string> #include <utility> -#include "icing/text_classifier/lib3/utils/base/logging.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/annotate.h" diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc index 8c8369c..8a79b6d 100644 --- a/icing/store/document-store.cc +++ b/icing/store/document-store.cc @@ -46,13 +46,14 @@ #include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" #include "icing/store/document-log-creator.h" -#include "icing/store/key-mapper.h" +#include "icing/store/dynamic-trie-key-mapper.h" #include "icing/store/namespace-id.h" #include "icing/store/usage-store.h" #include "icing/tokenization/language-segmenter.h" #include "icing/util/clock.h" #include "icing/util/crc32.h" #include "icing/util/data-loss.h" +#include "icing/util/fingerprint-util.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" #include "icing/util/tokenized-document.h" @@ -77,8 +78,8 @@ constexpr char kCorpusIdMapperFilename[] = "corpus_mapper"; // because we allow up to 1 million DocumentIds. constexpr int32_t kUriMapperMaxSize = 36 * 1024 * 1024; // 36 MiB -// 384 KiB for a KeyMapper would allow each internal array to have a max of -// 128 KiB for storage. +// 384 KiB for a DynamicTrieKeyMapper would allow each internal array to have a +// max of 128 KiB for storage. constexpr int32_t kNamespaceMapperMaxSize = 3 * 128 * 1024; // 384 KiB constexpr int32_t kCorpusMapperMaxSize = 3 * 128 * 1024; // 384 KiB @@ -125,22 +126,13 @@ std::string MakeCorpusMapperFilename(const std::string& base_dir) { // overhead per key. As we know that these fingerprints are always 8-bytes in // length and that they're random, we might be able to store them more // compactly. -std::string MakeFingerprint(std::string_view name_space, std::string_view uri) { +std::string MakeFingerprint(std::string_view field1, std::string_view field2) { // Using a 64-bit fingerprint to represent the key could lead to collisions. // But, even with 200K unique keys, the probability of collision is about // one-in-a-billion (https://en.wikipedia.org/wiki/Birthday_attack). uint64_t fprint = - tc3farmhash::Fingerprint64(absl_ports::StrCat(name_space, uri)); - - std::string encoded_fprint; - // DynamicTrie cannot handle keys with '0' as bytes. 
So, we encode it in - // base128 and add 1 to make sure that no byte is '0'. This increases the - // size of the encoded_fprint from 8-bytes to 10-bytes. - while (fprint) { - encoded_fprint.push_back((fprint & 0x7F) + 1); - fprint >>= 7; - } - return encoded_fprint; + tc3farmhash::Fingerprint64(absl_ports::StrCat(field1, field2)); + return fingerprint_util::GetFingerprintString(fprint); } int64_t CalculateExpirationTimestampMs(int64_t creation_timestamp_ms, @@ -266,12 +258,13 @@ libtextclassifier3::StatusOr<DataLoss> DocumentStore::Initialize( GetRecoveryCause(create_result, force_recovery_and_revalidate_documents); if (recovery_cause != InitializeStatsProto::NONE || create_result.new_file) { - ICING_LOG(WARNING) << "Starting Document Store Recovery with cause=" - << recovery_cause << ", and create result { new_file=" - << create_result.new_file << ", preeisting_file_version=" - << create_result.preexisting_file_version << ", data_loss=" - << create_result.log_create_result.data_loss << "} and kCurrentVersion=" - << DocumentLogCreator::kCurrentVersion; + ICING_LOG(INFO) << "Starting Document Store Recovery with cause=" + << recovery_cause << ", and create result { new_file=" + << create_result.new_file << ", preeisting_file_version=" + << create_result.preexisting_file_version << ", data_loss=" + << create_result.log_create_result.data_loss + << "} and kCurrentVersion=" + << DocumentLogCreator::kCurrentVersion; // We can't rely on any existing derived files. Recreate them from scratch. // Currently happens if: // 1) This is a new log and we don't have derived files yet @@ -348,8 +341,11 @@ libtextclassifier3::Status DocumentStore::InitializeExistingDerivedFiles() { // TODO(b/144458732): Implement a more robust version of TC_ASSIGN_OR_RETURN // that can support error logging. - auto document_key_mapper_or = - KeyMapper<DocumentId>::Create(*filesystem_, base_dir_, kUriMapperMaxSize); + auto document_key_mapper_or = DynamicTrieKeyMapper< + DocumentId, + fingerprint_util::FingerprintStringFormatter>::Create(*filesystem_, + base_dir_, + kUriMapperMaxSize); if (!document_key_mapper_or.ok()) { ICING_LOG(ERROR) << document_key_mapper_or.status().error_message() << "Failed to initialize KeyMapper"; @@ -381,18 +377,23 @@ libtextclassifier3::Status DocumentStore::InitializeExistingDerivedFiles() { ICING_ASSIGN_OR_RETURN( namespace_mapper_, - KeyMapper<NamespaceId>::Create(*filesystem_, - MakeNamespaceMapperFilename(base_dir_), - kNamespaceMapperMaxSize)); + DynamicTrieKeyMapper<NamespaceId>::Create( + *filesystem_, MakeNamespaceMapperFilename(base_dir_), + kNamespaceMapperMaxSize)); ICING_ASSIGN_OR_RETURN( usage_store_, UsageStore::Create(filesystem_, MakeUsageStoreDirectoryName(base_dir_))); - ICING_ASSIGN_OR_RETURN(corpus_mapper_, - KeyMapper<CorpusId>::Create( - *filesystem_, MakeCorpusMapperFilename(base_dir_), - kCorpusMapperMaxSize)); + auto corpus_mapper_or = + DynamicTrieKeyMapper<CorpusId, + fingerprint_util::FingerprintStringFormatter>:: + Create(*filesystem_, MakeCorpusMapperFilename(base_dir_), + kCorpusMapperMaxSize); + if (!corpus_mapper_or.ok()) { + return std::move(corpus_mapper_or).status(); + } + corpus_mapper_ = std::move(corpus_mapper_or).ValueOrDie(); ICING_ASSIGN_OR_RETURN(corpus_score_cache_, FileBackedVector<CorpusAssociatedScoreData>::Create( @@ -561,7 +562,7 @@ libtextclassifier3::Status DocumentStore::ResetDocumentKeyMapper() { // TODO(b/216487496): Implement a more robust version of TC_RETURN_IF_ERROR // that can support error logging. 
libtextclassifier3::Status status = - KeyMapper<DocumentId>::Delete(*filesystem_, base_dir_); + DynamicTrieKeyMapper<DocumentId>::Delete(*filesystem_, base_dir_); if (!status.ok()) { ICING_LOG(ERROR) << status.error_message() << "Failed to delete old key mapper"; @@ -570,8 +571,11 @@ libtextclassifier3::Status DocumentStore::ResetDocumentKeyMapper() { // TODO(b/216487496): Implement a more robust version of TC_ASSIGN_OR_RETURN // that can support error logging. - auto document_key_mapper_or = - KeyMapper<DocumentId>::Create(*filesystem_, base_dir_, kUriMapperMaxSize); + auto document_key_mapper_or = DynamicTrieKeyMapper< + DocumentId, + fingerprint_util::FingerprintStringFormatter>::Create(*filesystem_, + base_dir_, + kUriMapperMaxSize); if (!document_key_mapper_or.ok()) { ICING_LOG(ERROR) << document_key_mapper_or.status().error_message() << "Failed to re-init key mapper"; @@ -648,7 +652,7 @@ libtextclassifier3::Status DocumentStore::ResetNamespaceMapper() { namespace_mapper_.reset(); // TODO(b/216487496): Implement a more robust version of TC_RETURN_IF_ERROR // that can support error logging. - libtextclassifier3::Status status = KeyMapper<NamespaceId>::Delete( + libtextclassifier3::Status status = DynamicTrieKeyMapper<NamespaceId>::Delete( *filesystem_, MakeNamespaceMapperFilename(base_dir_)); if (!status.ok()) { ICING_LOG(ERROR) << status.error_message() @@ -657,9 +661,9 @@ libtextclassifier3::Status DocumentStore::ResetNamespaceMapper() { } ICING_ASSIGN_OR_RETURN( namespace_mapper_, - KeyMapper<NamespaceId>::Create(*filesystem_, - MakeNamespaceMapperFilename(base_dir_), - kNamespaceMapperMaxSize)); + DynamicTrieKeyMapper<NamespaceId>::Create( + *filesystem_, MakeNamespaceMapperFilename(base_dir_), + kNamespaceMapperMaxSize)); return libtextclassifier3::Status::OK; } @@ -668,17 +672,22 @@ libtextclassifier3::Status DocumentStore::ResetCorpusMapper() { corpus_mapper_.reset(); // TODO(b/216487496): Implement a more robust version of TC_RETURN_IF_ERROR // that can support error logging. - libtextclassifier3::Status status = KeyMapper<CorpusId>::Delete( + libtextclassifier3::Status status = DynamicTrieKeyMapper<CorpusId>::Delete( *filesystem_, MakeCorpusMapperFilename(base_dir_)); if (!status.ok()) { ICING_LOG(ERROR) << status.error_message() << "Failed to delete old corpus_id mapper"; return status; } - ICING_ASSIGN_OR_RETURN(corpus_mapper_, - KeyMapper<CorpusId>::Create( - *filesystem_, MakeCorpusMapperFilename(base_dir_), - kCorpusMapperMaxSize)); + auto corpus_mapper_or = + DynamicTrieKeyMapper<CorpusId, + fingerprint_util::FingerprintStringFormatter>:: + Create(*filesystem_, MakeCorpusMapperFilename(base_dir_), + kCorpusMapperMaxSize); + if (!corpus_mapper_or.ok()) { + return std::move(corpus_mapper_or).status(); + } + corpus_mapper_ = std::move(corpus_mapper_or).ValueOrDie(); return libtextclassifier3::Status::OK; } @@ -931,7 +940,18 @@ libtextclassifier3::StatusOr<DocumentProto> DocumentStore::Get( libtextclassifier3::StatusOr<DocumentProto> DocumentStore::Get( DocumentId document_id, bool clear_internal_fields) const { - ICING_RETURN_IF_ERROR(DoesDocumentExistWithStatus(document_id)); + auto document_filter_data_optional_ = GetAliveDocumentFilterData(document_id); + if (!document_filter_data_optional_) { + // The document doesn't exist. Let's check if the document id is invalid, we + // will return InvalidArgumentError. Otherwise we should return NOT_FOUND + // error. 
+ if (!IsDocumentIdValid(document_id)) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Document id '%d' invalid.", document_id)); + } + return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( + "Document id '%d' doesn't exist", document_id)); + } auto document_log_offset_or = document_id_mapper_->Get(document_id); if (!document_log_offset_or.ok()) { @@ -991,7 +1011,7 @@ std::vector<std::string> DocumentStore::GetAllNamespaces() const { } const DocumentFilterData* data = status_or_data.ValueOrDie(); - if (InternalDoesDocumentExist(document_id)) { + if (GetAliveDocumentFilterData(document_id)) { existing_namespace_ids.insert(data->namespace_id()); } } @@ -1004,43 +1024,15 @@ std::vector<std::string> DocumentStore::GetAllNamespaces() const { return existing_namespaces; } -bool DocumentStore::DoesDocumentExist(DocumentId document_id) const { - if (!IsDocumentIdValid(document_id)) { - return false; - } - - if (document_id >= document_id_mapper_->num_elements()) { - // Somehow got an validly constructed document_id that the document store - // doesn't know about - return false; - } - - return InternalDoesDocumentExist(document_id); -} - -libtextclassifier3::Status DocumentStore::DoesDocumentExistWithStatus( +std::optional<DocumentFilterData> DocumentStore::GetAliveDocumentFilterData( DocumentId document_id) const { if (!IsDocumentIdValid(document_id)) { - return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "Document id '%d' invalid.", document_id)); + return std::nullopt; } - - if (document_id >= document_id_mapper_->num_elements()) { - // Somehow got a validly constructed document_id that the document store - // doesn't know about. - return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( - "Unknown document id '%d'.", document_id)); + if (IsDeleted(document_id)) { + return std::nullopt; } - - if (!InternalDoesDocumentExist(document_id)) { - return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( - "Document id '%d' doesn't exist", document_id)); - }; - return libtextclassifier3::Status::OK; -} - -bool DocumentStore::InternalDoesDocumentExist(DocumentId document_id) const { - return !IsDeleted(document_id) && !IsExpired(document_id); + return GetNonExpiredDocumentFilterData(document_id); } bool DocumentStore::IsDeleted(DocumentId document_id) const { @@ -1057,21 +1049,27 @@ bool DocumentStore::IsDeleted(DocumentId document_id) const { return file_offset == kDocDeletedFlag; } -bool DocumentStore::IsExpired(DocumentId document_id) const { - auto filter_data_or = filter_cache_->Get(document_id); +// Returns DocumentFilterData if the document is not expired. Otherwise, +// std::nullopt. +std::optional<DocumentFilterData> +DocumentStore::GetNonExpiredDocumentFilterData(DocumentId document_id) const { + auto filter_data_or = filter_cache_->GetCopy(document_id); if (!filter_data_or.ok()) { // This would only happen if document_id is out of range of the // filter_cache, meaning we got some invalid document_id. Callers should // already have checked that their document_id is valid or used // DoesDocumentExist(WithStatus). Regardless, return true since the // document doesn't exist. 
- return true; + return std::nullopt; } - const DocumentFilterData* filter_data = filter_data_or.ValueOrDie(); + DocumentFilterData document_filter_data = filter_data_or.ValueOrDie(); // Check if it's past the expiration time - return clock_.GetSystemTimeMilliseconds() >= - filter_data->expiration_timestamp_ms(); + if (clock_.GetSystemTimeMilliseconds() >= + document_filter_data.expiration_timestamp_ms()) { + return std::nullopt; + } + return document_filter_data; } libtextclassifier3::Status DocumentStore::Delete( @@ -1088,7 +1086,17 @@ libtextclassifier3::Status DocumentStore::Delete( } libtextclassifier3::Status DocumentStore::Delete(DocumentId document_id) { - ICING_RETURN_IF_ERROR(DoesDocumentExistWithStatus(document_id)); + auto document_filter_data_optional_ = GetAliveDocumentFilterData(document_id); + if (!document_filter_data_optional_) { + // The document doesn't exist. We should return InvalidArgumentError if the + // document id is invalid. Otherwise we should return NOT_FOUND error. + if (!IsDocumentIdValid(document_id)) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Document id '%d' invalid.", document_id)); + } + return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( + "Document id '%d' doesn't exist", document_id)); + } auto document_log_offset_or = document_id_mapper_->Get(document_id); if (!document_log_offset_or.ok()) { @@ -1113,7 +1121,7 @@ libtextclassifier3::StatusOr<CorpusId> DocumentStore::GetCorpusId( libtextclassifier3::StatusOr<DocumentAssociatedScoreData> DocumentStore::GetDocumentAssociatedScoreData(DocumentId document_id) const { - if (!DoesDocumentExist(document_id)) { + if (!GetAliveDocumentFilterData(document_id)) { return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( "Can't get usage scores, document id '%d' doesn't exist", document_id)); } @@ -1162,27 +1170,9 @@ DocumentStore::GetCorpusAssociatedScoreDataToUpdate(CorpusId corpus_id) const { return corpus_scoring_data_or.status(); } -libtextclassifier3::StatusOr<DocumentFilterData> -DocumentStore::GetDocumentFilterData(DocumentId document_id) const { - if (!DoesDocumentExist(document_id)) { - return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( - "Can't get filter data, document id '%d' doesn't exist", document_id)); - } - - auto filter_data_or = filter_cache_->GetCopy(document_id); - if (!filter_data_or.ok()) { - ICING_LOG(ERROR) << " while trying to access DocumentId " << document_id - << " from filter_cache_"; - return filter_data_or.status(); - } - DocumentFilterData document_filter_data = - std::move(filter_data_or).ValueOrDie(); - return document_filter_data; -} - libtextclassifier3::StatusOr<UsageStore::UsageScores> DocumentStore::GetUsageScores(DocumentId document_id) const { - if (!DoesDocumentExist(document_id)) { + if (!GetAliveDocumentFilterData(document_id)) { return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( "Can't get usage scores, document id '%d' doesn't exist", document_id)); } @@ -1197,7 +1187,7 @@ libtextclassifier3::Status DocumentStore::ReportUsage( // We can use the internal version here because we got our document_id from // our internal data structures. We would have thrown some error if the // namespace and/or uri were incorrect. - if (!InternalDoesDocumentExist(document_id)) { + if (!GetAliveDocumentFilterData(document_id)) { // Document was probably deleted or expired. 
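Call sites in this patch that only need a yes/no answer, like ReportUsage() here, rely on std::optional's contextual conversion to bool; call sites that also need the cached data bind the optional first and then dereference it. Both styles, restated with only the accessors visible in this diff:

// Existence only:
if (!store.GetAliveDocumentFilterData(document_id)) {
  // Deleted, expired, or never a valid id.
}

// Existence plus the cached filter data:
std::optional<DocumentFilterData> filter_data =
    store.GetAliveDocumentFilterData(document_id);
if (filter_data.has_value()) {
  SchemaTypeId schema_type_id = filter_data->schema_type_id();
  NamespaceId namespace_id = filter_data->namespace_id();
  int64_t expiration_ms = filter_data->expiration_timestamp_ms();
}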
return absl_ports::NotFoundError(absl_ports::StrCat( "Couldn't report usage on a nonexistent document: (namespace: '", @@ -1415,7 +1405,7 @@ DocumentStorageInfoProto DocumentStore::CalculateDocumentStatusCounts( UsageStore::UsageScores usage_scores = usage_scores_or.ValueOrDie(); // Update our stats - if (IsExpired(document_id)) { + if (!GetNonExpiredDocumentFilterData(document_id)) { ++total_num_expired; namespace_storage_info.set_num_expired_documents( namespace_storage_info.num_expired_documents() + 1); @@ -1499,8 +1489,11 @@ libtextclassifier3::Status DocumentStore::UpdateSchemaStore( // Update the SchemaTypeId for this entry ICING_ASSIGN_OR_RETURN(SchemaTypeId schema_type_id, schema_store_->GetSchemaTypeId(document.schema())); - filter_cache_->mutable_array()[document_id].set_schema_type_id( - schema_type_id); + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<DocumentFilterData>::MutableView + doc_filter_data_view, + filter_cache_->GetMutable(document_id)); + doc_filter_data_view.Get().set_schema_type_id(schema_type_id); } else { // Document is no longer valid with the new SchemaStore. Mark as // deleted @@ -1529,7 +1522,7 @@ libtextclassifier3::Status DocumentStore::OptimizedUpdateSchemaStore( int size = document_id_mapper_->num_elements(); for (DocumentId document_id = 0; document_id < size; document_id++) { - if (!InternalDoesDocumentExist(document_id)) { + if (!GetAliveDocumentFilterData(document_id)) { // Skip nonexistent documents continue; } @@ -1560,8 +1553,11 @@ libtextclassifier3::Status DocumentStore::OptimizedUpdateSchemaStore( ICING_ASSIGN_OR_RETURN( SchemaTypeId schema_type_id, schema_store_->GetSchemaTypeId(document.schema())); - filter_cache_->mutable_array()[document_id].set_schema_type_id( - schema_type_id); + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<DocumentFilterData>::MutableView + doc_filter_data_view, + filter_cache_->GetMutable(document_id)); + doc_filter_data_view.Get().set_schema_type_id(schema_type_id); } if (revalidate_document) { delete_document = !document_validator_.Validate(document).ok(); @@ -1586,9 +1582,10 @@ libtextclassifier3::Status DocumentStore::Optimize() { return libtextclassifier3::Status::OK; } -libtextclassifier3::Status DocumentStore::OptimizeInto( - const std::string& new_directory, const LanguageSegmenter* lang_segmenter, - OptimizeStatsProto* stats) { +libtextclassifier3::StatusOr<std::vector<DocumentId>> +DocumentStore::OptimizeInto(const std::string& new_directory, + const LanguageSegmenter* lang_segmenter, + OptimizeStatsProto* stats) { // Validates directory if (new_directory == base_dir_) { return absl_ports::InvalidArgumentError( @@ -1606,12 +1603,13 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( int num_deleted = 0; int num_expired = 0; UsageStore::UsageScores default_usage; + std::vector<DocumentId> document_id_old_to_new(size, kInvalidDocumentId); for (DocumentId document_id = 0; document_id < size; document_id++) { auto document_or = Get(document_id, /*clear_internal_fields=*/false); if (absl_ports::IsNotFound(document_or.status())) { if (IsDeleted(document_id)) { ++num_deleted; - } else if (IsExpired(document_id)) { + } else if (!GetNonExpiredDocumentFilterData(document_id)) { ++num_expired; } continue; @@ -1651,6 +1649,8 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( return new_document_id_or.status(); } + document_id_old_to_new[document_id] = new_document_id_or.ValueOrDie(); + // Copy over usage scores. 
ICING_ASSIGN_OR_RETURN(UsageStore::UsageScores usage_scores, usage_store_->GetUsageScores(document_id)); @@ -1669,7 +1669,7 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( stats->set_num_expired_documents(num_expired); } ICING_RETURN_IF_ERROR(new_doc_store->PersistToDisk(PersistType::FULL)); - return libtextclassifier3::Status::OK; + return document_id_old_to_new; } libtextclassifier3::StatusOr<DocumentStore::OptimizeInfo> @@ -1680,7 +1680,7 @@ DocumentStore::GetOptimizeInfo() const { int32_t num_documents = document_id_mapper_->num_elements(); for (DocumentId document_id = kMinDocumentId; document_id < num_documents; ++document_id) { - if (!InternalDoesDocumentExist(document_id)) { + if (!GetAliveDocumentFilterData(document_id)) { ++optimize_info.optimizable_docs; } @@ -1713,8 +1713,8 @@ DocumentStore::GetOptimizeInfo() const { ICING_ASSIGN_OR_RETURN(const int64_t usage_store_file_size, usage_store_->GetElementsFileSize()); - // We use a combined disk usage and file size for the KeyMapper because it's - // backed by a trie, which has some sparse property bitmaps. + // We use a combined disk usage and file size for the DynamicTrieKeyMapper + // because it's backed by a trie, which has some sparse property bitmaps. ICING_ASSIGN_OR_RETURN(const int64_t document_key_mapper_size, document_key_mapper_->GetElementsSize()); @@ -1794,7 +1794,7 @@ DocumentStore::CollectCorpusInfo() const { const SchemaProto* schema_proto = schema_proto_or.ValueOrDie(); for (DocumentId document_id = 0; document_id < filter_cache_->num_elements(); ++document_id) { - if (!InternalDoesDocumentExist(document_id)) { + if (!GetAliveDocumentFilterData(document_id)) { continue; } ICING_ASSIGN_OR_RETURN(const DocumentFilterData* filter_data, diff --git a/icing/store/document-store.h b/icing/store/document-store.h index e6d2e5c..41dd6a9 100644 --- a/icing/store/document-store.h +++ b/icing/store/document-store.h @@ -48,6 +48,7 @@ #include "icing/util/crc32.h" #include "icing/util/data-loss.h" #include "icing/util/document-validator.h" +#include "icing/util/fingerprint-util.h" namespace icing { namespace lib { @@ -198,19 +199,6 @@ class DocumentStore { // or expired). Order of namespaces is undefined. std::vector<std::string> GetAllNamespaces() const; - // Check if a document exists. Existence means it hasn't been deleted and it - // hasn't expired yet. - // - // NOTE: This should be used when callers don't care about error messages, - // expect documents to be deleted/not found, or in frequently called code - // paths that could cause performance issues. A signficant amount of CPU - // cycles can be saved if we don't construct strings and create new Status - // objects on the heap. See b/185822483. - // - // Returns: - // boolean whether a document exists or not - bool DoesDocumentExist(DocumentId document_id) const; - // Deletes the document identified by the given namespace and uri. The // document proto will be erased immediately. // @@ -280,14 +268,15 @@ class DocumentStore { libtextclassifier3::StatusOr<CorpusAssociatedScoreData> GetCorpusAssociatedScoreData(CorpusId corpus_id) const; - // Returns the DocumentFilterData of the document specified by the DocumentId. + // Gets the document filter data if a document exists. Otherwise, will get a + // false optional. + // + // Existence means it hasn't been deleted and it hasn't expired yet. 
// // Returns: - // DocumentFilterData on success - // OUT_OF_RANGE if document_id is negative or exceeds previously seen - // DocumentIds - // NOT_FOUND if the document or the filter data is not found - libtextclassifier3::StatusOr<DocumentFilterData> GetDocumentFilterData( + // True:DocumentFilterData if the given document exists. + // False if the given document doesn't exist. + std::optional<DocumentFilterData> GetAliveDocumentFilterData( DocumentId document_id) const; // Gets the usage scores of a document. @@ -399,10 +388,10 @@ class DocumentStore { // method based on device usage. // // Returns: - // OK on success + // A vector that maps from old document id to new document id on success // INVALID_ARGUMENT if new_directory is same as current base directory // INTERNAL_ERROR on IO error - libtextclassifier3::Status OptimizeInto( + libtextclassifier3::StatusOr<std::vector<DocumentId>> OptimizeInto( const std::string& new_directory, const LanguageSegmenter* lang_segmenter, OptimizeStatsProto* stats = nullptr); @@ -455,7 +444,9 @@ class DocumentStore { std::unique_ptr<PortableFileBackedProtoLog<DocumentWrapper>> document_log_; // Key (namespace + uri) to DocumentId mapping - std::unique_ptr<KeyMapper<DocumentId>> document_key_mapper_; + std::unique_ptr< + KeyMapper<DocumentId, fingerprint_util::FingerprintStringFormatter>> + document_key_mapper_; // DocumentId to file offset mapping std::unique_ptr<FileBackedVector<int64_t>> document_id_mapper_; @@ -491,7 +482,9 @@ class DocumentStore { // unique id. A coprus is assigned an // id when the first document belonging to that corpus is added to the // DocumentStore. Corpus ids may be removed from the mapper during compaction. - std::unique_ptr<KeyMapper<CorpusId>> corpus_mapper_; + std::unique_ptr< + KeyMapper<CorpusId, fingerprint_util::FingerprintStringFormatter>> + corpus_mapper_; // A storage class that caches all usage scores. Usage scores are not // considered as ground truth. Usage scores are associated with document ids @@ -648,18 +641,6 @@ class DocumentStore { libtextclassifier3::Status DoesDocumentExistWithStatus( DocumentId document_id) const; - // Check if a document exists. Existence means it hasn't been deleted and it - // hasn't expired yet. - // - // This is for internal-use only because we assume that the document_id is - // already valid. If you're unsure if the document_id is valid, use - // DoesDocumentExist(document_id) instead, which will perform those additional - // checks. - // - // Returns: - // boolean whether a document exists or not - bool InternalDoesDocumentExist(DocumentId document_id) const; - // Checks if a document has been deleted // // This is for internal-use only because we assume that the document_id is @@ -674,7 +655,12 @@ class DocumentStore { // already valid. If you're unsure if the document_id is valid, use // DoesDocumentExist(document_id) instead, which will perform those additional // checks. - bool IsExpired(DocumentId document_id) const; + + // Returns: + // True:DocumentFilterData if the given document isn't expired. + // False if the given doesn't document is expired. + std::optional<DocumentFilterData> GetNonExpiredDocumentFilterData( + DocumentId document_id) const; // Updates the entry in the score cache for document_id. 
libtextclassifier3::Status UpdateDocumentAssociatedScoreCache( diff --git a/icing/store/document-store_benchmark.cc b/icing/store/document-store_benchmark.cc index fc3fd9d..c4d2346 100644 --- a/icing/store/document-store_benchmark.cc +++ b/icing/store/document-store_benchmark.cc @@ -46,7 +46,7 @@ // //icing/store:document-store_benchmark // // $ blaze-bin/icing/store/document-store_benchmark -// --benchmarks=all --benchmark_memory_usage +// --benchmark_filter=all --benchmark_memory_usage // // Run on an Android device: // $ blaze build --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" @@ -57,7 +57,7 @@ // /data/local/tmp/ // // $ adb shell /data/local/tmp/document-store_benchmark -// --benchmarks=all +// --benchmark_filter=all namespace icing { namespace lib { @@ -164,7 +164,8 @@ void BM_DoesDocumentExistBenchmark(benchmark::State& state) { // Check random document ids to see if they exist. Hopefully to simulate // page faulting in different sections of our mmapped derived files. int document_id = dist(random); - benchmark::DoNotOptimize(document_store->DoesDocumentExist(document_id)); + benchmark::DoNotOptimize( + document_store->GetAliveDocumentFilterData(document_id)); } } BENCHMARK(BM_DoesDocumentExistBenchmark); diff --git a/icing/store/document-store_test.cc b/icing/store/document-store_test.cc index a30b4e4..6f444cb 100644 --- a/icing/store/document-store_test.cc +++ b/icing/store/document-store_test.cc @@ -59,6 +59,7 @@ namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::_; +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Ge; using ::testing::Gt; @@ -358,23 +359,22 @@ TEST_F(DocumentStoreTest, IsDocumentExistingWithoutStatus) { ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store->Put(DocumentProto(test_document2_))); - EXPECT_THAT(doc_store->DoesDocumentExist(document_id1), IsTrue()); - EXPECT_THAT(doc_store->DoesDocumentExist(document_id2), IsTrue()); + EXPECT_TRUE(doc_store->GetAliveDocumentFilterData(document_id1)); + EXPECT_TRUE(doc_store->GetAliveDocumentFilterData(document_id2)); DocumentId invalid_document_id_negative = -1; - EXPECT_THAT(doc_store->DoesDocumentExist(invalid_document_id_negative), - IsFalse()); + EXPECT_FALSE( + doc_store->GetAliveDocumentFilterData(invalid_document_id_negative)); DocumentId invalid_document_id_greater_than_max = kMaxDocumentId + 2; - EXPECT_THAT( - doc_store->DoesDocumentExist(invalid_document_id_greater_than_max), - IsFalse()); + EXPECT_FALSE(doc_store->GetAliveDocumentFilterData( + invalid_document_id_greater_than_max)); - EXPECT_THAT(doc_store->DoesDocumentExist(kInvalidDocumentId), IsFalse()); + EXPECT_FALSE(doc_store->GetAliveDocumentFilterData(kInvalidDocumentId)); DocumentId invalid_document_id_out_of_range = document_id2 + 1; - EXPECT_THAT(doc_store->DoesDocumentExist(invalid_document_id_out_of_range), - IsFalse()); + EXPECT_FALSE( + doc_store->GetAliveDocumentFilterData(invalid_document_id_out_of_range)); } TEST_F(DocumentStoreTest, GetDeletedDocumentNotFound) { @@ -485,6 +485,35 @@ TEST_F(DocumentStoreTest, DeleteNonexistentDocumentNotFound) { EXPECT_THAT(document_log_size_before, Eq(document_log_size_after)); } +TEST_F(DocumentStoreTest, DeleteNonexistentDocumentPrintableErrorMessage) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> document_store = + std::move(create_result.document_store); + + // Validates that 
deleting something non-existing won't append anything to + // ground truth + int64_t document_log_size_before = filesystem_.GetFileSize( + absl_ports::StrCat(document_store_dir_, "/", + DocumentLogCreator::GetDocumentLogFilename()) + .c_str()); + + libtextclassifier3::Status status = + document_store->Delete("android$contacts/", "661"); + EXPECT_THAT(status, StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + for (char c : status.error_message()) { + EXPECT_THAT(std::isprint(c), IsTrue()); + } + + int64_t document_log_size_after = filesystem_.GetFileSize( + absl_ports::StrCat(document_store_dir_, "/", + DocumentLogCreator::GetDocumentLogFilename()) + .c_str()); + EXPECT_THAT(document_log_size_before, Eq(document_log_size_after)); +} + TEST_F(DocumentStoreTest, DeleteAlreadyDeletedDocumentNotFound) { ICING_ASSERT_OK_AND_ASSIGN( DocumentStore::CreateResult create_result, @@ -1030,8 +1059,8 @@ TEST_F(DocumentStoreTest, OptimizeInto) { // deleted ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK( - doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); + EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(0, 1, 2))); int64_t optimized_size1 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_EQ(original_size, optimized_size1); @@ -1041,8 +1070,9 @@ TEST_F(DocumentStoreTest, OptimizeInto) { ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); ICING_ASSERT_OK(doc_store->Delete("namespace", "uri1")); - ICING_ASSERT_OK( - doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); + // DocumentId 0 is removed. + EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(kInvalidDocumentId, 0, 1))); int64_t optimized_size2 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_THAT(original_size, Gt(optimized_size2)); @@ -1055,11 +1085,39 @@ TEST_F(DocumentStoreTest, OptimizeInto) { // expired ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK( - doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); + // DocumentId 0 is removed, and DocumentId 2 is expired. + EXPECT_THAT( + doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(kInvalidDocumentId, 0, kInvalidDocumentId))); int64_t optimized_size3 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_THAT(optimized_size2, Gt(optimized_size3)); + + // Delete the last document + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); + ICING_ASSERT_OK(doc_store->Delete("namespace", "uri2")); + // DocumentId 0 and 1 is removed, and DocumentId 2 is expired. 
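OptimizeInto() now returns, for every pre-optimize DocumentId, either its id in the optimized store or kInvalidDocumentId when the document was dropped as deleted or expired, which is exactly what the expectations in this test spell out. A sketch of how a caller might translate ids through that vector (the function is illustrative only, not part of the patch):

// Hypothetical translation helper built on the vector returned by
// DocumentStore::OptimizeInto().
DocumentId TranslateToPostOptimizeId(
    const std::vector<DocumentId>& document_id_old_to_new,
    DocumentId old_document_id) {
  if (old_document_id < 0 ||
      static_cast<size_t>(old_document_id) >= document_id_old_to_new.size()) {
    return kInvalidDocumentId;
  }
  // kInvalidDocumentId marks documents that were not copied into the
  // optimized store (deleted or expired).
  return document_id_old_to_new[old_document_id];
}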
+ EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(kInvalidDocumentId, kInvalidDocumentId, + kInvalidDocumentId))); + int64_t optimized_size4 = + filesystem_.GetFileSize(optimized_document_log.c_str()); + EXPECT_THAT(optimized_size3, Gt(optimized_size4)); +} + +TEST_F(DocumentStoreTest, OptimizeIntoForEmptyDocumentStore) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + std::string optimized_dir = document_store_dir_ + "_optimize"; + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); + EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(IsEmpty())); } TEST_F(DocumentStoreTest, ShouldRecoverFromDataLoss) { @@ -1130,12 +1188,15 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromDataLoss) { StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); EXPECT_THAT(doc_store->Get(document_id2), IsOkAndHolds(EqualsProto(test_document2_))); - // Checks derived filter cache - EXPECT_THAT(doc_store->GetDocumentFilterData(document_id2), - IsOkAndHolds(DocumentFilterData( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id2)); + EXPECT_THAT(doc_filter_data, + Eq(DocumentFilterData( /*namespace_id=*/0, /*schema_type_id=*/0, document2_expiration_timestamp_))); + // Checks derived score cache EXPECT_THAT( doc_store->GetDocumentAssociatedScoreData(document_id2), @@ -1220,10 +1281,14 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromCorruptDerivedFile) { IsOkAndHolds(EqualsProto(test_document2_))); // Checks derived filter cache - EXPECT_THAT(doc_store->GetDocumentFilterData(document_id2), - IsOkAndHolds(DocumentFilterData( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id2)); + EXPECT_THAT(doc_filter_data, + Eq(DocumentFilterData( /*namespace_id=*/0, /*schema_type_id=*/0, document2_expiration_timestamp_))); + // Checks derived score cache - note that they aren't regenerated from // scratch. 
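These tests unwrap the new optionals with ICING_ASSERT_HAS_VALUE_AND_ASSIGN, a macro whose definition is not included in this diff. Judging purely from its call sites, a plausible shape for it is sketched below; this is an assumption for illustration, not the project's actual implementation:

// Sketch only: asserts that the right-hand expression holds a value and
// binds that value to the declaration or variable given as `lhs`.
#define ICING_SKETCH_CONCAT_INNER_(a, b) a##b
#define ICING_SKETCH_CONCAT_(a, b) ICING_SKETCH_CONCAT_INNER_(a, b)
#define ICING_ASSERT_HAS_VALUE_AND_ASSIGN(lhs, rexpr)                     \
  auto ICING_SKETCH_CONCAT_(optional_tmp_, __LINE__) = (rexpr);           \
  ASSERT_TRUE(ICING_SKETCH_CONCAT_(optional_tmp_, __LINE__).has_value()); \
  lhs = std::move(ICING_SKETCH_CONCAT_(optional_tmp_, __LINE__)).value()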
EXPECT_THAT( @@ -1293,8 +1358,11 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromBadChecksum) { IsOkAndHolds(EqualsProto(test_document2_))); // Checks derived filter cache - EXPECT_THAT(doc_store->GetDocumentFilterData(document_id2), - IsOkAndHolds(DocumentFilterData( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id2)); + EXPECT_THAT(doc_filter_data, + Eq(DocumentFilterData( /*namespace_id=*/0, /*schema_type_id=*/0, document2_expiration_timestamp_))); // Checks derived score cache @@ -1704,8 +1772,7 @@ TEST_F(DocumentStoreTest, NonexistentDocumentFilterDataNotFound) { std::unique_ptr<DocumentStore> doc_store = std::move(create_result.document_store); - EXPECT_THAT(doc_store->GetDocumentFilterData(/*document_id=*/0), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_FALSE(doc_store->GetAliveDocumentFilterData(/*document_id=*/0)); } TEST_F(DocumentStoreTest, DeleteClearsFilterCache) { @@ -1719,17 +1786,17 @@ TEST_F(DocumentStoreTest, DeleteClearsFilterCache) { ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, doc_store->Put(test_document1_)); - EXPECT_THAT( - doc_store->GetDocumentFilterData(document_id), - IsOkAndHolds(DocumentFilterData( - /*namespace_id=*/0, - /*schema_type_id=*/0, - /*expiration_timestamp_ms=*/document1_expiration_timestamp_))); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id)); + EXPECT_THAT(doc_filter_data, + Eq(DocumentFilterData( + /*namespace_id=*/0, + /*schema_type_id=*/0, document1_expiration_timestamp_))); ICING_ASSERT_OK(doc_store->Delete("icing", "email/1")); // Associated entry of the deleted document is removed. - EXPECT_THAT(doc_store->GetDocumentFilterData(document_id), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_FALSE(doc_store->GetAliveDocumentFilterData(document_id)); } TEST_F(DocumentStoreTest, DeleteClearsScoreCache) { @@ -1857,12 +1924,13 @@ TEST_F(DocumentStoreTest, std::move(create_result.document_store); ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, doc_store->Put(document)); - - EXPECT_THAT( - doc_store->GetDocumentFilterData(document_id), - IsOkAndHolds(DocumentFilterData(/*namespace_id=*/0, - /*schema_type_id=*/0, - /*expiration_timestamp_ms=*/1100))); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id)); + EXPECT_THAT(doc_filter_data, Eq(DocumentFilterData( + /*namespace_id=*/0, + /*schema_type_id=*/0, + /*expiration_timestamp_ms=*/1100))); } TEST_F(DocumentStoreTest, ExpirationTimestampIsInt64MaxIfTtlIsZero) { @@ -1882,9 +1950,13 @@ TEST_F(DocumentStoreTest, ExpirationTimestampIsInt64MaxIfTtlIsZero) { ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, doc_store->Put(document)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id)); + EXPECT_THAT( - doc_store->GetDocumentFilterData(document_id), - IsOkAndHolds(DocumentFilterData( + doc_filter_data, + Eq(DocumentFilterData( /*namespace_id=*/0, /*schema_type_id=*/0, /*expiration_timestamp_ms=*/std::numeric_limits<int64_t>::max()))); @@ -1908,9 +1980,13 @@ TEST_F(DocumentStoreTest, ExpirationTimestampIsInt64MaxOnOverflow) { ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, doc_store->Put(document)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData doc_filter_data, + doc_store->GetAliveDocumentFilterData(document_id)); + EXPECT_THAT( - 
doc_store->GetDocumentFilterData(document_id), - IsOkAndHolds(DocumentFilterData( + doc_filter_data, + Eq(DocumentFilterData( /*namespace_id=*/0, /*schema_type_id=*/0, /*expiration_timestamp_ms=*/std::numeric_limits<int64_t>::max()))); @@ -2108,9 +2184,9 @@ TEST_F(DocumentStoreTest, RegenerateDerivedFilesSkipsUnknownSchemaTypeIds) { email_document_id, document_store->Put(DocumentProto(email_document))); EXPECT_THAT(document_store->Get(email_document_id), IsOkAndHolds(EqualsProto(email_document))); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData email_data, - document_store->GetDocumentFilterData(email_document_id)); + document_store->GetAliveDocumentFilterData(email_document_id)); EXPECT_THAT(email_data.schema_type_id(), Eq(email_schema_type_id)); email_namespace_id = email_data.namespace_id(); email_expiration_timestamp = email_data.expiration_timestamp_ms(); @@ -2121,9 +2197,9 @@ TEST_F(DocumentStoreTest, RegenerateDerivedFilesSkipsUnknownSchemaTypeIds) { document_store->Put(DocumentProto(message_document))); EXPECT_THAT(document_store->Get(message_document_id), IsOkAndHolds(EqualsProto(message_document))); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData message_data, - document_store->GetDocumentFilterData(message_document_id)); + document_store->GetAliveDocumentFilterData(message_document_id)); EXPECT_THAT(message_data.schema_type_id(), Eq(message_schema_type_id)); message_namespace_id = message_data.namespace_id(); message_expiration_timestamp = message_data.expiration_timestamp_ms(); @@ -2161,9 +2237,9 @@ TEST_F(DocumentStoreTest, RegenerateDerivedFilesSkipsUnknownSchemaTypeIds) { // "email" document is fine EXPECT_THAT(document_store->Get(email_document_id), IsOkAndHolds(EqualsProto(email_document))); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData email_data, - document_store->GetDocumentFilterData(email_document_id)); + document_store->GetAliveDocumentFilterData(email_document_id)); EXPECT_THAT(email_data.schema_type_id(), Eq(email_schema_type_id)); // Make sure that all the other fields are stll valid/the same EXPECT_THAT(email_data.namespace_id(), Eq(email_namespace_id)); @@ -2173,9 +2249,9 @@ TEST_F(DocumentStoreTest, RegenerateDerivedFilesSkipsUnknownSchemaTypeIds) { // "message" document has an invalid SchemaTypeId EXPECT_THAT(document_store->Get(message_document_id), IsOkAndHolds(EqualsProto(message_document))); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData message_data, - document_store->GetDocumentFilterData(message_document_id)); + document_store->GetAliveDocumentFilterData(message_document_id)); EXPECT_THAT(message_data.schema_type_id(), Eq(-1)); // Make sure that all the other fields are stll valid/the same EXPECT_THAT(message_data.namespace_id(), Eq(message_namespace_id)); @@ -2227,16 +2303,16 @@ TEST_F(DocumentStoreTest, UpdateSchemaStoreUpdatesSchemaTypeIds) { ICING_ASSERT_OK_AND_ASSIGN(DocumentId email_document_id, document_store->Put(email_document)); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData email_data, - document_store->GetDocumentFilterData(email_document_id)); + document_store->GetAliveDocumentFilterData(email_document_id)); EXPECT_THAT(email_data.schema_type_id(), Eq(old_email_schema_type_id)); ICING_ASSERT_OK_AND_ASSIGN(DocumentId message_document_id, document_store->Put(message_document)); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( 
DocumentFilterData message_data, - document_store->GetDocumentFilterData(message_document_id)); + document_store->GetAliveDocumentFilterData(message_document_id)); EXPECT_THAT(message_data.schema_type_id(), Eq(old_message_schema_type_id)); // Rearrange the schema types. Since SchemaTypeId is assigned based on order, @@ -2260,12 +2336,14 @@ TEST_F(DocumentStoreTest, UpdateSchemaStoreUpdatesSchemaTypeIds) { ICING_EXPECT_OK(document_store->UpdateSchemaStore(schema_store.get())); // Check that the FilterCache holds the new SchemaTypeIds - ICING_ASSERT_OK_AND_ASSIGN( - email_data, document_store->GetDocumentFilterData(email_document_id)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + email_data, + document_store->GetAliveDocumentFilterData(email_document_id)); EXPECT_THAT(email_data.schema_type_id(), Eq(new_email_schema_type_id)); - ICING_ASSERT_OK_AND_ASSIGN( - message_data, document_store->GetDocumentFilterData(message_document_id)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + message_data, + document_store->GetAliveDocumentFilterData(message_document_id)); EXPECT_THAT(message_data.schema_type_id(), Eq(new_message_schema_type_id)); } @@ -2457,16 +2535,16 @@ TEST_F(DocumentStoreTest, OptimizedUpdateSchemaStoreUpdatesSchemaTypeIds) { ICING_ASSERT_OK_AND_ASSIGN(DocumentId email_document_id, document_store->Put(email_document)); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData email_data, - document_store->GetDocumentFilterData(email_document_id)); + document_store->GetAliveDocumentFilterData(email_document_id)); EXPECT_THAT(email_data.schema_type_id(), Eq(old_email_schema_type_id)); ICING_ASSERT_OK_AND_ASSIGN(DocumentId message_document_id, document_store->Put(message_document)); - ICING_ASSERT_OK_AND_ASSIGN( + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( DocumentFilterData message_data, - document_store->GetDocumentFilterData(message_document_id)); + document_store->GetAliveDocumentFilterData(message_document_id)); EXPECT_THAT(message_data.schema_type_id(), Eq(old_message_schema_type_id)); // Rearrange the schema types. 
Since SchemaTypeId is assigned based on order, @@ -2492,12 +2570,14 @@ TEST_F(DocumentStoreTest, OptimizedUpdateSchemaStoreUpdatesSchemaTypeIds) { schema_store.get(), set_schema_result)); // Check that the FilterCache holds the new SchemaTypeIds - ICING_ASSERT_OK_AND_ASSIGN( - email_data, document_store->GetDocumentFilterData(email_document_id)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + email_data, + document_store->GetAliveDocumentFilterData(email_document_id)); EXPECT_THAT(email_data.schema_type_id(), Eq(new_email_schema_type_id)); - ICING_ASSERT_OK_AND_ASSIGN( - message_data, document_store->GetDocumentFilterData(message_document_id)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + message_data, + document_store->GetAliveDocumentFilterData(message_document_id)); EXPECT_THAT(message_data.schema_type_id(), Eq(new_message_schema_type_id)); } @@ -3379,8 +3459,9 @@ TEST_F(DocumentStoreTest, InitializeForceRecoveryUpdatesTypeIds) { .SetTtlMs(document1_ttl_) .Build(); ICING_ASSERT_OK_AND_ASSIGN(docid, doc_store->Put(doc)); - ICING_ASSERT_OK_AND_ASSIGN(DocumentFilterData filter_data, - doc_store->GetDocumentFilterData(docid)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData filter_data, + doc_store->GetAliveDocumentFilterData(docid)); ASSERT_THAT(filter_data.schema_type_id(), Eq(0)); } @@ -3420,8 +3501,9 @@ TEST_F(DocumentStoreTest, InitializeForceRecoveryUpdatesTypeIds) { std::move(create_result.document_store); // Ensure that the type id of the email document has been correctly updated. - ICING_ASSERT_OK_AND_ASSIGN(DocumentFilterData filter_data, - doc_store->GetDocumentFilterData(docid)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData filter_data, + doc_store->GetAliveDocumentFilterData(docid)); EXPECT_THAT(filter_data.schema_type_id(), Eq(1)); EXPECT_THAT(initialize_stats.document_store_recovery_cause(), Eq(InitializeStatsProto::SCHEMA_CHANGES_OUT_OF_SYNC)); @@ -3477,8 +3559,9 @@ TEST_F(DocumentStoreTest, InitializeDontForceRecoveryDoesntUpdateTypeIds) { .SetTtlMs(document1_ttl_) .Build(); ICING_ASSERT_OK_AND_ASSIGN(docid, doc_store->Put(doc)); - ICING_ASSERT_OK_AND_ASSIGN(DocumentFilterData filter_data, - doc_store->GetDocumentFilterData(docid)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData filter_data, + doc_store->GetAliveDocumentFilterData(docid)); ASSERT_THAT(filter_data.schema_type_id(), Eq(0)); } @@ -3516,8 +3599,9 @@ TEST_F(DocumentStoreTest, InitializeDontForceRecoveryDoesntUpdateTypeIds) { std::move(create_result.document_store); // Check that the type id of the email document has not been updated. - ICING_ASSERT_OK_AND_ASSIGN(DocumentFilterData filter_data, - doc_store->GetDocumentFilterData(docid)); + ICING_ASSERT_HAS_VALUE_AND_ASSIGN( + DocumentFilterData filter_data, + doc_store->GetAliveDocumentFilterData(docid)); ASSERT_THAT(filter_data.schema_type_id(), Eq(0)); } } @@ -3733,7 +3817,6 @@ TEST_F(DocumentStoreTest, InitializeDontForceRecoveryKeepsInvalidDocument) { } } -#ifndef DISABLE_BACKWARDS_COMPAT_TEST TEST_F(DocumentStoreTest, MigrateToPortableFileBackedProtoLog) { // Set up schema. 
SchemaProto schema = @@ -3854,7 +3937,6 @@ TEST_F(DocumentStoreTest, MigrateToPortableFileBackedProtoLog) { EXPECT_THAT(document_store->Get(/*document_id=*/2), IsOkAndHolds(EqualsProto(document3))); } -#endif // DISABLE_BACKWARDS_COMPAT_TEST TEST_F(DocumentStoreTest, GetDebugInfo) { SchemaProto schema = @@ -3928,8 +4010,9 @@ TEST_F(DocumentStoreTest, GetDebugInfo) { .Build(); ICING_ASSERT_OK(document_store->Put(document4, 2)); - ICING_ASSERT_OK_AND_ASSIGN(DocumentDebugInfoProto out1, - document_store->GetDebugInfo(/*verbosity=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentDebugInfoProto out1, + document_store->GetDebugInfo(DebugInfoVerbosity::DETAILED)); EXPECT_THAT(out1.crc(), Gt(0)); EXPECT_THAT(out1.document_storage_info().num_alive_documents(), Eq(4)); EXPECT_THAT(out1.document_storage_info().num_deleted_documents(), Eq(0)); @@ -3957,8 +4040,9 @@ TEST_F(DocumentStoreTest, GetDebugInfo) { // Delete document3. ICING_ASSERT_OK(document_store->Delete("namespace2", "email/3")); - ICING_ASSERT_OK_AND_ASSIGN(DocumentDebugInfoProto out2, - document_store->GetDebugInfo(/*verbosity=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentDebugInfoProto out2, + document_store->GetDebugInfo(DebugInfoVerbosity::DETAILED)); EXPECT_THAT(out2.crc(), Gt(0)); EXPECT_THAT(out2.crc(), Not(Eq(out1.crc()))); EXPECT_THAT(out2.document_storage_info().num_alive_documents(), Eq(3)); @@ -3970,8 +4054,9 @@ TEST_F(DocumentStoreTest, GetDebugInfo) { UnorderedElementsAre(EqualsProto(info1), EqualsProto(info2), EqualsProto(info3))); - ICING_ASSERT_OK_AND_ASSIGN(DocumentDebugInfoProto out3, - document_store->GetDebugInfo(/*verbosity=*/0)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentDebugInfoProto out3, + document_store->GetDebugInfo(DebugInfoVerbosity::BASIC)); EXPECT_THAT(out3.corpus_info(), IsEmpty()); } @@ -3989,8 +4074,9 @@ TEST_F(DocumentStoreTest, GetDebugInfoWithoutSchema) { schema_store.get())); std::unique_ptr<DocumentStore> document_store = std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(DocumentDebugInfoProto out, - document_store->GetDebugInfo(/*verbosity=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentDebugInfoProto out, + document_store->GetDebugInfo(DebugInfoVerbosity::DETAILED)); EXPECT_THAT(out.crc(), Gt(0)); EXPECT_THAT(out.document_storage_info().num_alive_documents(), Eq(0)); EXPECT_THAT(out.document_storage_info().num_deleted_documents(), Eq(0)); @@ -4005,8 +4091,9 @@ TEST_F(DocumentStoreTest, GetDebugInfoForEmptyDocumentStore) { schema_store_.get())); std::unique_ptr<DocumentStore> document_store = std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(DocumentDebugInfoProto out, - document_store->GetDebugInfo(/*verbosity=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentDebugInfoProto out, + document_store->GetDebugInfo(DebugInfoVerbosity::DETAILED)); EXPECT_THAT(out.crc(), Gt(0)); EXPECT_THAT(out.document_storage_info().num_alive_documents(), Eq(0)); EXPECT_THAT(out.document_storage_info().num_deleted_documents(), Eq(0)); diff --git a/icing/store/dynamic-trie-key-mapper.h b/icing/store/dynamic-trie-key-mapper.h new file mode 100644 index 0000000..dedd7b9 --- /dev/null +++ b/icing/store/dynamic-trie-key-mapper.h @@ -0,0 +1,299 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_STORE_DYNAMIC_TRIE_KEY_MAPPER_H_ +#define ICING_STORE_DYNAMIC_TRIE_KEY_MAPPER_H_ + +#include <cstdint> +#include <cstring> +#include <memory> +#include <string> +#include <string_view> +#include <type_traits> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/absl_ports/str_join.h" +#include "icing/file/filesystem.h" +#include "icing/legacy/index/icing-dynamic-trie.h" +#include "icing/legacy/index/icing-filesystem.h" +#include "icing/store/key-mapper.h" +#include "icing/util/crc32.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +// File-backed mapping between the string key and a trivially copyable value +// type. +// +// DynamicTrieKeyMapper is thread-compatible +template <typename T, typename Formatter = absl_ports::DefaultFormatter> +class DynamicTrieKeyMapper : public KeyMapper<T, Formatter> { + public: + // Returns an initialized instance of DynamicTrieKeyMapper that can + // immediately handle read/write operations. + // Returns any encountered IO errors. + // + // base_dir : Base directory used to save all the files required to persist + // DynamicTrieKeyMapper. If this base_dir was previously used to + // create a DynamicTrieKeyMapper, then this existing data would be + // loaded. Otherwise, an empty DynamicTrieKeyMapper would be + // created. + // maximum_size_bytes : The maximum allowable size of the key mapper storage. + static libtextclassifier3::StatusOr< + std::unique_ptr<DynamicTrieKeyMapper<T, Formatter>>> + Create(const Filesystem& filesystem, std::string_view base_dir, + int maximum_size_bytes); + + // Deletes all the files associated with the DynamicTrieKeyMapper. Returns + // success or any encountered IO errors + // + // base_dir : Base directory used to save all the files required to persist + // DynamicTrieKeyMapper. Should be the same as passed into + // Create(). + static libtextclassifier3::Status Delete(const Filesystem& filesystem, + std::string_view base_dir); + + ~DynamicTrieKeyMapper() override = default; + + libtextclassifier3::Status Put(std::string_view key, T value) override; + + libtextclassifier3::StatusOr<T> GetOrPut(std::string_view key, + T next_value) override; + + libtextclassifier3::StatusOr<T> Get(std::string_view key) const override; + + bool Delete(std::string_view key) override; + + std::unordered_map<T, std::string> GetValuesToKeys() const override; + + int32_t num_keys() const override { return trie_.size(); } + + libtextclassifier3::Status PersistToDisk() override; + + libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const override; + + libtextclassifier3::StatusOr<int64_t> GetElementsSize() const override; + + Crc32 ComputeChecksum() override; + + private: + static constexpr char kDynamicTrieKeyMapperDir[] = "key_mapper_dir"; + static constexpr char kDynamicTrieKeyMapperPrefix[] = "key_mapper"; + + // Use DynamicTrieKeyMapper::Create() to instantiate. 
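The public surface declared above (Create(), Put(), GetOrPut(), Get(), Delete(), PersistToDisk(), ComputeChecksum()) is exercised directly by the tests renamed further down. A minimal usage sketch inside namespace icing::lib, with the base directory and size limit as placeholders and assuming DocumentId and the status macros come from their usual icing headers:

#include "icing/store/dynamic-trie-key-mapper.h"

libtextclassifier3::Status UseKeyMapper(const Filesystem& filesystem,
                                        const std::string& base_dir) {
  constexpr int kMaxSize = 3 * 1024 * 1024;  // 3 MiB, as in the tests
  ICING_ASSIGN_OR_RETURN(
      std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
      DynamicTrieKeyMapper<DocumentId>::Create(filesystem, base_dir, kMaxSize));
  ICING_RETURN_IF_ERROR(key_mapper->Put("default-google.com", 100));
  ICING_ASSIGN_OR_RETURN(DocumentId id, key_mapper->Get("default-google.com"));
  ICING_RETURN_IF_ERROR(key_mapper->Put("default-google.com", id + 200));
  return key_mapper->PersistToDisk();
}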
+ explicit DynamicTrieKeyMapper(std::string_view key_mapper_dir); + + // Load any existing DynamicTrieKeyMapper data from disk, or creates a new + // instance of DynamicTrieKeyMapper on disk and gets ready to process + // read/write operations. + // + // Returns any encountered IO errors. + libtextclassifier3::Status Initialize(int maximum_size_bytes); + + const std::string file_prefix_; + + // TODO(adorokhine) Filesystem is a forked class that's available both in + // icing and icing namespaces. We will need icing::Filesystem in order + // to use IcingDynamicTrie. Filesystem class should be fully refactored + // to have a single definition across both namespaces. Such a class should + // use icing (and general google3) coding conventions and behave like + // a proper C++ class. + const IcingFilesystem icing_filesystem_; + IcingDynamicTrie trie_; + + static_assert(std::is_trivially_copyable<T>::value, + "T must be trivially copyable"); +}; + +template <typename T, typename Formatter> +libtextclassifier3::StatusOr< + std::unique_ptr<DynamicTrieKeyMapper<T, Formatter>>> +DynamicTrieKeyMapper<T, Formatter>::Create(const Filesystem& filesystem, + std::string_view base_dir, + int maximum_size_bytes) { + // We create a subdirectory since the trie creates and stores multiple files. + // This makes it easier to isolate the trie files away from other files that + // could potentially be in the same base_dir, and makes it easier to delete. + const std::string key_mapper_dir = + absl_ports::StrCat(base_dir, "/", kDynamicTrieKeyMapperDir); + if (!filesystem.CreateDirectoryRecursively(key_mapper_dir.c_str())) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to create DynamicTrieKeyMapper directory: ", key_mapper_dir)); + } + auto mapper = std::unique_ptr<DynamicTrieKeyMapper<T, Formatter>>( + new DynamicTrieKeyMapper<T, Formatter>(key_mapper_dir)); + ICING_RETURN_IF_ERROR(mapper->Initialize(maximum_size_bytes)); + return mapper; +} + +template <typename T, typename Formatter> +libtextclassifier3::Status DynamicTrieKeyMapper<T, Formatter>::Delete( + const Filesystem& filesystem, std::string_view base_dir) { + std::string key_mapper_dir = + absl_ports::StrCat(base_dir, "/", kDynamicTrieKeyMapperDir); + if (!filesystem.DeleteDirectoryRecursively(key_mapper_dir.c_str())) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to delete DynamicTrieKeyMapper directory: ", key_mapper_dir)); + } + return libtextclassifier3::Status::OK; +} + +template <typename T, typename Formatter> +DynamicTrieKeyMapper<T, Formatter>::DynamicTrieKeyMapper( + std::string_view key_mapper_dir) + : file_prefix_( + absl_ports::StrCat(key_mapper_dir, "/", kDynamicTrieKeyMapperPrefix)), + trie_(file_prefix_, + IcingDynamicTrie::RuntimeOptions().set_storage_policy( + IcingDynamicTrie::RuntimeOptions::kMapSharedWithCrc), + &icing_filesystem_) {} + +template <typename T, typename Formatter> +libtextclassifier3::Status DynamicTrieKeyMapper<T, Formatter>::Initialize( + int maximum_size_bytes) { + IcingDynamicTrie::Options options; + // Divide the max space between the three internal arrays: nodes, nexts and + // suffixes. MaxNodes and MaxNexts are in units of their own data structures. + // MaxSuffixesSize is in units of bytes. 
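To make the split concrete: with the 3 MiB limit the unit tests use, and assuming purely for illustration that sizeof(IcingDynamicTrie::Node) is 4 bytes (the real size is defined in icing-dynamic-trie.h and may differ), the sizing code that follows works out to:

constexpr int kMaximumSizeBytes = 3 * 1024 * 1024;  // 3 MiB, as in the tests
constexpr int kAssumedNodeSize = 4;                 // illustrative assumption
constexpr int kMaxNodes = kMaximumSizeBytes / (3 * kAssumedNodeSize);  // 262144
constexpr int kMaxNexts = kMaxNodes;                                   // 262144
constexpr int kMaxSuffixesSize = kAssumedNodeSize * kMaxNodes;  // 1048576 bytes (1 MiB)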
+ options.max_nodes = maximum_size_bytes / (3 * sizeof(IcingDynamicTrie::Node)); + options.max_nexts = options.max_nodes; + options.max_suffixes_size = + sizeof(IcingDynamicTrie::Node) * options.max_nodes; + options.value_size = sizeof(T); + + if (!trie_.CreateIfNotExist(options)) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to create DynamicTrieKeyMapper file: ", file_prefix_)); + } + if (!trie_.Init()) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to init DynamicTrieKeyMapper file: ", file_prefix_)); + } + return libtextclassifier3::Status::OK; +} + +template <typename T, typename Formatter> +libtextclassifier3::StatusOr<T> DynamicTrieKeyMapper<T, Formatter>::GetOrPut( + std::string_view key, T next_value) { + std::string string_key(key); + uint32_t value_index; + if (!trie_.Insert(string_key.c_str(), &next_value, &value_index, + /*replace=*/false)) { + return absl_ports::InternalError( + absl_ports::StrCat("Unable to insert key ", Formatter()(string_key), + " into DynamicTrieKeyMapper ", file_prefix_, ".")); + } + // This memory address could be unaligned since we're just grabbing the value + // from somewhere in the trie's suffix array. The suffix array is filled with + // chars, so the address might not be aligned to T values. + const T* unaligned_value = + static_cast<const T*>(trie_.GetValueAtIndex(value_index)); + + // memcpy the value to ensure that the returned value here is in a T-aligned + // address + T aligned_value; + memcpy(&aligned_value, unaligned_value, sizeof(T)); + return aligned_value; +} + +template <typename T, typename Formatter> +libtextclassifier3::Status DynamicTrieKeyMapper<T, Formatter>::Put( + std::string_view key, T value) { + std::string string_key(key); + if (!trie_.Insert(string_key.c_str(), &value)) { + return absl_ports::InternalError( + absl_ports::StrCat("Unable to insert key ", Formatter()(string_key), + " into DynamicTrieKeyMapper ", file_prefix_, ".")); + } + return libtextclassifier3::Status::OK; +} + +template <typename T, typename Formatter> +libtextclassifier3::StatusOr<T> DynamicTrieKeyMapper<T, Formatter>::Get( + std::string_view key) const { + std::string string_key(key); + T value; + if (!trie_.Find(string_key.c_str(), &value)) { + return absl_ports::NotFoundError( + absl_ports::StrCat("Key not found ", Formatter()(string_key), + " in DynamicTrieKeyMapper ", file_prefix_, ".")); + } + return value; +} + +template <typename T, typename Formatter> +bool DynamicTrieKeyMapper<T, Formatter>::Delete(std::string_view key) { + return trie_.Delete(key); +} + +template <typename T, typename Formatter> +std::unordered_map<T, std::string> +DynamicTrieKeyMapper<T, Formatter>::GetValuesToKeys() const { + std::unordered_map<T, std::string> values_to_keys; + for (IcingDynamicTrie::Iterator itr(trie_, /*prefix=*/""); itr.IsValid(); + itr.Advance()) { + if (itr.IsValid()) { + T value; + memcpy(&value, itr.GetValue(), sizeof(T)); + values_to_keys.insert({value, itr.GetKey()}); + } + } + + return values_to_keys; +} + +template <typename T, typename Formatter> +libtextclassifier3::Status DynamicTrieKeyMapper<T, Formatter>::PersistToDisk() { + if (!trie_.Sync()) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to sync DynamicTrieKeyMapper file: ", file_prefix_)); + } + + return libtextclassifier3::Status::OK; +} + +template <typename T, typename Formatter> +libtextclassifier3::StatusOr<int64_t> +DynamicTrieKeyMapper<T, Formatter>::GetDiskUsage() const { + int64_t size = trie_.GetDiskUsage(); + if (size 
== IcingFilesystem::kBadFileSize || size < 0) { + return absl_ports::InternalError("Failed to get disk usage of key mapper"); + } + return size; +} + +template <typename T, typename Formatter> +libtextclassifier3::StatusOr<int64_t> +DynamicTrieKeyMapper<T, Formatter>::GetElementsSize() const { + int64_t size = trie_.GetElementsSize(); + if (size == IcingFilesystem::kBadFileSize || size < 0) { + return absl_ports::InternalError( + "Failed to get disk usage of elements in the key mapper"); + } + return size; +} + +template <typename T, typename Formatter> +Crc32 DynamicTrieKeyMapper<T, Formatter>::ComputeChecksum() { + return Crc32(trie_.UpdateCrc()); +} + +} // namespace lib +} // namespace icing + +#endif // ICING_STORE_DYNAMIC_TRIE_KEY_MAPPER_H_ diff --git a/icing/store/key-mapper_test.cc b/icing/store/dynamic-trie-key-mapper_test.cc index 4e3dd8a..03ba5f2 100644 --- a/icing/store/key-mapper_test.cc +++ b/icing/store/dynamic-trie-key-mapper_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/store/key-mapper.h" +#include "icing/store/dynamic-trie-key-mapper.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -29,9 +29,9 @@ using ::testing::UnorderedElementsAre; namespace icing { namespace lib { namespace { -constexpr int kMaxKeyMapperSize = 3 * 1024 * 1024; // 3 MiB +constexpr int kMaxDynamicTrieKeyMapperSize = 3 * 1024 * 1024; // 3 MiB -class KeyMapperTest : public testing::Test { +class DynamicTrieKeyMapperTest : public testing::Test { protected: void SetUp() override { base_dir_ = GetTestTempDir() + "/key_mapper"; } @@ -43,36 +43,39 @@ class KeyMapperTest : public testing::Test { Filesystem filesystem_; }; -TEST_F(KeyMapperTest, InvalidBaseDir) { - ASSERT_THAT( - KeyMapper<DocumentId>::Create(filesystem_, "/dev/null", kMaxKeyMapperSize) - .status() - .error_message(), - HasSubstr("Failed to create KeyMapper")); +TEST_F(DynamicTrieKeyMapperTest, InvalidBaseDir) { + ASSERT_THAT(DynamicTrieKeyMapper<DocumentId>::Create( + filesystem_, "/dev/null", kMaxDynamicTrieKeyMapperSize) + .status() + .error_message(), + HasSubstr("Failed to create DynamicTrieKeyMapper")); } -TEST_F(KeyMapperTest, NegativeMaxKeyMapperSizeReturnsInternalError) { - ASSERT_THAT(KeyMapper<DocumentId>::Create(filesystem_, base_dir_, -1), - StatusIs(libtextclassifier3::StatusCode::INTERNAL)); +TEST_F(DynamicTrieKeyMapperTest, NegativeMaxKeyMapperSizeReturnsInternalError) { + ASSERT_THAT( + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, -1), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } -TEST_F(KeyMapperTest, TooLargeMaxKeyMapperSizeReturnsInternalError) { - ASSERT_THAT(KeyMapper<DocumentId>::Create(filesystem_, base_dir_, - std::numeric_limits<int>::max()), +TEST_F(DynamicTrieKeyMapperTest, TooLargeMaxKeyMapperSizeReturnsInternalError) { + ASSERT_THAT(DynamicTrieKeyMapper<DocumentId>::Create( + filesystem_, base_dir_, std::numeric_limits<int>::max()), StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } -TEST_F(KeyMapperTest, CreateNewKeyMapper) { +TEST_F(DynamicTrieKeyMapperTest, CreateNewKeyMapper) { ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); EXPECT_THAT(key_mapper->num_keys(), 0); } -TEST_F(KeyMapperTest, 
CanUpdateSameKeyMultipleTimes) { +TEST_F(DynamicTrieKeyMapperTest, CanUpdateSameKeyMultipleTimes) { ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 50)); @@ -88,10 +91,11 @@ TEST_F(KeyMapperTest, CanUpdateSameKeyMultipleTimes) { EXPECT_THAT(key_mapper->num_keys(), 2); } -TEST_F(KeyMapperTest, GetOrPutOk) { +TEST_F(DynamicTrieKeyMapperTest, GetOrPutOk) { ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); EXPECT_THAT(key_mapper->Get("foo"), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); @@ -99,15 +103,16 @@ TEST_F(KeyMapperTest, GetOrPutOk) { EXPECT_THAT(key_mapper->Get("foo"), IsOkAndHolds(1)); } -TEST_F(KeyMapperTest, CanPersistToDiskRegularly) { +TEST_F(DynamicTrieKeyMapperTest, CanPersistToDiskRegularly) { ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); - // Can persist an empty KeyMapper. + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); + // Can persist an empty DynamicTrieKeyMapper. ICING_EXPECT_OK(key_mapper->PersistToDisk()); EXPECT_THAT(key_mapper->num_keys(), 0); - // Can persist the smallest KeyMapper. + // Can persist the smallest DynamicTrieKeyMapper. 
ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); ICING_EXPECT_OK(key_mapper->PersistToDisk()); EXPECT_THAT(key_mapper->num_keys(), 1); @@ -124,17 +129,18 @@ TEST_F(KeyMapperTest, CanPersistToDiskRegularly) { EXPECT_THAT(key_mapper->num_keys(), 2); } -TEST_F(KeyMapperTest, CanUseAcrossMultipleInstances) { +TEST_F(DynamicTrieKeyMapperTest, CanUseAcrossMultipleInstances) { ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); ICING_EXPECT_OK(key_mapper->PersistToDisk()); key_mapper.reset(); ICING_ASSERT_OK_AND_ASSIGN( - key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + key_mapper, DynamicTrieKeyMapper<DocumentId>::Create( + filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize)); EXPECT_THAT(key_mapper->num_keys(), 1); EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); @@ -146,30 +152,34 @@ TEST_F(KeyMapperTest, CanUseAcrossMultipleInstances) { EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); } -TEST_F(KeyMapperTest, CanDeleteAndRestartKeyMapping) { +TEST_F(DynamicTrieKeyMapperTest, CanDeleteAndRestartKeyMapping) { // Can delete even if there's nothing there - ICING_EXPECT_OK(KeyMapper<DocumentId>::Delete(filesystem_, base_dir_)); + ICING_EXPECT_OK( + DynamicTrieKeyMapper<DocumentId>::Delete(filesystem_, base_dir_)); ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); ICING_EXPECT_OK(key_mapper->PersistToDisk()); - ICING_EXPECT_OK(KeyMapper<DocumentId>::Delete(filesystem_, base_dir_)); + ICING_EXPECT_OK( + DynamicTrieKeyMapper<DocumentId>::Delete(filesystem_, base_dir_)); key_mapper.reset(); ICING_ASSERT_OK_AND_ASSIGN( - key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + key_mapper, DynamicTrieKeyMapper<DocumentId>::Create( + filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize)); EXPECT_THAT(key_mapper->num_keys(), 0); ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); EXPECT_THAT(key_mapper->num_keys(), 1); } -TEST_F(KeyMapperTest, GetValuesToKeys) { +TEST_F(DynamicTrieKeyMapperTest, GetValuesToKeys) { ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<KeyMapper<DocumentId>> key_mapper, - KeyMapper<DocumentId>::Create(filesystem_, base_dir_, kMaxKeyMapperSize)); + std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, + DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, + kMaxDynamicTrieKeyMapperSize)); EXPECT_THAT(key_mapper->GetValuesToKeys(), IsEmpty()); ICING_EXPECT_OK(key_mapper->Put("foo", /*value=*/1)); diff --git a/icing/store/key-mapper.h b/icing/store/key-mapper.h index 23c7b69..e05d1b7 100644 --- a/icing/store/key-mapper.h +++ b/icing/store/key-mapper.h @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Google LLC +// Copyright (C) 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -17,81 +17,56 @@ #include <cstdint> #include <cstring> -#include <memory> #include <string> #include <string_view> #include <type_traits> +#include <unordered_map> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" -#include "icing/absl_ports/canonical_errors.h" -#include "icing/absl_ports/str_cat.h" -#include "icing/file/filesystem.h" -#include "icing/legacy/index/icing-dynamic-trie.h" -#include "icing/legacy/index/icing-filesystem.h" +#include "icing/absl_ports/str_join.h" #include "icing/util/crc32.h" -#include "icing/util/status-macros.h" namespace icing { namespace lib { -// File-backed mapping between the string key and a trivially copyable value -// type. +// An interface for file-backed mapping between the string key and a trivially +// copyable value type. // -// KeyMapper is thread-compatible -template <typename T> +// The implementation for KeyMapper should be thread-compatible +template <typename T, typename Formatter = absl_ports::DefaultFormatter> class KeyMapper { public: - // Returns an initialized instance of KeyMapper that can immediately handle - // read/write operations. - // Returns any encountered IO errors. - // - // base_dir : Base directory used to save all the files required to persist - // KeyMapper. If this base_dir was previously used to create a - // KeyMapper, then this existing data would be loaded. Otherwise, - // an empty KeyMapper would be created. - // maximum_size_bytes : The maximum allowable size of the key mapper storage. - static libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<T>>> Create( - const Filesystem& filesystem, std::string_view base_dir, - int maximum_size_bytes); - - // Deletes all the files associated with the KeyMapper. Returns success or any - // encountered IO errors - // - // base_dir : Base directory used to save all the files required to persist - // KeyMapper. Should be the same as passed into Create(). - static libtextclassifier3::Status Delete(const Filesystem& filesystem, - std::string_view base_dir); - - ~KeyMapper() = default; + virtual ~KeyMapper() = default; // Inserts/Updates value for key. // Returns any encountered IO errors. // // NOTE: Put() doesn't automatically flush changes to disk and relies on // either explicit calls to PersistToDisk() or a clean shutdown of the class. - libtextclassifier3::Status Put(std::string_view key, T value); + virtual libtextclassifier3::Status Put(std::string_view key, T value) = 0; // Finds the current value for key and returns it. If key is not present, it // is inserted with next_value and next_value is returned. // // Returns any IO errors that may occur during Put. - libtextclassifier3::StatusOr<T> GetOrPut(std::string_view key, T next_value); + virtual libtextclassifier3::StatusOr<T> GetOrPut(std::string_view key, + T next_value) = 0; // Returns the value corresponding to the key. // // Returns NOT_FOUND error if the key was missing. // Returns any encountered IO errors. - libtextclassifier3::StatusOr<T> Get(std::string_view key) const; + virtual libtextclassifier3::StatusOr<T> Get(std::string_view key) const = 0; // Deletes data related to the given key. Returns true on success. - bool Delete(std::string_view key); + virtual bool Delete(std::string_view key) = 0; // Returns a map of values to keys. Empty map if the mapper is empty. 
- std::unordered_map<T, std::string> GetValuesToKeys() const; + virtual std::unordered_map<T, std::string> GetValuesToKeys() const = 0; // Count of unique keys stored in the KeyMapper. - int32_t num_keys() const { return trie_.size(); } + virtual int32_t num_keys() const = 0; // Syncs all the changes made to the KeyMapper to disk. // Returns any encountered IO errors. @@ -103,7 +78,7 @@ class KeyMapper { // Returns: // OK on success // INTERNAL on I/O error - libtextclassifier3::Status PersistToDisk(); + virtual libtextclassifier3::Status PersistToDisk() = 0; // Calculates and returns the disk usage in bytes. Rounds up to the nearest // block size. @@ -111,7 +86,7 @@ class KeyMapper { // Returns: // Disk usage on success // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const; + virtual libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const = 0; // Returns the size of the elements held in the key mapper. This excludes the // size of any internal metadata of the key mapper, e.g. the key mapper's @@ -120,197 +95,16 @@ class KeyMapper { // Returns: // File size on success // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<int64_t> GetElementsSize() const; + virtual libtextclassifier3::StatusOr<int64_t> GetElementsSize() const = 0; // Computes and returns the checksum of the header and contents. - Crc32 ComputeChecksum(); + virtual Crc32 ComputeChecksum() = 0; private: - static constexpr char kKeyMapperDir[] = "key_mapper_dir"; - static constexpr char kKeyMapperPrefix[] = "key_mapper"; - - // Use KeyMapper::Create() to instantiate. - explicit KeyMapper(std::string_view key_mapper_dir); - - // Load any existing KeyMapper data from disk, or creates a new instance - // of KeyMapper on disk and gets ready to process read/write operations. - // - // Returns any encountered IO errors. - libtextclassifier3::Status Initialize(int maximum_size_bytes); - - const std::string file_prefix_; - - // TODO(adorokhine) Filesystem is a forked class that's available both in - // icing and icing namespaces. We will need icing::Filesystem in order - // to use IcingDynamicTrie. Filesystem class should be fully refactored - // to have a single definition across both namespaces. Such a class should - // use icing (and general google3) coding conventions and behave like - // a proper C++ class. - const IcingFilesystem icing_filesystem_; - IcingDynamicTrie trie_; - static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable"); }; -template <typename T> -libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<T>>> -KeyMapper<T>::Create(const Filesystem& filesystem, std::string_view base_dir, - int maximum_size_bytes) { - // We create a subdirectory since the trie creates and stores multiple files. - // This makes it easier to isolate the trie files away from other files that - // could potentially be in the same base_dir, and makes it easier to delete. 
- const std::string key_mapper_dir = - absl_ports::StrCat(base_dir, "/", kKeyMapperDir); - if (!filesystem.CreateDirectoryRecursively(key_mapper_dir.c_str())) { - return absl_ports::InternalError(absl_ports::StrCat( - "Failed to create KeyMapper directory: ", key_mapper_dir)); - } - auto mapper = std::unique_ptr<KeyMapper<T>>(new KeyMapper<T>(key_mapper_dir)); - ICING_RETURN_IF_ERROR(mapper->Initialize(maximum_size_bytes)); - return mapper; -} - -template <typename T> -libtextclassifier3::Status KeyMapper<T>::Delete(const Filesystem& filesystem, - std::string_view base_dir) { - std::string key_mapper_dir = absl_ports::StrCat(base_dir, "/", kKeyMapperDir); - if (!filesystem.DeleteDirectoryRecursively(key_mapper_dir.c_str())) { - return absl_ports::InternalError(absl_ports::StrCat( - "Failed to delete KeyMapper directory: ", key_mapper_dir)); - } - return libtextclassifier3::Status::OK; -} - -template <typename T> -KeyMapper<T>::KeyMapper(std::string_view key_mapper_dir) - : file_prefix_(absl_ports::StrCat(key_mapper_dir, "/", kKeyMapperPrefix)), - trie_(file_prefix_, - IcingDynamicTrie::RuntimeOptions().set_storage_policy( - IcingDynamicTrie::RuntimeOptions::kMapSharedWithCrc), - &icing_filesystem_) {} - -template <typename T> -libtextclassifier3::Status KeyMapper<T>::Initialize(int maximum_size_bytes) { - IcingDynamicTrie::Options options; - // Divide the max space between the three internal arrays: nodes, nexts and - // suffixes. MaxNodes and MaxNexts are in units of their own data structures. - // MaxSuffixesSize is in units of bytes. - options.max_nodes = maximum_size_bytes / (3 * sizeof(IcingDynamicTrie::Node)); - options.max_nexts = options.max_nodes; - options.max_suffixes_size = - sizeof(IcingDynamicTrie::Node) * options.max_nodes; - options.value_size = sizeof(T); - - if (!trie_.CreateIfNotExist(options)) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to create KeyMapper file: ", file_prefix_)); - } - if (!trie_.Init()) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to init KeyMapper file: ", file_prefix_)); - } - return libtextclassifier3::Status::OK; -} - -template <typename T> -libtextclassifier3::StatusOr<T> KeyMapper<T>::GetOrPut(std::string_view key, - T next_value) { - std::string string_key(key); - uint32_t value_index; - if (!trie_.Insert(string_key.c_str(), &next_value, &value_index, - /*replace=*/false)) { - return absl_ports::InternalError(absl_ports::StrCat( - "Unable to insert key ", key, " into KeyMapper ", file_prefix_, ".")); - } - // This memory address could be unaligned since we're just grabbing the value - // from somewhere in the trie's suffix array. The suffix array is filled with - // chars, so the address might not be aligned to T values. 
- const T* unaligned_value = - static_cast<const T*>(trie_.GetValueAtIndex(value_index)); - - // memcpy the value to ensure that the returned value here is in a T-aligned - // address - T aligned_value; - memcpy(&aligned_value, unaligned_value, sizeof(T)); - return aligned_value; -} - -template <typename T> -libtextclassifier3::Status KeyMapper<T>::Put(std::string_view key, T value) { - std::string string_key(key); - if (!trie_.Insert(string_key.c_str(), &value)) { - return absl_ports::InternalError(absl_ports::StrCat( - "Unable to insert key ", key, " into KeyMapper ", file_prefix_, ".")); - } - return libtextclassifier3::Status::OK; -} - -template <typename T> -libtextclassifier3::StatusOr<T> KeyMapper<T>::Get(std::string_view key) const { - std::string string_key(key); - T value; - if (!trie_.Find(string_key.c_str(), &value)) { - return absl_ports::NotFoundError(absl_ports::StrCat( - "Key not found ", key, " in KeyMapper ", file_prefix_, ".")); - } - return value; -} - -template <typename T> -bool KeyMapper<T>::Delete(std::string_view key) { - return trie_.Delete(key); -} - -template <typename T> -std::unordered_map<T, std::string> KeyMapper<T>::GetValuesToKeys() const { - std::unordered_map<T, std::string> values_to_keys; - for (IcingDynamicTrie::Iterator itr(trie_, /*prefix=*/""); itr.IsValid(); - itr.Advance()) { - if (itr.IsValid()) { - T value; - memcpy(&value, itr.GetValue(), sizeof(T)); - values_to_keys.insert({value, itr.GetKey()}); - } - } - - return values_to_keys; -} - -template <typename T> -libtextclassifier3::Status KeyMapper<T>::PersistToDisk() { - if (!trie_.Sync()) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to sync KeyMapper file: ", file_prefix_)); - } - - return libtextclassifier3::Status::OK; -} - -template <typename T> -libtextclassifier3::StatusOr<int64_t> KeyMapper<T>::GetDiskUsage() const { - int64_t size = trie_.GetDiskUsage(); - if (size == IcingFilesystem::kBadFileSize || size < 0) { - return absl_ports::InternalError("Failed to get disk usage of key mapper"); - } - return size; -} - -template <typename T> -libtextclassifier3::StatusOr<int64_t> KeyMapper<T>::GetElementsSize() const { - int64_t size = trie_.GetElementsSize(); - if (size == IcingFilesystem::kBadFileSize || size < 0) { - return absl_ports::InternalError( - "Failed to get disk usage of elements in the key mapper"); - } - return size; -} - -template <typename T> -Crc32 KeyMapper<T>::ComputeChecksum() { - return Crc32(trie_.UpdateCrc()); -} - } // namespace lib } // namespace icing diff --git a/icing/store/namespace-checker-impl.h b/icing/store/namespace-checker-impl.h index bcd0643..0b6fca9 100644 --- a/icing/store/namespace-checker-impl.h +++ b/icing/store/namespace-checker-impl.h @@ -32,14 +32,18 @@ class NamespaceCheckerImpl : public NamespaceChecker { target_namespace_ids_(std::move(target_namespace_ids)) {} bool BelongsToTargetNamespaces(DocumentId document_id) const override { + auto document_filter_data_optional_ = + document_store_.GetAliveDocumentFilterData(document_id); + if (!document_filter_data_optional_) { + // The document doesn't exist. 
+ return false; + } if (target_namespace_ids_.empty()) { return true; } - auto document_filter_data_or_ = - document_store_.GetDocumentFilterData(document_id); - return document_filter_data_or_.ok() && - target_namespace_ids_.count( - document_filter_data_or_.ValueOrDie().namespace_id())> 0; + DocumentFilterData document_filter_data = + document_filter_data_optional_.value(); + return target_namespace_ids_.count(document_filter_data.namespace_id()) > 0; } const DocumentStore& document_store_; std::unordered_set<NamespaceId> target_namespace_ids_; diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h index f83fe0a..81f65b2 100644 --- a/icing/testing/common-matchers.h +++ b/icing/testing/common-matchers.h @@ -460,6 +460,10 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") { ICING_ASSERT_OK(statusor.status()); \ lhs = std::move(statusor).ValueOrDie() +#define ICING_ASSERT_HAS_VALUE_AND_ASSIGN(lhs, rexpr) \ + ASSERT_TRUE(rexpr); \ + lhs = rexpr.value() + } // namespace lib } // namespace icing diff --git a/icing/tokenization/combined-tokenizer_test.cc b/icing/tokenization/combined-tokenizer_test.cc index 0212e4f..42c7743 100644 --- a/icing/tokenization/combined-tokenizer_test.cc +++ b/icing/tokenization/combined-tokenizer_test.cc @@ -15,19 +15,19 @@ #include <string_view> #include <vector> -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" -#include "third_party/icing/portable/platform.h" -#include "third_party/icing/proto/schema_proto_portable.pb.h" -#include "third_party/icing/testing/common-matchers.h" -#include "third_party/icing/testing/icu-data-file-helper.h" -#include "third_party/icing/testing/jni-test-helpers.h" -#include "third_party/icing/testing/test-data.h" -#include "third_party/icing/tokenization/language-segmenter-factory.h" -#include "third_party/icing/tokenization/language-segmenter.h" -#include "third_party/icing/tokenization/tokenizer-factory.h" -#include "third_party/icing/tokenization/tokenizer.h" -#include "third_party/icu/include/unicode/uloc.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/portable/platform.h" +#include "icing/proto/schema.pb.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/jni-test-helpers.h" +#include "icing/testing/test-data.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/tokenization/tokenizer-factory.h" +#include "icing/tokenization/tokenizer.h" +#include "unicode/uloc.h" namespace icing { namespace lib { @@ -43,9 +43,9 @@ class CombinedTokenizerTest : public ::testing::Test { void SetUp() override { if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { ICING_ASSERT_OK( - // File generated via icu_data_file rule in //third_party/icing/BUILD. + // File generated via icu_data_file rule in //icing/BUILD. 
icu_data_file_helper::SetUpICUDataFile( - GetTestFilePath("third_party/icing/icu.dat"))); + GetTestFilePath("icing/icu.dat"))); } jni_cache_ = GetTestJniCache(); diff --git a/icing/tokenization/icu/icu-language-segmenter_test.cc b/icing/tokenization/icu/icu-language-segmenter_test.cc index 4098be5..71e04e2 100644 --- a/icing/tokenization/icu/icu-language-segmenter_test.cc +++ b/icing/tokenization/icu/icu-language-segmenter_test.cc @@ -15,12 +15,12 @@ #include <memory> #include <string_view> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" +#include "icing/jni/jni-cache.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/icu-i18n-test-utils.h" diff --git a/icing/tokenization/language-segmenter-factory.h b/icing/tokenization/language-segmenter-factory.h index cae3eee..2505a07 100644 --- a/icing/tokenization/language-segmenter-factory.h +++ b/icing/tokenization/language-segmenter-factory.h @@ -18,9 +18,8 @@ #include <memory> #include <string_view> -#include "icing/jni/jni-cache.h" - #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/jni/jni-cache.h" #include "icing/tokenization/language-segmenter.h" namespace icing { diff --git a/icing/tokenization/language-segmenter_benchmark.cc b/icing/tokenization/language-segmenter_benchmark.cc index 6f7d4df..748a322 100644 --- a/icing/tokenization/language-segmenter_benchmark.cc +++ b/icing/tokenization/language-segmenter_benchmark.cc @@ -27,7 +27,7 @@ // //icing/tokenization:language-segmenter_benchmark // // $ blaze-bin/icing/tokenization/language-segmenter_benchmark -// --benchmarks=all +// --benchmark_filter=all // // Run on an Android device: // Make target //icing/tokenization:language-segmenter depend on @@ -41,7 +41,7 @@ // blaze-bin/icing/tokenization/language-segmenter_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/language-segmenter_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/language-segmenter_benchmark --benchmark_filter=all // --adb // Flag to tell the benchmark that it'll be run on an Android device via adb, diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc index 8e1e563..dbd7f5a 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc @@ -21,11 +21,11 @@ #include <cmath> #include <map> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/text_classifier/lib3/utils/java/jni-base.h" #include "icing/text_classifier/lib3/utils/java/jni-helper.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/jni/jni-cache.h" #include "icing/util/status-macros.h" namespace icing { diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h index 41b470c..537666c 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h +++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h @@ -20,8 +20,8 @@ #include <queue> #include <string> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/java/jni-base.h" +#include "icing/jni/jni-cache.h" namespace icing { namespace lib { diff --git 
a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc index 0da4c2d..a251f90 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/jni/jni-cache.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/jni/jni-cache.h" #include "icing/tokenization/language-segmenter-factory.h" #include "icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h" #include "icing/util/logging.h" diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc index e5de6e6..bd80718 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc @@ -74,6 +74,7 @@ class ReverseJniLanguageSegmenterIterator : public LanguageSegmenter::Iterator { MarkAsDone(); return false; } + return true; } diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h index f06dac9..29df4ee 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h @@ -21,8 +21,8 @@ #include <string_view> #include <vector> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/jni/jni-cache.h" #include "icing/tokenization/language-segmenter.h" namespace icing { diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc index 277ece6..47a01fe 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc @@ -17,11 +17,11 @@ #include <memory> #include <string_view> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "icing/absl_ports/str_cat.h" +#include "icing/jni/jni-cache.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-i18n-test-utils.h" #include "icing/testing/jni-test-helpers.h" @@ -423,7 +423,6 @@ TEST_P(ReverseJniLanguageSegmenterTest, CJKT) { // Khmer EXPECT_THAT(language_segmenter->GetAllTerms("ញុំដើរទៅធ្វើការរាល់ថ្ងៃ។"), IsOkAndHolds(ElementsAre("ញុំ", "ដើរទៅ", "ធ្វើការ", "រាល់ថ្ងៃ", "។"))); - // Thai EXPECT_THAT( language_segmenter->GetAllTerms("ฉันเดินไปทำงานทุกวัน"), diff --git a/icing/tokenization/rfc822-tokenizer.cc b/icing/tokenization/rfc822-tokenizer.cc new file mode 100644 index 0000000..4a96783 --- /dev/null +++ b/icing/tokenization/rfc822-tokenizer.cc @@ -0,0 +1,565 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/tokenization/rfc822-tokenizer.h" + +#include <algorithm> +#include <deque> +#include <queue> +#include <string_view> +#include <utility> + +#include "icing/tokenization/token.h" +#include "icing/tokenization/tokenizer.h" +#include "icing/util/character-iterator.h" +#include "icing/util/i18n-utils.h" +#include "icing/util/status-macros.h" +#include "unicode/umachine.h" + +namespace icing { +namespace lib { + +class Rfc822TokenIterator : public Tokenizer::Iterator { + public: + // Cursor is the index into the string_view, text_end_ is the length. + explicit Rfc822TokenIterator(std::string_view text) + : term_(std::move(text)), + iterator_(text, 0, 0, 0), + text_end_(text.length()) {} + + struct NameInfo { + NameInfo(const char* at_sign, bool name_found) + : at_sign(at_sign), name_found(name_found) {} + const char* at_sign; + bool name_found; + }; + + bool Advance() override { + // Advance through the queue. + if (!token_queue_.empty()) { + token_queue_.pop_front(); + } + + // There is still something left. + if (!token_queue_.empty()) { + return true; + } + + // Done with the entire string_view. + if (iterator_.utf8_index() >= text_end_) { + return false; + } + + AdvancePastWhitespace(); + + GetNextRfc822Token(); + + return true; + } + + // Advance until the next email delimiter, generating as many tokens as + // necessary. + void GetNextRfc822Token() { + int token_start = iterator_.utf8_index(); + const char* at_sign_in_name = nullptr; + bool address_found = false; + bool name_found = false; + // We start in the unquoted state and run until one of ",;\n<( . + while (iterator_.utf8_index() < text_end_) { + UChar32 c = iterator_.GetCurrentChar(); + if (c == ',' || c == ';' || c == '\n') { + // End of the token; advance the cursor past this, then quit. + token_queue_.push_back(Token( + Token::Type::RFC822_TOKEN, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + AdvanceCursor(); + break; + } + + if (c == '"') { + NameInfo quoted_result = ConsumeQuotedSection(); + if (quoted_result.at_sign != nullptr) { + at_sign_in_name = quoted_result.at_sign; + } + if (!name_found) { + name_found = quoted_result.name_found; + } + } else if (c == '(') { + ConsumeParenthesizedSection(); + } else if (c == '<') { + // Only set address_found to true if ConsumeAddress returns true. + // Otherwise, keep address_found as is to prevent setting address_found + // back to false if it is true. + if (ConsumeAddress()) { + address_found = true; + } + } else { + NameInfo unquoted_result = ConsumeUnquotedSection(); + if (unquoted_result.at_sign != nullptr) { + at_sign_in_name = unquoted_result.at_sign; + } + if (!name_found) { + name_found = unquoted_result.name_found; + } + } + } + if (iterator_.utf8_index() >= text_end_) { + token_queue_.push_back( + Token(Token::Type::RFC822_TOKEN, + term_.substr(token_start, text_end_ - token_start))); + } + + // At this point the token_queue is not empty. + // If an address is found, use the tokens we have. + // If an address isn't found, and a name isn't found, also use the tokens + // we have.
+ // If an address isn't found but a name is, convert name Tokens to email + Tokens + if (!address_found && name_found) { + ConvertNameToEmail(at_sign_in_name); + } + } + + void ConvertNameToEmail(const char* at_sign_in_name) { + // The name tokens will be used as the address now. + const char* address_start = nullptr; + const char* local_address_end = nullptr; + const char* address_end = term_.begin(); + + // If we need to transform name tokens into various tokens, we keep the + // order in which the name tokens appeared. Name tokens that appear before + // an @ sign in the name will become RFC822_ADDRESS_COMPONENT_LOCAL, and + // those after will become RFC822_ADDRESS_COMPONENT_HOST. We aren't able + // to determine RFC822_ADDRESS and RFC822_LOCAL_ADDRESS before checking + // the name tokens, so they will be added after the component tokens. + + for (Token& token : token_queue_) { + if (token.type == Token::Type::RFC822_NAME) { + // Names need to be converted to address tokens. + std::string_view text = token.text; + + // Find the ADDRESS and LOCAL_ADDRESS. + if (address_start == nullptr) { + address_start = text.begin(); + } + + if (at_sign_in_name >= text.end()) { + local_address_end = text.end(); + } + + address_end = text.end(); + + if (text.begin() < at_sign_in_name) { + token = Token(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, text); + } else if (text.begin() > at_sign_in_name) { + token = Token(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, text); + } + } + } + + token_queue_.push_back( + Token(Token::Type::RFC822_ADDRESS, + std::string_view(address_start, address_end - address_start))); + + if (local_address_end != nullptr) { + token_queue_.push_back(Token( + Token::Type::RFC822_LOCAL_ADDRESS, + std::string_view(address_start, local_address_end - address_start))); + } + } + + // Returns the location of the last at sign in the unquoted section, and if + // we have found a name. This is useful in case we do not find an address + // and have to use the name. An unquoted section may look like "Alex Sav", or + // "alex@google.com". In the absence of a bracketed email address, the + // unquoted section will be used as the email address along with the quoted + // section. + NameInfo ConsumeUnquotedSection() { + const char* at_sign_location = nullptr; + UChar32 c; + + int token_start = -1; + bool name_found = false; + + // Advance to another state or a character marking the end of the token, one + // of \n,; . + while (iterator_.utf8_index() < text_end_) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + name_found = true; + + if (token_start == -1) { + // Start recording. + token_start = iterator_.utf8_index(); + } + AdvanceCursor(); + + } else { + if (token_start != -1) { + if (c == '@') { + // Mark the last @ sign. + at_sign_location = term_.data() + iterator_.utf8_index(); + } + + // The character is non-alphabetic, save a token. + token_queue_.push_back(Token( + Token::Type::RFC822_NAME, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + token_start = -1; + } + + if (c == '"' || c == '<' || c == '(' || c == '\n' || c == ';' || + c == ',') { + // Stay on the token. + break; + } + + AdvanceCursor(); + } + } + if (token_start != -1) { + token_queue_.push_back(Token( + Token::Type::RFC822_NAME, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + return NameInfo(at_sign_location, name_found); + } + + // Names that are within quotes should have all characters blindly unescaped.
+ // When a name is made into an address, it isn't re-escaped. + + // Returns the location of the last at sign in the quoted section. This is + // useful in case we do not find an address and have to use the name. The + // quoted section may contain whitespace. + NameInfo ConsumeQuotedSection() { + // Get past the first quote. + AdvanceCursor(); + const char* at_sign_location = nullptr; + + bool end_quote_found = false; + bool name_found = false; + UChar32 c; + + int token_start = -1; + + while (!end_quote_found && (iterator_.utf8_index() < text_end_)) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + name_found = true; + + if (token_start == -1) { + // Start tracking the token. + token_start = iterator_.utf8_index(); + } + AdvanceCursor(); + + } else { + // Non-alphabetic. + if (c == '\\') { + // A backslash, let's look at the next character. + CharacterIterator temp = iterator_; + temp.AdvanceToUtf32(iterator_.utf32_index() + 1); + UChar32 n = temp.GetCurrentChar(); + if (i18n_utils::IsAlphaNumeric(n)) { + // The next character is alphabetic, skip the slash and don't end + // the last token. For quoted sections, the only things that are + // escaped are double quotes and slashes. For example, in "a\lex", + // an l appears after the slash. We want to treat this as if it was + // just "alex". So we tokenize it as <RFC822_NAME, "a\lex">. + AdvanceCursor(); + } else { + // Not alphabetic, so save the last token if necessary. + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_NAME, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + // Skip the backslash. + AdvanceCursor(); + + if (n == '"' || n == '\\' || n == '@') { + // Skip these too if they're next. + AdvanceCursor(); + } + } + + } else { + // Not a backslash. + + if (c == '@') { + // Mark the last @ sign. + at_sign_location = term_.data() + iterator_.utf8_index(); + } + + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_NAME, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + if (c == '"') { + end_quote_found = true; + } + // Advance one more time to get past the non-alphabetic character. + AdvanceCursor(); + } + } + } + if (token_start != -1) { + token_queue_.push_back(Token( + Token::Type::RFC822_NAME, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + return NameInfo(at_sign_location, name_found); + } + + // '(', ')', '\\' chars should be escaped. All other escaped chars should be + // unescaped. + void ConsumeParenthesizedSection() { + // Skip the initial ( + AdvanceCursor(); + + int paren_layer = 1; + UChar32 c; + + int token_start = -1; + + while (paren_layer > 0 && (iterator_.utf8_index() < text_end_)) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + if (token_start == -1) { + // Start tracking a token. + token_start = iterator_.utf8_index(); + } + AdvanceCursor(); + + } else { + // Non-alphabetic. + if (c == '\\') { + // A backslash, let's look at the next character. + UChar32 n = i18n_utils::GetUChar32At(term_.data(), term_.length(), + iterator_.utf8_index() + 1); + if (i18n_utils::IsAlphaNumeric(n)) { + // Alphabetic, skip the slash and don't end the last token. + AdvanceCursor(); + } else { + // Not alphabetic, save the last token if necessary.
+ if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_COMMENT, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + // Skip the backslash. + AdvanceCursor(); + + if (n == ')' || n == '(' || n == '\\') { + // Skip these too if they're next. + AdvanceCursor(); + } + } + } else { + // Not a backslash. + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_COMMENT, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + if (c == '(') { + paren_layer++; + } else if (c == ')') { + paren_layer--; + } + AdvanceCursor(); + } + } + } + + if (token_start != -1) { + // Ran past the end of term_ without getting the last token. + + // substr returns "a view of the substring [pos, pos + // rcount), where + // rcount is the smaller of count and size() - pos" therefore the count + // argument can be any value >= this->cursor - token_start. Therefore, + // ignoring the mutation warning. + token_queue_.push_back(Token( + Token::Type::RFC822_COMMENT, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + } + + // Returns true if we find an address. + bool ConsumeAddress() { + // Skip the first <. + AdvanceCursor(); + + // Save the start position. + CharacterIterator address_start_iterator = iterator_; + + int at_sign = -1; + int address_end = -1; + + UChar32 c = iterator_.GetCurrentChar(); + // Quick scan for @ and > signs. + while (c != '>' && iterator_.utf8_index() < text_end_) { + AdvanceCursor(); + c = iterator_.GetCurrentChar(); + if (c == '@') { + at_sign = iterator_.utf8_index(); + } + } + + if (iterator_.utf8_index() <= address_start_iterator.utf8_index()) { + // There is nothing between the brackets, either we have "<" or "<>" + return false; + } + + // Either we find a > or run to the end, either way this is the end of the + // address. The ending bracket will be handled by ConsumeUnquoted. + address_end = iterator_.utf8_index(); + + // Reset to the start. + iterator_ = address_start_iterator; + + int address_start = address_start_iterator.utf8_index(); + + Token::Type type = Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL; + + // Create a local address token. + if (at_sign != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_LOCAL_ADDRESS, + term_.substr(address_start, at_sign - address_start))); + } else { + // All the tokens in the address are host components. + type = Token::Type::RFC822_ADDRESS_COMPONENT_HOST; + } + + token_queue_.push_back( + Token(Token::Type::RFC822_ADDRESS, + term_.substr(address_start, address_end - address_start))); + + int token_start = -1; + + while (iterator_.utf8_index() < address_end) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + if (token_start == -1) { + token_start = iterator_.utf8_index(); + } + + } else { + // non alphabetic + if (c == '\\') { + // A backslash, let's look at the next character. + CharacterIterator temp = iterator_; + temp.AdvanceToUtf32(iterator_.utf32_index() + 1); + UChar32 n = temp.GetCurrentChar(); + if (!i18n_utils::IsAlphaNumeric(n)) { + // Not alphabetic, end the last token if necessary. + if (token_start != -1) { + token_queue_.push_back(Token( + type, term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + } + } else { + // Not backslash. 
+ if (token_start != -1) { + token_queue_.push_back(Token( + type, term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + // Switch to host component tokens. + if (iterator_.utf8_index() == at_sign) { + type = Token::Type::RFC822_ADDRESS_COMPONENT_HOST; + } + } + } + AdvanceCursor(); + } + if (token_start != -1) { + token_queue_.push_back(Token( + type, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + // Unquoted will handle the closing bracket > if there is one. + return true; + } + + Token GetToken() const override { + if (token_queue_.empty()) { + return Token(Token::Type::INVALID, term_); + } + return token_queue_.front(); + } + + private: + void AdvanceCursor() { + iterator_.AdvanceToUtf32(iterator_.utf32_index() + 1); + } + + void AdvancePastWhitespace() { + while (i18n_utils::IsWhitespaceAt(term_, iterator_.utf8_index())) { + AdvanceCursor(); + } + } + + std::string_view term_; + CharacterIterator iterator_; + int text_end_; + + // A temporary store of Tokens. As we advance through the provided string, we + // parse entire addresses at a time rather than one token at a time. However, + // since we call the tokenizer with Advance() alternating with GetToken(), we + // need to store tokens for subsequent GetToken calls if Advance generates + // multiple tokens (it usually does). A queue is used as we want the first + // token generated to be the first token returned from GetToken. + std::deque<Token> token_queue_; +}; + +libtextclassifier3::StatusOr<std::unique_ptr<Tokenizer::Iterator>> +Rfc822Tokenizer::Tokenize(std::string_view text) const { + return std::make_unique<Rfc822TokenIterator>(text); +} + +libtextclassifier3::StatusOr<std::vector<Token>> Rfc822Tokenizer::TokenizeAll( + std::string_view text) const { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> iterator, + Tokenize(text)); + std::vector<Token> tokens; + while (iterator->Advance()) { + tokens.push_back(iterator->GetToken()); + } + return tokens; +} + +} // namespace lib +} // namespace icing diff --git a/icing/absl_ports/status_imports.h b/icing/tokenization/rfc822-tokenizer.h index 3a97fd6..09e4624 100644 --- a/icing/absl_ports/status_imports.h +++ b/icing/tokenization/rfc822-tokenizer.h @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Google LLC +// Copyright (C) 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef ICING_ABSL_PORTS_STATUS_IMPORTS_H_ -#define ICING_ABSL_PORTS_STATUS_IMPORTS_H_ +#ifndef ICING_TOKENIZATION_RFC822_TOKENIZER_H_ +#define ICING_TOKENIZATION_RFC822_TOKENIZER_H_ -#include "icing/text_classifier/lib3/utils/base/status.h" +#include <vector> + +#include "icing/tokenization/tokenizer.h" namespace icing { namespace lib { -namespace absl_ports { -// TODO(b/144458732) Delete this file once visibility on TC3 Status has been -// granted to the sample app.
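A small sketch of the iterator-style API that backs TokenizeAll above; the input string is illustrative only:
Rfc822Tokenizer tokenizer;
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Tokenizer::Iterator> itr,
                           tokenizer.Tokenize("Alex Sav <alex@google.com>"));
while (itr->Advance()) {
  Token token = itr->GetToken();
  // Each token exposes token.type (e.g. RFC822_TOKEN, RFC822_ADDRESS) and token.text.
}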
-using Status = libtextclassifier3::Status; +class Rfc822Tokenizer : public Tokenizer { + public: + libtextclassifier3::StatusOr<std::unique_ptr<Tokenizer::Iterator>> Tokenize( + std::string_view text) const override; + + libtextclassifier3::StatusOr<std::vector<Token>> TokenizeAll( + std::string_view text) const override; + +}; -} // namespace absl_ports } // namespace lib } // namespace icing -#endif // ICING_ABSL_PORTS_STATUS_IMPORTS_H_ +#endif // ICING_TOKENIZATION_RFC822_TOKENIZER_H_ diff --git a/icing/tokenization/rfc822-tokenizer_test.cc b/icing/tokenization/rfc822-tokenizer_test.cc new file mode 100644 index 0000000..e3c6da6 --- /dev/null +++ b/icing/tokenization/rfc822-tokenizer_test.cc @@ -0,0 +1,797 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/tokenization/rfc822-tokenizer.h" + +#include <memory> +#include <string> +#include <string_view> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/jni-test-helpers.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { +namespace { +using ::testing::ElementsAre; + +class Rfc822TokenizerTest : public testing::Test { + protected: + void SetUp() override { + jni_cache_ = GetTestJniCache(); + language_segmenter_factory::SegmenterOptions options(ULOC_US, + jni_cache_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + } + std::unique_ptr<const JniCache> jni_cache_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; +}; + +TEST_F(Rfc822TokenizerTest, Simple) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("<你alex@google.com>"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "你alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<你alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, Small) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("\"a\""); + + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"a\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "a")))); + + s = "\"a\", \"b\""; + + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"a\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "a"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "b"), + EqualsToken(Token::Type::RFC822_TOKEN, 
"\"b\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "b")))); + + s = "(a)"; + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre(EqualsToken(Token::Type::RFC822_COMMENT, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, "(a)")))); +} + +TEST_F(Rfc822TokenizerTest, PB) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("peanut (comment) butter, <alex@google.com>"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "peanut"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, "peanut (comment) butter"), + EqualsToken(Token::Type::RFC822_ADDRESS, "peanut (comment) butter"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, NoBrackets) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("alex@google.com"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex")))); +} + +TEST_F(Rfc822TokenizerTest, TwoAddresses) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("<你alex@google.com>; <alexsav@gmail.com>"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "你alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<你alex@google.com>"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alexsav"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alexsav@gmail.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alexsav"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "gmail"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<alexsav@gmail.com>")))); +} + +TEST_F(Rfc822TokenizerTest, CommentB) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("(a comment) <alex@google.com>"); + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "a"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + 
EqualsToken(Token::Type::RFC822_TOKEN, + "(a comment) <alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, NameAndComment) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("\"a name\" also a name <alex@google.com>"); + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "a"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_NAME, "also"), + EqualsToken(Token::Type::RFC822_NAME, "a"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + "\"a name\" also a name <alex@google.com>")))); +} + +// Test from tokenizer_test.cc. +TEST_F(Rfc822TokenizerTest, Rfc822SanityCheck) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string addr1("A name (A comment) <address@domain.com>"); + std::string addr2( + "\"(Another name)\" (A different comment) " + "<bob-loblaw@foo.bar.com>"); + std::string addr3("<no.at.sign.present>"); + std::string addr4("<double@at@signs.present>"); + std::string rfc822 = addr1 + ", " + addr2 + ", " + addr3 + ", " + addr4; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(rfc822), + IsOkAndHolds(ElementsAre( + + EqualsToken(Token::Type::RFC822_NAME, "A"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "A"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address@domain.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "domain"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, addr1), + + EqualsToken(Token::Type::RFC822_NAME, "Another"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "A"), + EqualsToken(Token::Type::RFC822_COMMENT, "different"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "bob-loblaw"), + EqualsToken(Token::Type::RFC822_ADDRESS, "bob-loblaw@foo.bar.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "bob"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "loblaw"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "bar"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, addr2), + + EqualsToken(Token::Type::RFC822_ADDRESS, "no.at.sign.present"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "no"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "at"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "sign"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "present"), + EqualsToken(Token::Type::RFC822_TOKEN, addr3), + + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "double@at"), + EqualsToken(Token::Type::RFC822_ADDRESS, "double@at@signs.present"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "double"), + 
EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "at"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "signs"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "present"), + EqualsToken(Token::Type::RFC822_TOKEN, addr4)))); +} + +// Tests from rfc822 converter. +TEST_F(Rfc822TokenizerTest, SimpleRfcText) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string test_string = + "foo@google.com,bar@google.com,baz@google.com,foo+hello@google.com,baz@" + "corp.google.com"; + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(test_string), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "foo@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "bar"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "bar@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "bar@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "bar"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "baz@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "baz@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "hello"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "foo+hello@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo+hello@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo+hello"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "corp"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "baz@corp.google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "baz@corp.google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "baz")))); +} + +TEST_F(Rfc822TokenizerTest, ComplicatedRfcText) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string test_string = + R"raw("Weird, But&(Also)\\Valid" Name (!With, "an" \\odd\\ cmt too¡) <Foo B(a)r,Baz@g.co> + <easy@google.com>)raw"; + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(test_string), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "Weird"), + EqualsToken(Token::Type::RFC822_NAME, "But"), + EqualsToken(Token::Type::RFC822_NAME, "Also"), + EqualsToken(Token::Type::RFC822_NAME, "Valid"), + EqualsToken(Token::Type::RFC822_NAME, "Name"), + EqualsToken(Token::Type::RFC822_COMMENT, "With"), + EqualsToken(Token::Type::RFC822_COMMENT, "an"), + EqualsToken(Token::Type::RFC822_COMMENT, "odd"), + EqualsToken(Token::Type::RFC822_COMMENT, "cmt"), + EqualsToken(Token::Type::RFC822_COMMENT, "too"), + 
EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "Foo B(a)r,Baz"), + EqualsToken(Token::Type::RFC822_ADDRESS, "Foo B(a)r,Baz@g.co"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "Foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "B"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "a"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "r"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "Baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "g"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "co"), + EqualsToken( + Token::Type::RFC822_TOKEN, + R"raw("Weird, But&(Also)\\Valid" Name (!With, "an" \\odd\\ cmt too¡) <Foo B(a)r,Baz@g.co>)raw"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "easy"), + EqualsToken(Token::Type::RFC822_ADDRESS, "easy@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "easy"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<easy@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, FromHtmlBugs) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + // This input used to cause HTML parsing exception. We don't do HTML parsing + // any more (b/8388100) so we are just checking that it does not crash and + // that it retains the input. + + // http://b/8988210. Put crashing string "&\r" x 100 into name and comment + // field of rfc822 token. + + std::string s("\""); + for (int i = 0; i < 100; i++) { + s.append("&\r"); + } + s.append("\" ("); + for (int i = 0; i < 100; i++) { + s.append("&\r"); + } + s.append(") <foo@google.com>"); + + // It shouldn't change anything + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, s)))); +} + +TEST_F(Rfc822TokenizerTest, EmptyComponentsTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(""), + IsOkAndHolds(testing::IsEmpty())); + + // Name is considered the address if address is empty. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("name<>"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "name<>"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + // Empty name and address means that there is no token. 
+ EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("(a long comment with nothing else)"), + IsOkAndHolds( + ElementsAre(EqualsToken(Token::Type::RFC822_COMMENT, "a"), + EqualsToken(Token::Type::RFC822_COMMENT, "long"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_COMMENT, "with"), + EqualsToken(Token::Type::RFC822_COMMENT, "nothing"), + EqualsToken(Token::Type::RFC822_COMMENT, "else"), + EqualsToken(Token::Type::RFC822_TOKEN, + "(a long comment with nothing else)")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("name ()"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "name ()"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(R"((comment) "")"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, "(comment) \"\"")))); +} + +TEST_F(Rfc822TokenizerTest, NameTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + // Name spread between address or comment. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("peanut <address> butter"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "peanut"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "address"), + EqualsToken(Token::Type::RFC822_NAME, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, "peanut <address> butter")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("peanut (comment) butter"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "peanut"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, "peanut (comment) butter"), + EqualsToken(Token::Type::RFC822_ADDRESS, + "peanut (comment) butter")))); + + // Dropping quotes when they're not needed. + std::string s = R"(peanut <address> "butter")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "peanut"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "address"), + EqualsToken(Token::Type::RFC822_NAME, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, s)))); + + s = R"(peanut "butter")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "peanut"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, s), + EqualsToken(Token::Type::RFC822_ADDRESS, "peanut \"butter")))); + // Adding quotes when they are needed. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("ple@se quote this <addr>"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "ple"), + EqualsToken(Token::Type::RFC822_NAME, "se"), + EqualsToken(Token::Type::RFC822_NAME, "quote"), + EqualsToken(Token::Type::RFC822_NAME, "this"), + EqualsToken(Token::Type::RFC822_ADDRESS, "addr"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "addr"), + + EqualsToken(Token::Type::RFC822_TOKEN, "ple@se quote this <addr>")))); +} + +TEST_F(Rfc822TokenizerTest, CommentEscapeTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + // '(', ')', '\\' chars should be escaped. All other escaped chars should be + // unescaped. 
+ EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"((co\)mm\\en\(t))"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "co"), + EqualsToken(Token::Type::RFC822_COMMENT, "mm"), + EqualsToken(Token::Type::RFC822_COMMENT, "en"), + EqualsToken(Token::Type::RFC822_COMMENT, "t"), + EqualsToken(Token::Type::RFC822_TOKEN, R"((co\)mm\\en\(t))")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"((c\om\ment) name)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, R"(c\om\ment)"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, R"((c\om\ment) name)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"((co(m\))ment) name)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "co"), + EqualsToken(Token::Type::RFC822_COMMENT, "m"), + EqualsToken(Token::Type::RFC822_COMMENT, "ment"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, R"((co(m\))ment) name)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); +} + +TEST_F(Rfc822TokenizerTest, QuoteEscapeTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + // All names that include non-alphanumeric chars must be quoted and have '\\' + // and '"' chars escaped. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(n\\a\me <addr>)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "n"), + EqualsToken(Token::Type::RFC822_NAME, "a"), + EqualsToken(Token::Type::RFC822_NAME, "me"), + EqualsToken(Token::Type::RFC822_ADDRESS, "addr"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "addr"), + EqualsToken(Token::Type::RFC822_TOKEN, R"(n\\a\me <addr>)")))); + + // Names that are within quotes should have all characters blindly unescaped. + // When a name is made into an address, it isn't re-escaped. 
+ EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"("n\\a\m\"e")"), + // <n\am"e> + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "n"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a\\m"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "e"), + EqualsToken(Token::Type::RFC822_TOKEN, R"("n\\a\m\"e")"), + EqualsToken(Token::Type::RFC822_ADDRESS, R"(n\\a\m\"e)")))); +} + +TEST_F(Rfc822TokenizerTest, UnterminatedComponentTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("name (comment"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, "name (comment"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(half of "the name)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "half"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "of"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "the"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "half of \"the name"), + EqualsToken(Token::Type::RFC822_ADDRESS, "half of \"the name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"("name\)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"name\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(name (comment\)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, "name (comment\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(<addr> "name\)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS, "addr"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "addr"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "<addr> \"name\\")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(name (comment\))"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, R"(name (comment\))"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); +} + +TEST_F(Rfc822TokenizerTest, Tokenize) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string text = + R"raw("Berg" (home) <berg\@google.com>, tom\@google.com (work))raw"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "Berg"), + EqualsToken(Token::Type::RFC822_COMMENT, "home"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "berg\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "berg\\@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "berg"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + R"("Berg" (home) <berg\@google.com>)"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "tom"), + 
EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_COMMENT, "work"), + EqualsToken(Token::Type::RFC822_TOKEN, "tom\\@google.com (work)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "tom\\@google.com")))); + + text = R"raw(Foo Bar (something) <foo\@google.com>, )raw" + R"raw(blah\@google.com (something))raw"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "Foo"), + EqualsToken(Token::Type::RFC822_NAME, "Bar"), + EqualsToken(Token::Type::RFC822_COMMENT, "something"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo\\@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + "Foo Bar (something) <foo\\@google.com>"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "blah"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_COMMENT, "something"), + EqualsToken(Token::Type::RFC822_TOKEN, + "blah\\@google.com (something)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "blah\\@google.com")))); +} + +TEST_F(Rfc822TokenizerTest, EdgeCases) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + // Text to trigger the scenario where you have a non-alphabetic followed + // by a \ followed by non alphabetic to end an in-address token. + std::string text = R"raw(<be.\&rg@google.com>)raw"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "be.\\&rg"), + EqualsToken(Token::Type::RFC822_ADDRESS, "be.\\&rg@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "be"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "rg"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + R"raw(<be.\&rg@google.com>)raw")))); + + // A \ followed by an alphabetic shouldn't end the token. + text = "<a\\lex@google.com>"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "a\\lex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "a\\lex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "a\\lex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<a\\lex@google.com>")))); + + // \\ or \" in a quoted section. 
+ text = R"("al\\ex@goo\"<idk>gle.com")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "al"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "ex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "goo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "idk"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "gle"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + R"("al\\ex@goo\"<idk>gle.com")"), + EqualsToken(Token::Type::RFC822_ADDRESS, + R"(al\\ex@goo\"<idk>gle.com)"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "al\\\\ex")))); + + text = "<alex@google.com"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<alex@google.com")))); +} + +TEST_F(Rfc822TokenizerTest, NumberInAddress) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "<3alex@google.com>"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "3alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "3alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "3alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<3alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, DoubleQuoteDoubleSlash) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = R"("alex\"")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "alex"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex")))); + + text = R"("alex\\\a")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, R"(alex\\\a)")))); +} + +TEST_F(Rfc822TokenizerTest, TwoEmails) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "tjbarron@google.com alexsav@google.com"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "tjbarron"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alexsav"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, text), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, + "tjbarron@google.com alexsav")))); +} + +TEST_F(Rfc822TokenizerTest, BackSlashes) { + Rfc822Tokenizer rfc822_tokenizer = 
Rfc822Tokenizer(); + std::string text = R"("\name")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"\\name\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + text = R"("name@foo\@gmail")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "name"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "gmail"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "name@foo\\@gmail"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "name")))); +} + +TEST_F(Rfc822TokenizerTest, BigWhitespace) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "\"quoted\" <address>"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "quoted"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "address"), + EqualsToken(Token::Type::RFC822_TOKEN, text)))); +} + +TEST_F(Rfc822TokenizerTest, AtSignFirst) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "\"@foo\""; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "foo"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo")))); +} + +TEST_F(Rfc822TokenizerTest, SlashThenUnicode) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = R"("quoted\你cjk")"; + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, + "quoted\\你cjk"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "quoted\\你cjk")))); +} + +TEST_F(Rfc822TokenizerTest, AddressEmptyAddress) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "<address> <> Name"; + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, + "address"), + EqualsToken(Token::Type::RFC822_NAME, "Name"), + EqualsToken(Token::Type::RFC822_TOKEN, text)))); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/tokenization/token.h b/icing/tokenization/token.h index 0c268be..24f567b 100644 --- a/icing/tokenization/token.h +++ b/icing/tokenization/token.h @@ -29,6 +29,15 @@ struct Token { VERBATIM, // A token that should be indexed and searched without any // modifications to the raw text + // An RFC822 section with the content in RFC822_TOKEN tokenizes as follows: + RFC822_NAME, // "User", "Johnsson" + RFC822_COMMENT, // "A", "comment", "here" + RFC822_LOCAL_ADDRESS, // "user.name" + RFC822_ADDRESS, // "user.name@domain.name.com" + RFC822_ADDRESS_COMPONENT_LOCAL, // "user", "name", + RFC822_ADDRESS_COMPONENT_HOST, // "domain", "name", "com" + RFC822_TOKEN, // "User Johnsson (A comment) <user.name@domain.name.com>" + // Types only used in raw query QUERY_OR, // Indicates OR logic between its left and right tokens QUERY_EXCLUSION, // Indicates exclusion operation on next token @@ -45,10 +54,10 @@ struct Token { : type(type_in), 
text(text_in) {} // The type of token - const Type type; + Type type; // The content of token - const std::string_view text; + std::string_view text; }; } // namespace lib diff --git a/icing/transform/icu/icu-normalizer_benchmark.cc b/icing/transform/icu/icu-normalizer_benchmark.cc index fdd4c70..fe8289a 100644 --- a/icing/transform/icu/icu-normalizer_benchmark.cc +++ b/icing/transform/icu/icu-normalizer_benchmark.cc @@ -25,7 +25,7 @@ // //icing/transform/icu:icu-normalizer_benchmark // // $ blaze-bin/icing/transform/icu/icu-normalizer_benchmark -// --benchmarks=all +// --benchmark_filter=all // // Run on an Android device: // Make target //icing/transform:normalizer depend on @@ -39,7 +39,7 @@ // blaze-bin/icing/transform/icu/icu-normalizer_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/icu-normalizer_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/icu-normalizer_benchmark --benchmark_filter=all // --adb // Flag to tell the benchmark that it'll be run on an Android device via adb, diff --git a/icing/transform/map/map-normalizer_benchmark.cc b/icing/transform/map/map-normalizer_benchmark.cc index 8268541..4560329 100644 --- a/icing/transform/map/map-normalizer_benchmark.cc +++ b/icing/transform/map/map-normalizer_benchmark.cc @@ -24,7 +24,7 @@ // //icing/transform/map:map-normalizer_benchmark // // $ blaze-bin/icing/transform/map/map-normalizer_benchmark -// --benchmarks=all +// --benchmark_filter=all // // Run on an Android device: // $ blaze build --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" @@ -35,7 +35,7 @@ // blaze-bin/icing/transform/map/map-normalizer_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/map-normalizer_benchmark --benchmarks=all +// $ adb shell /data/local/tmp/map-normalizer_benchmark --benchmark_filter=all namespace icing { namespace lib { diff --git a/icing/util/clock.h b/icing/util/clock.h index 2bb7818..9e57854 100644 --- a/icing/util/clock.h +++ b/icing/util/clock.h @@ -16,6 +16,7 @@ #define ICING_UTIL_CLOCK_H_ #include <cstdint> +#include <functional> #include <memory> namespace icing { @@ -69,6 +70,32 @@ class Clock { virtual std::unique_ptr<Timer> GetNewTimer() const; }; +// A convenient RAII timer class that receives a callback. Upon destruction, the +// callback will be called with the elapsed milliseconds or nanoseconds passed +// as a parameter, depending on which Unit was passed in the constructor. 
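A minimal usage sketch of the ScopedTimer class declared just below, assuming the caller holds an icing::lib::Clock; the function and variable names here are illustrative and not part of this change.

// Illustrative only: capture elapsed wall time for a block of work.
int64_t MeasureWorkMs(const Clock& clock) {
  int64_t latency_ms = 0;
  {
    ScopedTimer timer(clock.GetNewTimer(),
                      [&latency_ms](int64_t elapsed) { latency_ms = elapsed; });
    // ... the work being measured goes here ...
  }  // ~ScopedTimer runs here and invokes the callback with the elapsed ms.
  return latency_ms;
}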
+class ScopedTimer { + public: + enum class Unit { kMillisecond, kNanosecond }; + + ScopedTimer(std::unique_ptr<Timer> timer, + std::function<void(int64_t)> callback, + Unit unit = Unit::kMillisecond) + : timer_(std::move(timer)), callback_(std::move(callback)), unit_(unit) {} + + ~ScopedTimer() { + if (unit_ == Unit::kMillisecond) { + callback_(timer_->GetElapsedMilliseconds()); + } else { + callback_(timer_->GetElapsedNanoseconds()); + } + } + + private: + std::unique_ptr<Timer> timer_; + std::function<void(int64_t)> callback_; + Unit unit_; +}; + } // namespace lib } // namespace icing diff --git a/icing/util/crc32.h b/icing/util/crc32.h index 5befe44..207a80a 100644 --- a/icing/util/crc32.h +++ b/icing/util/crc32.h @@ -35,6 +35,8 @@ class Crc32 { explicit Crc32(uint32_t init_crc) : crc_(init_crc) {} + explicit Crc32(std::string_view str) : crc_(0) { Append(str); } + inline bool operator==(const Crc32& other) const { return crc_ == other.Get(); } diff --git a/icing/util/document-validator_test.cc b/icing/util/document-validator_test.cc index 45c23e0..b03d3f5 100644 --- a/icing/util/document-validator_test.cc +++ b/icing/util/document-validator_test.cc @@ -125,10 +125,10 @@ class DocumentValidatorTest : public ::testing::Test { } std::string schema_dir_; - std::unique_ptr<DocumentValidator> document_validator_; - std::unique_ptr<SchemaStore> schema_store_; Filesystem filesystem_; FakeClock fake_clock_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentValidator> document_validator_; }; TEST_F(DocumentValidatorTest, ValidateSimpleSchemasOk) { diff --git a/icing/util/fingerprint-util.cc b/icing/util/fingerprint-util.cc new file mode 100644 index 0000000..0ea843f --- /dev/null +++ b/icing/util/fingerprint-util.cc @@ -0,0 +1,48 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/util/fingerprint-util.h" + +namespace icing { +namespace lib { + +namespace fingerprint_util { + +// A formatter to properly handle a string that is actually just a hash value. +std::string GetFingerprintString(uint64_t fingerprint) { + std::string encoded_fprint; + // DynamicTrie cannot handle keys with '0' as bytes. So, we encode it in + // base128 and add 1 to make sure that no byte is '0'. This increases the + // size of the encoded_fprint from 8-bytes to 10-bytes. 
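A worked trace of the encoding described in the comment above, using an illustrative fingerprint value (300) rather than one taken from the change:

// GetFingerprintString(300), where 300 is 0b1'0010'1100:
//   iteration 1: 300 & 0x7F = 44 -> push_back(44 + 1 = 45); 300 >> 7 = 2
//   iteration 2:   2 & 0x7F =  2 -> push_back( 2 + 1 =  3);   2 >> 7 = 0
// encoded_fprint holds the bytes {45, 3}, neither of which is 0.
// GetFingerprint() reverses the steps: ((3 - 1) << 7) | (45 - 1) = 300.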
+ while (fingerprint) { + encoded_fprint.push_back((fingerprint & 0x7F) + 1); + fingerprint >>= 7; + } + return encoded_fprint; +} + +uint64_t GetFingerprint(std::string_view fingerprint_string) { + uint64_t fprint = 0; + for (int i = fingerprint_string.length() - 1; i >= 0; --i) { + fprint <<= 7; + char c = fingerprint_string[i] - 1; + fprint |= (c & 0x7F); + } + return fprint; +} + +} // namespace fingerprint_util + +} // namespace lib +} // namespace icing diff --git a/icing/util/fingerprint-util.h b/icing/util/fingerprint-util.h new file mode 100644 index 0000000..9e98617 --- /dev/null +++ b/icing/util/fingerprint-util.h @@ -0,0 +1,47 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_UTIL_FINGERPRINT_UTIL_H_ +#define ICING_UTIL_FINGERPRINT_UTIL_H_ + +#include <cstdint> +#include <string> +#include <string_view> + +namespace icing { +namespace lib { + +namespace fingerprint_util { + +// Converts from a fingerprint to a fingerprint string. +std::string GetFingerprintString(uint64_t fingerprint); + +// Converts from a fingerprint string to a fingerprint. +uint64_t GetFingerprint(std::string_view fingerprint_string); + +// A formatter to properly handle a string that is actually just a hash value. +class FingerprintStringFormatter { + public: + std::string operator()(std::string_view fingerprint_string) { + uint64_t fingerprint = GetFingerprint(fingerprint_string); + return std::to_string(fingerprint); + } +}; + +} // namespace fingerprint_util + +} // namespace lib +} // namespace icing + +#endif // ICING_UTIL_FINGERPRINT_UTIL_H_ diff --git a/icing/util/fingerprint-util_test.cc b/icing/util/fingerprint-util_test.cc new file mode 100644 index 0000000..948c75a --- /dev/null +++ b/icing/util/fingerprint-util_test.cc @@ -0,0 +1,75 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/util/fingerprint-util.h" + +#include <cstdint> +#include <limits> + +#include "icing/text_classifier/lib3/utils/hash/farmhash.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace icing { +namespace lib { +namespace fingerprint_util { + +namespace { + +using ::testing::Eq; + +TEST(FingerprintUtilTest, ConversionIsReversible) { + std::string str = "foo-bar-baz"; + uint64_t fprint = tc3farmhash::Fingerprint64(str); + std::string fprint_string = GetFingerprintString(fprint); + EXPECT_THAT(GetFingerprint(fprint_string), Eq(fprint)); +} + +TEST(FingerprintUtilTest, ZeroConversionIsReversible) { + uint64_t fprint = 0; + std::string fprint_string = GetFingerprintString(fprint); + EXPECT_THAT(GetFingerprint(fprint_string), Eq(fprint)); +} + +TEST(FingerprintUtilTest, MultipleConversionsAreReversible) { + EXPECT_THAT(GetFingerprint(GetFingerprintString(25)), Eq(25)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(766)), Eq(766)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(2305)), Eq(2305)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(6922)), Eq(6922)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(62326)), Eq(62326)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(186985)), Eq(186985)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(560962)), Eq(560962)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(1682893)), Eq(1682893)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(15146065)), Eq(15146065)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(136314613)), Eq(136314613)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(1226831545)), Eq(1226831545)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(11041483933)), + Eq(11041483933)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(2683080596566)), + Eq(2683080596566)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(72443176107373)), + Eq(72443176107373)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(1955965754899162)), + Eq(1955965754899162)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(52811075382277465)), + Eq(52811075382277465)); + EXPECT_THAT(GetFingerprint(GetFingerprintString(4277697105964474945)), + Eq(4277697105964474945)); +} + +} // namespace + +} // namespace fingerprint_util +} // namespace lib +} // namespace icing diff --git a/icing/util/logging.cc b/icing/util/logging.cc new file mode 100644 index 0000000..8498be4 --- /dev/null +++ b/icing/util/logging.cc @@ -0,0 +1,124 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/util/logging.h" + +#include <atomic> +#include <exception> +#include <string_view> + +#include "icing/util/logging_raw.h" + +namespace icing { +namespace lib { +namespace { +// Returns pointer to beginning of last /-separated token from file_name. +// file_name should be a pointer to a zero-terminated array of chars. +// E.g., "foo/bar.cc" -> "bar.cc", "foo/" -> "", "foo" -> "foo". 
+const char *JumpToBasename(const char *file_name) { + if (file_name == nullptr) { + return nullptr; + } + + // Points to the beginning of the last encountered token. + size_t last_token_start = std::string_view(file_name).find_last_of('/'); + if (last_token_start == std::string_view::npos) { + return file_name; + } + return file_name + last_token_start + 1; +} + +// Calculate the logging level value based on severity and verbosity. +constexpr uint32_t CalculateLoggingLevel(LogSeverity::Code severity, + uint16_t verbosity) { + uint32_t logging_level = static_cast<uint16_t>(severity); + logging_level = (logging_level << 16) | verbosity; + return logging_level; +} + +#if defined(ICING_DEBUG_LOGGING) +#define DEFAULT_LOGGING_LEVEL CalculateLoggingLevel(LogSeverity::VERBOSE, 1) +#else +#define DEFAULT_LOGGING_LEVEL CalculateLoggingLevel(LogSeverity::INFO, 0) +#endif + +// The current global logging level for Icing, which controls which logs are +// printed based on severity and verbosity. +// +// This needs to be global so that it can be easily accessed from ICING_LOG and +// ICING_VLOG macros spread throughout the entire code base. +// +// The first 16 bits represent the minimal log severity. +// The last 16 bits represent the current verbosity. +std::atomic<uint32_t> global_logging_level = DEFAULT_LOGGING_LEVEL; + +} // namespace + +// Whether we should log according to the current logging level. +bool ShouldLog(LogSeverity::Code severity, int16_t verbosity) { + if (verbosity < 0) { + return false; + } + // Using the relaxed order for better performance because we only need to + // guarantee the atomicity for this specific statement, without the need to + // worry about reordering. + uint32_t curr_logging_level = + global_logging_level.load(std::memory_order_relaxed); + // If severity is less than the threshold set. + if (static_cast<uint16_t>(severity) < (curr_logging_level >> 16)) { + return false; + } + if (severity == LogSeverity::VERBOSE) { + // Return whether the verbosity is within the current verbose level set. + return verbosity <= (curr_logging_level & 0xffff); + } + return true; +} + +bool SetLoggingLevel(LogSeverity::Code severity, int16_t verbosity) { + if (verbosity < 0) { + return false; + } + if (severity > LogSeverity::VERBOSE && verbosity > 0) { + return false; + } + // Using the relaxed order for better performance because we only need to + // guarantee the atomicity for this specific statement, without the need to + // worry about reordering. + global_logging_level.store(CalculateLoggingLevel(severity, verbosity), + std::memory_order_relaxed); + return true; +} + +LogMessage::LogMessage(LogSeverity::Code severity, uint16_t verbosity, + const char *file_name, int line_number) + : severity_(severity), + verbosity_(verbosity), + should_log_(ShouldLog(severity_, verbosity_)), + stream_(should_log_) { + if (should_log_) { + stream_ << JumpToBasename(file_name) << ":" << line_number << ": "; + } +} + +LogMessage::~LogMessage() { + if (should_log_) { + LowLevelLogging(severity_, kIcingLoggingTag, stream_.message); + } + if (severity_ == LogSeverity::FATAL) { + std::terminate(); // Will print a stacktrace (stdout or logcat).
+ } +} +} // namespace lib +} // namespace icing diff --git a/icing/util/logging.h b/icing/util/logging.h index 9d598fe..7742302 100644 --- a/icing/util/logging.h +++ b/icing/util/logging.h @@ -15,14 +15,130 @@ #ifndef ICING_UTIL_LOGGING_H_ #define ICING_UTIL_LOGGING_H_ -#include "icing/text_classifier/lib3/utils/base/logging.h" +#include <atomic> +#include <cstdint> +#include <string> +#include "icing/proto/debug.pb.h" + +// This header provides base/logging.h style macros, ICING_LOG and ICING_VLOG, +// for logging in various platforms. The macros use __android_log_write on +// Android, and log to stdout/stderr on others. It also provides a function +// SetLoggingLevel to control the log severity level for ICING_LOG and verbosity +// for ICING_VLOG. namespace icing { namespace lib { -// TODO(b/146903474) Add verbose level control -#define ICING_VLOG(verbose_level) TC3_VLOG(verbose_level) -#define ICING_LOG(severity) TC3_LOG(severity) +// Whether we should log according to the current logging level. +// The function will always return false when verbosity is negative. +bool ShouldLog(LogSeverity::Code severity, int16_t verbosity = 0); + +// Set the minimal logging severity to be enabled, and the verbose level to see +// from the logs. +// Return false if severity is set higher than VERBOSE but verbosity is not 0. +// The function will always return false when verbosity is negative. +bool SetLoggingLevel(LogSeverity::Code severity, int16_t verbosity = 0); + +// A tiny code footprint string stream for assembling log messages. +struct LoggingStringStream { + explicit LoggingStringStream(bool should_log) : should_log_(should_log) {} + LoggingStringStream& stream() { return *this; } + + std::string message; + const bool should_log_; +}; + +template <typename T> +inline LoggingStringStream& operator<<(LoggingStringStream& stream, + const T& entry) { + if (stream.should_log_) { + stream.message.append(std::to_string(entry)); + } + return stream; +} + +template <typename T> +inline LoggingStringStream& operator<<(LoggingStringStream& stream, + T* const entry) { + if (stream.should_log_) { + stream.message.append( + std::to_string(reinterpret_cast<const uint64_t>(entry))); + } + return stream; +} + +inline LoggingStringStream& operator<<(LoggingStringStream& stream, + const char* message) { + if (stream.should_log_) { + stream.message.append(message); + } + return stream; +} + +inline LoggingStringStream& operator<<(LoggingStringStream& stream, + const std::string& message) { + if (stream.should_log_) { + stream.message.append(message); + } + return stream; +} + +inline LoggingStringStream& operator<<(LoggingStringStream& stream, + std::string_view message) { + if (stream.should_log_) { + stream.message.append(message); + } + return stream; +} + +template <typename T1, typename T2> +inline LoggingStringStream& operator<<(LoggingStringStream& stream, + const std::pair<T1, T2>& entry) { + if (stream.should_log_) { + stream << "(" << entry.first << ", " << entry.second << ")"; + } + return stream; +} + +// The class that does all the work behind our ICING_LOG(severity) macros. Each +// ICING_LOG(severity) << obj1 << obj2 << ...; logging statement creates a +// LogMessage temporary object containing a stringstream. Each operator<< adds +// info to that stringstream and the LogMessage destructor performs the actual +// logging. 
The reason this works is that in C++, "all temporary objects are +// destroyed as the last step in evaluating the full-expression that (lexically) +// contains the point where they were created." For more info, see +// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is +// invoked after the last << from that logging statement. +class LogMessage { + public: + LogMessage(LogSeverity::Code severity, uint16_t verbosity, + const char* file_name, int line_number) __attribute__((noinline)); + + ~LogMessage() __attribute__((noinline)); + + // Returns the stream associated with the logger object. + LoggingStringStream& stream() { return stream_; } + + private: + const LogSeverity::Code severity_; + const uint16_t verbosity_; + const bool should_log_; + + // Stream that "prints" all info into a string (not to a file). We construct + // here the entire logging message and next print it in one operation. + LoggingStringStream stream_; +}; + +inline constexpr char kIcingLoggingTag[] = "AppSearchIcing"; + +#define ICING_VLOG(verbose_level) \ + ::icing::lib::LogMessage(::icing::lib::LogSeverity::VERBOSE, verbose_level, \ + __FILE__, __LINE__) \ + .stream() +#define ICING_LOG(severity) \ + ::icing::lib::LogMessage(::icing::lib::LogSeverity::severity, \ + /*verbosity=*/0, __FILE__, __LINE__) \ + .stream() } // namespace lib } // namespace icing diff --git a/icing/util/logging_raw.cc b/icing/util/logging_raw.cc new file mode 100644 index 0000000..5e67fb3 --- /dev/null +++ b/icing/util/logging_raw.cc @@ -0,0 +1,102 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/util/logging_raw.h" + +#include <cstdio> +#include <string> + +// NOTE: this file contains two implementations: one for Android, one for all +// other cases. We always build exactly one implementation. +#if defined(__ANDROID__) + +// Compiled as part of Android. +#include <android/log.h> + +namespace icing { +namespace lib { + +namespace { +// Converts LogSeverity to level for __android_log_write. 
+int GetAndroidLogLevel(LogSeverity::Code severity) { + switch (severity) { + case LogSeverity::VERBOSE: + return ANDROID_LOG_VERBOSE; + case LogSeverity::DBG: + return ANDROID_LOG_DEBUG; + case LogSeverity::INFO: + return ANDROID_LOG_INFO; + case LogSeverity::WARNING: + return ANDROID_LOG_WARN; + case LogSeverity::ERROR: + return ANDROID_LOG_ERROR; + case LogSeverity::FATAL: + return ANDROID_LOG_FATAL; + } +} +} // namespace + +void LowLevelLogging(LogSeverity::Code severity, const std::string& tag, + const std::string& message) { + const int android_log_level = GetAndroidLogLevel(severity); +#if __ANDROID_API__ >= 30 + if (!__android_log_is_loggable(android_log_level, tag.c_str(), + /*default_prio=*/ANDROID_LOG_INFO)) { + return; + } +#endif // __ANDROID_API__ >= 30 + __android_log_write(android_log_level, tag.c_str(), message.c_str()); +} + +} // namespace lib +} // namespace icing + +#else // if defined(__ANDROID__) + +// Not on Android: implement LowLevelLogging to print to stderr (see below). +namespace icing { +namespace lib { + +namespace { +// Converts LogSeverity to human-readable text. +const char *LogSeverityToString(LogSeverity::Code severity) { + switch (severity) { + case LogSeverity::VERBOSE: + return "VERBOSE"; + case LogSeverity::DBG: + return "DEBUG"; + case LogSeverity::INFO: + return "INFO"; + case LogSeverity::WARNING: + return "WARNING"; + case LogSeverity::ERROR: + return "ERROR"; + case LogSeverity::FATAL: + return "FATAL"; + } +} +} // namespace + +void LowLevelLogging(LogSeverity::Code severity, const std::string &tag, + const std::string &message) { + // TODO(b/146903474) Do not log to stderr for logs other than FATAL and ERROR. + fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(), + message.c_str()); + fflush(stderr); +} + +} // namespace lib +} // namespace icing + +#endif // if defined(__ANDROID__) diff --git a/icing/util/logging_raw.h b/icing/util/logging_raw.h new file mode 100644 index 0000000..99dddb6 --- /dev/null +++ b/icing/util/logging_raw.h @@ -0,0 +1,34 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_UTIL_LOGGING_RAW_H_ +#define ICING_UTIL_LOGGING_RAW_H_ + +#include <string> + +#include "icing/proto/debug.pb.h" + +namespace icing { +namespace lib { + +// Low-level logging primitive. Logs a message, with the indicated log +// severity. From android/log.h: "the tag normally corresponds to the component +// that emits the log message, and should be reasonably small". 
+void LowLevelLogging(LogSeverity::Code severity, const std::string &tag, + const std::string &message); + +} // namespace lib +} // namespace icing + +#endif // ICING_UTIL_LOGGING_RAW_H_ diff --git a/icing/util/logging_test.cc b/icing/util/logging_test.cc new file mode 100644 index 0000000..eac018e --- /dev/null +++ b/icing/util/logging_test.cc @@ -0,0 +1,158 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/util/logging.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/proto/debug.pb.h" +#include "icing/util/logging_raw.h" + +namespace icing { +namespace lib { + +namespace { +using ::testing::EndsWith; +using ::testing::IsEmpty; + +TEST(LoggingTest, SetLoggingLevelWithInvalidArguments) { + EXPECT_FALSE(SetLoggingLevel(LogSeverity::DBG, 1)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::INFO, 1)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::WARNING, 1)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::ERROR, 1)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::FATAL, 1)); + + EXPECT_FALSE(SetLoggingLevel(LogSeverity::DBG, 2)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::INFO, 2)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::WARNING, 2)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::ERROR, 2)); + EXPECT_FALSE(SetLoggingLevel(LogSeverity::FATAL, 2)); + + EXPECT_FALSE(SetLoggingLevel(LogSeverity::VERBOSE, -1)); +} + +TEST(LoggingTest, SetLoggingLevelTest) { + // Set to INFO + ASSERT_TRUE(SetLoggingLevel(LogSeverity::INFO)); + EXPECT_FALSE(ShouldLog(LogSeverity::DBG)); + EXPECT_TRUE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); + + // Set to WARNING + ASSERT_TRUE(SetLoggingLevel(LogSeverity::WARNING)); + EXPECT_FALSE(ShouldLog(LogSeverity::DBG)); + EXPECT_FALSE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); + + // Set to DEBUG + ASSERT_TRUE(SetLoggingLevel(LogSeverity::DBG)); + EXPECT_TRUE(ShouldLog(LogSeverity::DBG)); + EXPECT_TRUE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); +} + +TEST(LoggingTest, VerboseLoggingTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::VERBOSE, 1)); + EXPECT_TRUE(ShouldLog(LogSeverity::VERBOSE, 1)); + EXPECT_TRUE(ShouldLog(LogSeverity::DBG)); + EXPECT_TRUE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); + EXPECT_TRUE(ShouldLog(LogSeverity::ERROR)); + EXPECT_TRUE(ShouldLog(LogSeverity::FATAL)); +} + +TEST(LoggingTest, VerboseLoggingIsControlledByVerbosity) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::VERBOSE, 2)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 3)); + EXPECT_TRUE(ShouldLog(LogSeverity::VERBOSE, 2)); + EXPECT_TRUE(ShouldLog(LogSeverity::VERBOSE, 1)); + + ASSERT_TRUE(SetLoggingLevel(LogSeverity::VERBOSE, 1)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 2)); + EXPECT_TRUE(ShouldLog(LogSeverity::VERBOSE, 1)); + + ASSERT_TRUE(SetLoggingLevel(LogSeverity::VERBOSE, 0)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 1)); + 
EXPECT_TRUE(ShouldLog(LogSeverity::VERBOSE, 0)); + + // Negative verbosity is invalid. + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, -1)); +} + +TEST(LoggingTest, DebugLoggingTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::DBG)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 1)); + EXPECT_TRUE(ShouldLog(LogSeverity::DBG)); + EXPECT_TRUE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); + EXPECT_TRUE(ShouldLog(LogSeverity::ERROR)); + EXPECT_TRUE(ShouldLog(LogSeverity::FATAL)); +} + +TEST(LoggingTest, InfoLoggingTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::INFO)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 1)); + EXPECT_FALSE(ShouldLog(LogSeverity::DBG)); + EXPECT_TRUE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); + EXPECT_TRUE(ShouldLog(LogSeverity::ERROR)); + EXPECT_TRUE(ShouldLog(LogSeverity::FATAL)); +} + +TEST(LoggingTest, WarningLoggingTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::WARNING)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 1)); + EXPECT_FALSE(ShouldLog(LogSeverity::DBG)); + EXPECT_FALSE(ShouldLog(LogSeverity::INFO)); + EXPECT_TRUE(ShouldLog(LogSeverity::WARNING)); + EXPECT_TRUE(ShouldLog(LogSeverity::ERROR)); + EXPECT_TRUE(ShouldLog(LogSeverity::FATAL)); +} + +TEST(LoggingTest, ErrorLoggingTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::ERROR)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 1)); + EXPECT_FALSE(ShouldLog(LogSeverity::DBG)); + EXPECT_FALSE(ShouldLog(LogSeverity::INFO)); + EXPECT_FALSE(ShouldLog(LogSeverity::WARNING)); + EXPECT_TRUE(ShouldLog(LogSeverity::ERROR)); + EXPECT_TRUE(ShouldLog(LogSeverity::FATAL)); +} + +TEST(LoggingTest, FatalLoggingTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::FATAL)); + EXPECT_FALSE(ShouldLog(LogSeverity::VERBOSE, 1)); + EXPECT_FALSE(ShouldLog(LogSeverity::DBG)); + EXPECT_FALSE(ShouldLog(LogSeverity::INFO)); + EXPECT_FALSE(ShouldLog(LogSeverity::WARNING)); + EXPECT_FALSE(ShouldLog(LogSeverity::ERROR)); + EXPECT_TRUE(ShouldLog(LogSeverity::FATAL)); +} + +TEST(LoggingTest, LoggingStreamTest) { + ASSERT_TRUE(SetLoggingLevel(LogSeverity::INFO)); + // This one should be logged. + LoggingStringStream stream1 = (ICING_LOG(INFO) << "Hello" + << "World!"); + EXPECT_THAT(stream1.message, EndsWith("HelloWorld!")); + + // This one should not be logged, thus empty. 
+ LoggingStringStream stream2 = (ICING_LOG(DBG) << "Hello" + << "World!"); + EXPECT_THAT(stream2.message, IsEmpty()); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java index 95e0c84..16a4a4a 100644 --- a/java/src/com/google/android/icing/IcingSearchEngine.java +++ b/java/src/com/google/android/icing/IcingSearchEngine.java @@ -16,6 +16,9 @@ package com.google.android.icing; import android.util.Log; import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import com.google.android.icing.proto.DebugInfoResultProto; +import com.google.android.icing.proto.DebugInfoVerbosity; import com.google.android.icing.proto.DeleteByNamespaceResultProto; import com.google.android.icing.proto.DeleteByQueryResultProto; import com.google.android.icing.proto.DeleteBySchemaTypeResultProto; @@ -29,6 +32,7 @@ import com.google.android.icing.proto.GetSchemaResultProto; import com.google.android.icing.proto.GetSchemaTypeResultProto; import com.google.android.icing.proto.IcingSearchEngineOptions; import com.google.android.icing.proto.InitializeResultProto; +import com.google.android.icing.proto.LogSeverity; import com.google.android.icing.proto.OptimizeResultProto; import com.google.android.icing.proto.PersistToDiskResultProto; import com.google.android.icing.proto.PersistType; @@ -74,7 +78,9 @@ public class IcingSearchEngine implements Closeable { System.loadLibrary("icing"); } - /** @throws IllegalStateException if IcingSearchEngine fails to be created */ + /** + * @throws IllegalStateException if IcingSearchEngine fails to be created + */ public IcingSearchEngine(@NonNull IcingSearchEngineOptions options) { nativePointer = nativeCreate(options.toByteArray()); if (nativePointer == 0) { @@ -439,9 +445,16 @@ public class IcingSearchEngine implements Closeable { @NonNull public DeleteByQueryResultProto deleteByQuery(@NonNull SearchSpecProto searchSpec) { + return deleteByQuery(searchSpec, /*returnDeletedDocumentInfo=*/ false); + } + + @NonNull + public DeleteByQueryResultProto deleteByQuery( + @NonNull SearchSpecProto searchSpec, boolean returnDeletedDocumentInfo) { throwIfClosed(); - byte[] deleteResultBytes = nativeDeleteByQuery(this, searchSpec.toByteArray()); + byte[] deleteResultBytes = + nativeDeleteByQuery(this, searchSpec.toByteArray(), returnDeletedDocumentInfo); if (deleteResultBytes == null) { Log.e(TAG, "Received null DeleteResultProto from native."); return DeleteByQueryResultProto.newBuilder() @@ -539,8 +552,7 @@ public class IcingSearchEngine implements Closeable { } try { - return StorageInfoResultProto.parseFrom( - storageInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); + return StorageInfoResultProto.parseFrom(storageInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); } catch (InvalidProtocolBufferException e) { Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e); return StorageInfoResultProto.newBuilder() @@ -550,6 +562,28 @@ public class IcingSearchEngine implements Closeable { } @NonNull + public DebugInfoResultProto getDebugInfo(DebugInfoVerbosity.Code verbosity) { + throwIfClosed(); + + byte[] debugInfoResultProtoBytes = nativeGetDebugInfo(this, verbosity.getNumber()); + if (debugInfoResultProtoBytes == null) { + Log.e(TAG, "Received null DebugInfoResultProto from native."); + return DebugInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return 
DebugInfoResultProto.parseFrom(debugInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing DebugInfoResultProto.", e); + return DebugInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull public ResetResultProto reset() { throwIfClosed(); @@ -571,6 +605,31 @@ public class IcingSearchEngine implements Closeable { } } + public static boolean shouldLog(LogSeverity.Code severity) { + return shouldLog(severity, (short) 0); + } + + public static boolean shouldLog(LogSeverity.Code severity, short verbosity) { + return nativeShouldLog((short) severity.getNumber(), verbosity); + } + + public static boolean setLoggingLevel(LogSeverity.Code severity) { + return setLoggingLevel(severity, (short) 0); + } + + public static boolean setLoggingLevel(LogSeverity.Code severity, short verbosity) { + return nativeSetLoggingLevel((short) severity.getNumber(), verbosity); + } + + @Nullable + public static String getLoggingTag() { + String tag = nativeGetLoggingTag(); + if (tag == null) { + Log.e(TAG, "Received null logging tag from native."); + } + return tag; + } + private static native long nativeCreate(byte[] icingSearchEngineOptionsBytes); private static native void nativeDestroy(IcingSearchEngine instance); @@ -615,7 +674,7 @@ public class IcingSearchEngine implements Closeable { IcingSearchEngine instance, String schemaType); private static native byte[] nativeDeleteByQuery( - IcingSearchEngine instance, byte[] searchSpecBytes); + IcingSearchEngine instance, byte[] searchSpecBytes, boolean returnDeletedDocumentInfo); private static native byte[] nativePersistToDisk(IcingSearchEngine instance, int persistType); @@ -629,4 +688,12 @@ public class IcingSearchEngine implements Closeable { private static native byte[] nativeSearchSuggestions( IcingSearchEngine instance, byte[] suggestionSpecBytes); + + private static native byte[] nativeGetDebugInfo(IcingSearchEngine instance, int verbosity); + + private static native boolean nativeShouldLog(short severity, short verbosity); + + private static native boolean nativeSetLoggingLevel(short severity, short verbosity); + + private static native String nativeGetLoggingTag(); } diff --git a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java index a46814c..b55cfd1 100644 --- a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java +++ b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java @@ -17,6 +17,9 @@ package com.google.android.icing; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import com.google.android.icing.IcingSearchEngine; +import com.google.android.icing.proto.DebugInfoResultProto; +import com.google.android.icing.proto.DebugInfoVerbosity; import com.google.android.icing.proto.DeleteByNamespaceResultProto; import com.google.android.icing.proto.DeleteByQueryResultProto; import com.google.android.icing.proto.DeleteBySchemaTypeResultProto; @@ -30,6 +33,7 @@ import com.google.android.icing.proto.GetSchemaResultProto; import com.google.android.icing.proto.GetSchemaTypeResultProto; import com.google.android.icing.proto.IcingSearchEngineOptions; import com.google.android.icing.proto.InitializeResultProto; +import com.google.android.icing.proto.LogSeverity; import 
com.google.android.icing.proto.OptimizeResultProto; import com.google.android.icing.proto.PersistToDiskResultProto; import com.google.android.icing.proto.PersistType; @@ -57,7 +61,6 @@ import com.google.android.icing.proto.SuggestionSpecProto.SuggestionScoringSpecP import com.google.android.icing.proto.TermMatchType; import com.google.android.icing.proto.TermMatchType.Code; import com.google.android.icing.proto.UsageReport; -import com.google.android.icing.IcingSearchEngine; import java.io.File; import java.util.HashMap; import java.util.Map; @@ -389,6 +392,60 @@ public final class IcingSearchEngineTest { DeleteByQueryResultProto deleteResultProto = icingSearchEngine.deleteByQuery(searchSpec); assertStatusOk(deleteResultProto.getStatus()); + // By default, the deleteByQuery API does not return the summary about deleted documents, unless + // the returnDeletedDocumentInfo parameter is set to true. + assertThat(deleteResultProto.getDeletedDocumentsList()).isEmpty(); + + GetResultProto getResultProto = + icingSearchEngine.get("namespace", "uri1", GetResultSpecProto.getDefaultInstance()); + assertThat(getResultProto.getStatus().getCode()).isEqualTo(StatusProto.Code.NOT_FOUND); + getResultProto = + icingSearchEngine.get("namespace", "uri2", GetResultSpecProto.getDefaultInstance()); + assertStatusOk(getResultProto.getStatus()); + } + + @Test + public void testDeleteByQueryWithDeletedDocumentInfo() throws Exception { + assertStatusOk(icingSearchEngine.initialize().getStatus()); + + SchemaTypeConfigProto emailTypeConfig = createEmailTypeConfig(); + SchemaProto schema = SchemaProto.newBuilder().addTypes(emailTypeConfig).build(); + assertThat( + icingSearchEngine + .setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false) + .getStatus() + .getCode()) + .isEqualTo(StatusProto.Code.OK); + + DocumentProto emailDocument1 = + createEmailDocument("namespace", "uri1").toBuilder() + .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("foo")) + .build(); + + assertStatusOk(icingSearchEngine.put(emailDocument1).getStatus()); + DocumentProto emailDocument2 = + createEmailDocument("namespace", "uri2").toBuilder() + .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("bar")) + .build(); + + assertStatusOk(icingSearchEngine.put(emailDocument2).getStatus()); + + SearchSpecProto searchSpec = + SearchSpecProto.newBuilder() + .setQuery("foo") + .setTermMatchType(TermMatchType.Code.PREFIX) + .build(); + + DeleteByQueryResultProto deleteResultProto = + icingSearchEngine.deleteByQuery(searchSpec, /*returnDeletedDocumentInfo=*/ true); + assertStatusOk(deleteResultProto.getStatus()); + DeleteByQueryResultProto.DocumentGroupInfo info = + DeleteByQueryResultProto.DocumentGroupInfo.newBuilder() + .setNamespace("namespace") + .setSchema("Email") + .addUris("uri1") + .build(); + assertThat(deleteResultProto.getDeletedDocumentsList()).containsExactly(info); GetResultProto getResultProto = icingSearchEngine.get("namespace", "uri1", GetResultSpecProto.getDefaultInstance()); @@ -434,6 +491,35 @@ public final class IcingSearchEngineTest { } @Test + public void testGetDebugInfo() throws Exception { + assertStatusOk(icingSearchEngine.initialize().getStatus()); + + SchemaTypeConfigProto emailTypeConfig = createEmailTypeConfig(); + SchemaProto schema = SchemaProto.newBuilder().addTypes(emailTypeConfig).build(); + assertThat( + icingSearchEngine + .setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false) + .getStatus() + .getCode()) + .isEqualTo(StatusProto.Code.OK); + + 
DocumentProto emailDocument = createEmailDocument("namespace", "uri"); + assertStatusOk(icingSearchEngine.put(emailDocument).getStatus()); + + DebugInfoResultProto debugInfoResultProtoBasic = + icingSearchEngine.getDebugInfo(DebugInfoVerbosity.Code.BASIC); + assertStatusOk(debugInfoResultProtoBasic.getStatus()); + assertThat(debugInfoResultProtoBasic.getDebugInfo().getDocumentInfo().getCorpusInfoList()) + .isEmpty(); // because verbosity=BASIC + + DebugInfoResultProto debugInfoResultProtoDetailed = + icingSearchEngine.getDebugInfo(DebugInfoVerbosity.Code.DETAILED); + assertStatusOk(debugInfoResultProtoDetailed.getStatus()); + assertThat(debugInfoResultProtoDetailed.getDebugInfo().getDocumentInfo().getCorpusInfoList()) + .hasSize(1); // because verbosity=DETAILED + } + + @Test public void testGetAllNamespaces() throws Exception { assertStatusOk(icingSearchEngine.initialize().getStatus()); @@ -668,6 +754,31 @@ public final class IcingSearchEngineTest { assertThat(response.getSuggestions(1).getQuery()).isEqualTo("fo"); } + @Test + public void testLogging() throws Exception { + // Set to INFO + assertThat(IcingSearchEngine.setLoggingLevel(LogSeverity.Code.INFO)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.INFO)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.DBG)).isFalse(); + + // Set to WARNING + assertThat(IcingSearchEngine.setLoggingLevel(LogSeverity.Code.WARNING)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.WARNING)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.INFO)).isFalse(); + + // Set to DEBUG + assertThat(IcingSearchEngine.setLoggingLevel(LogSeverity.Code.DBG)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.DBG)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.VERBOSE)).isFalse(); + + // Set to VERBOSE + assertThat(IcingSearchEngine.setLoggingLevel(LogSeverity.Code.VERBOSE, (short) 1)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.VERBOSE, (short) 1)).isTrue(); + assertThat(IcingSearchEngine.shouldLog(LogSeverity.Code.VERBOSE, (short) 2)).isFalse(); + + assertThat(IcingSearchEngine.getLoggingTag()).isNotEmpty(); + } + private static void assertStatusOk(StatusProto status) { assertWithMessage(status.getMessage()).that(status.getCode()).isEqualTo(StatusProto.Code.OK); } diff --git a/proto/icing/proto/debug.proto b/proto/icing/proto/debug.proto index 504ae43..90d1981 100644 --- a/proto/icing/proto/debug.proto +++ b/proto/icing/proto/debug.proto @@ -24,48 +24,57 @@ option java_package = "com.google.android.icing.proto"; option java_multiple_files = true; option objc_class_prefix = "ICNG"; +message LogSeverity { + enum Code { + VERBOSE = 0; + // Unable to use DEBUG at this time because it breaks YTM's iOS tests + // cs/?q=%22-DDEBUG%3D1%22%20f:%2FYoutubeMusic%20f:blueprint&ssfr=1 + DBG = 1; + INFO = 2; + WARNING = 3; + ERROR = 4; + FATAL = 5; + } +} + +message DebugInfoVerbosity { + enum Code { + // Simplest debug information. + BASIC = 0; + // More detailed debug information as indicated in the field documentation + // below. + DETAILED = 1; + } +} + // Next tag: 4 message IndexDebugInfoProto { // Storage information of the index. optional IndexStorageInfoProto index_storage_info = 1; - message MainIndexDebugInfoProto { - // Information about the main lexicon. - // TODO(b/222349894) Convert the string output to a protocol buffer instead. - optional string lexicon_info = 1; - - // Last added document id. 
- optional uint32 last_added_document_id = 2; - - // If verbosity > 0, return information about the posting list storage. - // TODO(b/222349894) Convert the string output to a protocol buffer instead. - optional string flash_index_storage_info = 3; - } - optional MainIndexDebugInfoProto main_index_info = 2; - - message LiteIndexDebugInfoProto { - // Current number of hits. - optional uint32 curr_size = 1; - - // The maximum possible number of hits. - optional uint32 hit_buffer_size = 2; - - // Last added document id. - optional uint32 last_added_document_id = 3; - - // The first position in the hit buffer that is not sorted yet, - // or curr_size if all hits are sorted. - optional uint32 searchable_end = 4; - - // The most recent checksum of the lite index, by calling - // LiteIndex::ComputeChecksum(). - optional uint32 index_crc = 5; - - // Information about the lite lexicon. - // TODO(b/222349894) Convert the string output to a protocol buffer instead. - optional string lexicon_info = 6; - } - optional LiteIndexDebugInfoProto lite_index_info = 3; + // A formatted string containing the following information: + // lexicon_info: Information about the main lexicon + // last_added_document_id: Last added document id + // flash_index_storage_info: If verbosity = DETAILED, return information about + // the posting list storage + // + // No direct contents from user-provided documents will ever appear in this + // string. + optional string main_index_info = 2; + + // A formatted string containing the following information: + // curr_size: Current number of hits + // hit_buffer_size: The maximum possible number of hits + // last_added_document_id: Last added document id + // searchable_end: The first position in the hit buffer that is not sorted + // yet, or curr_size if all hits are sorted + // index_crc: The most recent checksum of the lite index, by calling + // LiteIndex::ComputeChecksum() + // lexicon_info: Information about the lite lexicon + // + // No direct contents from user-provided documents will ever appear in this + // string. + optional string lite_index_info = 3; } // Next tag: 4 @@ -84,8 +93,8 @@ message DocumentDebugInfoProto { optional uint32 total_token = 4; } - // If verbosity > 0, return the total number of documents and tokens in each - // (namespace, schema type) pair. + // If verbosity = DETAILED, return the total number of documents and tokens in + // each (namespace, schema type) pair. // Note that deleted and expired documents are skipped in the output. repeated CorpusInfo corpus_info = 3; } @@ -117,7 +126,8 @@ message DebugInfoProto { message DebugInfoResultProto { // Status code can be one of: // OK - // FAILED_PRECONDITION + // FAILED_PRECONDITION if IcingSearchEngine has not been initialized yet + // INTERNAL on IO errors, crc compute error. // // See status.proto for more details. optional StatusProto status = 1; diff --git a/proto/icing/proto/optimize.proto b/proto/icing/proto/optimize.proto index 42290f3..0accb9a 100644 --- a/proto/icing/proto/optimize.proto +++ b/proto/icing/proto/optimize.proto @@ -63,7 +63,7 @@ message GetOptimizeInfoResultProto { optional int64 time_since_last_optimize_ms = 4; } -// Next tag: 10 +// Next tag: 11 message OptimizeStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -91,4 +91,15 @@ message OptimizeStatsProto { // The amount of time since the last optimize ran. 
optional int64 time_since_last_optimize_ms = 9; + + enum IndexRestorationMode { + // The index has been translated in place to match the optimized document + // store. + INDEX_TRANSLATION = 0; + // The index has been rebuilt from scratch during optimization. This can + // happen when OptimizeDocumentStore returns a DATA_LOSS error, when + // Index::Optimize fails, or when a full rebuild is expected to be faster. + FULL_INDEX_REBUILD = 1; + } + optional IndexRestorationMode index_restoration_mode = 10; } diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index f005c76..7a361d3 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -65,7 +65,7 @@ message SearchSpecProto { // Client-supplied specifications on what to include/how to format the search // results. -// Next tag: 6 +// Next tag: 7 message ResultSpecProto { // The results will be returned in pages, and num_per_page specifies the // number of documents in one page. @@ -133,6 +133,15 @@ message ResultSpecProto { // ["ns0doc0", "ns0doc1", "ns1doc0", "ns3doc0", "ns3doc1", "ns2doc1", // "ns3doc2"]. repeated ResultGrouping result_groupings = 5; + + // A soft limit on the total number of document bytes returned in one page. + // A page is cut off once this threshold is reached, so a page's total size + // may exceed the threshold, but only by less than the size of the final + // result added to it. + optional int32 num_total_bytes_per_page_threshold = 6 + [default = 2147483647]; // INT_MAX } // The representation of a single match within a DocumentProto property. diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index 73d349b..cd00254 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=436284873) +set(synced_AOSP_CL_number=466546985)
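
A minimal sketch of how a caller might use the new deleteByQuery overload exercised above. It is not part of the change: the helper name deleteFooAndReport is illustrative, and it assumes the same proto imports as IcingSearchEngineTest plus an initialized engine with the email schema applied.

static void deleteFooAndReport(IcingSearchEngine icing) {
  // Match every document whose indexed content has a term starting with "foo".
  SearchSpecProto searchSpec =
      SearchSpecProto.newBuilder()
          .setQuery("foo")
          .setTermMatchType(TermMatchType.Code.PREFIX)
          .build();

  // Ask for the per-group summary of what was deleted.
  DeleteByQueryResultProto result =
      icing.deleteByQuery(searchSpec, /*returnDeletedDocumentInfo=*/ true);
  if (result.getStatus().getCode() != StatusProto.Code.OK) {
    return;
  }
  for (DeleteByQueryResultProto.DocumentGroupInfo group : result.getDeletedDocumentsList()) {
    // Each group covers one (namespace, schema type) pair and lists its deleted URIs.
    System.out.println(
        group.getNamespace() + "/" + group.getSchema() + ": " + group.getUrisList());
  }
}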
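
Similarly, a sketch of reading the new debug-info API; the helper name is again illustrative. BASIC verbosity leaves corpus_info empty, while DETAILED fills in the per-(namespace, schema type) counts.

static void dumpCorpusInfo(IcingSearchEngine icing) {
  DebugInfoResultProto result = icing.getDebugInfo(DebugInfoVerbosity.Code.DETAILED);
  if (result.getStatus().getCode() != StatusProto.Code.OK) {
    // FAILED_PRECONDITION if the engine is uninitialized, INTERNAL on I/O or checksum errors.
    return;
  }
  // corpus_info is only populated at DETAILED verbosity; at BASIC this list is empty.
  result.getDebugInfo().getDocumentInfo().getCorpusInfoList().forEach(System.out::println);
}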
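
The static logging helpers are intended to let callers skip building expensive log messages that the current level would drop anyway. A sketch of that gating pattern, assuming the code runs on Android; the android.util.Log call is illustrative and not part of the Icing API.

static void logDetailedDebugInfoIfEnabled(IcingSearchEngine icing) {
  // Optionally adjust the global minimum severity first, e.g.:
  // IcingSearchEngine.setLoggingLevel(LogSeverity.Code.DBG);
  if (IcingSearchEngine.shouldLog(LogSeverity.Code.DBG)) {
    // Only build the (potentially large) debug dump when DBG messages would be emitted.
    String dump = icing.getDebugInfo(DebugInfoVerbosity.Code.DETAILED).toString();
    android.util.Log.d(IcingSearchEngine.getLoggingTag(), dump);
  }
}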
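
Finally, a sketch of the new ResultSpecProto byte threshold. The setter names below follow standard protobuf-Java codegen for num_per_page and num_total_bytes_per_page_threshold, and the 512 KiB figure is arbitrary.

// Cap each result page at 10 documents or roughly 512 KiB of document bytes,
// whichever is reached first; the byte cap may be exceeded only by the size of
// the final result appended to the page.
ResultSpecProto resultSpec =
    ResultSpecProto.newBuilder()
        .setNumPerPage(10)
        .setNumTotalBytesPerPageThreshold(512 * 1024)
        .build();
// resultSpec is then passed to the engine's search call alongside the
// SearchSpecProto and scoring spec.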