about | summary | refs | log | tree | commit | diff
path: root/icing/scoring/advanced_scoring/score-expression.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/scoring/advanced_scoring/score-expression.cc')
-rw-r--r-- icing/scoring/advanced_scoring/score-expression.cc 128
1 files changed, 126 insertions, 2 deletions
diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc
index e8a2a89..687180a 100644
--- a/icing/scoring/advanced_scoring/score-expression.cc
+++ b/icing/scoring/advanced_scoring/score-expression.cc
@@ -14,10 +14,39 @@
#include "icing/scoring/advanced_scoring/score-expression.h"
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
#include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/index/embed/embedding-query-results.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/join/join-children-fetcher.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/scoring/section-weights.h"
+#include "icing/store/document-associated-score-data.h"
+#include "icing/store/document-filter-data.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/util/embedding-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -49,7 +78,7 @@ OperatorScoreExpression::Create(
return absl_ports::InvalidArgumentError(
"Operators are only supported for double type.");
}
- if (!child->is_constant_double()) {
+ if (!child->is_constant()) {
children_all_constant_double = false;
}
}
@@ -149,7 +178,7 @@ MathFunctionScoreExpression::Create(
"Got an invalid type for the math function. Should expect a double "
"type argument.");
}
- if (!child->is_constant_double()) {
+ if (!child->is_constant()) {
args_all_constant_double = false;
}
}
@@ -517,5 +546,100 @@ SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId(
return filter_data_optional.value().schema_type_id();
}
+// Builds the expression that resolves the index of a search-spec embedding
+// query.
+//
+// Expects exactly one non-null argument of double type. When that argument is
+// constant, the expression is evaluated once here and the result is handed to
+// ConstantScoreExpression::Create; otherwise the expression itself is returned
+// so that it can be evaluated per hit.
+//
+// Returns:
+//   - The created expression on success.
+//   - INVALID_ARGUMENT on a wrong argument count or argument type.
+//   - Any error reported by CheckChildrenNotNull.
+libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
+GetSearchSpecEmbeddingFunctionScoreExpression::Create(
+    std::vector<std::unique_ptr<ScoreExpression>> args) {
+  // Reject null children up front, the same way
+  // MatchedSemanticScoresFunctionScoreExpression::Create does; args[0] is
+  // dereferenced below.
+  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
+
+  if (args.size() != 1) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " must have 1 argument."));
+  }
+  if (args[0]->type() != ScoreExpressionType::kDouble) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " got invalid argument type."));
+  }
+  // Must be read before args[0] is moved into the new expression.
+  const bool is_constant = args[0]->is_constant();
+  std::unique_ptr<ScoreExpression> expression =
+      std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>(
+          new GetSearchSpecEmbeddingFunctionScoreExpression(
+              std::move(args[0])));
+  if (is_constant) {
+    // A constant argument does not depend on the hit, so a default DocHitInfo
+    // and a null query iterator are passed for this one-off evaluation.
+    return ConstantScoreExpression::Create(
+        expression->eval(DocHitInfo(), /*query_it=*/nullptr),
+        expression->type());
+  }
+  return expression;
+}
+
+// Evaluates the wrapped argument and interprets it as an embedding query
+// index.
+//
+// Returns:
+//   - The index (as a double) when the argument evaluates to a whole number
+//     in [0, UINT32_MAX].
+//   - INVALID_ARGUMENT when the value is NaN, negative, above UINT32_MAX or
+//     has a fractional part.
+//   - Any error from evaluating the argument.
+libtextclassifier3::StatusOr<double>
+GetSearchSpecEmbeddingFunctionScoreExpression::eval(
+    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+  ICING_ASSIGN_OR_RETURN(double raw_query_index,
+                         arg_->eval(hit_info, query_it));
+  // Converting a double outside the range of uint32_t (or NaN) to uint32_t is
+  // undefined behavior, so the range is validated before the cast. The
+  // negated form makes NaN, for which every comparison is false, fail too.
+  if (!(raw_query_index >= 0 && raw_query_index <= UINT32_MAX)) {
+    return absl_ports::InvalidArgumentError(
+        "The index of an embedding query must be a non-negative integer "
+        "that fits in uint32.");
+  }
+  const uint32_t query_index = static_cast<uint32_t>(raw_query_index);
+  // A fractional value does not survive the round trip through uint32_t.
+  if (query_index != raw_query_index) {
+    return absl_ports::InvalidArgumentError(
+        "The index of an embedding query must be an integer.");
+  }
+  return query_index;
+}
+
+// Validates the arguments of the matched-semantic-scores function and builds
+// the corresponding expression.
+//
+// Accepted argument lists (all children must be non-null):
+//   args[0]: an expression of document type ("this").
+//   args[1]: an expression of vector-index type.
+//   args[2]: optional; a constant string expression naming the metric. When
+//            absent, default_metric_type is used.
+//
+// Returns:
+//   - The created expression on success.
+//   - INVALID_ARGUMENT when the argument count or any argument type is wrong,
+//     or when the metric argument is not constant.
+//   - Any error from the null checks, from evaluating the metric string, or
+//     from resolving the metric name.
+libtextclassifier3::StatusOr<
+    std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
+MatchedSemanticScoresFunctionScoreExpression::Create(
+    std::vector<std::unique_ptr<ScoreExpression>> args,
+    SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
+    const EmbeddingQueryResults* embedding_query_results) {
+  ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
+  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
+
+  const bool has_metric_arg = args.size() == 3;
+
+  // The receiver is checked first so that a missing "this" is reported ahead
+  // of an argument-count problem.
+  if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
+  }
+  if (!has_metric_arg && args.size() != 2) {
+    return absl_ports::InvalidArgumentError(
+        absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
+  }
+  if (args[1]->type() != ScoreExpressionType::kVectorIndex) {
+    return absl_ports::InvalidArgumentError(absl_ports::StrCat(
+        kFunctionName, " got invalid argument type for embedding vector."));
+  }
+
+  // Start from the default; an explicit metric argument overrides it.
+  SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+      default_metric_type;
+  if (has_metric_arg) {
+    ScoreExpression* metric_arg = args[2].get();
+    if (metric_arg->type() != ScoreExpressionType::kString) {
+      return absl_ports::InvalidArgumentError(
+          "Embedding metric can only be given as a string.");
+    }
+    if (!metric_arg->is_constant()) {
+      return absl_ports::InvalidArgumentError(
+          "Embedding metric can only be given as a constant string.");
+    }
+    ICING_ASSIGN_OR_RETURN(std::string_view metric,
+                           metric_arg->eval_string());
+    ICING_ASSIGN_OR_RETURN(
+        metric_type,
+        embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+  }
+
+  return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
+      new MatchedSemanticScoresFunctionScoreExpression(
+          std::move(args), metric_type, *embedding_query_results));
+}
+
+// Returns the list of scores that embedding_query_results_ holds for the
+// hit's document, looked up by the query index obtained from args_[1] and by
+// metric_type_.
+//
+// Returns:
+//   - A copy of the stored scores, or an empty list when the lookup yields
+//     nullptr.
+//   - Any error from evaluating args_[1].
+libtextclassifier3::StatusOr<std::vector<double>>
+MatchedSemanticScoresFunctionScoreExpression::eval_list(
+    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
+  ICING_ASSIGN_OR_RETURN(double index_as_double,
+                         args_[1]->eval(hit_info, query_it));
+  const auto embedding_query_index = static_cast<uint32_t>(index_as_double);
+
+  if (const std::vector<double>* matched =
+          embedding_query_results_.GetMatchedScoresForDocument(
+              embedding_query_index, metric_type_, hit_info.document_id());
+      matched != nullptr) {
+    return *matched;
+  }
+  // A null result from the lookup is reported as "no scores".
+  return std::vector<double>();
+}
+
} // namespace lib
} // namespace icing