diff options
Diffstat (limited to 'icing/query/advanced_query_parser/query-visitor.cc')
-rw-r--r-- | icing/query/advanced_query_parser/query-visitor.cc | 139 |
1 files changed, 122 insertions, 17 deletions
diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc index 31da959..1ac52c5 100644 --- a/icing/query/advanced_query_parser/query-visitor.cc +++ b/icing/query/advanced_query_parser/query-visitor.cc @@ -16,20 +16,26 @@ #include <algorithm> #include <cstdint> -#include <cstdlib> #include <iterator> #include <limits> #include <memory> #include <set> #include <string> +#include <string_view> +#include <unordered_map> #include <utility> #include <vector> +#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/absl_ports/str_join.h" +#include "icing/index/embed/doc-hit-info-iterator-embedding.h" +#include "icing/index/embed/embedding-query-results.h" #include "icing/index/iterator/doc-hit-info-iterator-all-document-id.h" #include "icing/index/iterator/doc-hit-info-iterator-and.h" +#include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator-none.h" #include "icing/index/iterator/doc-hit-info-iterator-not.h" #include "icing/index/iterator/doc-hit-info-iterator-or.h" @@ -37,17 +43,23 @@ #include "icing/index/iterator/doc-hit-info-iterator-property-in-schema.h" #include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/iterator/section-restrict-data.h" #include "icing/index/property-existence-indexing-handler.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/function.h" #include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/advanced_query_parser/param.h" #include "icing/query/advanced_query_parser/parser.h" #include "icing/query/advanced_query_parser/pending-value.h" #include "icing/query/advanced_query_parser/util/string-util.h" #include "icing/query/query-features.h" +#include "icing/query/query-results.h" #include "icing/schema/property-util.h" +#include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer.h" +#include "icing/util/embedding-util.h" #include "icing/util/status-macros.h" namespace icing { @@ -241,6 +253,34 @@ void QueryVisitor::RegisterFunctions() { .ValueOrDie(); registered_functions_.insert( {has_property_function.name(), std::move(has_property_function)}); + + // vector_index getSearchSpecEmbedding(long); + auto get_search_spec_embedding = [](std::vector<PendingValue>&& args) { + return PendingValue::CreateVectorIndexPendingValue( + args.at(0).long_val().ValueOrDie()); + }; + Function get_search_spec_embedding_function = + Function::Create(DataType::kVectorIndex, "getSearchSpecEmbedding", + {Param(DataType::kLong)}, + std::move(get_search_spec_embedding)) + .ValueOrDie(); + registered_functions_.insert({get_search_spec_embedding_function.name(), + std::move(get_search_spec_embedding_function)}); + + // DocHitInfoIterator semanticSearch(vector_index, double, double, string); + auto semantic_search = [this](std::vector<PendingValue>&& args) { + return this->SemanticSearchFunction(std::move(args)); + }; + Function semantic_search_function = + Function::Create(DataType::kDocumentIterator, "semanticSearch", + {Param(DataType::kVectorIndex), + Param(DataType::kDouble, Cardinality::kOptional), + Param(DataType::kDouble, Cardinality::kOptional), + Param(DataType::kString, Cardinality::kOptional)}, + std::move(semantic_search)) + .ValueOrDie(); + registered_functions_.insert( + {semantic_search_function.name(), std::move(semantic_search_function)}); } libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction( @@ -278,10 +318,11 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SearchFunction( document_store_.last_added_document_id()); } else { QueryVisitor query_visitor( - &index_, &numeric_index_, &document_store_, &schema_store_, - &normalizer_, &tokenizer_, query->raw_term, filter_options_, - match_type_, needs_term_frequency_info_, pending_property_restricts_, - processing_not_, current_time_ms_); + &index_, &numeric_index_, &embedding_index_, &document_store_, + &schema_store_, &normalizer_, &tokenizer_, query->raw_term, + embedding_query_vectors_, filter_options_, match_type_, + embedding_query_metric_type_, needs_term_frequency_info_, + pending_property_restricts_, processing_not_, current_time_ms_); tree_root->Accept(&query_visitor); ICING_ASSIGN_OR_RETURN(query_result, std::move(query_visitor).ConsumeResults()); @@ -359,6 +400,57 @@ libtextclassifier3::StatusOr<PendingValue> QueryVisitor::HasPropertyFunction( return PendingValue(std::move(property_in_document_iterator)); } +libtextclassifier3::StatusOr<PendingValue> QueryVisitor::SemanticSearchFunction( + std::vector<PendingValue>&& args) { + features_.insert(kEmbeddingSearchFeature); + + int64_t vector_index = args.at(0).vector_index_val().ValueOrDie(); + if (embedding_query_vectors_ == nullptr || vector_index < 0 || + vector_index >= embedding_query_vectors_->size()) { + return absl_ports::InvalidArgumentError("Got invalid vector search index!"); + } + + // Handle default values for the optional arguments. + double low = -std::numeric_limits<double>::infinity(); + double high = std::numeric_limits<double>::infinity(); + SearchSpecProto::EmbeddingQueryMetricType::Code metric_type = + embedding_query_metric_type_; + if (args.size() >= 2) { + low = args.at(1).double_val().ValueOrDie(); + } + if (args.size() >= 3) { + high = args.at(2).double_val().ValueOrDie(); + } + if (args.size() >= 4) { + const std::string& metric = args.at(3).string_val().ValueOrDie()->term; + ICING_ASSIGN_OR_RETURN( + metric_type, + embedding_util::GetEmbeddingQueryMetricTypeFromName(metric)); + } + + // Create SectionRestrictData for section restriction. + std::unique_ptr<SectionRestrictData> section_restrict_data = nullptr; + if (pending_property_restricts_.has_active_property_restricts()) { + std::unordered_map<std::string, std::set<std::string>> + type_property_filters; + type_property_filters[std::string(SchemaStore::kSchemaTypeWildcard)] = + pending_property_restricts_.active_property_restricts(); + section_restrict_data = std::make_unique<SectionRestrictData>( + &document_store_, &schema_store_, current_time_ms_, + type_property_filters); + } + + // Create and return iterator. + EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map = + &embedding_query_results_.result_scores[vector_index][metric_type]; + ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator, + DocHitInfoIteratorEmbedding::Create( + &embedding_query_vectors_->at(vector_index), + std::move(section_restrict_data), metric_type, low, + high, score_map, &embedding_index_)); + return PendingValue(std::move(iterator)); +} + libtextclassifier3::StatusOr<int64_t> QueryVisitor::PopPendingIntValue() { if (pending_values_.empty()) { return absl_ports::InvalidArgumentError("Unable to retrieve int value."); @@ -435,8 +527,8 @@ QueryVisitor::PopPendingIterator() { // raw_text, then all of raw_text must correspond to this token. raw_token = raw_text; } else { - ICING_ASSIGN_OR_RETURN(raw_token, string_util::FindEscapedToken( - raw_text, token.text)); + ICING_ASSIGN_OR_RETURN( + raw_token, string_util::FindEscapedToken(raw_text, token.text)); } normalized_term = normalizer_.NormalizeTerm(token.text); QueryTerm term_value{std::move(normalized_term), raw_token, @@ -570,15 +662,14 @@ libtextclassifier3::Status QueryVisitor::ProcessNegationOperator( "Visit unary operator child didn't correctly add pending values."); } - // 3. We want to preserve the original text of the integer value, append our - // minus and *then* parse as an int. - ICING_ASSIGN_OR_RETURN(QueryTerm int_text_val, PopPendingTextValue()); - int_text_val.term = absl_ports::StrCat("-", int_text_val.term); + // 3. We want to preserve the original text of the numeric value, append our + // minus to the text. It will be parsed as either an int or a double later. + ICING_ASSIGN_OR_RETURN(QueryTerm numeric_text_val, PopPendingTextValue()); + numeric_text_val.term = absl_ports::StrCat("-", numeric_text_val.term); PendingValue pending_value = - PendingValue::CreateTextPendingValue(std::move(int_text_val)); - ICING_RETURN_IF_ERROR(pending_value.long_val()); + PendingValue::CreateTextPendingValue(std::move(numeric_text_val)); - // We've parsed our integer value successfully. Pop our placeholder, push it + // We've parsed our numeric value successfully. Pop our placeholder, push it // on to the stack and return successfully. if (!pending_values_.top().is_placeholder()) { return absl_ports::InvalidArgumentError( @@ -768,7 +859,8 @@ void QueryVisitor::VisitMember(const MemberNode* node) { end = text_val.raw_term.data() + text_val.raw_term.length(); } else { start = std::min(start, text_val.raw_term.data()); - end = std::max(end, text_val.raw_term.data() + text_val.raw_term.length()); + end = std::max(end, + text_val.raw_term.data() + text_val.raw_term.length()); } members.push_back(std::move(text_val.term)); } @@ -800,13 +892,26 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) { "Function ", node->function_name()->value(), " is not supported.")); return; } + const Function& function = itr->second; // 2. Put in a placeholder PendingValue pending_values_.push(PendingValue()); // 3. Visit the children. - for (const std::unique_ptr<Node>& arg : node->args()) { + expecting_numeric_arg_ = true; + for (int i = 0; i < node->args().size(); ++i) { + const std::unique_ptr<Node>& arg = node->args()[i]; + libtextclassifier3::StatusOr<DataType> arg_type_or = + function.get_param_type(i); + bool current_level_expecting_numeric_arg = expecting_numeric_arg_; + // If arg_type_or has an error, we should ignore it for now, since + // function.Eval should do the type check and return better error messages. + if (arg_type_or.ok() && (arg_type_or.ValueOrDie() == DataType::kLong || + arg_type_or.ValueOrDie() == DataType::kDouble)) { + expecting_numeric_arg_ = true; + } arg->Accept(this); + expecting_numeric_arg_ = current_level_expecting_numeric_arg; if (has_pending_error()) { return; } @@ -819,7 +924,6 @@ void QueryVisitor::VisitFunction(const FunctionNode* node) { pending_values_.pop(); } std::reverse(args.begin(), args.end()); - const Function& function = itr->second; auto eval_result = function.Eval(std::move(args)); if (!eval_result.ok()) { pending_error_ = std::move(eval_result).status(); @@ -955,6 +1059,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && { results.root_iterator = std::move(iterator_or).ValueOrDie(); results.query_term_iterators = std::move(query_term_iterators_); results.query_terms = std::move(property_query_terms_map_); + results.embedding_query_results = std::move(embedding_query_results_); results.features_in_use = std::move(features_); return results; } |