aboutsummaryrefslogtreecommitdiff
path: root/icing/query/advanced_query_parser/query-visitor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/query/advanced_query_parser/query-visitor.cc')
-rw-r--r--icing/query/advanced_query_parser/query-visitor.cc139
1 files changed, 122 insertions, 17 deletions
diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc
index 31da959..1ac52c5 100644
--- a/icing/query/advanced_query_parser/query-visitor.cc
+++ b/icing/query/advanced_query_parser/query-visitor.cc
@@ -16,20 +16,26 @@
#include <algorithm>
#include <cstdint>
-#include <cstdlib>
#include <iterator>
#include <limits>
#include <memory>
#include <set>
#include <string>
+#include <string_view>
+#include <unordered_map>
#include <utility>
#include <vector>
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/absl_ports/str_join.h"
+#include "icing/index/embed/doc-hit-info-iterator-embedding.h"
+#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h"
#include "icing/index/iterator/doc-hit-info-iterator-and.h"
+#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
#include "icing/index/iterator/doc-hit-info-iterator-none.h"
#include "icing/index/iterator/doc-hit-info-iterator-not.h"
#include "icing/index/iterator/doc-hit-info-iterator-or.h"
@@ -37,17 +43,23 @@
#include "icing/index/iterator/doc-hit-info-iterator-property-in-schema.h"
#include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/iterator/section-restrict-data.h"
#include "icing/index/property-existence-indexing-handler.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/function.h"
#include "icing/query/advanced_query_parser/lexer.h"
#include "icing/query/advanced_query_parser/param.h"
#include "icing/query/advanced_query_parser/parser.h"
#include "icing/query/advanced_query_parser/pending-value.h"
#include "icing/query/advanced_query_parser/util/string-util.h"
#include "icing/query/query-features.h"
+#include "icing/query/query-results.h"
#include "icing/schema/property-util.h"
+#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer.h"
+#include "icing/util/embedding-util.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -241,6 +253,34 @@ void QueryVisitor::RegisterFunctions() {
.ValueOrDie();
registered_functions_.insert(
{has_property_function.name(), std::move(has_property_function)});
+
+ // vector_index getSearchSpecEmbedding(long);
+ auto get_search_spec_embedding = [](std::vector<PendingValue>&& args) {
+ return PendingValue::CreateVectorIndexPendingValue(
+ args.at(0).long_val().ValueOrDie());
+ };
+ Function get_search_spec_embedding_function =
+ Function::Create(DataType::kVectorIndex, "getSearchSpecEmbedding",
+ {Param(DataType::kLong)},
+ std::move(get_search_spec_embedding))
+ .ValueOrDie();
+ registered_functions_.insert({get_search_spec_embedding_function.name(),
+ std::move(get_search_spec_embedding_function)});
+
+ // DocHitInfoIterator semanticSearch(vector_index, double, double, string);
+ auto semantic_search = [this](std::vector<PendingValue>&& args) {
+ return this->SemanticSearchFunction(std::move(args));
+ };
+ Function semantic_search_function =
+ Function::Create(DataType::kDocumentIterator, "semanticSearch",
+ {Param(DataType::kVectorIndex),
+ Param(DataType::kDouble, Cardinality::kOptional),
+ Param(DataType::kDouble, Cardinality::kOptional),
+ Param(DataType::kString, Cardinality::kOptional)},
+ std::move(semantic_search))
+ .ValueOrDie();
+ registered_functions_.insert(
+ {semantic_search_function.name(), std::move(semantic_search_function)});
}
libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction(
@@ -278,10 +318,11 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction(
document_store_.last_added_document_id());
} else {
QueryVisitor query_visitor(
- &index_, &numeric_index_, &document_store_, &schema_store_,
- &normalizer_, &tokenizer_, query->raw_term, filter_options_,
- match_type_, needs_term_frequency_info_, pending_property_restricts_,
- processing_not_, current_time_ms_);
+ &index_, &numeric_index_, &embedding_index_, &document_store_,
+ &schema_store_, &normalizer_, &tokenizer_, query->raw_term,
+ embedding_query_vectors_, filter_options_, match_type_,
+ embedding_query_metric_type_, needs_term_frequency_info_,
+ pending_property_restricts_, processing_not_, current_time_ms_);
tree_root->Accept(&query_visitor);
ICING_ASSIGN_OR_RETURN(query_result,
std::move(query_visitor).ConsumeResults());
@@ -359,6 +400,57 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::HasPropertyFunction(
return PendingValue(std::move(property_in_document_iterator));
}
+libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SemanticSearchFunction(
+ std::vector<PendingValue>&& args) {
+ features_.insert(kEmbeddingSearchFeature);
+
+ int64_t vector_index = args.at(0).vector_index_val().ValueOrDie();
+ if (embedding_query_vectors_ == nullptr || vector_index < 0 ||
+ vector_index >= embedding_query_vectors_->size()) {
+ return absl_ports::InvalidArgumentError("Got invalid vector search index!");
+ }
+
+ // Handle default values for the optional arguments.
+ double low = -std::numeric_limits<double>::infinity();
+ double high = std::numeric_limits<double>::infinity();
+ SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
+ embedding_query_metric_type_;
+ if (args.size() >= 2) {
+ low = args.at(1).double_val().ValueOrDie();
+ }
+ if (args.size() >= 3) {
+ high = args.at(2).double_val().ValueOrDie();
+ }
+ if (args.size() >= 4) {
+ const std::string& metric = args.at(3).string_val().ValueOrDie()->term;
+ ICING_ASSIGN_OR_RETURN(
+ metric_type,
+ embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
+ }
+
+ // Create SectionRestrictData for section restriction.
+ std::unique_ptr<SectionRestrictData> section_restrict_data = nullptr;
+ if (pending_property_restricts_.has_active_property_restricts()) {
+ std::unordered_map<std::string, std::set<std::string>>
+ type_property_filters;
+ type_property_filters[std::string(SchemaStore::kSchemaTypeWildcard)] =
+ pending_property_restricts_.active_property_restricts();
+ section_restrict_data = std::make_unique<SectionRestrictData>(
+ &document_store_, &schema_store_, current_time_ms_,
+ type_property_filters);
+ }
+
+ // Create and return iterator.
+ EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map =
+ &embedding_query_results_.result_scores[vector_index][metric_type];
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator,
+ DocHitInfoIteratorEmbedding::Create(
+ &embedding_query_vectors_->at(vector_index),
+ std::move(section_restrict_data), metric_type, low,
+ high, score_map, &embedding_index_));
+ return PendingValue(std::move(iterator));
+}
+
libtextclassifier3::StatusOr<int64_t> QueryVisitor::PopPendingIntValue() {
if (pending_values_.empty()) {
return absl_ports::InvalidArgumentError("Unable to retrieve int value.");
@@ -435,8 +527,8 @@ QueryVisitor::PopPendingIterator() {
// raw_text, then all of raw_text must correspond to this token.
raw_token = raw_text;
} else {
- ICING_ASSIGN_OR_RETURN(raw_token, string_util::FindEscapedToken(
- raw_text, token.text));
+ ICING_ASSIGN_OR_RETURN(
+ raw_token, string_util::FindEscapedToken(raw_text, token.text));
}
normalized_term = normalizer_.NormalizeTerm(token.text);
QueryTerm term_value{std::move(normalized_term), raw_token,
@@ -570,15 +662,14 @@ libtextclassifier3::Status QueryVisitor::ProcessNegationOperator(
"Visit unary operator child didn't correctly add pending values.");
}
- // 3. We want to preserve the original text of the integer value, append our
- // minus and *then* parse as an int.
- ICING_ASSIGN_OR_RETURN(QueryTerm int_text_val, PopPendingTextValue());
- int_text_val.term = absl_ports::StrCat("-", int_text_val.term);
+ // 3. We want to preserve the original text of the numeric value, append our
+ // minus to the text. It will be parsed as either an int or a double later.
+ ICING_ASSIGN_OR_RETURN(QueryTerm numeric_text_val, PopPendingTextValue());
+ numeric_text_val.term = absl_ports::StrCat("-", numeric_text_val.term);
PendingValue pending_value =
- PendingValue::CreateTextPendingValue(std::move(int_text_val));
- ICING_RETURN_IF_ERROR(pending_value.long_val());
+ PendingValue::CreateTextPendingValue(std::move(numeric_text_val));
- // We've parsed our integer value successfully. Pop our placeholder, push it
+ // We've parsed our numeric value successfully. Pop our placeholder, push it
// on to the stack and return successfully.
if (!pending_values_.top().is_placeholder()) {
return absl_ports::InvalidArgumentError(
@@ -768,7 +859,8 @@ void QueryVisitor::VisitMember(const MemberNode* node) {
end = text_val.raw_term.data() + text_val.raw_term.length();
} else {
start = std::min(start, text_val.raw_term.data());
- end = std::max(end, text_val.raw_term.data() + text_val.raw_term.length());
+ end = std::max(end,
+ text_val.raw_term.data() + text_val.raw_term.length());
}
members.push_back(std::move(text_val.term));
}
@@ -800,13 +892,26 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) {
"Function ", node->function_name()->value(), " is not supported."));
return;
}
+ const Function& function = itr->second;
// 2. Put in a placeholder PendingValue
pending_values_.push(PendingValue());
// 3. Visit the children.
- for (const std::unique_ptr<Node>& arg : node->args()) {
+ expecting_numeric_arg_ = true;
+ for (int i = 0; i < node->args().size(); ++i) {
+ const std::unique_ptr<Node>& arg = node->args()[i];
+ libtextclassifier3::StatusOr<DataType> arg_type_or =
+ function.get_param_type(i);
+ bool current_level_expecting_numeric_arg = expecting_numeric_arg_;
+ // If arg_type_or has an error, we should ignore it for now, since
+ // function.Eval should do the type check and return better error messages.
+ if (arg_type_or.ok() && (arg_type_or.ValueOrDie() == DataType::kLong ||
+ arg_type_or.ValueOrDie() == DataType::kDouble)) {
+ expecting_numeric_arg_ = true;
+ }
arg->Accept(this);
+ expecting_numeric_arg_ = current_level_expecting_numeric_arg;
if (has_pending_error()) {
return;
}
@@ -819,7 +924,6 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) {
pending_values_.pop();
}
std::reverse(args.begin(), args.end());
- const Function& function = itr->second;
auto eval_result = function.Eval(std::move(args));
if (!eval_result.ok()) {
pending_error_ = std::move(eval_result).status();
@@ -955,6 +1059,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && {
results.root_iterator = std::move(iterator_or).ValueOrDie();
results.query_term_iterators = std::move(query_term_iterators_);
results.query_terms = std::move(property_query_terms_map_);
+ results.embedding_query_results = std::move(embedding_query_results_);
results.features_in_use = std::move(features_);
return results;
}