diff options
Diffstat (limited to 'icing/scoring/scorer-factory.cc')
-rw-r--r-- | icing/scoring/scorer-factory.cc | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc index e56f10c..1d66d7f 100644 --- a/icing/scoring/scorer-factory.cc +++ b/icing/scoring/scorer-factory.cc @@ -14,19 +14,26 @@ #include "icing/scoring/scorer-factory.h" +#include <cstdint> #include <memory> +#include <optional> +#include <string> #include <unordered_map> +#include <utility> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/join/join-children-fetcher.h" #include "icing/proto/scoring.pb.h" +#include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/advanced-scorer.h" #include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/scorer.h" #include "icing/scoring/section-weights.h" -#include "icing/store/document-id.h" +#include "icing/store/document-associated-score-data.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -173,10 +180,14 @@ namespace scorer_factory { libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, + SearchSpecProto::EmbeddingQueryMetricType::Code + default_semantic_metric_type, const DocumentStore* document_store, const SchemaStore* schema_store, - int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher) { + int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, + const EmbeddingQueryResults* embedding_query_results) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + ICING_RETURN_ERROR_IF_NULL(embedding_query_results); if (!scoring_spec.advanced_scoring_expression().empty() && scoring_spec.rank_by() != @@ -223,9 +234,10 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( return absl_ports::InvalidArgumentError( "Advanced scoring is enabled, but the expression is empty!"); } - return AdvancedScorer::Create(scoring_spec, default_score, document_store, - schema_store, current_time_ms, - join_children_fetcher); + return AdvancedScorer::Create( + scoring_spec, default_score, default_semantic_metric_type, + document_store, schema_store, current_time_ms, join_children_fetcher, + embedding_query_results); case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE: // Use join aggregate score to rank. Since the aggregation score is // calculated by child documents after joining (in JoinProcessor), we can |