diff options
Diffstat (limited to 'icing/scoring/advanced_scoring/score-expression.h')
-rw-r--r-- | icing/scoring/advanced_scoring/score-expression.h | 139 |
1 files changed, 124 insertions, 15 deletions
diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h index 08d7997..e28fcd7 100644 --- a/icing/scoring/advanced_scoring/score-expression.h +++ b/icing/scoring/advanced_scoring/score-expression.h @@ -15,20 +15,26 @@ #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ -#include <algorithm> -#include <cmath> +#include <cstdint> #include <memory> +#include <string> +#include <string_view> #include <unordered_map> #include <unordered_set> +#include <utility> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/join/join-children-fetcher.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" +#include "icing/store/document-filter-data.h" +#include "icing/store/document-id.h" #include "icing/store/document-store.h" -#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -36,7 +42,11 @@ namespace lib { enum class ScoreExpressionType { kDouble, kDoubleList, - kDocument // Only "this" is considered as document type. + kDocument, // Only "this" is considered as document type. + // TODO(b/326656531): Instead of creating a vector index type, consider + // changing it to vector type so that the data is the vector directly. + kVectorIndex, + kString, }; class ScoreExpression { @@ -75,12 +85,24 @@ class ScoreExpression { "checking."); } + virtual libtextclassifier3::StatusOr<std::string_view> eval_string() const { + if (type() == ScoreExpressionType::kString) { + return absl_ports::UnimplementedError( + "All ScoreExpressions of type string must provide their own " + "implementation of eval_string!"); + } + return absl_ports::InternalError( + "Runtime type error: the expression should never be evaluated to a " + "string. There must be inconsistencies in the static type checking."); + } + // Indicate the type to which the current expression will be evaluated. virtual ScoreExpressionType type() const = 0; - // Indicate whether the current expression is a constant double. - // Returns true if and only if the object is of ConstantScoreExpression type. - virtual bool is_constant_double() const { return false; } + // Indicate whether the current expression is a constant. + // Returns true if and only if the object is of ConstantScoreExpression or + // StringExpression type. + virtual bool is_constant() const { return false; } }; class ThisExpression : public ScoreExpression { @@ -100,9 +122,10 @@ class ThisExpression : public ScoreExpression { class ConstantScoreExpression : public ScoreExpression { public: static std::unique_ptr<ConstantScoreExpression> Create( - libtextclassifier3::StatusOr<double> c) { + libtextclassifier3::StatusOr<double> c, + ScoreExpressionType type = ScoreExpressionType::kDouble) { return std::unique_ptr<ConstantScoreExpression>( - new ConstantScoreExpression(c)); + new ConstantScoreExpression(c, type)); } libtextclassifier3::StatusOr<double> eval( @@ -110,17 +133,39 @@ class ConstantScoreExpression : public ScoreExpression { return c_; } - ScoreExpressionType type() const override { - return ScoreExpressionType::kDouble; - } + ScoreExpressionType type() const override { return type_; } - bool is_constant_double() const override { return true; } + bool is_constant() const override { return true; } private: - explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c) - : c_(c) {} + explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c, + ScoreExpressionType type) + : c_(c), type_(type) {} libtextclassifier3::StatusOr<double> c_; + ScoreExpressionType type_; +}; + +class StringExpression : public ScoreExpression { + public: + static std::unique_ptr<StringExpression> Create(std::string str) { + return std::unique_ptr<StringExpression>( + new StringExpression(std::move(str))); + } + + libtextclassifier3::StatusOr<std::string_view> eval_string() const override { + return str_; + } + + ScoreExpressionType type() const override { + return ScoreExpressionType::kString; + } + + bool is_constant() const override { return true; } + + private: + explicit StringExpression(std::string str) : str_(std::move(str)) {} + std::string str_; }; class OperatorScoreExpression : public ScoreExpression { @@ -342,6 +387,70 @@ class PropertyWeightsFunctionScoreExpression : public ScoreExpression { int64_t current_time_ms_; }; +class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression { + public: + static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding"; + + // RETURNS: + // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if + // not simplifiable. + // - A ConstantScoreExpression instance on success if simplifiable. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. + static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( + std::vector<std::unique_ptr<ScoreExpression>> args); + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) const override; + + ScoreExpressionType type() const override { + return ScoreExpressionType::kVectorIndex; + } + + private: + explicit GetSearchSpecEmbeddingFunctionScoreExpression( + std::unique_ptr<ScoreExpression> arg) + : arg_(std::move(arg)) {} + std::unique_ptr<ScoreExpression> arg_; +}; + +class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression { + public: + static constexpr std::string_view kFunctionName = "matchedSemanticScores"; + + // RETURNS: + // - A MatchedSemanticScoresFunctionScoreExpression instance on success. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. + static libtextclassifier3::StatusOr< + std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> + Create(std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, + const EmbeddingQueryResults* embedding_query_results); + + libtextclassifier3::StatusOr<std::vector<double>> eval_list( + const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) const override; + + ScoreExpressionType type() const override { + return ScoreExpressionType::kDoubleList; + } + + private: + explicit MatchedSemanticScoresFunctionScoreExpression( + std::vector<std::unique_ptr<ScoreExpression>> args, + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, + const EmbeddingQueryResults& embedding_query_results) + : args_(std::move(args)), + metric_type_(metric_type), + embedding_query_results_(embedding_query_results) {} + + std::vector<std::unique_ptr<ScoreExpression>> args_; + const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; + const EmbeddingQueryResults& embedding_query_results_; +}; + } // namespace lib } // namespace icing |