aboutsummaryrefslogtreecommitdiff
path: root/icing/scoring/advanced_scoring/scoring-visitor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/scoring/advanced_scoring/scoring-visitor.cc')
-rw-r--r--icing/scoring/advanced_scoring/scoring-visitor.cc24
1 files changed, 22 insertions, 2 deletions
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc
index e2b24a2..05240c0 100644
--- a/icing/scoring/advanced_scoring/scoring-visitor.cc
+++ b/icing/scoring/advanced_scoring/scoring-visitor.cc
@@ -14,7 +14,17 @@
#include "icing/scoring/advanced_scoring/scoring-visitor.h"
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/scoring/advanced_scoring/score-expression.h"
namespace icing {
namespace lib {
@@ -25,8 +35,7 @@ void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) {
}
void ScoringVisitor::VisitString(const StringNode* node) {
- pending_error_ =
- absl_ports::InvalidArgumentError("Scoring does not support String!");
+ stack_.push_back(StringExpression::Create(node->value()));
}
void ScoringVisitor::VisitText(const TextNode* node) {
@@ -120,6 +129,17 @@ void ScoringVisitor::VisitFunctionHelper(const FunctionNode* node,
expression = MathFunctionScoreExpression::Create(
MathFunctionScoreExpression::kFunctionNames.at(function_name),
std::move(args));
+ } else if (function_name ==
+ GetSearchSpecEmbeddingFunctionScoreExpression::kFunctionName) {
+ // getSearchSpecEmbedding function
+ expression =
+ GetSearchSpecEmbeddingFunctionScoreExpression::Create(std::move(args));
+ } else if (function_name ==
+ MatchedSemanticScoresFunctionScoreExpression::kFunctionName) {
+ // matchedSemanticScores function
+ expression = MatchedSemanticScoresFunctionScoreExpression::Create(
+ std::move(args), default_semantic_metric_type_,
+ &embedding_query_results_);
}
if (!expression.ok()) {