author Treehugger Robot <treehugger-gerrit@google.com> 2021-10-29 02:21:01 +0000
committer Gerrit Code Review <noreply-gerritcodereview@google.com> 2021-10-29 02:21:01 +0000
commit 8f16167c39d77866a213b0ef042058e858255d47 (patch)
tree 69c5b6c2253d521b1c26f1e6b4bb1cfe571c907e
parent 2776967dbcaf1b3be137e7364453eb1e03a3cfe3 (diff)
parent 34ddbebdafea0781a4342ac3fbf8d32f74176b81 (diff)
-rw-r--r--icing/file/file-backed-vector.h3
-rw-r--r--icing/file/file-backed-vector_test.cc3
-rw-r--r--icing/file/filesystem.cc2
-rw-r--r--icing/file/filesystem.h6
-rw-r--r--icing/icing-search-engine.cc82
-rw-r--r--icing/icing-search-engine.h11
-rw-r--r--icing/icing-search-engine_test.cc105
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-and.cc3
-rw-r--r--icing/index/lite/lite-index.cc5
-rw-r--r--icing/index/main/flash-index-storage.cc4
-rw-r--r--icing/index/main/flash-index-storage_test.cc2
-rw-r--r--icing/index/main/index-block.cc3
-rw-r--r--icing/index/main/index-block.h2
-rw-r--r--icing/index/main/posting-list-free.h2
-rw-r--r--icing/index/main/posting-list-used.h2
-rw-r--r--icing/jni/icing-search-engine-jni.cc19
-rw-r--r--icing/legacy/core/icing-core-types.h3
-rw-r--r--icing/legacy/core/icing-string-util.cc9
-rw-r--r--icing/legacy/core/icing-string-util.h5
-rw-r--r--icing/legacy/core/icing-timer.h3
-rw-r--r--icing/legacy/index/icing-array-storage.cc2
-rw-r--r--icing/legacy/index/icing-array-storage.h3
-rw-r--r--icing/legacy/index/icing-bit-util.h5
-rw-r--r--icing/legacy/index/icing-dynamic-trie.cc6
-rw-r--r--icing/legacy/index/icing-dynamic-trie.h3
-rw-r--r--icing/legacy/index/icing-filesystem.cc2
-rw-r--r--icing/legacy/index/icing-flash-bitmap.h3
-rw-r--r--icing/legacy/index/icing-mmapper.cc5
-rw-r--r--icing/legacy/index/icing-mock-filesystem.h9
-rw-r--r--icing/legacy/index/icing-storage-file.cc2
-rw-r--r--icing/portable/endian.h2
-rw-r--r--icing/query/query-processor.cc2
-rw-r--r--icing/query/suggestion-processor.cc93
-rw-r--r--icing/query/suggestion-processor.h68
-rw-r--r--icing/query/suggestion-processor_test.cc324
-rw-r--r--icing/schema/schema-store.cc5
-rw-r--r--icing/schema/schema-store.h6
-rw-r--r--icing/scoring/bm25f-calculator.cc51
-rw-r--r--icing/scoring/bm25f-calculator.h32
-rw-r--r--icing/scoring/score-and-rank_benchmark.cc125
-rw-r--r--icing/scoring/scorer.cc19
-rw-r--r--icing/scoring/scorer.h4
-rw-r--r--icing/scoring/scorer_test.cc191
-rw-r--r--icing/scoring/scoring-processor.cc9
-rw-r--r--icing/scoring/scoring-processor.h4
-rw-r--r--icing/scoring/scoring-processor_test.cc377
-rw-r--r--icing/scoring/section-weights.cc146
-rw-r--r--icing/scoring/section-weights.h95
-rw-r--r--icing/scoring/section-weights_test.cc386
-rw-r--r--icing/tokenization/raw-query-tokenizer.cc3
-rw-r--r--icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc2
-rw-r--r--icing/transform/icu/icu-normalizer.cc25
-rw-r--r--icing/transform/map/map-normalizer.cc3
-rw-r--r--java/src/com/google/android/icing/IcingSearchEngine.java25
-rw-r--r--java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java36
-rw-r--r--proto/icing/proto/scoring.proto39
-rw-r--r--proto/icing/proto/search.proto34
-rw-r--r--synced_AOSP_CL_number.txt2
58 files changed, 2214 insertions, 208 deletions
diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h
index 0989935..00bdc7e 100644
--- a/icing/file/file-backed-vector.h
+++ b/icing/file/file-backed-vector.h
@@ -56,10 +56,9 @@
#ifndef ICING_FILE_FILE_BACKED_VECTOR_H_
#define ICING_FILE_FILE_BACKED_VECTOR_H_
-#include <inttypes.h>
-#include <stdint.h>
#include <sys/mman.h>
+#include <cinttypes>
#include <cstdint>
#include <memory>
#include <string>
diff --git a/icing/file/file-backed-vector_test.cc b/icing/file/file-backed-vector_test.cc
index b05ce2d..7c02af9 100644
--- a/icing/file/file-backed-vector_test.cc
+++ b/icing/file/file-backed-vector_test.cc
@@ -14,9 +14,8 @@
#include "icing/file/file-backed-vector.h"
-#include <errno.h>
-
#include <algorithm>
+#include <cerrno>
#include <cstdint>
#include <memory>
#include <string_view>
diff --git a/icing/file/filesystem.cc b/icing/file/filesystem.cc
index 0655cb9..82b8d98 100644
--- a/icing/file/filesystem.cc
+++ b/icing/file/filesystem.cc
@@ -16,7 +16,6 @@
#include <dirent.h>
#include <dlfcn.h>
-#include <errno.h>
#include <fcntl.h>
#include <fnmatch.h>
#include <pthread.h>
@@ -26,6 +25,7 @@
#include <unistd.h>
#include <algorithm>
+#include <cerrno>
#include <cstdint>
#include <unordered_set>
diff --git a/icing/file/filesystem.h b/icing/file/filesystem.h
index 6bed8e6..ca8c4a8 100644
--- a/icing/file/filesystem.h
+++ b/icing/file/filesystem.h
@@ -17,11 +17,9 @@
#ifndef ICING_FILE_FILESYSTEM_H_
#define ICING_FILE_FILESYSTEM_H_
-#include <stdint.h>
-#include <stdio.h>
-#include <string.h>
-
#include <cstdint>
+#include <cstdio>
+#include <cstring>
#include <memory>
#include <string>
#include <unordered_set>
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc
index c4d585e..9aa833b 100644
--- a/icing/icing-search-engine.cc
+++ b/icing/icing-search-engine.cc
@@ -47,6 +47,7 @@
#include "icing/proto/search.pb.h"
#include "icing/proto/status.pb.h"
#include "icing/query/query-processor.h"
+#include "icing/query/suggestion-processor.h"
#include "icing/result/projection-tree.h"
#include "icing/result/projector.h"
#include "icing/result/result-retriever.h"
@@ -134,6 +135,26 @@ libtextclassifier3::Status ValidateSearchSpec(
return libtextclassifier3::Status::OK;
}
+libtextclassifier3::Status ValidateSuggestionSpec(
+ const SuggestionSpecProto& suggestion_spec,
+ const PerformanceConfiguration& configuration) {
+ if (suggestion_spec.prefix().empty()) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("SuggestionSpecProto.prefix is empty!"));
+ }
+ if (suggestion_spec.num_to_return() <= 0) {
+ return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+ "SuggestionSpecProto.num_to_return must be positive."));
+ }
+ if (suggestion_spec.prefix().size() > configuration.max_query_length) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("SuggestionSpecProto.prefix is longer than the "
+ "maximum allowed prefix length: ",
+ std::to_string(configuration.max_query_length)));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
// Document store files are in a standalone subfolder for easier file
// management. We can delete and recreate the subfolder and not touch/affect
// anything else.
@@ -1397,8 +1418,8 @@ SearchResultProto IcingSearchEngine::Search(
component_timer = clock_->GetNewTimer();
// Scores but does not rank the results.
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
- scoring_processor_or =
- ScoringProcessor::Create(scoring_spec, document_store_.get());
+ scoring_processor_or = ScoringProcessor::Create(
+ scoring_spec, document_store_.get(), schema_store_.get());
if (!scoring_processor_or.ok()) {
TransformStatus(scoring_processor_or.status(), result_status);
return result_proto;
@@ -1825,5 +1846,62 @@ ResetResultProto IcingSearchEngine::ResetInternal() {
return result_proto;
}
+SuggestionResponse IcingSearchEngine::SearchSuggestions(
+ const SuggestionSpecProto& suggestion_spec) {
+ // TODO(b/146008613) Explore ideas to make this function read-only.
+ absl_ports::unique_lock l(&mutex_);
+ SuggestionResponse response;
+ StatusProto* response_status = response.mutable_status();
+ if (!initialized_) {
+ response_status->set_code(StatusProto::FAILED_PRECONDITION);
+ response_status->set_message("IcingSearchEngine has not been initialized!");
+ return response;
+ }
+
+ libtextclassifier3::Status status =
+ ValidateSuggestionSpec(suggestion_spec, performance_configuration_);
+ if (!status.ok()) {
+ TransformStatus(status, response_status);
+ return response;
+ }
+
+ // Create the suggestion processor.
+ auto suggestion_processor_or = SuggestionProcessor::Create(
+ index_.get(), language_segmenter_.get(), normalizer_.get());
+ if (!suggestion_processor_or.ok()) {
+ TransformStatus(suggestion_processor_or.status(), response_status);
+ return response;
+ }
+ std::unique_ptr<SuggestionProcessor> suggestion_processor =
+ std::move(suggestion_processor_or).ValueOrDie();
+
+ std::vector<NamespaceId> namespace_ids;
+ namespace_ids.reserve(suggestion_spec.namespace_filters_size());
+ for (std::string_view name_space : suggestion_spec.namespace_filters()) {
+ auto namespace_id_or = document_store_->GetNamespaceId(name_space);
+ if (!namespace_id_or.ok()) {
+ continue;
+ }
+ namespace_ids.push_back(namespace_id_or.ValueOrDie());
+ }
+
+ // Run suggestion based on given SuggestionSpec.
+ libtextclassifier3::StatusOr<std::vector<TermMetadata>> terms_or =
+ suggestion_processor->QuerySuggestions(suggestion_spec, namespace_ids);
+ if (!terms_or.ok()) {
+ TransformStatus(terms_or.status(), response_status);
+ return response;
+ }
+
+  // Convert vector<TermMetadata> into the final SuggestionResponse proto.
+ for (TermMetadata& term : terms_or.ValueOrDie()) {
+ SuggestionResponse::Suggestion suggestion;
+ suggestion.set_query(std::move(term.content));
+ response.mutable_suggestions()->Add(std::move(suggestion));
+ }
+ response_status->set_code(StatusProto::OK);
+ return response;
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h
index d53dbd7..0a79714 100644
--- a/icing/icing-search-engine.h
+++ b/icing/icing-search-engine.h
@@ -302,6 +302,17 @@ class IcingSearchEngine {
const ResultSpecProto& result_spec)
ICING_LOCKS_EXCLUDED(mutex_);
+ // Retrieves, scores, ranks and returns the suggested query string according
+ // to the specs. Results can be empty.
+ //
+ // Returns a SuggestionResponse with status:
+ // OK with results on success
+ // INVALID_ARGUMENT if any of specs is invalid
+ // FAILED_PRECONDITION IcingSearchEngine has not been initialized yet
+ // INTERNAL_ERROR on any other errors
+ SuggestionResponse SearchSuggestions(
+ const SuggestionSpecProto& suggestion_spec) ICING_LOCKS_EXCLUDED(mutex_);
+
// Fetches the next page of results of a previously executed query. Results
// can be empty if next-page token is invalid. Invalid next page tokens are
// tokens that are either zero or were previously passed to
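
A minimal caller-side sketch of the new SearchSuggestions API (illustrative
only; it assumes an already-initialized IcingSearchEngine instance named
"icing", and ProcessSuggestion is a hypothetical helper):

    SuggestionSpecProto suggestion_spec;
    suggestion_spec.set_prefix("t");        // must be non-empty
    suggestion_spec.set_num_to_return(10);  // must be positive

    SuggestionResponse response = icing.SearchSuggestions(suggestion_spec);
    if (response.status().code() == StatusProto::OK) {
      // Suggestions are returned ranked by hit count.
      for (const SuggestionResponse::Suggestion& suggestion :
           response.suggestions()) {
        ProcessSuggestion(suggestion.query());  // hypothetical helper
      }
    }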
diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc
index 3db1701..b5206cd 100644
--- a/icing/icing-search-engine_test.cc
+++ b/icing/icing-search-engine_test.cc
@@ -8049,6 +8049,111 @@ TEST_F(IcingSearchEngineTest, PutDocumentIndexFailureDeletion) {
ASSERT_THAT(get_result.status(), ProtoStatusIs(StatusProto::NOT_FOUND));
}
+TEST_F(IcingSearchEngineTest, SearchSuggestionsTest) {
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(CreatePersonAndEmailSchema()).status(),
+ ProtoIsOk());
+
+  // Creates and inserts 6 documents. Across them, "termSix" is indexed six
+  // times, "termFive" five times, "termFour" four times, "termThree" three
+  // times, "termTwo" twice, and "termOne" once.
+ DocumentProto document1 =
+ DocumentBuilder()
+ .SetKey("namespace", "uri1")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(10)
+ .AddStringProperty(
+ "subject", "termOne termTwo termThree termFour termFive termSix")
+ .Build();
+ DocumentProto document2 =
+ DocumentBuilder()
+ .SetKey("namespace", "uri2")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(10)
+ .AddStringProperty("subject",
+ "termTwo termThree termFour termFive termSix")
+ .Build();
+ DocumentProto document3 =
+ DocumentBuilder()
+ .SetKey("namespace", "uri3")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(10)
+ .AddStringProperty("subject", "termThree termFour termFive termSix")
+ .Build();
+ DocumentProto document4 =
+ DocumentBuilder()
+ .SetKey("namespace", "uri4")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(10)
+ .AddStringProperty("subject", "termFour termFive termSix")
+ .Build();
+ DocumentProto document5 =
+ DocumentBuilder()
+ .SetKey("namespace", "uri5")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(10)
+ .AddStringProperty("subject", "termFive termSix")
+ .Build();
+ DocumentProto document6 = DocumentBuilder()
+ .SetKey("namespace", "uri6")
+ .SetSchema("Email")
+ .SetCreationTimestampMs(10)
+ .AddStringProperty("subject", "termSix")
+ .Build();
+ ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document2).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document3).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document4).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document5).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(document6).status(), ProtoIsOk());
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("t");
+ suggestion_spec.set_num_to_return(10);
+
+  // Query all suggestions; results are returned ranked by hit count.
+ SuggestionResponse response = icing.SearchSuggestions(suggestion_spec);
+ ASSERT_THAT(response.status(), ProtoIsOk());
+ ASSERT_THAT(response.suggestions().at(0).query(), "termsix");
+ ASSERT_THAT(response.suggestions().at(1).query(), "termfive");
+ ASSERT_THAT(response.suggestions().at(2).query(), "termfour");
+ ASSERT_THAT(response.suggestions().at(3).query(), "termthree");
+ ASSERT_THAT(response.suggestions().at(4).query(), "termtwo");
+ ASSERT_THAT(response.suggestions().at(5).query(), "termone");
+
+  // Query only the top three suggestions; results are returned ranked by hit
+  // count.
+ suggestion_spec.set_num_to_return(3);
+ response = icing.SearchSuggestions(suggestion_spec);
+ ASSERT_THAT(response.status(), ProtoIsOk());
+ ASSERT_THAT(response.suggestions().at(0).query(), "termsix");
+ ASSERT_THAT(response.suggestions().at(1).query(), "termfive");
+ ASSERT_THAT(response.suggestions().at(2).query(), "termfour");
+}
+
+TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_emptyPrefix) {
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("");
+ suggestion_spec.set_num_to_return(10);
+
+ ASSERT_THAT(icing.SearchSuggestions(suggestion_spec).status(),
+ ProtoStatusIs(StatusProto::INVALID_ARGUMENT));
+}
+
+TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_NonPositiveNumToReturn) {
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("prefix");
+ suggestion_spec.set_num_to_return(0);
+
+ ASSERT_THAT(icing.SearchSuggestions(suggestion_spec).status(),
+ ProtoStatusIs(StatusProto::INVALID_ARGUMENT));
+}
+
#ifndef ICING_JNI_TEST
// We skip this test case when we're running in a jni_test since the data files
// will be stored in the android-instrumented storage location, rather than the
diff --git a/icing/index/iterator/doc-hit-info-iterator-and.cc b/icing/index/iterator/doc-hit-info-iterator-and.cc
index 39aa969..543e9ef 100644
--- a/icing/index/iterator/doc-hit-info-iterator-and.cc
+++ b/icing/index/iterator/doc-hit-info-iterator-and.cc
@@ -14,8 +14,7 @@
#include "icing/index/iterator/doc-hit-info-iterator-and.h"
-#include <stddef.h>
-
+#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc
index fb23934..9e4ac28 100644
--- a/icing/index/lite/lite-index.cc
+++ b/icing/index/lite/lite-index.cc
@@ -14,12 +14,11 @@
#include "icing/index/lite/lite-index.h"
-#include <inttypes.h>
-#include <stddef.h>
-#include <stdint.h>
#include <sys/mman.h>
#include <algorithm>
+#include <cinttypes>
+#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
diff --git a/icing/index/main/flash-index-storage.cc b/icing/index/main/flash-index-storage.cc
index f125b6d..3c52375 100644
--- a/icing/index/main/flash-index-storage.cc
+++ b/icing/index/main/flash-index-storage.cc
@@ -14,11 +14,11 @@
#include "icing/index/main/flash-index-storage.h"
-#include <errno.h>
-#include <inttypes.h>
#include <sys/types.h>
#include <algorithm>
+#include <cerrno>
+#include <cinttypes>
#include <cstdint>
#include <memory>
#include <unordered_set>
diff --git a/icing/index/main/flash-index-storage_test.cc b/icing/index/main/flash-index-storage_test.cc
index 7e15524..25fcaad 100644
--- a/icing/index/main/flash-index-storage_test.cc
+++ b/icing/index/main/flash-index-storage_test.cc
@@ -14,10 +14,10 @@
#include "icing/index/main/flash-index-storage.h"
-#include <stdlib.h>
#include <unistd.h>
#include <algorithm>
+#include <cstdlib>
#include <limits>
#include <utility>
#include <vector>
diff --git a/icing/index/main/index-block.cc b/icing/index/main/index-block.cc
index 4590d06..c6ab345 100644
--- a/icing/index/main/index-block.cc
+++ b/icing/index/main/index-block.cc
@@ -14,9 +14,8 @@
#include "icing/index/main/index-block.h"
-#include <inttypes.h>
-
#include <algorithm>
+#include <cinttypes>
#include <limits>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
diff --git a/icing/index/main/index-block.h b/icing/index/main/index-block.h
index edf9a79..5d75a2a 100644
--- a/icing/index/main/index-block.h
+++ b/icing/index/main/index-block.h
@@ -15,10 +15,10 @@
#ifndef ICING_INDEX_MAIN_INDEX_BLOCK_H_
#define ICING_INDEX_MAIN_INDEX_BLOCK_H_
-#include <string.h>
#include <sys/mman.h>
#include <algorithm>
+#include <cstring>
#include <limits>
#include <memory>
#include <string>
diff --git a/icing/index/main/posting-list-free.h b/icing/index/main/posting-list-free.h
index 4f06057..75b99d7 100644
--- a/icing/index/main/posting-list-free.h
+++ b/icing/index/main/posting-list-free.h
@@ -15,10 +15,10 @@
#ifndef ICING_INDEX_MAIN_POSTING_LIST_FREE_H_
#define ICING_INDEX_MAIN_POSTING_LIST_FREE_H_
-#include <string.h>
#include <sys/mman.h>
#include <cstdint>
+#include <cstring>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
diff --git a/icing/index/main/posting-list-used.h b/icing/index/main/posting-list-used.h
index 1b2e24e..8944034 100644
--- a/icing/index/main/posting-list-used.h
+++ b/icing/index/main/posting-list-used.h
@@ -15,10 +15,10 @@
#ifndef ICING_INDEX_MAIN_POSTING_LIST_USED_H_
#define ICING_INDEX_MAIN_POSTING_LIST_USED_H_
-#include <string.h>
#include <sys/mman.h>
#include <algorithm>
+#include <cstring>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc
index ea2bcf7..51d3423 100644
--- a/icing/jni/icing-search-engine-jni.cc
+++ b/icing/jni/icing-search-engine-jni.cc
@@ -420,4 +420,23 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReset(
return SerializeProtoToJniByteArray(env, reset_result_proto);
}
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions(
+ JNIEnv* env, jclass clazz, jobject object,
+ jbyteArray suggestion_spec_bytes) {
+ icing::lib::IcingSearchEngine* icing =
+ GetIcingSearchEnginePointer(env, object);
+
+ icing::lib::SuggestionSpecProto suggestion_spec_proto;
+ if (!ParseProtoFromJniByteArray(env, suggestion_spec_bytes,
+ &suggestion_spec_proto)) {
+    ICING_LOG(ERROR)
+        << "Failed to parse SuggestionSpecProto in nativeSearchSuggestions";
+ return nullptr;
+ }
+  icing::lib::SuggestionResponse suggestion_response =
+      icing->SearchSuggestions(suggestion_spec_proto);
+
+  return SerializeProtoToJniByteArray(env, suggestion_response);
+}
+
} // extern "C"
diff --git a/icing/legacy/core/icing-core-types.h b/icing/legacy/core/icing-core-types.h
index cc12663..7db8408 100644
--- a/icing/legacy/core/icing-core-types.h
+++ b/icing/legacy/core/icing-core-types.h
@@ -21,9 +21,8 @@
#ifndef ICING_LEGACY_CORE_ICING_CORE_TYPES_H_
#define ICING_LEGACY_CORE_ICING_CORE_TYPES_H_
-#include <stdint.h>
-
#include <cstddef> // size_t not defined implicitly for all platforms.
+#include <cstdint>
#include <vector>
#include "icing/legacy/core/icing-compat.h"
diff --git a/icing/legacy/core/icing-string-util.cc b/icing/legacy/core/icing-string-util.cc
index 2eb64ac..ed06e03 100644
--- a/icing/legacy/core/icing-string-util.cc
+++ b/icing/legacy/core/icing-string-util.cc
@@ -13,12 +13,11 @@
// limitations under the License.
#include "icing/legacy/core/icing-string-util.h"
-#include <stdarg.h>
-#include <stddef.h>
-#include <stdint.h>
-#include <stdio.h>
-
#include <algorithm>
+#include <cstdarg>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
#include <string>
#include "icing/legacy/portable/icing-zlib.h"
diff --git a/icing/legacy/core/icing-string-util.h b/icing/legacy/core/icing-string-util.h
index 767e581..e5e4941 100644
--- a/icing/legacy/core/icing-string-util.h
+++ b/icing/legacy/core/icing-string-util.h
@@ -15,9 +15,8 @@
#ifndef ICING_LEGACY_CORE_ICING_STRING_UTIL_H_
#define ICING_LEGACY_CORE_ICING_STRING_UTIL_H_
-#include <stdarg.h>
-#include <stdint.h>
-
+#include <cstdarg>
+#include <cstdint>
#include <string>
#include "icing/legacy/core/icing-compat.h"
diff --git a/icing/legacy/core/icing-timer.h b/icing/legacy/core/icing-timer.h
index 49ba9ad..af38912 100644
--- a/icing/legacy/core/icing-timer.h
+++ b/icing/legacy/core/icing-timer.h
@@ -16,7 +16,8 @@
#define ICING_LEGACY_CORE_ICING_TIMER_H_
#include <sys/time.h>
-#include <time.h>
+
+#include <ctime>
namespace icing {
namespace lib {
diff --git a/icing/legacy/index/icing-array-storage.cc b/icing/legacy/index/icing-array-storage.cc
index b462135..4d2ef67 100644
--- a/icing/legacy/index/icing-array-storage.cc
+++ b/icing/legacy/index/icing-array-storage.cc
@@ -14,10 +14,10 @@
#include "icing/legacy/index/icing-array-storage.h"
-#include <inttypes.h>
#include <sys/mman.h>
#include <algorithm>
+#include <cinttypes>
#include "icing/legacy/core/icing-string-util.h"
#include "icing/legacy/core/icing-timer.h"
diff --git a/icing/legacy/index/icing-array-storage.h b/icing/legacy/index/icing-array-storage.h
index fad0565..0d93172 100644
--- a/icing/legacy/index/icing-array-storage.h
+++ b/icing/legacy/index/icing-array-storage.h
@@ -20,8 +20,7 @@
#ifndef ICING_LEGACY_INDEX_ICING_ARRAY_STORAGE_H_
#define ICING_LEGACY_INDEX_ICING_ARRAY_STORAGE_H_
-#include <stdint.h>
-
+#include <cstdint>
#include <string>
#include <vector>
diff --git a/icing/legacy/index/icing-bit-util.h b/icing/legacy/index/icing-bit-util.h
index 3273a68..d0c3f50 100644
--- a/icing/legacy/index/icing-bit-util.h
+++ b/icing/legacy/index/icing-bit-util.h
@@ -20,9 +20,8 @@
#ifndef ICING_LEGACY_INDEX_ICING_BIT_UTIL_H_
#define ICING_LEGACY_INDEX_ICING_BIT_UTIL_H_
-#include <stdint.h>
-#include <stdio.h>
-
+#include <cstdint>
+#include <cstdio>
#include <limits>
#include <vector>
diff --git a/icing/legacy/index/icing-dynamic-trie.cc b/icing/legacy/index/icing-dynamic-trie.cc
index 29843ba..baa043a 100644
--- a/icing/legacy/index/icing-dynamic-trie.cc
+++ b/icing/legacy/index/icing-dynamic-trie.cc
@@ -62,15 +62,15 @@
#include "icing/legacy/index/icing-dynamic-trie.h"
-#include <errno.h>
#include <fcntl.h>
-#include <inttypes.h>
-#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
+#include <cerrno>
+#include <cinttypes>
+#include <cstring>
#include <memory>
#include <utility>
diff --git a/icing/legacy/index/icing-dynamic-trie.h b/icing/legacy/index/icing-dynamic-trie.h
index 7fe290b..8821799 100644
--- a/icing/legacy/index/icing-dynamic-trie.h
+++ b/icing/legacy/index/icing-dynamic-trie.h
@@ -35,8 +35,7 @@
#ifndef ICING_LEGACY_INDEX_ICING_DYNAMIC_TRIE_H_
#define ICING_LEGACY_INDEX_ICING_DYNAMIC_TRIE_H_
-#include <stdint.h>
-
+#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
diff --git a/icing/legacy/index/icing-filesystem.cc b/icing/legacy/index/icing-filesystem.cc
index 90e9146..4f5e571 100644
--- a/icing/legacy/index/icing-filesystem.cc
+++ b/icing/legacy/index/icing-filesystem.cc
@@ -16,7 +16,6 @@
#include <dirent.h>
#include <dlfcn.h>
-#include <errno.h>
#include <fcntl.h>
#include <fnmatch.h>
#include <pthread.h>
@@ -27,6 +26,7 @@
#include <unistd.h>
#include <algorithm>
+#include <cerrno>
#include <unordered_set>
#include "icing/absl_ports/str_cat.h"
diff --git a/icing/legacy/index/icing-flash-bitmap.h b/icing/legacy/index/icing-flash-bitmap.h
index 3b3521a..e3ba0e2 100644
--- a/icing/legacy/index/icing-flash-bitmap.h
+++ b/icing/legacy/index/icing-flash-bitmap.h
@@ -37,8 +37,7 @@
#ifndef ICING_LEGACY_INDEX_ICING_FLASH_BITMAP_H_
#define ICING_LEGACY_INDEX_ICING_FLASH_BITMAP_H_
-#include <stdint.h>
-
+#include <cstdint>
#include <memory>
#include <string>
diff --git a/icing/legacy/index/icing-mmapper.cc b/icing/legacy/index/icing-mmapper.cc
index 737335c..7946c82 100644
--- a/icing/legacy/index/icing-mmapper.cc
+++ b/icing/legacy/index/icing-mmapper.cc
@@ -17,10 +17,11 @@
//
#include "icing/legacy/index/icing-mmapper.h"
-#include <errno.h>
-#include <string.h>
#include <sys/mman.h>
+#include <cerrno>
+#include <cstring>
+
#include "icing/legacy/core/icing-string-util.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/util/logging.h"
diff --git a/icing/legacy/index/icing-mock-filesystem.h b/icing/legacy/index/icing-mock-filesystem.h
index 75ac62f..122ee7b 100644
--- a/icing/legacy/index/icing-mock-filesystem.h
+++ b/icing/legacy/index/icing-mock-filesystem.h
@@ -15,16 +15,15 @@
#ifndef ICING_LEGACY_INDEX_ICING_MOCK_FILESYSTEM_H_
#define ICING_LEGACY_INDEX_ICING_MOCK_FILESYSTEM_H_
-#include <stdint.h>
-#include <stdio.h>
-#include <string.h>
-
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
#include <memory>
#include <string>
#include <vector>
-#include "icing/legacy/index/icing-filesystem.h"
#include "gmock/gmock.h"
+#include "icing/legacy/index/icing-filesystem.h"
namespace icing {
namespace lib {
diff --git a/icing/legacy/index/icing-storage-file.cc b/icing/legacy/index/icing-storage-file.cc
index b27ec67..35a4418 100644
--- a/icing/legacy/index/icing-storage-file.cc
+++ b/icing/legacy/index/icing-storage-file.cc
@@ -14,9 +14,9 @@
#include "icing/legacy/index/icing-storage-file.h"
-#include <inttypes.h>
#include <unistd.h>
+#include <cinttypes>
#include <string>
#include "icing/legacy/core/icing-compat.h"
diff --git a/icing/portable/endian.h b/icing/portable/endian.h
index 595b956..ecebb15 100644
--- a/icing/portable/endian.h
+++ b/icing/portable/endian.h
@@ -77,7 +77,7 @@
// The following guarantees declaration of the byte swap functions
#ifdef COMPILER_MSVC
-#include <stdlib.h> // NOLINT(build/include)
+#include <cstdlib> // NOLINT(build/include)
#define bswap_16(x) _byteswap_ushort(x)
#define bswap_32(x) _byteswap_ulong(x)
diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc
index 1f937fd..36c76db 100644
--- a/icing/query/query-processor.cc
+++ b/icing/query/query-processor.cc
@@ -182,7 +182,7 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) {
const Token& token = tokens.at(i);
std::unique_ptr<DocHitInfoIterator> result_iterator;
- // TODO(cassiewang): Handle negation tokens
+ // TODO(b/202076890): Handle negation tokens
switch (token.type) {
case Token::Type::QUERY_LEFT_PARENTHESES: {
frames.emplace(ParserStateFrame());
diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc
new file mode 100644
index 0000000..9c60810
--- /dev/null
+++ b/icing/query/suggestion-processor.cc
@@ -0,0 +1,93 @@
+// Copyright (C) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/suggestion-processor.h"
+
+#include "icing/tokenization/tokenizer-factory.h"
+#include "icing/tokenization/tokenizer.h"
+#include "icing/transform/normalizer.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>>
+SuggestionProcessor::Create(Index* index,
+ const LanguageSegmenter* language_segmenter,
+ const Normalizer* normalizer) {
+  ICING_RETURN_ERROR_IF_NULL(index);
+  ICING_RETURN_ERROR_IF_NULL(language_segmenter);
+  ICING_RETURN_ERROR_IF_NULL(normalizer);
+
+ return std::unique_ptr<SuggestionProcessor>(
+ new SuggestionProcessor(index, language_segmenter, normalizer));
+}
+
+libtextclassifier3::StatusOr<std::vector<TermMetadata>>
+SuggestionProcessor::QuerySuggestions(
+ const icing::lib::SuggestionSpecProto& suggestion_spec,
+ const std::vector<NamespaceId>& namespace_ids) {
+  // We use the query tokenizer to tokenize the given prefix, and we only use
+  // the last token as the suggestion prefix.
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<Tokenizer> tokenizer,
+ tokenizer_factory::CreateIndexingTokenizer(
+ StringIndexingConfig::TokenizerType::PLAIN, &language_segmenter_));
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> iterator,
+ tokenizer->Tokenize(suggestion_spec.prefix()));
+
+ // If there are previous tokens, they are prepended to the suggestion,
+ // separated by spaces.
+  std::string last_token;
+  int token_start_pos = 0;  // Initialized in case the prefix yields no tokens.
+ while (iterator->Advance()) {
+ Token token = iterator->GetToken();
+ last_token = token.text;
+ token_start_pos = token.text.data() - suggestion_spec.prefix().c_str();
+ }
+
+  // If the last token does not end at the end of the prefix, there are
+  // operator tokens after it that the tokenizer ignored.
+ bool is_last_token = token_start_pos + last_token.length() >=
+ suggestion_spec.prefix().length();
+
+ if (!is_last_token || last_token.empty()) {
+ // We don't have a valid last token, return early.
+ return std::vector<TermMetadata>();
+ }
+
+ std::string query_prefix =
+ suggestion_spec.prefix().substr(0, token_start_pos);
+ // Run suggestion based on given SuggestionSpec.
+ // Normalize token text to lowercase since all tokens in the lexicon are
+ // lowercase.
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<TermMetadata> terms,
+ index_.FindTermsByPrefix(normalizer_.NormalizeTerm(last_token),
+ namespace_ids, suggestion_spec.num_to_return()));
+
+ for (TermMetadata& term : terms) {
+ term.content = query_prefix + term.content;
+ }
+ return terms;
+}
+
+SuggestionProcessor::SuggestionProcessor(
+ Index* index, const LanguageSegmenter* language_segmenter,
+ const Normalizer* normalizer)
+ : index_(*index),
+ language_segmenter_(*language_segmenter),
+ normalizer_(*normalizer) {}
+
+} // namespace lib
+} // namespace icing
\ No newline at end of file
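
A hedged walk-through of the last-token handling in QuerySuggestions above,
with illustrative values (not taken from the change):

    // prefix = "bar f", and the index contains "foo":
    //   tokens: "bar" (starts at 0), "f" (starts at 4)
    //   last_token = "f", token_start_pos = 4
    //   is_last_token: 4 + 1 >= 5 -> true
    //   query_prefix = prefix.substr(0, 4) = "bar "
    //   FindTermsByPrefix("f", ...) -> {"foo"} -> suggestion "bar foo"
    //
    // prefix = "f " (trailing space):
    //   the last token "f" ends at position 1, but the prefix length is 2,
    //   so is_last_token is false and an empty vector is returned
    //   (see PrefixTrailingSpaceTest below).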
diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h
new file mode 100644
index 0000000..b10dc84
--- /dev/null
+++ b/icing/query/suggestion-processor.h
@@ -0,0 +1,68 @@
+// Copyright (C) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_QUERY_SUGGESTION_PROCESSOR_H_
+#define ICING_QUERY_SUGGESTION_PROCESSOR_H_
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/index.h"
+#include "icing/proto/search.pb.h"
+#include "icing/tokenization/language-segmenter.h"
+#include "icing/transform/normalizer.h"
+
+namespace icing {
+namespace lib {
+
+// Processes SuggestionSpecProtos and retrieves the TermMetadata that satisfy
+// the prefix and its restrictions. This also performs ranking, and returns
+// TermMetadata ordered by hit count.
+class SuggestionProcessor {
+ public:
+ // Factory function to create a SuggestionProcessor which does not take
+ // ownership of any input components, and all pointers must refer to valid
+ // objects that outlive the created SuggestionProcessor instance.
+ //
+ // Returns:
+//   A SuggestionProcessor on success
+ // FAILED_PRECONDITION if any of the pointers is null.
+ static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>>
+ Create(Index* index, const LanguageSegmenter* language_segmenter,
+ const Normalizer* normalizer);
+
+ // Query suggestions based on the given SuggestionSpecProto.
+ //
+ // Returns:
+  //   On success, a vector of TermMetadata ranked by hit count
+  //   INTERNAL_ERROR on all other errors
+ libtextclassifier3::StatusOr<std::vector<TermMetadata>> QuerySuggestions(
+ const SuggestionSpecProto& suggestion_spec,
+ const std::vector<NamespaceId>& namespace_ids);
+
+ private:
+ explicit SuggestionProcessor(Index* index,
+ const LanguageSegmenter* language_segmenter,
+ const Normalizer* normalizer);
+
+  // Not const because we could modify/sort the TermMetadata buffer in the
+  // lite index.
+ Index& index_;
+ const LanguageSegmenter& language_segmenter_;
+ const Normalizer& normalizer_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_QUERY_SUGGESTION_PROCESSOR_H_
diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc
new file mode 100644
index 0000000..5e62277
--- /dev/null
+++ b/icing/query/suggestion-processor_test.cc
@@ -0,0 +1,324 @@
+// Copyright (C) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/suggestion-processor.h"
+
+#include "gmock/gmock.h"
+#include "icing/helpers/icu/icu-data-file-helper.h"
+#include "icing/store/document-store.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/jni-test-helpers.h"
+#include "icing/testing/test-data.h"
+#include "icing/testing/tmp-directory.h"
+#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/transform/normalizer-factory.h"
+#include "unicode/uloc.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::Test;
+
+class SuggestionProcessorTest : public Test {
+ protected:
+ SuggestionProcessorTest()
+ : test_dir_(GetTestTempDir() + "/icing"),
+ store_dir_(test_dir_ + "/store"),
+ index_dir_(test_dir_ + "/index") {}
+
+ void SetUp() override {
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ filesystem_.CreateDirectoryRecursively(index_dir_.c_str());
+ filesystem_.CreateDirectoryRecursively(store_dir_.c_str());
+
+ if (!IsCfStringTokenization() && !IsReverseJniTokenization()) {
+ // If we've specified using the reverse-JNI method for segmentation (i.e.
+ // not ICU), then we won't have the ICU data file included to set up.
+ // Technically, we could choose to use reverse-JNI for segmentation AND
+ // include an ICU data file, but that seems unlikely and our current BUILD
+ // setup doesn't do this.
+ ICING_ASSERT_OK(
+ // File generated via icu_data_file rule in //icing/BUILD.
+ icu_data_file_helper::SetUpICUDataFile(
+ GetTestFilePath("icing/icu.dat")));
+ }
+
+ Index::Options options(index_dir_,
+ /*index_merge_size=*/1024 * 1024);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ index_, Index::Create(options, &filesystem_, &icing_filesystem_));
+
+ language_segmenter_factory::SegmenterOptions segmenter_options(
+ ULOC_US, jni_cache_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ language_segmenter_,
+ language_segmenter_factory::Create(segmenter_options));
+
+ ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create(
+ /*max_term_byte_size=*/1000));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult create_result,
+ DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
+ schema_store_.get()));
+ document_store_ = std::move(create_result.document_store);
+ }
+
+ libtextclassifier3::Status AddTokenToIndex(
+ DocumentId document_id, SectionId section_id,
+ TermMatchType::Code term_match_type, const std::string& token) {
+ Index::Editor editor = index_->Edit(document_id, section_id,
+ term_match_type, /*namespace_id=*/0);
+ auto status = editor.BufferTerm(token.c_str());
+ return status.ok() ? editor.IndexAllBufferedTerms() : status;
+ }
+
+ void TearDown() override {
+ document_store_.reset();
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ }
+
+ Filesystem filesystem_;
+ const std::string test_dir_;
+ const std::string store_dir_;
+ std::unique_ptr<Index> index_;
+ std::unique_ptr<LanguageSegmenter> language_segmenter_;
+ std::unique_ptr<Normalizer> normalizer_;
+ std::unique_ptr<DocumentStore> document_store_;
+ std::unique_ptr<SchemaStore> schema_store_;
+ std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache();
+ FakeClock fake_clock_;
+
+ private:
+ IcingFilesystem icing_filesystem_;
+ const std::string index_dir_;
+};
+
+constexpr DocumentId kDocumentId0 = 0;
+constexpr SectionId kSectionId2 = 2;
+
+TEST_F(SuggestionProcessorTest, PrependedPrefixTokenTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix(
+ "prefix token should be prepended to the suggestion f");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+ EXPECT_THAT(terms.at(0).content,
+ "prefix token should be prepended to the suggestion foo");
+}
+
+TEST_F(SuggestionProcessorTest, NonExistentPrefixTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("nonExistTerm");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+
+ EXPECT_THAT(terms, IsEmpty());
+}
+
+TEST_F(SuggestionProcessorTest, PrefixTrailingSpaceTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("f ");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+
+ EXPECT_THAT(terms, IsEmpty());
+}
+
+TEST_F(SuggestionProcessorTest, NormalizePrefixTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("F");
+ suggestion_spec.set_num_to_return(10);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(suggestion_spec,
+ /*namespace_ids=*/{}));
+ EXPECT_THAT(terms.at(0).content, "foo");
+
+ suggestion_spec.set_prefix("fO");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ terms, suggestion_processor->QuerySuggestions(suggestion_spec,
+ /*namespace_ids=*/{}));
+ EXPECT_THAT(terms.at(0).content, "foo");
+
+ suggestion_spec.set_prefix("Fo");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ terms, suggestion_processor->QuerySuggestions(suggestion_spec,
+ /*namespace_ids=*/{}));
+ EXPECT_THAT(terms.at(0).content, "foo");
+
+ suggestion_spec.set_prefix("FO");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ terms, suggestion_processor->QuerySuggestions(suggestion_spec,
+ /*namespace_ids=*/{}));
+ EXPECT_THAT(terms.at(0).content, "foo");
+}
+
+TEST_F(SuggestionProcessorTest, OrOperatorPrefixTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "original"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("f OR");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+
+  // The last token ("OR", normalized to "or") is used as the suggestion
+  // prefix and matches "original".
+ EXPECT_THAT(terms.at(0).content, "f original");
+}
+
+TEST_F(SuggestionProcessorTest, ParenthesesOperatorPrefixTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("{f}");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+ EXPECT_THAT(terms, IsEmpty());
+
+ suggestion_spec.set_prefix("[f]");
+ ICING_ASSERT_OK_AND_ASSIGN(terms, suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+ EXPECT_THAT(terms, IsEmpty());
+
+ suggestion_spec.set_prefix("(f)");
+ ICING_ASSERT_OK_AND_ASSIGN(terms, suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+ EXPECT_THAT(terms, IsEmpty());
+}
+
+TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "foo"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("f:");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+ EXPECT_THAT(terms, IsEmpty());
+
+ suggestion_spec.set_prefix("f-");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ terms, suggestion_processor->QuerySuggestions(suggestion_spec,
+ /*namespace_ids=*/{}));
+ EXPECT_THAT(terms, IsEmpty());
+}
+
+TEST_F(SuggestionProcessorTest, InvalidPrefixTest) {
+ ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2,
+ TermMatchType::EXACT_ONLY, "original"),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SuggestionProcessor> suggestion_processor,
+ SuggestionProcessor::Create(index_.get(), language_segmenter_.get(),
+ normalizer_.get()));
+
+ SuggestionSpecProto suggestion_spec;
+ suggestion_spec.set_prefix("OR OR - :");
+ suggestion_spec.set_num_to_return(10);
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms,
+ suggestion_processor->QuerySuggestions(
+ suggestion_spec, /*namespace_ids=*/{}));
+ EXPECT_THAT(terms, IsEmpty());
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc
index 3307638..67528ab 100644
--- a/icing/schema/schema-store.cc
+++ b/icing/schema/schema-store.cc
@@ -491,5 +491,10 @@ SchemaStoreStorageInfoProto SchemaStore::GetStorageInfo() const {
return storage_info;
}
+libtextclassifier3::StatusOr<const std::vector<SectionMetadata>*>
+SchemaStore::GetSectionMetadata(const std::string& schema_type) const {
+ return section_manager_->GetMetadataList(schema_type);
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/schema/schema-store.h b/icing/schema/schema-store.h
index b9be6c0..6b6528d 100644
--- a/icing/schema/schema-store.h
+++ b/icing/schema/schema-store.h
@@ -246,6 +246,12 @@ class SchemaStore {
// INTERNAL_ERROR on compute error
libtextclassifier3::StatusOr<Crc32> ComputeChecksum() const;
+ // Returns:
+ // - On success, the section metadata list for the specified schema type
+ // - NOT_FOUND if the schema type is not present in the schema
+ libtextclassifier3::StatusOr<const std::vector<SectionMetadata>*>
+ GetSectionMetadata(const std::string& schema_type) const;
+
// Calculates the StorageInfo for the Schema Store.
//
// If an IO error occurs while trying to calculate the value for a field, then
diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc
index 4822d7f..28d385e 100644
--- a/icing/scoring/bm25f-calculator.cc
+++ b/icing/scoring/bm25f-calculator.cc
@@ -26,6 +26,7 @@
#include "icing/store/corpus-associated-scoring-data.h"
#include "icing/store/corpus-id.h"
#include "icing/store/document-associated-score-data.h"
+#include "icing/store/document-filter-data.h"
#include "icing/store/document-id.h"
namespace icing {
@@ -42,8 +43,11 @@ constexpr float k1_ = 1.2f;
constexpr float b_ = 0.7f;
// TODO(b/158603900): add tests for Bm25fCalculator
-Bm25fCalculator::Bm25fCalculator(const DocumentStore* document_store)
- : document_store_(document_store) {}
+Bm25fCalculator::Bm25fCalculator(
+ const DocumentStore* document_store,
+ std::unique_ptr<SectionWeights> section_weights)
+ : document_store_(document_store),
+ section_weights_(std::move(section_weights)) {}
// During initialization, Bm25fCalculator iterates through
// hit-iterators for each query term to pre-compute n(q_i) for each corpus under
@@ -121,9 +125,9 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it,
// Compute inverse document frequency (IDF) weight for query term in the given
// corpus, and cache it in the map.
//
-// N - n(q_i) + 0.5
-// IDF(q_i) = log(1 + ------------------)
-// n(q_i) + 0.5
+// N - n(q_i) + 0.5
+// IDF(q_i) = ln(1 + ------------------)
+// n(q_i) + 0.5
//
// where N is the number of documents in the corpus, and n(q_i) is the number
// of documents in the corpus containing the query term q_i.
@@ -149,7 +153,7 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term,
uint32_t num_docs = csdata.num_docs();
uint32_t nqi = corpus_nqi_map_[corpus_term_info.value];
float idf =
- nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi - 0.5f)) : 0.0f;
+ nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f;
corpus_idf_map_.insert({corpus_term_info.value, idf});
ICING_VLOG(1) << IcingStringUtil::StringPrintf(
"corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id,
@@ -158,6 +162,11 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term,
}
// Get per corpus average document length and cache the result in the map.
+// The average doc length is calculated as:
+//
+// total_tokens_in_corpus
+// Avg Doc Length = -------------------------
+// num_docs_in_corpus + 1
float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) {
auto iter = corpus_avgdl_map_.find(corpus_id);
if (iter != corpus_avgdl_map_.end()) {
@@ -191,8 +200,8 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency(
const DocumentAssociatedScoreData& data) {
uint32_t dl = data.length_in_tokens();
float avgdl = GetCorpusAvgDocLength(data.corpus_id());
- float f_q =
- ComputeTermFrequencyForMatchedSections(data.corpus_id(), term_match_info);
+ float f_q = ComputeTermFrequencyForMatchedSections(
+ data.corpus_id(), term_match_info, hit_info.document_id());
float normalized_tf =
f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl));
@@ -202,23 +211,41 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency(
return normalized_tf;
}
-// Note: once we support section weights, we should update this function to
-// compute the weighted term frequency.
float Bm25fCalculator::ComputeTermFrequencyForMatchedSections(
- CorpusId corpus_id, const TermMatchInfo& term_match_info) const {
+ CorpusId corpus_id, const TermMatchInfo& term_match_info,
+ DocumentId document_id) const {
float sum = 0.0f;
SectionIdMask sections = term_match_info.section_ids_mask;
+ SchemaTypeId schema_type_id = GetSchemaTypeId(document_id);
+
while (sections != 0) {
SectionId section_id = __builtin_ctz(sections);
sections &= ~(1u << section_id);
Hit::TermFrequency tf = term_match_info.term_frequencies[section_id];
+ double weighted_tf = tf * section_weights_->GetNormalizedSectionWeight(
+ schema_type_id, section_id);
if (tf != Hit::kNoTermFrequency) {
- sum += tf;
+ sum += weighted_tf;
}
}
return sum;
}
+SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const {
+ auto filter_data_or = document_store_->GetDocumentFilterData(document_id);
+ if (!filter_data_or.ok()) {
+ // This should never happen. The only failure case for
+ // GetDocumentFilterData is if the document_id is outside of the range of
+ // allocated document_ids, which shouldn't be possible since we're getting
+ // this document_id from the posting lists.
+ ICING_LOG(WARNING) << IcingStringUtil::StringPrintf(
+ "No document filter data for document [%d]", document_id);
+ return kInvalidSchemaTypeId;
+ }
+ DocumentFilterData data = filter_data_or.ValueOrDie();
+ return data.schema_type_id();
+}
+
} // namespace lib
} // namespace icing
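
A hedged numeric check of the corrected IDF denominator (illustrative values,
not from the change):

    // N = 100 documents in the corpus, n(q_i) = 10 of them contain the term:
    float idf = log(1.0f + (100.0f - 10.0f + 0.5f) / (10.0f + 0.5f));
    // = ln(1 + 90.5 / 10.5) ~= ln(9.619) ~= 2.264
    // The previous denominator, (nqi - 0.5f), would have yielded
    // ln(1 + 90.5 / 9.5) ~= 2.354, slightly inflating every term's weight.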
diff --git a/icing/scoring/bm25f-calculator.h b/icing/scoring/bm25f-calculator.h
index 91b4f24..05009d8 100644
--- a/icing/scoring/bm25f-calculator.h
+++ b/icing/scoring/bm25f-calculator.h
@@ -22,6 +22,7 @@
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/legacy/index/icing-bit-util.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/corpus-id.h"
#include "icing/store/document-store.h"
@@ -62,7 +63,8 @@ namespace lib {
// see: glossary/bm25
class Bm25fCalculator {
public:
- explicit Bm25fCalculator(const DocumentStore *document_store_);
+ explicit Bm25fCalculator(const DocumentStore *document_store_,
+ std::unique_ptr<SectionWeights> section_weights_);
// Precompute and cache statistics relevant to BM25F.
// Populates term_id_map_ and corpus_nqi_map_ for use while scoring other
@@ -108,18 +110,43 @@ class Bm25fCalculator {
}
};
+ // Returns idf weight for the term and provided corpus.
float GetCorpusIdfWeightForTerm(std::string_view term, CorpusId corpus_id);
+
+ // Returns the average document length for the corpus. The average is
+ // calculated as the sum of tokens in the corpus' documents over the total
+ // number of documents plus one.
float GetCorpusAvgDocLength(CorpusId corpus_id);
+
+ // Returns the normalized term frequency for the term match and document hit.
+ // This normalizes the term frequency by applying smoothing parameters and
+ // factoring document length.
float ComputedNormalizedTermFrequency(
const TermMatchInfo &term_match_info, const DocHitInfo &hit_info,
const DocumentAssociatedScoreData &data);
+
+  // Returns the weighted term frequency for the term match and document. For
+  // each section in which the term is present, we scale the term frequency by
+  // the section's weight, and return the sum of the weighted term frequencies
+  // over all sections.
float ComputeTermFrequencyForMatchedSections(
- CorpusId corpus_id, const TermMatchInfo &term_match_info) const;
+ CorpusId corpus_id, const TermMatchInfo &term_match_info,
+ DocumentId document_id) const;
+ // Returns the schema type id for the document by retrieving it from the
+ // DocumentFilterData.
+ SchemaTypeId GetSchemaTypeId(DocumentId document_id) const;
+
+ // Clears cached scoring data and prepares the calculator for a new scoring
+ // run.
void Clear();
const DocumentStore *document_store_; // Does not own.
+ // Used for accessing normalized section weights when computing the weighted
+ // term frequency.
+ std::unique_ptr<SectionWeights> section_weights_;
+
// Map from query term to compact term ID.
// Necessary as a key to the other maps.
// The use of the string_view as key here means that the query_term_iterators
@@ -130,7 +157,6 @@ class Bm25fCalculator {
// Necessary to calculate the normalized term frequency.
// This information is cached in the DocumentStore::CorpusScoreCache
std::unordered_map<CorpusId, float> corpus_avgdl_map_;
-
// Map from <corpus ID, term ID> to number of documents containing term q_i,
// called n(q_i).
// Necessary to calculate IDF(q_i) (inverse document frequency).
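
A hedged sketch of the weighted term frequency described above (the section
weights are illustrative):

    // A term appears in section 0 (tf = 3, normalized weight 0.5) and in
    // section 2 (tf = 2, normalized weight 0.25):
    //   f_q = 3 * 0.5 + 2 * 0.25 = 2.0
    // The previous unweighted sum would have been 3 + 2 = 5.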
diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc
index e940e98..cc1d995 100644
--- a/icing/scoring/score-and-rank_benchmark.cc
+++ b/icing/scoring/score-and-rank_benchmark.cc
@@ -117,7 +117,8 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) {
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get()));
+ ScoringProcessor::Create(scoring_spec, document_store.get(),
+ schema_store.get()));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -220,7 +221,8 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) {
ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get()));
+ ScoringProcessor::Create(scoring_spec, document_store.get(),
+ schema_store.get()));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -322,7 +324,8 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) {
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(scoring_spec, document_store.get()));
+ ScoringProcessor::Create(scoring_spec, document_store.get(),
+ schema_store.get()));
int num_to_score = state.range(0);
int num_of_documents = state.range(1);
@@ -390,6 +393,122 @@ BENCHMARK(BM_ScoreAndRankDocumentHitsNoScoring)
->ArgPair(10000, 18000)
->ArgPair(10000, 20000);
+void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) {
+ const std::string base_dir = GetTestTempDir() + "/score_and_rank_benchmark";
+ const std::string document_store_dir = base_dir + "/document_store";
+ const std::string schema_store_dir = base_dir + "/schema_store";
+
+ // Creates file directories
+ Filesystem filesystem;
+ filesystem.DeleteDirectoryRecursively(base_dir.c_str());
+ filesystem.CreateDirectoryRecursively(document_store_dir.c_str());
+ filesystem.CreateDirectoryRecursively(schema_store_dir.c_str());
+
+ Clock clock;
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SchemaStore> schema_store,
+ SchemaStore::Create(&filesystem, base_dir, &clock));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult create_result,
+ DocumentStore::Create(&filesystem, document_store_dir, &clock,
+ schema_store.get()));
+ std::unique_ptr<DocumentStore> document_store =
+ std::move(create_result.document_store);
+
+ ICING_ASSERT_OK(schema_store->SetSchema(CreateSchemaWithEmailType()));
+
+ ScoringSpecProto scoring_spec;
+ scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoringProcessor> scoring_processor,
+ ScoringProcessor::Create(scoring_spec, document_store.get(),
+ schema_store.get()));
+
+ int num_to_score = state.range(0);
+ int num_of_documents = state.range(1);
+
+ std::mt19937 random_generator;
+ std::uniform_int_distribution<int> distribution(
+ 1, std::numeric_limits<int>::max());
+
+ SectionId section_id = 0;
+ SectionIdMask section_id_mask = 1U << section_id;
+
+ // Puts documents into document store
+ std::vector<DocHitInfo> doc_hit_infos;
+ for (int i = 0; i < num_of_documents; i++) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store->Put(CreateEmailDocument(
+ /*id=*/i, /*document_score=*/1,
+ /*creation_timestamp_ms=*/1),
+ /*num_tokens=*/10));
+ DocHitInfo doc_hit = DocHitInfo(document_id, section_id_mask);
+ // Set five matches for term "foo" for each document hit.
+ doc_hit.UpdateSection(section_id, /*hit_term_frequency=*/5);
+ doc_hit_infos.push_back(doc_hit);
+ }
+
+ ScoredDocumentHitComparator scored_document_hit_comparator(
+ /*is_descending=*/true);
+
+ for (auto _ : state) {
+    // Creates a dummy DocHitInfoIterator with results. We need to pause the
+    // timer here so that the cost of copying test data is not included.
+ state.PauseTiming();
+ std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+ // Create a query term iterator that assigns the document hits to term
+ // "foo".
+ std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
+ query_term_iterators;
+ query_term_iterators["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+ state.ResumeTiming();
+
+ std::vector<ScoredDocumentHit> scored_document_hits =
+ scoring_processor->Score(std::move(doc_hit_info_iterator), num_to_score,
+ &query_term_iterators);
+
+ BuildHeapInPlace(&scored_document_hits, scored_document_hit_comparator);
+    // Ranks and gets the first page; 20 is a common page size
+ std::vector<ScoredDocumentHit> results =
+ PopTopResultsFromHeap(&scored_document_hits, /*num_results=*/20,
+ scored_document_hit_comparator);
+ }
+
+ // Clean up
+ document_store.reset();
+ schema_store.reset();
+ filesystem.DeleteDirectoryRecursively(base_dir.c_str());
+}
+BENCHMARK(BM_ScoreAndRankDocumentHitsByRelevanceScoring)
+ // num_to_score, num_of_documents in document store
+ ->ArgPair(1000, 30000)
+ ->ArgPair(3000, 30000)
+ ->ArgPair(5000, 30000)
+ ->ArgPair(7000, 30000)
+ ->ArgPair(9000, 30000)
+ ->ArgPair(11000, 30000)
+ ->ArgPair(13000, 30000)
+ ->ArgPair(15000, 30000)
+ ->ArgPair(17000, 30000)
+ ->ArgPair(19000, 30000)
+ ->ArgPair(21000, 30000)
+ ->ArgPair(23000, 30000)
+ ->ArgPair(25000, 30000)
+ ->ArgPair(27000, 30000)
+ ->ArgPair(29000, 30000)
+  // Starting from this line, we measure whether num_of_documents affects
+  // performance
+ ->ArgPair(10000, 10000)
+ ->ArgPair(10000, 12000)
+ ->ArgPair(10000, 14000)
+ ->ArgPair(10000, 16000)
+ ->ArgPair(10000, 18000)
+ ->ArgPair(10000, 20000);
+
} // namespace
} // namespace lib
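As a usage note for the benchmark above: each ArgPair supplies
state.range(0) (num_to_score) and state.range(1) (num_of_documents). A
self-contained sketch of that Google Benchmark pattern (BM_Sketch is a
made-up name):

#include "benchmark/benchmark.h"

static void BM_Sketch(benchmark::State& state) {
  int num_to_score = state.range(0);      // first ArgPair value
  int num_of_documents = state.range(1);  // second ArgPair value
  for (auto _ : state) {
    benchmark::DoNotOptimize(num_to_score * num_of_documents);
  }
}
BENCHMARK(BM_Sketch)->ArgPair(1000, 30000);
BENCHMARK_MAIN();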
diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer.cc
index a4734b4..5f33e66 100644
--- a/icing/scoring/scorer.cc
+++ b/icing/scoring/scorer.cc
@@ -22,6 +22,7 @@
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/scoring.pb.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/util/status-macros.h"
@@ -156,11 +157,12 @@ class NoScorer : public Scorer {
};
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create(
- ScoringSpecProto::RankingStrategy::Code rank_by, double default_score,
- const DocumentStore* document_store) {
+ const ScoringSpecProto& scoring_spec, double default_score,
+ const DocumentStore* document_store, const SchemaStore* schema_store) {
ICING_RETURN_ERROR_IF_NULL(document_store);
+ ICING_RETURN_ERROR_IF_NULL(schema_store);
- switch (rank_by) {
+ switch (scoring_spec.rank_by()) {
case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE:
return std::make_unique<DocumentScoreScorer>(document_store,
default_score);
@@ -168,7 +170,12 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create(
return std::make_unique<DocumentCreationTimestampScorer>(document_store,
default_score);
case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: {
- auto bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store);
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store, scoring_spec));
+
+ auto bm25f_calculator = std::make_unique<Bm25fCalculator>(
+ document_store, std::move(section_weights));
return std::make_unique<RelevanceScoreScorer>(std::move(bm25f_calculator),
default_score);
}
@@ -183,8 +190,8 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create(
case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP:
[[fallthrough]];
case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP:
- return std::make_unique<UsageScorer>(document_store, rank_by,
- default_score);
+ return std::make_unique<UsageScorer>(
+ document_store, scoring_spec.rank_by(), default_score);
case ScoringSpecProto::RankingStrategy::NONE:
return std::make_unique<NoScorer>(default_score);
}
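To illustrate the widened Create() signature, a minimal hedged sketch of a
call site (error handling elided; document_store and schema_store are assumed
to be valid, non-null pointers):

ScoringSpecProto scoring_spec;
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<Scorer> scorer,
    Scorer::Create(scoring_spec, /*default_score=*/0, document_store,
                   schema_store));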
diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h
index a22db0f..abdd5ca 100644
--- a/icing/scoring/scorer.h
+++ b/icing/scoring/scorer.h
@@ -43,8 +43,8 @@ class Scorer {
// FAILED_PRECONDITION on any null pointer input
// INVALID_ARGUMENT if fails to create an instance
static libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
- ScoringSpecProto::RankingStrategy::Code rank_by, double default_score,
- const DocumentStore* document_store);
+ const ScoringSpecProto& scoring_spec, double default_score,
+ const DocumentStore* document_store, const SchemaStore* schema_store);
// Returns a non-negative score of a document. The score can be a
// document-associated score which comes from the DocumentProto directly, an
diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc
index 8b89514..f22a31a 100644
--- a/icing/scoring/scorer_test.cc
+++ b/icing/scoring/scorer_test.cc
@@ -27,6 +27,7 @@
#include "icing/proto/scoring.pb.h"
#include "icing/schema-builder.h"
#include "icing/schema/schema-store.h"
+#include "icing/scoring/section-weights.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/testing/common-matchers.h"
@@ -91,6 +92,8 @@ class ScorerTest : public testing::Test {
DocumentStore* document_store() { return document_store_.get(); }
+ SchemaStore* schema_store() { return schema_store_.get(); }
+
const FakeClock& fake_clock1() { return fake_clock1_; }
const FakeClock& fake_clock2() { return fake_clock2_; }
@@ -121,17 +124,37 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri,
return usage_report;
}
-TEST_F(ScorerTest, CreationWithNullPointerShouldFail) {
- EXPECT_THAT(Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE,
- /*default_score=*/0, /*document_store=*/nullptr),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ScoringSpecProto CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::Code ranking_strategy) {
+ ScoringSpecProto scoring_spec;
+ scoring_spec.set_rank_by(ranking_strategy);
+ return scoring_spec;
+}
+
+TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) {
+ EXPECT_THAT(
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/0, /*document_store=*/nullptr,
+ schema_store()),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+}
+
+TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) {
+ EXPECT_THAT(
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/0, document_store(),
+ /*schema_store=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE,
- /*default_score=*/10, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
// Non existent document id
DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1);
@@ -153,8 +176,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE,
- /*default_score=*/10, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
@@ -185,8 +209,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE,
- /*default_score=*/10, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
@@ -213,8 +238,9 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) {
document_store()->Put(test_document));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE,
- /*default_score=*/10, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0));
@@ -235,8 +261,9 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) {
document_store()->Put(test_document));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5));
@@ -259,8 +286,9 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) {
document_store()->Put(test_document));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE,
- /*default_score=*/10, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
@@ -290,8 +318,9 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) {
document_store()->Put(test_document2));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo1 = DocHitInfo(document_id1);
DocHitInfo docHitInfo2 = DocHitInfo(document_id2);
@@ -316,16 +345,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -357,16 +389,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -398,16 +433,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -439,19 +477,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -499,19 +540,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -559,19 +603,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -607,8 +654,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) {
TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(ScoringSpecProto::RankingStrategy::NONE,
- /*default_score=*/3, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::NONE),
+ /*default_score=*/3, document_store(), schema_store()));
DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0);
DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1);
@@ -618,8 +666,10 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer, Scorer::Create(ScoringSpecProto::RankingStrategy::NONE,
- /*default_score=*/111, document_store()));
+ scorer,
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::NONE),
+ /*default_score=*/111, document_store(), schema_store()));
docHitInfo1 = DocHitInfo(/*document_id_in=*/4);
docHitInfo2 = DocHitInfo(/*document_id_in=*/5);
@@ -643,9 +693,10 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP,
- /*default_score=*/0, document_store()));
+ Scorer::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
// Create usage report for the maximum allowable timestamp.
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc
index 24480ef..e36f3bb 100644
--- a/icing/scoring/scoring-processor.cc
+++ b/icing/scoring/scoring-processor.cc
@@ -39,19 +39,20 @@ constexpr double kDefaultScoreInAscendingOrder =
libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
ScoringProcessor::Create(const ScoringSpecProto& scoring_spec,
- const DocumentStore* document_store) {
+ const DocumentStore* document_store,
+ const SchemaStore* schema_store) {
ICING_RETURN_ERROR_IF_NULL(document_store);
+ ICING_RETURN_ERROR_IF_NULL(schema_store);
bool is_descending_order =
scoring_spec.order_by() == ScoringSpecProto::Order::DESC;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(scoring_spec.rank_by(),
+ Scorer::Create(scoring_spec,
is_descending_order ? kDefaultScoreInDescendingOrder
: kDefaultScoreInAscendingOrder,
- document_store));
-
+ document_store, schema_store));
// Using `new` to access a non-public constructor.
return std::unique_ptr<ScoringProcessor>(
new ScoringProcessor(std::move(scorer)));
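Correspondingly, callers now thread the SchemaStore through
ScoringProcessor::Create so that relevance scoring can resolve section
weights. A minimal sketch (iterator setup elided; all pointers assumed
valid):

ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<ScoringProcessor> scoring_processor,
    ScoringProcessor::Create(scoring_spec, document_store, schema_store));
std::vector<ScoredDocumentHit> hits = scoring_processor->Score(
    std::move(doc_hit_info_iterator), /*num_to_score=*/20);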
diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h
index 2289605..e7d09b1 100644
--- a/icing/scoring/scoring-processor.h
+++ b/icing/scoring/scoring-processor.h
@@ -40,8 +40,8 @@ class ScoringProcessor {
// A ScoringProcessor on success
// FAILED_PRECONDITION on any null pointer input
static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create(
- const ScoringSpecProto& scoring_spec,
- const DocumentStore* document_store);
+ const ScoringSpecProto& scoring_spec, const DocumentStore* document_store,
+ const SchemaStore* schema_store);
// Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns
// a vector of ScoredDocumentHits. The size of results is no more than
diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc
index 125e2a7..7e5cb0f 100644
--- a/icing/scoring/scoring-processor_test.cc
+++ b/icing/scoring/scoring-processor_test.cc
@@ -69,11 +69,24 @@ class ScoringProcessorTest : public testing::Test {
// Creates a simple email schema
SchemaProto test_email_schema =
SchemaBuilder()
- .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty(
- PropertyConfigBuilder()
- .SetName("subject")
- .SetDataType(TYPE_STRING)
- .SetCardinality(CARDINALITY_OPTIONAL)))
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("email")
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("subject")
+ .SetDataTypeString(
+ TermMatchType::PREFIX,
+ StringIndexingConfig::TokenizerType::PLAIN)
+ .SetDataType(TYPE_STRING)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("body")
+ .SetDataTypeString(
+ TermMatchType::PREFIX,
+ StringIndexingConfig::TokenizerType::PLAIN)
+ .SetDataType(TYPE_STRING)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
.Build();
ICING_ASSERT_OK(schema_store_->SetSchema(test_email_schema));
}
@@ -86,6 +99,8 @@ class ScoringProcessorTest : public testing::Test {
DocumentStore* document_store() { return document_store_.get(); }
+ SchemaStore* schema_store() { return schema_store_.get(); }
+
private:
const std::string test_dir_;
const std::string doc_store_dir_;
@@ -139,16 +154,46 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri,
return usage_report;
}
-TEST_F(ScoringProcessorTest, CreationWithNullPointerShouldFail) {
+TypePropertyWeights CreateTypePropertyWeights(
+ std::string schema_type, std::vector<PropertyWeight> property_weights) {
+ TypePropertyWeights type_property_weights;
+ type_property_weights.set_schema_type(std::move(schema_type));
+ type_property_weights.mutable_property_weights()->Reserve(
+ property_weights.size());
+
+ for (PropertyWeight& property_weight : property_weights) {
+ *type_property_weights.add_property_weights() = std::move(property_weight);
+ }
+
+ return type_property_weights;
+}
+
+PropertyWeight CreatePropertyWeight(std::string path, double weight) {
+ PropertyWeight property_weight;
+ property_weight.set_path(std::move(path));
+ property_weight.set_weight(weight);
+ return property_weight;
+}
+
+TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) {
+ ScoringSpecProto spec_proto;
+ EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr,
+ schema_store()),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+}
+
+TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) {
ScoringSpecProto spec_proto;
- EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr),
+ EXPECT_THAT(ScoringProcessor::Create(spec_proto, document_store(),
+ /*schema_store=*/nullptr),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_F(ScoringProcessorTest, ShouldCreateInstance) {
ScoringSpecProto spec_proto;
spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
- ICING_EXPECT_OK(ScoringProcessor::Create(spec_proto, document_store()));
+ ICING_EXPECT_OK(
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
}
TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) {
@@ -163,7 +208,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/5),
@@ -189,7 +234,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/-1),
@@ -219,7 +264,7 @@ TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/2),
@@ -251,7 +296,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) {
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -306,7 +351,7 @@ TEST_F(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -316,11 +361,11 @@ TEST_F(ScoringProcessorTest,
  // the document's length determines the final score. Documents shorter than
  // the average corpus length are slightly boosted.
ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask,
- /*score=*/0.255482);
+ /*score=*/0.187114);
ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask,
- /*score=*/0.115927);
+ /*score=*/0.084904);
ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask,
- /*score=*/0.166435);
+ /*score=*/0.121896);
EXPECT_THAT(
scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3, &query_term_iterators),
@@ -375,7 +420,7 @@ TEST_F(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -384,11 +429,11 @@ TEST_F(ScoringProcessorTest,
// Since the three documents all contain the query term "foo" exactly once
  // and they have the same length, they will have the same BM25F score.
ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask,
- /*score=*/0.16173716);
+ /*score=*/0.118455);
ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask,
- /*score=*/0.16173716);
+ /*score=*/0.118455);
ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask,
- /*score=*/0.16173716);
+ /*score=*/0.118455);
EXPECT_THAT(
scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3, &query_term_iterators),
@@ -448,7 +493,7 @@ TEST_F(ScoringProcessorTest,
// Creates a ScoringProcessor
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
query_term_iterators;
@@ -457,11 +502,11 @@ TEST_F(ScoringProcessorTest,
// Since the three documents all have the same length, the score is decided by
// the frequency of the query term "foo".
ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1,
- /*score=*/0.309497);
+ /*score=*/0.226674);
ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask2,
- /*score=*/0.16173716);
+ /*score=*/0.118455);
ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask3,
- /*score=*/0.268599);
+ /*score=*/0.196720);
EXPECT_THAT(
scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3, &query_term_iterators),
@@ -470,6 +515,280 @@ TEST_F(ScoringProcessorTest,
EqualsScoredDocumentHit(expected_scored_doc_hit3)));
}
+TEST_F(ScoringProcessorTest,
+ ShouldScoreByRelevanceScore_HitTermWithZeroFrequency) {
+ DocumentProto document1 =
+ CreateDocument("icing", "email/1", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id1,
+ document_store()->Put(document1, /*num_tokens=*/10));
+
+ // Document 1 contains the term "foo" 0 times in the "subject" property
+ DocHitInfo doc_hit_info1(document_id1);
+  doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/0);
+
+ // Creates input doc_hit_infos and expected output scored_document_hits
+ std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1};
+
+ // Creates a dummy DocHitInfoIterator with 1 result for the query "foo"
+ std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ ScoringSpecProto spec_proto;
+ spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+
+ // Creates a ScoringProcessor
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoringProcessor> scoring_processor,
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
+
+ std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
+ query_term_iterators;
+ query_term_iterators["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ SectionIdMask section_id_mask1 = 0b00000001;
+
+ // Since the document hit has zero frequency, expect a score of zero.
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1,
+ /*score=*/0.000000);
+ EXPECT_THAT(
+ scoring_processor->Score(std::move(doc_hit_info_iterator),
+ /*num_to_score=*/1, &query_term_iterators),
+ ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1)));
+}
+
+TEST_F(ScoringProcessorTest,
+ ShouldScoreByRelevanceScore_SameHitFrequencyDifferentPropertyWeights) {
+ DocumentProto document1 =
+ CreateDocument("icing", "email/1", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+ DocumentProto document2 =
+ CreateDocument("icing", "email/2", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id1,
+ document_store()->Put(document1, /*num_tokens=*/1));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id2,
+ document_store()->Put(document2, /*num_tokens=*/1));
+
+ // Document 1 contains the term "foo" 1 time in the "body" property
+ SectionId body_section_id = 0;
+ DocHitInfo doc_hit_info1(document_id1);
+ doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1);
+
+ // Document 2 contains the term "foo" 1 time in the "subject" property
+ SectionId subject_section_id = 1;
+ DocHitInfo doc_hit_info2(document_id2);
+ doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1);
+
+ // Creates input doc_hit_infos and expected output scored_document_hits
+ std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2};
+
+ // Creates a dummy DocHitInfoIterator with 2 results for the query "foo"
+ std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ ScoringSpecProto spec_proto;
+ spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+
+ PropertyWeight body_property_weight =
+ CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5);
+ PropertyWeight subject_property_weight =
+ CreatePropertyWeight(/*path=*/"subject", /*weight=*/2.0);
+ *spec_proto.add_type_property_weights() = CreateTypePropertyWeights(
+ /*schema_type=*/"email", {body_property_weight, subject_property_weight});
+
+ // Creates a ScoringProcessor
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoringProcessor> scoring_processor,
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
+
+ std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
+ query_term_iterators;
+ query_term_iterators["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ SectionIdMask body_section_id_mask = 1U << body_section_id;
+ SectionIdMask subject_section_id_mask = 1U << subject_section_id;
+
+ // We expect document 2 to have a higher score than document 1 as it matches
+  // "foo" in the "subject" property, which is weighted higher than the "body"
+ // property. Final scores are computed with smoothing applied.
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask,
+ /*score=*/0.053624);
+ ScoredDocumentHit expected_scored_doc_hit2(document_id2,
+ subject_section_id_mask,
+ /*score=*/0.153094);
+ EXPECT_THAT(
+ scoring_processor->Score(std::move(doc_hit_info_iterator),
+ /*num_to_score=*/2, &query_term_iterators),
+ ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1),
+ EqualsScoredDocumentHit(expected_scored_doc_hit2)));
+}
+
+TEST_F(ScoringProcessorTest,
+ ShouldScoreByRelevanceScore_WithImplicitPropertyWeight) {
+ DocumentProto document1 =
+ CreateDocument("icing", "email/1", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+ DocumentProto document2 =
+ CreateDocument("icing", "email/2", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id1,
+ document_store()->Put(document1, /*num_tokens=*/1));
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id2,
+ document_store()->Put(document2, /*num_tokens=*/1));
+
+ // Document 1 contains the term "foo" 1 time in the "body" property
+ SectionId body_section_id = 0;
+ DocHitInfo doc_hit_info1(document_id1);
+ doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1);
+
+ // Document 2 contains the term "foo" 1 time in the "subject" property
+ SectionId subject_section_id = 1;
+ DocHitInfo doc_hit_info2(document_id2);
+ doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1);
+
+ // Creates input doc_hit_infos and expected output scored_document_hits
+ std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2};
+
+ // Creates a dummy DocHitInfoIterator with 2 results for the query "foo"
+ std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ ScoringSpecProto spec_proto;
+ spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+
+ PropertyWeight body_property_weight =
+ CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5);
+ *spec_proto.add_type_property_weights() = CreateTypePropertyWeights(
+ /*schema_type=*/"email", {body_property_weight});
+
+ // Creates a ScoringProcessor
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoringProcessor> scoring_processor,
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
+
+ std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
+ query_term_iterators;
+ query_term_iterators["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ SectionIdMask body_section_id_mask = 1U << body_section_id;
+ SectionIdMask subject_section_id_mask = 1U << subject_section_id;
+
+ // We expect document 2 to have a higher score than document 1 as it matches
+  // "foo" in the "subject" property, which is weighted higher than the
+  // "body" property. This is because the "subject" property is implicitly
+  // given a weight of 1.0, the default weight value. Final scores are
+  // computed with smoothing applied.
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask,
+ /*score=*/0.094601);
+ ScoredDocumentHit expected_scored_doc_hit2(document_id2,
+ subject_section_id_mask,
+ /*score=*/0.153094);
+ EXPECT_THAT(
+ scoring_processor->Score(std::move(doc_hit_info_iterator),
+ /*num_to_score=*/2, &query_term_iterators),
+ ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1),
+ EqualsScoredDocumentHit(expected_scored_doc_hit2)));
+}
+
+TEST_F(ScoringProcessorTest,
+ ShouldScoreByRelevanceScore_WithDefaultPropertyWeight) {
+ DocumentProto document1 =
+ CreateDocument("icing", "email/1", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+ DocumentProto document2 =
+ CreateDocument("icing", "email/2", kDefaultScore,
+ /*creation_timestamp_ms=*/kDefaultCreationTimestampMs);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id1,
+ document_store()->Put(document1, /*num_tokens=*/1));
+
+ // Document 1 contains the term "foo" 1 time in the "body" property
+ SectionId body_section_id = 0;
+ DocHitInfo doc_hit_info1(document_id1);
+ doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1);
+
+ // Creates input doc_hit_infos and expected output scored_document_hits
+ std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1};
+
+ // Creates a dummy DocHitInfoIterator with 1 result for the query "foo"
+ std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ ScoringSpecProto spec_proto;
+ spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+
+ *spec_proto.add_type_property_weights() =
+ CreateTypePropertyWeights(/*schema_type=*/"email", {});
+
+ // Creates a ScoringProcessor with no explicit weights set.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoringProcessor> scoring_processor,
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
+
+ ScoringSpecProto spec_proto_with_weights;
+ spec_proto_with_weights.set_rank_by(
+ ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
+
+ PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body",
+ /*weight=*/1.0);
+ *spec_proto_with_weights.add_type_property_weights() =
+ CreateTypePropertyWeights(/*schema_type=*/"email",
+ {body_property_weight});
+
+ // Creates a ScoringProcessor with default weight set for "body" property.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoringProcessor> scoring_processor_with_weights,
+ ScoringProcessor::Create(spec_proto_with_weights, document_store(),
+ schema_store()));
+
+ std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
+ query_term_iterators;
+ query_term_iterators["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+  // Creates query term iterators for the scorer with explicit weights
+ std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>
+ query_term_iterators_scoring_with_weights;
+ query_term_iterators_scoring_with_weights["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ SectionIdMask body_section_id_mask = 1U << body_section_id;
+
+ // We expect document 1 to have the same score whether a weight is explicitly
+  // set to 1.0 or implicitly scored with the default weight. Final scores are
+ // computed with smoothing applied.
+ ScoredDocumentHit expected_scored_doc_hit(document_id1, body_section_id_mask,
+ /*score=*/0.208191);
+ EXPECT_THAT(
+ scoring_processor->Score(std::move(doc_hit_info_iterator),
+ /*num_to_score=*/1, &query_term_iterators),
+ ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit)));
+
+  // Recreate the doc hit iterator and query term iterator, which were
+  // consumed by the first Score() call.
+ doc_hit_info_iterator =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+ query_term_iterators["foo"] =
+ std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo");
+
+ EXPECT_THAT(scoring_processor_with_weights->Score(
+ std::move(doc_hit_info_iterator),
+ /*num_to_score=*/1, &query_term_iterators),
+ ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit)));
+}
+
TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) {
DocumentProto document1 =
CreateDocument("icing", "email/1", kDefaultScore,
@@ -509,7 +828,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -569,7 +888,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -629,7 +948,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
@@ -665,7 +984,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNoScores) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/4),
ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default),
@@ -714,7 +1033,7 @@ TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) {
// Creates a ScoringProcessor which ranks in descending order
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ScoringProcessor> scoring_processor,
- ScoringProcessor::Create(spec_proto, document_store()));
+ ScoringProcessor::Create(spec_proto, document_store(), schema_store()));
EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator),
/*num_to_score=*/3),
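For intuition about the weights exercised in the tests above: raw weights are
normalized per schema type against the maximum weight, so {body: 0.5,
subject: 2.0} becomes {body: 0.25, subject: 1.0}, and an unlisted property
takes the normalized default 1.0 / 2.0 = 0.5. A standalone sketch that
mirrors (but does not reuse) that normalization:

#include <algorithm>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, double> NormalizeSketch(
    std::unordered_map<std::string, double> weights) {
  double max_weight = 0.0;
  for (const auto& entry : weights) {
    max_weight = std::max(max_weight, entry.second);
  }
  for (auto& entry : weights) {
    entry.second /= max_weight;  // {body: 0.5, subject: 2.0} -> {0.25, 1.0}
  }
  return weights;
}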
diff --git a/icing/scoring/section-weights.cc b/icing/scoring/section-weights.cc
new file mode 100644
index 0000000..c4afe7f
--- /dev/null
+++ b/icing/scoring/section-weights.cc
@@ -0,0 +1,146 @@
+// Copyright (C) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/section-weights.h"
+
+#include <cfloat>
+#include <unordered_map>
+#include <utility>
+
+#include "icing/proto/scoring.pb.h"
+#include "icing/schema/section.h"
+#include "icing/util/logging.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+// Normalizes all weights in the map to be in range (0.0, 1.0], where the max
+// weight is normalized to 1.0.
+inline void NormalizeSectionWeights(
+ double max_weight, std::unordered_map<SectionId, double>& section_weights) {
+ for (auto& raw_weight : section_weights) {
+ raw_weight.second = raw_weight.second / max_weight;
+ }
+}
+} // namespace
+
+libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>>
+SectionWeights::Create(const SchemaStore* schema_store,
+ const ScoringSpecProto& scoring_spec) {
+ ICING_RETURN_ERROR_IF_NULL(schema_store);
+
+ std::unordered_map<SchemaTypeId, NormalizedSectionWeights>
+ schema_property_weight_map;
+ for (const TypePropertyWeights& type_property_weights :
+ scoring_spec.type_property_weights()) {
+ std::string_view schema_type = type_property_weights.schema_type();
+ auto schema_type_id_or = schema_store->GetSchemaTypeId(schema_type);
+ if (!schema_type_id_or.ok()) {
+ ICING_LOG(WARNING) << "No schema type id found for schema type: "
+ << schema_type;
+ continue;
+ }
+ SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie();
+ auto section_metadata_list_or =
+ schema_store->GetSectionMetadata(schema_type.data());
+ if (!section_metadata_list_or.ok()) {
+ ICING_LOG(WARNING) << "No metadata found for schema type: "
+ << schema_type;
+ continue;
+ }
+
+ const std::vector<SectionMetadata>* metadata_list =
+ section_metadata_list_or.ValueOrDie();
+
+ std::unordered_map<std::string, double> property_paths_weights;
+ for (const PropertyWeight& property_weight :
+ type_property_weights.property_weights()) {
+ double property_path_weight = property_weight.weight();
+
+ // Return error on negative and zero weights.
+ if (property_path_weight <= 0.0) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Property weight for property path \"%s\" is negative or zero. "
+ "Negative and zero weights are invalid.",
+ property_weight.path().c_str()));
+ }
+ property_paths_weights.insert(
+ {property_weight.path(), property_path_weight});
+ }
+ NormalizedSectionWeights normalized_section_weights =
+ ExtractNormalizedSectionWeights(property_paths_weights, *metadata_list);
+
+ schema_property_weight_map.insert(
+ {schema_type_id,
+ {/*section_weights*/ std::move(
+ normalized_section_weights.section_weights),
+ /*default_weight*/ normalized_section_weights.default_weight}});
+ }
+ // Using `new` to access a non-public constructor.
+ return std::unique_ptr<SectionWeights>(
+ new SectionWeights(std::move(schema_property_weight_map)));
+}
+
+double SectionWeights::GetNormalizedSectionWeight(SchemaTypeId schema_type_id,
+ SectionId section_id) const {
+ auto schema_type_map = schema_section_weight_map_.find(schema_type_id);
+ if (schema_type_map == schema_section_weight_map_.end()) {
+ // Return default weight if the schema type has no weights specified.
+ return kDefaultSectionWeight;
+ }
+
+ auto section_weight =
+ schema_type_map->second.section_weights.find(section_id);
+ if (section_weight == schema_type_map->second.section_weights.end()) {
+ // If there is no entry for SectionId, the weight is implicitly the
+ // normalized default weight.
+ return schema_type_map->second.default_weight;
+ }
+ return section_weight->second;
+}
+
+inline SectionWeights::NormalizedSectionWeights
+SectionWeights::ExtractNormalizedSectionWeights(
+ const std::unordered_map<std::string, double>& raw_weights,
+ const std::vector<SectionMetadata>& metadata_list) {
+ double max_weight = 0.0;
+ std::unordered_map<SectionId, double> section_weights;
+ for (const SectionMetadata& section_metadata : metadata_list) {
+ std::string_view metadata_path = section_metadata.path;
+ double section_weight = kDefaultSectionWeight;
+ auto iter = raw_weights.find(metadata_path.data());
+ if (iter != raw_weights.end()) {
+ section_weight = iter->second;
+ section_weights.insert({section_metadata.id, section_weight});
+ }
+    // Replace max if we see a new max weight.
+ max_weight = std::max(max_weight, section_weight);
+ }
+
+ NormalizeSectionWeights(max_weight, section_weights);
+ // Set normalized default weight to 1.0 in case there is no section
+ // metadata and max_weight is 0.0 (we should not see this case).
+ double normalized_default_weight = max_weight == 0.0
+ ? kDefaultSectionWeight
+ : kDefaultSectionWeight / max_weight;
+ SectionWeights::NormalizedSectionWeights normalized_section_weights =
+ SectionWeights::NormalizedSectionWeights();
+ normalized_section_weights.section_weights = std::move(section_weights);
+ normalized_section_weights.default_weight = normalized_default_weight;
+ return normalized_section_weights;
+}
+} // namespace lib
+} // namespace icing
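A minimal usage sketch for the new class (assuming a valid SchemaStore* named
schema_store and a populated ScoringSpecProto named scoring_spec; the
signatures below are taken from this change):

ICING_ASSIGN_OR_RETURN(std::unique_ptr<SectionWeights> section_weights,
                       SectionWeights::Create(schema_store, scoring_spec));
ICING_ASSIGN_OR_RETURN(SchemaTypeId email_type_id,
                       schema_store->GetSchemaTypeId("email"));
double weight = section_weights->GetNormalizedSectionWeight(
    email_type_id, /*section_id=*/0);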
diff --git a/icing/scoring/section-weights.h b/icing/scoring/section-weights.h
new file mode 100644
index 0000000..23a9188
--- /dev/null
+++ b/icing/scoring/section-weights.h
@@ -0,0 +1,95 @@
+// Copyright (C) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_SCORING_SECTION_WEIGHTS_H_
+#define ICING_SCORING_SECTION_WEIGHTS_H_
+
+#include <unordered_map>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/schema/schema-store.h"
+#include "icing/store/document-store.h"
+
+namespace icing {
+namespace lib {
+
+inline constexpr double kDefaultSectionWeight = 1.0;
+
+// Provides functions for setting and retrieving section weights for schema
+// type properties. Section weights are used to promote and demote term matches
+// in sections when scoring results. Section weights are provided by property
+// path, and can range from (0, DBL_MAX]. The SectionId is matched to the
+// property path by going over the schema type's section metadata. Weights that
+// correspond to a valid property path are then normalized against the maximum
+// section weight, and put into a map for quick access by scorers. By default,
+// a section is given a raw, pre-normalized weight of 1.0.
+class SectionWeights {
+ public:
+ // SectionWeights instances should not be copied.
+ SectionWeights(const SectionWeights&) = delete;
+ SectionWeights& operator=(const SectionWeights&) = delete;
+
+ // Factory function to create a SectionWeights instance. Raw weights are
+ // provided through the ScoringSpecProto. Provided property paths for weights
+ // are validated against the schema type's section metadata. If the property
+ // path doesn't exist, the property weight is ignored. If a weight is 0 or
+ // negative, an invalid argument error is returned. Raw weights are then
+ // normalized against the maximum weight for that schema type.
+ //
+ // Returns:
+ // A SectionWeights instance on success
+ // FAILED_PRECONDITION on any null pointer input
+ // INVALID_ARGUMENT if a provided weight for a property path is less than or
+ // equal to 0.
+ static libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> Create(
+ const SchemaStore* schema_store, const ScoringSpecProto& scoring_spec);
+
+ // Returns the normalized section weight by SchemaTypeId and SectionId. If
+ // the SchemaTypeId, or the SectionId for a SchemaTypeId, is not found in the
+ // normalized weights map, the default weight is returned instead.
+ double GetNormalizedSectionWeight(SchemaTypeId schema_type_id,
+ SectionId section_id) const;
+
+ private:
+ // Holds the normalized section weights for a schema type, as well as the
+ // normalized default weight for sections that have no weight set.
+ struct NormalizedSectionWeights {
+ std::unordered_map<SectionId, double> section_weights;
+ double default_weight;
+ };
+
+ explicit SectionWeights(
+ const std::unordered_map<SchemaTypeId, NormalizedSectionWeights>
+ schema_section_weight_map)
+ : schema_section_weight_map_(std::move(schema_section_weight_map)) {}
+
+ // Creates a map of section ids to normalized weights from the raw property
+ // path weight map and section metadata and calculates the normalized default
+ // section weight.
+ static inline SectionWeights::NormalizedSectionWeights
+ ExtractNormalizedSectionWeights(
+ const std::unordered_map<std::string, double>& raw_weights,
+ const std::vector<SectionMetadata>& metadata_list);
+
+ // A map of (SchemaTypeId -> SectionId -> Normalized Weight), allows for fast
+ // look up of normalized weights. This is precomputed when creating a
+ // SectionWeights instance.
+ std::unordered_map<SchemaTypeId, NormalizedSectionWeights>
+ schema_section_weight_map_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_SCORING_SECTION_WEIGHTS_H_
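Stated as a formula, the normalization this header describes is: given raw
weights w_1..w_n for a schema type with w_max = max(w_i), each stored weight
is w_i / w_max, and a section with no explicit entry scores with the
normalized default kDefaultSectionWeight / w_max (falling back to
kDefaultSectionWeight itself when there is no section metadata and w_max is
0.0).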
diff --git a/icing/scoring/section-weights_test.cc b/icing/scoring/section-weights_test.cc
new file mode 100644
index 0000000..b90c3d5
--- /dev/null
+++ b/icing/scoring/section-weights_test.cc
@@ -0,0 +1,386 @@
+// Copyright (C) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/section-weights.h"
+
+#include <cfloat>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/proto/scoring.pb.h"
+#include "icing/schema-builder.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+using ::testing::Eq;
+
+class SectionWeightsTest : public testing::Test {
+ protected:
+ SectionWeightsTest()
+ : test_dir_(GetTestTempDir() + "/icing"),
+ schema_store_dir_(test_dir_ + "/schema_store") {}
+
+ void SetUp() override {
+ // Creates file directories
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_));
+
+ SchemaTypeConfigProto sender_schema =
+ SchemaTypeConfigBuilder()
+ .SetType("sender")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("name")
+ .SetDataTypeString(
+ TermMatchType::PREFIX,
+ StringIndexingConfig::TokenizerType::PLAIN)
+ .SetCardinality(
+ PropertyConfigProto_Cardinality_Code_OPTIONAL))
+ .Build();
+ SchemaTypeConfigProto email_schema =
+ SchemaTypeConfigBuilder()
+ .SetType("email")
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("subject")
+ .SetDataTypeString(
+ TermMatchType::PREFIX,
+ StringIndexingConfig::TokenizerType::PLAIN)
+ .SetDataType(PropertyConfigProto_DataType_Code_STRING)
+ .SetCardinality(
+ PropertyConfigProto_Cardinality_Code_OPTIONAL))
+ .AddProperty(
+ PropertyConfigBuilder()
+ .SetName("body")
+ .SetDataTypeString(
+ TermMatchType::PREFIX,
+ StringIndexingConfig::TokenizerType::PLAIN)
+ .SetDataType(PropertyConfigProto_DataType_Code_STRING)
+ .SetCardinality(
+ PropertyConfigProto_Cardinality_Code_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("sender")
+ .SetDataTypeDocument(
+ "sender", /*index_nested_properties=*/true)
+ .SetCardinality(
+ PropertyConfigProto_Cardinality_Code_OPTIONAL))
+ .Build();
+ SchemaProto schema =
+ SchemaBuilder().AddType(sender_schema).AddType(email_schema).Build();
+
+ ICING_ASSERT_OK(schema_store_->SetSchema(schema));
+ }
+
+ void TearDown() override {
+ schema_store_.reset();
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ }
+
+ SchemaStore *schema_store() { return schema_store_.get(); }
+
+ private:
+ const std::string test_dir_;
+ const std::string schema_store_dir_;
+ Filesystem filesystem_;
+ FakeClock fake_clock_;
+ std::unique_ptr<SchemaStore> schema_store_;
+};
+
+TEST_F(SectionWeightsTest, ShouldNormalizeSinglePropertyWeight) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("sender");
+
+ PropertyWeight *property_weight =
+ type_property_weights->add_property_weights();
+ property_weight->set_weight(5.0);
+ property_weight->set_path("name");
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id,
+ schema_store()->GetSchemaTypeId("sender"));
+
+  // section_id 0 corresponds to property "name". Since "name" is the only
+  // property in the "sender" schema type, its weight is the per-type maximum
+  // and so normalizes to 1.0.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id,
+ /*section_id=*/0),
+ Eq(1.0));
+}
+
+TEST_F(SectionWeightsTest, ShouldAcceptMaxWeightValue) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("sender");
+
+ PropertyWeight *property_weight =
+ type_property_weights->add_property_weights();
+ property_weight->set_weight(DBL_MAX);
+ property_weight->set_path("name");
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id,
+ schema_store()->GetSchemaTypeId("sender"));
+
+ // section_id 0 corresponds to property "name".
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id,
+ /*section_id=*/0),
+ Eq(1.0));
+}
+
+TEST_F(SectionWeightsTest, ShouldFailWithNegativeWeights) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("email");
+
+  PropertyWeight *body_property_weight =
+      type_property_weights->add_property_weights();
+  body_property_weight->set_weight(-100.0);
+  body_property_weight->set_path("body");
+
+ EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(SectionWeightsTest, ShouldFailWithZeroWeight) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("sender");
+
+ PropertyWeight *property_weight =
+ type_property_weights->add_property_weights();
+ property_weight->set_weight(0.0);
+ property_weight->set_path("name");
+
+ EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(SectionWeightsTest, ShouldReturnDefaultIfTypePropertyWeightsNotSet) {
+ ScoringSpecProto spec_proto;
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id,
+ schema_store()->GetSchemaTypeId("email"));
+
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/0),
+ Eq(kDefaultSectionWeight));
+}
+
+TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeights) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("email");
+
+ PropertyWeight *body_property_weight =
+ type_property_weights->add_property_weights();
+ body_property_weight->set_weight(1.0);
+ body_property_weight->set_path("body");
+
+ PropertyWeight *subject_property_weight =
+ type_property_weights->add_property_weights();
+ subject_property_weight->set_weight(100.0);
+ subject_property_weight->set_path("subject");
+
+ PropertyWeight *nested_property_weight =
+ type_property_weights->add_property_weights();
+ nested_property_weight->set_weight(50.0);
+ nested_property_weight->set_path("sender.name");
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id,
+ schema_store()->GetSchemaTypeId("email"));
+
+ // Normalized weight for "body" property.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/0),
+ Eq(0.01));
+ // Normalized weight for "sender.name" property (the nested property).
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/1),
+ Eq(0.5));
+ // Normalized weight for "subject" property.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/2),
+ Eq(1.0));
+}
+
+TEST_F(SectionWeightsTest, ShouldNormalizeIfAllWeightsBelowOne) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("email");
+
+ PropertyWeight *body_property_weight =
+ type_property_weights->add_property_weights();
+ body_property_weight->set_weight(0.1);
+ body_property_weight->set_path("body");
+
+ PropertyWeight *sender_name_weight =
+ type_property_weights->add_property_weights();
+ sender_name_weight->set_weight(0.2);
+ sender_name_weight->set_path("sender.name");
+
+ PropertyWeight *subject_property_weight =
+ type_property_weights->add_property_weights();
+ subject_property_weight->set_weight(0.4);
+ subject_property_weight->set_path("subject");
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id,
+ schema_store()->GetSchemaTypeId("email"));
+
+ // Normalized weight for "body" property.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/0),
+ Eq(1.0 / 4.0));
+ // Normalized weight for "sender.name" property (the nested property).
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/1),
+ Eq(2.0 / 4.0));
+ // Normalized weight for "subject" property.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/2),
+ Eq(1.0));
+}
+
+TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeightSeparatelyForTypes) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *email_type_property_weights =
+ spec_proto.add_type_property_weights();
+ email_type_property_weights->set_schema_type("email");
+
+ PropertyWeight *body_property_weight =
+ email_type_property_weights->add_property_weights();
+ body_property_weight->set_weight(1.0);
+ body_property_weight->set_path("body");
+
+ PropertyWeight *subject_property_weight =
+ email_type_property_weights->add_property_weights();
+ subject_property_weight->set_weight(100.0);
+ subject_property_weight->set_path("subject");
+
+ PropertyWeight *sender_name_property_weight =
+ email_type_property_weights->add_property_weights();
+ sender_name_property_weight->set_weight(50.0);
+ sender_name_property_weight->set_path("sender.name");
+
+ TypePropertyWeights *sender_type_property_weights =
+ spec_proto.add_type_property_weights();
+ sender_type_property_weights->set_schema_type("sender");
+
+  PropertyWeight *sender_property_weight =
+      sender_type_property_weights->add_property_weights();
+  sender_property_weight->set_weight(25.0);
+  sender_property_weight->set_path("name");
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id,
+ schema_store()->GetSchemaTypeId("email"));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id,
+ schema_store()->GetSchemaTypeId("sender"));
+
+ // Normalized weight for "sender.name" property (the nested property)
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/1),
+ Eq(0.5));
+  // Normalized weight for the "name" property of the "sender" schema type. As
+  // it is the only property of the type, it takes the max normalized weight of
+  // 1.0.
+  EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id,
+                                                          /*section_id=*/0),
+              Eq(1.0));
+}
+
+TEST_F(SectionWeightsTest, ShouldSkipNonExistentPathWhenSettingWeights) {
+ ScoringSpecProto spec_proto;
+
+ TypePropertyWeights *type_property_weights =
+ spec_proto.add_type_property_weights();
+ type_property_weights->set_schema_type("email");
+
+ // If this property weight isn't skipped, then the max property weight would
+ // be set to 100.0 and all weights would be normalized against the max.
+  PropertyWeight *invalid_property_weight =
+      type_property_weights->add_property_weights();
+  invalid_property_weight->set_weight(100.0);
+  invalid_property_weight->set_path("sender.organization");
+
+ PropertyWeight *subject_property_weight =
+ type_property_weights->add_property_weights();
+ subject_property_weight->set_weight(10.0);
+ subject_property_weight->set_path("subject");
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<SectionWeights> section_weights,
+ SectionWeights::Create(schema_store(), spec_proto));
+ ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id,
+ schema_store()->GetSchemaTypeId("email"));
+
+ // Normalized weight for "body" property. Because the weight is not explicitly
+ // set, it is set to the default of 1.0 before being normalized.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/0),
+ Eq(0.1));
+ // Normalized weight for "sender.name" property (the nested property). Because
+ // the weight is not explicitly set, it is set to the default of 1.0 before
+ // being normalized.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/1),
+ Eq(0.1));
+ // Normalized weight for "subject" property. Because the invalid property path
+ // is skipped when assigning weights, subject takes the max normalized weight
+ // of 1.0 instead.
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id,
+ /*section_id=*/2),
+ Eq(1.0));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
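
The two-level fallback these tests exercise (default weight when a section, or
an entire type, has no entry) can be pictured with the sketch below; this is
illustrative and not the implementation in section-weights.cc:

  #include <unordered_map>

  // Per-type data: explicitly weighted sections plus a normalized default.
  struct TypeWeights {
    std::unordered_map<int, double> section_weights;
    double default_weight;
  };

  double LookupNormalizedWeight(
      const std::unordered_map<int, TypeWeights>& weights_map,
      int schema_type_id, int section_id, double global_default) {
    auto type_it = weights_map.find(schema_type_id);
    if (type_it == weights_map.end()) {
      return global_default;  // no weights configured for this schema type
    }
    const TypeWeights& type_weights = type_it->second;
    auto section_it = type_weights.section_weights.find(section_id);
    if (section_it == type_weights.section_weights.end()) {
      return type_weights.default_weight;  // section had no explicit weight
    }
    return section_it->second;
  }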
diff --git a/icing/tokenization/raw-query-tokenizer.cc b/icing/tokenization/raw-query-tokenizer.cc
index 205d3a2..2d461ee 100644
--- a/icing/tokenization/raw-query-tokenizer.cc
+++ b/icing/tokenization/raw-query-tokenizer.cc
@@ -14,9 +14,8 @@
#include "icing/tokenization/raw-query-tokenizer.h"
-#include <stddef.h>
-
#include <cctype>
+#include <cstddef>
#include <memory>
#include <string>
#include <string_view>
diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc
index 6b1cb3a..8e1e563 100644
--- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc
+++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc
@@ -15,10 +15,10 @@
#include "icing/tokenization/reverse_jni/reverse-jni-break-iterator.h"
#include <jni.h>
-#include <math.h>
#include <cassert>
#include <cctype>
+#include <cmath>
#include <map>
#include "icing/jni/jni-cache.h"
diff --git a/icing/transform/icu/icu-normalizer.cc b/icing/transform/icu/icu-normalizer.cc
index 250d6cf..aceb11d 100644
--- a/icing/transform/icu/icu-normalizer.cc
+++ b/icing/transform/icu/icu-normalizer.cc
@@ -302,14 +302,16 @@ IcuNormalizer::TermTransformer::FindNormalizedNonLatinMatchEndPosition(
int32_t c16_length;
int32_t limit;
- constexpr int kUtf32CharBufferLength = 3;
- UChar32 normalized_buffer[kUtf32CharBufferLength];
- int32_t c32_length;
+ constexpr int kCharBufferLength = 3 * 4;
+ char normalized_buffer[kCharBufferLength];
+ int32_t c8_length;
while (char_itr.utf8_index() < term.length() &&
normalized_char_itr.utf8_index() < normalized_term.length()) {
UChar32 c = char_itr.GetCurrentChar();
- u_strFromUTF32(c16, kUtf16CharBufferLength, &c16_length, &c,
- /*srcLength=*/1, &status);
+    int c_length = i18n_utils::GetUtf8Length(c);
+    u_strFromUTF8(c16, kUtf16CharBufferLength, &c16_length,
+                  term.data() + char_itr.utf8_index(),
+                  /*srcLength=*/c_length, &status);
if (U_FAILURE(status)) {
break;
}
@@ -322,19 +324,20 @@ IcuNormalizer::TermTransformer::FindNormalizedNonLatinMatchEndPosition(
break;
}
- u_strToUTF32(normalized_buffer, kUtf32CharBufferLength, &c32_length, c16,
- c16_length, &status);
+ u_strToUTF8(normalized_buffer, kCharBufferLength, &c8_length, c16,
+ c16_length, &status);
if (U_FAILURE(status)) {
break;
}
- for (int i = 0; i < c32_length; ++i) {
- UChar32 normalized_c = normalized_char_itr.GetCurrentChar();
- if (normalized_buffer[i] != normalized_c) {
+ for (int i = 0; i < c8_length; ++i) {
+ if (normalized_buffer[i] !=
+ normalized_term[normalized_char_itr.utf8_index() + i]) {
return char_itr;
}
- normalized_char_itr.AdvanceToUtf32(normalized_char_itr.utf32_index() + 1);
}
+ normalized_char_itr.AdvanceToUtf8(normalized_char_itr.utf8_index() +
+ c8_length);
char_itr.AdvanceToUtf32(char_itr.utf32_index() + 1);
}
if (U_FAILURE(status)) {
diff --git a/icing/transform/map/map-normalizer.cc b/icing/transform/map/map-normalizer.cc
index 95aa633..61fce65 100644
--- a/icing/transform/map/map-normalizer.cc
+++ b/icing/transform/map/map-normalizer.cc
@@ -14,8 +14,7 @@
#include "icing/transform/map/map-normalizer.h"
-#include <ctype.h>
-
+#include <cctype>
#include <string>
#include <string_view>
#include <unordered_map>
diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java
index 1f5fb51..95e0c84 100644
--- a/java/src/com/google/android/icing/IcingSearchEngine.java
+++ b/java/src/com/google/android/icing/IcingSearchEngine.java
@@ -43,6 +43,8 @@ import com.google.android.icing.proto.SearchSpecProto;
import com.google.android.icing.proto.SetSchemaResultProto;
import com.google.android.icing.proto.StatusProto;
import com.google.android.icing.proto.StorageInfoResultProto;
+import com.google.android.icing.proto.SuggestionResponse;
+import com.google.android.icing.proto.SuggestionSpecProto;
import com.google.android.icing.proto.UsageReport;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
@@ -370,6 +372,28 @@ public class IcingSearchEngine implements Closeable {
}
@NonNull
+  public SuggestionResponse searchSuggestions(@NonNull SuggestionSpecProto suggestionSpec) {
+    throwIfClosed();
+
+ byte[] suggestionResponseBytes = nativeSearchSuggestions(this, suggestionSpec.toByteArray());
+ if (suggestionResponseBytes == null) {
+ Log.e(TAG, "Received null suggestionResponseBytes from native.");
+ return SuggestionResponse.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing suggestionResponseBytes.", e);
+ return SuggestionResponse.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
public DeleteByNamespaceResultProto deleteByNamespace(@NonNull String namespace) {
throwIfClosed();
@@ -604,4 +626,7 @@ public class IcingSearchEngine implements Closeable {
private static native byte[] nativeGetStorageInfo(IcingSearchEngine instance);
private static native byte[] nativeReset(IcingSearchEngine instance);
+
+ private static native byte[] nativeSearchSuggestions(
+ IcingSearchEngine instance, byte[] suggestionSpecBytes);
}
diff --git a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java
index 0cee80c..cb28331 100644
--- a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java
+++ b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java
@@ -51,6 +51,8 @@ import com.google.android.icing.proto.StatusProto;
import com.google.android.icing.proto.StorageInfoResultProto;
import com.google.android.icing.proto.StringIndexingConfig;
import com.google.android.icing.proto.StringIndexingConfig.TokenizerType;
+import com.google.android.icing.proto.SuggestionResponse;
+import com.google.android.icing.proto.SuggestionSpecProto;
import com.google.android.icing.proto.TermMatchType;
import com.google.android.icing.proto.UsageReport;
import com.google.android.icing.IcingSearchEngine;
@@ -623,6 +625,40 @@ public final class IcingSearchEngineTest {
assertThat(match).isEqualTo("𐀂𐀃");
}
+ @Test
+ public void testSearchSuggestions() {
+ assertStatusOk(icingSearchEngine.initialize().getStatus());
+
+ SchemaTypeConfigProto emailTypeConfig = createEmailTypeConfig();
+ SchemaProto schema = SchemaProto.newBuilder().addTypes(emailTypeConfig).build();
+ assertThat(
+ icingSearchEngine
+ .setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false)
+ .getStatus()
+ .getCode())
+ .isEqualTo(StatusProto.Code.OK);
+
+ DocumentProto emailDocument1 =
+ createEmailDocument("namespace", "uri1").toBuilder()
+ .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("fo"))
+ .build();
+ DocumentProto emailDocument2 =
+ createEmailDocument("namespace", "uri2").toBuilder()
+ .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("foo"))
+ .build();
+ assertStatusOk(icingSearchEngine.put(emailDocument1).getStatus());
+ assertStatusOk(icingSearchEngine.put(emailDocument2).getStatus());
+
+ SuggestionSpecProto suggestionSpec =
+ SuggestionSpecProto.newBuilder().setPrefix("f").setNumToReturn(10).build();
+
+ SuggestionResponse response = icingSearchEngine.searchSuggestions(suggestionSpec);
+ assertStatusOk(response.getStatus());
+ assertThat(response.getSuggestionsList()).hasSize(2);
+ assertThat(response.getSuggestions(0).getQuery()).isEqualTo("foo");
+ assertThat(response.getSuggestions(1).getQuery()).isEqualTo("fo");
+ }
+
private static void assertStatusOk(StatusProto status) {
assertWithMessage(status.getMessage()).that(status.getCode()).isEqualTo(StatusProto.Code.OK);
}
diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto
index 6186fde..a3a64df 100644
--- a/proto/icing/proto/scoring.proto
+++ b/proto/icing/proto/scoring.proto
@@ -23,7 +23,7 @@ option objc_class_prefix = "ICNG";
// Encapsulates the configurations on how Icing should score and rank the search
// results.
// TODO(b/170347684): Change all timestamps to seconds.
-// Next tag: 3
+// Next tag: 4
message ScoringSpecProto {
// OPTIONAL: Indicates how the search results will be ranked.
message RankingStrategy {
@@ -83,4 +83,41 @@ message ScoringSpecProto {
}
}
optional Order.Code order_by = 2;
+
+ // OPTIONAL: Specifies property weights for RELEVANCE_SCORE scoring strategy.
+ // Property weights are used for promoting or demoting query term matches in a
+ // document property. When property weights are provided, the term frequency
+ // is multiplied by the normalized property weight when computing the
+ // normalized term frequency component of BM25F. To prefer query term matches
+ // in the "subject" property over the "body" property of "Email" documents,
+ // set a higher property weight value for "subject" than "body". By default,
+ // all properties that are not specified are given a raw, pre-normalized
+ // weight of 1.0 when scoring.
+ repeated TypePropertyWeights type_property_weights = 3;
+}
+
+// Next tag: 3
+message TypePropertyWeights {
+ // Schema type to apply property weights to.
+ optional string schema_type = 1;
+
+ // Property weights to apply to the schema type.
+ repeated PropertyWeight property_weights = 2;
+}
+
+// Next tag: 3
+message PropertyWeight {
+ // Property path to assign property weight to. Property paths must be composed
+ // only of property names and property separators (the '.' character).
+ // For example, if an "Email" schema type has string property "subject" and
+ // document property "sender", which has string property "name", the property
+ // path for the email's subject would just be "subject" and the property path
+ // for the sender's name would be "sender.name". If an invalid path is
+ // specified, the property weight is discarded.
+ optional string path = 1;
+
+  // Property weight. Valid values are positive; zero and negative weights are
+  // invalid and will result in an error. By default, a property is given a
+  // raw, pre-normalized weight of 1.0.
+ optional double weight = 2;
}
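
The BM25F interaction described in the comment above can be summarized with a
small sketch. The aggregation below is illustrative (WeightedTermFrequency is
not a name from this change); it shows only that per-section hit counts are
scaled by the normalized weights before entering the usual BM25 term-frequency
saturation:

  #include <utility>
  #include <vector>

  #include "icing/scoring/section-weights.h"

  // hits: (section_id, hit_count) pairs for one query term in one document.
  double WeightedTermFrequency(
      const std::vector<std::pair<SectionId, int>>& hits,
      const SectionWeights& section_weights, SchemaTypeId schema_type_id) {
    double f = 0.0;
    for (const auto& [section_id, hit_count] : hits) {
      f += section_weights.GetNormalizedSectionWeight(schema_type_id,
                                                      section_id) *
           hit_count;
    }
    return f;  // feeds the f(q,D) term of BM25's saturation formula
  }

A higher weight on "subject" than on "body" thus inflates the contribution of
subject hits to f(q,D), promoting documents whose subjects match the query.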
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto
index 544995e..c712ab2 100644
--- a/proto/icing/proto/search.proto
+++ b/proto/icing/proto/search.proto
@@ -308,3 +308,37 @@ message GetResultSpecProto {
// type will be retrieved.
repeated TypePropertyMask type_property_masks = 1;
}
+
+// Next tag: 4
+message SuggestionSpecProto {
+  // REQUIRED: The "raw" prefix string that users may type. For example, "f"
+  // will search for suggested queries that start with "f", like "foo", "fool".
+ optional string prefix = 1;
+
+  // OPTIONAL: Only search for suggestions under the specified namespaces. If
+  // unset, the suggestion search runs over all namespaces. Note that this
+  // applies to the entire 'prefix'. To issue different suggestions for
+  // different namespaces, separate RunSuggestion() calls will need to be made.
+ repeated string namespace_filters = 2;
+
+ // REQUIRED: The number of suggestions to be returned.
+ optional int32 num_to_return = 3;
+}
+
+// Next tag: 3
+message SuggestionResponse {
+ message Suggestion {
+    // The suggested query string for the client to search for.
+ optional string query = 1;
+ }
+
+ // Status code can be one of:
+ // OK
+ // FAILED_PRECONDITION
+ // INTERNAL
+ //
+ // See status.proto for more details.
+ optional StatusProto status = 1;
+
+ repeated Suggestion suggestions = 2;
+}
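
A hypothetical end-to-end use of the new messages, mirroring the Java test
above; the C++ entry-point name SearchSuggestions is assumed from the JNI
wiring (nativeSearchSuggestions) and is not itself shown in this diff:

  SuggestionSpecProto suggestion_spec;
  suggestion_spec.set_prefix("f");
  suggestion_spec.set_num_to_return(10);

  SuggestionResponse response = icing.SearchSuggestions(suggestion_spec);
  if (response.status().code() == StatusProto::OK) {
    for (const SuggestionResponse::Suggestion& suggestion :
         response.suggestions()) {
      // Each suggestion.query() is a completion of "f", e.g. "foo" then "fo".
    }
  }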
diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt
index 90b2b5a..7e0431b 100644
--- a/synced_AOSP_CL_number.txt
+++ b/synced_AOSP_CL_number.txt
@@ -1 +1 @@
-set(synced_AOSP_CL_number=396463043)
+set(synced_AOSP_CL_number=404879391)