diff options
Diffstat (limited to 'icing/scoring/advanced_scoring/score-expression.cc')
-rw-r--r-- | icing/scoring/advanced_scoring/score-expression.cc | 176 |
1 files changed, 174 insertions, 2 deletions
diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc
index e8a2a89..687180a 100644
--- a/icing/scoring/advanced_scoring/score-expression.cc
+++ b/icing/scoring/advanced_scoring/score-expression.cc
@@ -14,10 +14,40 @@
 
 #include "icing/scoring/advanced_scoring/score-expression.h"
 
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <memory>
 #include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
 #include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-associated-score-data.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/util/embedding-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
 
 namespace icing {
 namespace lib {
@@ -49,7 +79,7 @@ OperatorScoreExpression::Create(
       return absl_ports::InvalidArgumentError(
           "Operators are only supported for double type.");
     }
-    if (!child->is_constant_double()) {
+    if (!child->is_constant()) {
       children_all_constant_double = false;
     }
   }
@@ -149,7 +179,7 @@ MathFunctionScoreExpression::Create(
           "Got an invalid type for the math function. Should expect a double "
           "type argument.");
     }
-    if (!child->is_constant_double()) {
+    if (!child->is_constant()) {
       args_all_constant_double = false;
     }
   }
@@ -517,5 +547,147 @@ SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId(
   return filter_data_optional.value().schema_type_id();
 }
 
+namespace {
+
+// Converts the double obtained from evaluating an embedding query index
+// argument into a uint32_t index.
+//
+// Returns:
+//   - The index on success
+//   - INVALID_ARGUMENT if the value is NaN, negative, too large for uint32_t,
+//     or not a whole number
+libtextclassifier3::StatusOr<uint32_t> ToEmbeddingQueryIndex(
+    double raw_query_index) {
+  // Range-check before converting: casting a NaN or out-of-range double to an
+  // integer type is undefined behavior. The negated comparison also rejects
+  // NaN, for which every ordered comparison is false.
+  if (!(raw_query_index >= 0 &&
+        raw_query_index <= std::numeric_limits<uint32_t>::max())) {
+    return absl_ports::InvalidArgumentError(
+        "The index of an embedding query is out of range.");
+  }
+  uint32_t query_index = static_cast<uint32_t>(raw_query_index);
+  if (query_index != raw_query_index) {
+    return absl_ports::InvalidArgumentError(
+        "The index of an embedding query must be an integer.");
+  }
+  return query_index;
+}
+
+}  // namespace
+
+// Builds the expression that yields the index of an embedding query. A
+// constant argument is evaluated once here and the outcome is handed to
+// ConstantScoreExpression::Create.
+libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
+GetSearchSpecEmbeddingFunctionScoreExpression::Create(
+    std::vector<std::unique_ptr<ScoreExpression>> args) {
+  // args[0] is dereferenced below, so reject null children first.
+  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
+
+  if (args.size() != 1) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " must have 1 argument."));
+  }
+  if (args[0]->type() != ScoreExpressionType::kDouble) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " got invalid argument type."));
+  }
+  bool is_constant = args[0]->is_constant();
+  std::unique_ptr<ScoreExpression> expression =
+      std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>(
+          new GetSearchSpecEmbeddingFunctionScoreExpression(
+              std::move(args[0])));
+  if (is_constant) {
+    return ConstantScoreExpression::Create(
+        expression->eval(DocHitInfo(), /*query_it=*/nullptr),
+        expression->type());
+  }
+  return expression;
+}
+
+// Evaluates the argument and returns it as an embedding query index.
+//
+// Returns:
+//   - The index, as a double, on success
+//   - INVALID_ARGUMENT if the value is not a valid uint32_t index
+//   - Any error from evaluating the argument
+libtextclassifier3::StatusOr<double>
+GetSearchSpecEmbeddingFunctionScoreExpression::eval(
+    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+  ICING_ASSIGN_OR_RETURN(double raw_query_index,
+                         arg_->eval(hit_info, query_it));
+  ICING_ASSIGN_OR_RETURN(uint32_t query_index,
+                         ToEmbeddingQueryIndex(raw_query_index));
+  return static_cast<double>(query_index);
+}
+
+// Validates the arguments ("this", an embedding query index, and an optional
+// constant metric name) and builds the expression. Without a metric argument
+// default_metric_type is used.
+libtextclassifier3::StatusOr<
+    std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+MatchedSemanticScoresFunctionScoreExpression::Create(
+    std::vector<std::unique_ptr<ScoreExpression>> args,
+    SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+    const EmbeddingQueryResults* embedding_query_results) {
+  ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
+  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
+
+  if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
+  }
+  if (args.size() != 2 && args.size() != 3) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
+  }
+  if (args[1]->type() != ScoreExpressionType::kVectorIndex) {
+    return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+        kFunctionName, " got invalid argument type for embedding vector."));
+  }
+  if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) {
+    return absl_ports::InvalidArgumentError(
+        "Embedding metric can only be given as a string.");
+  }
+
+  SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+      default_metric_type;
+  if (args.size() == 3) {
+    if (!args[2]->is_constant()) {
+      return absl_ports::InvalidArgumentError(
+          "Embedding metric can only be given as a constant string.");
+    }
+    ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->eval_string());
+    ICING_ASSIGN_OR_RETURN(
+        metric_type,
+        embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+  }
+  return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
+      new MatchedSemanticScoresFunctionScoreExpression(
+          std::move(args), metric_type, *embedding_query_results));
+}
+
+// Returns the scores matched for the hit's document under the evaluated
+// embedding query index and this expression's metric type; empty when the
+// lookup yields no scores.
+libtextclassifier3::StatusOr<std::vector<double>>
+MatchedSemanticScoresFunctionScoreExpression::eval_list(
+    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+  ICING_ASSIGN_OR_RETURN(double raw_query_index,
+                         args_[1]->eval(hit_info, query_it));
+  // Validate rather than cast blindly, so an invalid index never reaches the
+  // lookup.
+  ICING_ASSIGN_OR_RETURN(uint32_t query_index,
+                         ToEmbeddingQueryIndex(raw_query_index));
+  const std::vector<double>* scores =
+      embedding_query_results_.GetMatchedScoresForDocument(
+          query_index, metric_type_, hit_info.document_id());
+  if (scores == nullptr) {
+    return std::vector<double>();
+  }
+  return *scores;
+}
+
 }  // namespace lib
 }  // namespace icing