diff options
Diffstat (limited to 'icing/scoring/advanced_scoring/advanced-scorer.cc')
-rw-r--r-- | icing/scoring/advanced_scoring/advanced-scorer.cc | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc index 83c1519..e375a8e 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer.cc @@ -14,14 +14,25 @@ #include "icing/scoring/advanced_scoring/advanced-scorer.h" +#include <cstdint> #include <memory> +#include <utility> +#include <vector> +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" +#include "icing/join/join-children-fetcher.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/parser.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/advanced_scoring/scoring-visitor.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/section-weights.h" +#include "icing/store/document-store.h" +#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -29,11 +40,15 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, int64_t current_time_ms, - const JoinChildrenFetcher* join_children_fetcher) { + const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); Lexer lexer(scoring_spec.advanced_scoring_expression(), Lexer::Language::SCORING); @@ -48,9 +63,10 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, std::unique_ptr<Bm25fCalculator> bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store, section_weights.get(), current_time_ms); - ScoringVisitor visitor(default_score, document_store, schema_store, - section_weights.get(), bm25f_calculator.get(), - join_children_fetcher, current_time_ms); + ScoringVisitor visitor(default_score, default_semantic_metric_type, + document_store, schema_store, section_weights.get(), + bm25f_calculator.get(), join_children_fetcher, + embedding_query_results, current_time_ms); tree_root->Accept(&visitor); ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression, |