aboutsummaryrefslogtreecommitdiff
path: root/icing/scoring/advanced_scoring/score-expression.h
diff options
context:
space:
mode:
Diffstat (limited to 'icing/scoring/advanced_scoring/score-expression.h')
-rw-r--r--icing/scoring/advanced_scoring/score-expression.h139
1 files changed, 124 insertions, 15 deletions
diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h
index 08d7997..e28fcd7 100644
--- a/icing/scoring/advanced_scoring/score-expression.h
+++ b/icing/scoring/advanced_scoring/score-expression.h
@@ -15,20 +15,26 @@
#ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
#define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
-#include <algorithm>
-#include <cmath>
+#include <cstdint>
#include <memory>
+#include <string>
+#include <string_view>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
-#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -36,7 +42,11 @@ namespace lib {
enum class ScoreExpressionType {
kDouble,
kDoubleList,
- kDocument // Only "this" is considered as document type.
+ kDocument, // Only "this" is considered as document type.
+ // TODO(b/326656531): Instead of creating a vector index type, consider
+ // changing it to vector type so that the data is the vector directly.
+ kVectorIndex,
+ kString,
};
class ScoreExpression {
@@ -75,12 +85,24 @@ class ScoreExpression {
"checking.");
}
+ virtual libtextclassifier3::StatusOr<std::string_view> eval_string() const {
+ if (type() == ScoreExpressionType::kString) {
+ return absl_ports::UnimplementedError(
+ "All ScoreExpressions of type string must provide their own "
+ "implementation of eval_string!");
+ }
+ return absl_ports::InternalError(
+ "Runtime type error: the expression should never be evaluated to a "
+ "string. There must be inconsistencies in the static type checking.");
+ }
+
// Indicate the type to which the current expression will be evaluated.
virtual ScoreExpressionType type() const = 0;
- // Indicate whether the current expression is a constant double.
- // Returns true if and only if the object is of ConstantScoreExpression type.
- virtual bool is_constant_double() const { return false; }
+ // Indicate whether the current expression is a constant.
+ // Returns true if and only if the object is of ConstantScoreExpression or
+ // StringExpression type.
+ virtual bool is_constant() const { return false; }
};
class ThisExpression : public ScoreExpression {
@@ -100,9 +122,10 @@ class ThisExpression : public ScoreExpression {
class ConstantScoreExpression : public ScoreExpression {
public:
static std::unique_ptr<ConstantScoreExpression> Create(
- libtextclassifier3::StatusOr<double> c) {
+ libtextclassifier3::StatusOr<double> c,
+ ScoreExpressionType type = ScoreExpressionType::kDouble) {
return std::unique_ptr<ConstantScoreExpression>(
- new ConstantScoreExpression(c));
+ new ConstantScoreExpression(c, type));
}
libtextclassifier3::StatusOr<double> eval(
@@ -110,17 +133,39 @@ class ConstantScoreExpression : public ScoreExpression {
return c_;
}
- ScoreExpressionType type() const override {
- return ScoreExpressionType::kDouble;
- }
+ ScoreExpressionType type() const override { return type_; }
- bool is_constant_double() const override { return true; }
+ bool is_constant() const override { return true; }
private:
- explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c)
- : c_(c) {}
+ explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c,
+ ScoreExpressionType type)
+ : c_(c), type_(type) {}
libtextclassifier3::StatusOr<double> c_;
+ ScoreExpressionType type_;
+};
+
+class StringExpression : public ScoreExpression {
+ public:
+ static std::unique_ptr<StringExpression> Create(std::string str) {
+ return std::unique_ptr<StringExpression>(
+ new StringExpression(std::move(str)));
+ }
+
+ libtextclassifier3::StatusOr<std::string_view> eval_string() const override {
+ return str_;
+ }
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kString;
+ }
+
+ bool is_constant() const override { return true; }
+
+ private:
+ explicit StringExpression(std::string str) : str_(std::move(str)) {}
+ std::string str_;
};
class OperatorScoreExpression : public ScoreExpression {
@@ -342,6 +387,70 @@ class PropertyWeightsFunctionScoreExpression : public ScoreExpression {
int64_t current_time_ms_;
};
+class GetSearchSpecEmbeddingFunctionScoreExpression : public ScoreExpression {
+ public:
+ static constexpr std::string_view kFunctionName = "getSearchSpecEmbedding";
+
+ // RETURNS:
+ // - A GetSearchSpecEmbeddingFunctionScoreExpression instance on success if
+ // not simplifiable.
+ // - A ConstantScoreExpression instance on success if simplifiable.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
+ std::vector<std::unique_ptr<ScoreExpression>> args);
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) const override;
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kVectorIndex;
+ }
+
+ private:
+ explicit GetSearchSpecEmbeddingFunctionScoreExpression(
+ std::unique_ptr<ScoreExpression> arg)
+ : arg_(std::move(arg)) {}
+ std::unique_ptr<ScoreExpression> arg_;
+};
+
+class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression {
+ public:
+ static constexpr std::string_view kFunctionName = "matchedSemanticScores";
+
+ // RETURNS:
+ // - A MatchedSemanticScoresFunctionScoreExpression instance on success.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+ Create(std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+ const EmbeddingQueryResults* embedding_query_results);
+
+ libtextclassifier3::StatusOr<std::vector<double>> eval_list(
+ const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) const override;
+
+ ScoreExpressionType type() const override {
+ return ScoreExpressionType::kDoubleList;
+ }
+
+ private:
+ explicit MatchedSemanticScoresFunctionScoreExpression(
+ std::vector<std::unique_ptr<ScoreExpression>> args,
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
+ const EmbeddingQueryResults& embedding_query_results)
+ : args_(std::move(args)),
+ metric_type_(metric_type),
+ embedding_query_results_(embedding_query_results) {}
+
+ std::vector<std::unique_ptr<ScoreExpression>> args_;
+ const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
+ const EmbeddingQueryResults& embedding_query_results_;
+};
+
} // namespace lib
} // namespace icing