aboutsummaryrefslogtreecommitdiff
path: root/icing/scoring/advanced_scoring/advanced-scorer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/scoring/advanced_scoring/advanced-scorer.cc')
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer.cc24
1 files changed, 20 insertions, 4 deletions
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc
index 83c1519..e375a8e 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer.cc
@@ -14,14 +14,25 @@
#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+#include <cstdint>
#include <memory>
+#include <utility>
+#include <vector>
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/schema/schema-store.h"
#include "icing/scoring/advanced_scoring/score-expression.h"
#include "icing/scoring/advanced_scoring/scoring-visitor.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/section-weights.h"
+#include "icing/store/document-store.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -29,11 +40,15 @@ namespace lib {
libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>>
AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
double default_score,
+ SearchSpecProto::EmbeddingQueryMetricType::Code
+ default_semantic_metric_type,
const DocumentStore* document_store,
const SchemaStore* schema_store, int64_t current_time_ms,
- const JoinChildrenFetcher* join_children_fetcher) {
+ const JoinChildrenFetcher* join_children_fetcher,
+ const EmbeddingQueryResults* embedding_query_results) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
Lexer lexer(scoring_spec.advanced_scoring_expression(),
Lexer::Language::SCORING);
@@ -48,9 +63,10 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
std::unique_ptr<Bm25fCalculator> bm25f_calculator =
std::make_unique<Bm25fCalculator>(document_store, section_weights.get(),
current_time_ms);
- ScoringVisitor visitor(default_score, document_store, schema_store,
- section_weights.get(), bm25f_calculator.get(),
- join_children_fetcher, current_time_ms);
+ ScoringVisitor visitor(default_score, default_semantic_metric_type,
+ document_store, schema_store, section_weights.get(),
+ bm25f_calculator.get(), join_children_fetcher,
+ embedding_query_results, current_time_ms);
tree_root->Accept(&visitor);
ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression,