diff options
Diffstat (limited to 'source/opt/const_folding_rules.cpp')
-rw-r--r-- | source/opt/const_folding_rules.cpp | 257 |
1 files changed, 85 insertions, 172 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 14f22089..0ad755c9 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -19,7 +19,8 @@ namespace spvtools { namespace opt { namespace { -constexpr uint32_t kExtractCompositeIdInIdx = 0; + +const uint32_t kExtractCompositeIdInIdx = 0; // Returns a constants with the value NaN of the given type. Only works for // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. @@ -119,97 +120,11 @@ ConstantFoldingRule FoldExtractWithConstants() { }; } -// Folds an OpcompositeInsert where input is a composite constant. -ConstantFoldingRule FoldInsertWithConstants() { - return [](IRContext* context, Instruction* inst, - const std::vector<const analysis::Constant*>& constants) - -> const analysis::Constant* { - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); - const analysis::Constant* object = constants[0]; - const analysis::Constant* composite = constants[1]; - if (object == nullptr || composite == nullptr) { - return nullptr; - } - - // If there is more than 1 index, then each additional constant used by the - // index will need to be recreated to use the inserted object. - std::vector<const analysis::Constant*> chain; - std::vector<const analysis::Constant*> components; - const analysis::Type* type = nullptr; - const uint32_t final_index = (inst->NumInOperands() - 1); - - // Work down hierarchy of all indexes - for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { - type = composite->type(); - - if (composite->AsNullConstant()) { - // Make new composite so it can be inserted in the index with the - // non-null value - const auto new_composite = const_mgr->GetNullCompositeConstant(type); - // Keep track of any indexes along the way to last index - if (i != final_index) { - chain.push_back(new_composite); - } - components = new_composite->AsCompositeConstant()->GetComponents(); - } else { - // Keep track of any indexes along the way to last index - if (i != final_index) { - chain.push_back(composite); - } - components = composite->AsCompositeConstant()->GetComponents(); - } - const uint32_t index = inst->GetSingleWordInOperand(i); - composite = components[index]; - } - - // Final index in hierarchy is inserted with new object. - const uint32_t final_operand = inst->GetSingleWordInOperand(final_index); - std::vector<uint32_t> ids; - for (size_t i = 0; i < components.size(); i++) { - const analysis::Constant* constant = - (i == final_operand) ? object : components[i]; - Instruction* member_inst = const_mgr->GetDefiningInstruction(constant); - ids.push_back(member_inst->result_id()); - } - const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids); - - // Work backwards up the chain and replace each index with new constant. - for (size_t i = chain.size(); i > 0; i--) { - // Need to insert any previous instruction into the module first. - // Can't just insert in types_values_begin() because it will move above - // where the types are declared. - // Can't compare with location of inst because not all new added - // instructions are added to types_values_ - auto iter = context->types_values_end(); - Module::inst_iterator* pos = &iter; - const_mgr->BuildInstructionAndAddToModule(new_constant, pos); - - composite = chain[i - 1]; - components = composite->AsCompositeConstant()->GetComponents(); - type = composite->type(); - ids.clear(); - for (size_t k = 0; k < components.size(); k++) { - const uint32_t index = - inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i)); - const analysis::Constant* constant = - (k == index) ? new_constant : components[k]; - const uint32_t constant_id = - const_mgr->FindDeclaredConstant(constant, 0); - ids.push_back(constant_id); - } - new_constant = const_mgr->GetConstant(type, ids); - } - - // If multiple constants were created, only need to return the top index. - return new_constant; - }; -} - ConstantFoldingRule FoldVectorShuffleWithConstants() { return [](IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - assert(inst->opcode() == spv::Op::OpVectorShuffle); + assert(inst->opcode() == SpvOpVectorShuffle); const analysis::Constant* c1 = constants[0]; const analysis::Constant* c2 = constants[1]; if (c1 == nullptr || c2 == nullptr) { @@ -265,7 +180,7 @@ ConstantFoldingRule FoldVectorTimesScalar() { return [](IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - assert(inst->opcode() == spv::Op::OpVectorTimesScalar); + assert(inst->opcode() == SpvOpVectorTimesScalar); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); @@ -340,7 +255,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() { return [](IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - assert(inst->opcode() == spv::Op::OpVectorTimesMatrix); + assert(inst->opcode() == SpvOpVectorTimesMatrix); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); @@ -433,7 +348,7 @@ ConstantFoldingRule FoldMatrixTimesVector() { return [](IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { - assert(inst->opcode() == spv::Op::OpMatrixTimesVector); + assert(inst->opcode() == SpvOpMatrixTimesVector); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); @@ -543,9 +458,9 @@ ConstantFoldingRule FoldCompositeWithConstants() { } uint32_t component_type_id = 0; - if (type_inst->opcode() == spv::Op::OpTypeStruct) { + if (type_inst->opcode() == SpvOpTypeStruct) { component_type_id = type_inst->GetSingleWordInOperand(i); - } else if (type_inst->opcode() == spv::Op::OpTypeArray) { + } else if (type_inst->opcode() == SpvOpTypeArray) { component_type_id = type_inst->GetSingleWordInOperand(0); } @@ -594,7 +509,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { } const analysis::Constant* arg = - (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; + (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0]; if (arg == nullptr) { return nullptr; @@ -684,7 +599,7 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { if (!inst->IsFloatingPointFoldingAllowed()) { return nullptr; } - if (inst->opcode() == spv::Op::OpExtInst) { + if (inst->opcode() == SpvOpExtInst) { return FoldFPBinaryOp(scalar_rule, inst->type_id(), {constants[1], constants[2]}, context); } @@ -1042,7 +957,7 @@ UnaryScalarFoldingRule FoldFNegateOp() { ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); } -ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { +ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) { return [cmp_opcode](IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { @@ -1070,7 +985,7 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { return nullptr; } - if (operand_inst->opcode() != spv::Op::OpExtInst) { + if (operand_inst->opcode() != SpvOpExtInst) { return nullptr; } @@ -1094,25 +1009,25 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { bool result = false; switch (cmp_opcode) { - case spv::Op::OpFOrdLessThan: - case spv::Op::OpFUnordLessThan: - case spv::Op::OpFOrdGreaterThanEqual: - case spv::Op::OpFUnordGreaterThanEqual: + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: if (constants[0]) { if (min_const) { if (constants[0]->GetValueAsDouble() < min_const->GetValueAsDouble()) { found_result = true; - result = (cmp_opcode == spv::Op::OpFOrdLessThan || - cmp_opcode == spv::Op::OpFUnordLessThan); + result = (cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); } } if (max_const) { if (constants[0]->GetValueAsDouble() >= max_const->GetValueAsDouble()) { found_result = true; - result = !(cmp_opcode == spv::Op::OpFOrdLessThan || - cmp_opcode == spv::Op::OpFUnordLessThan); + result = !(cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); } } } @@ -1122,8 +1037,8 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { if (max_const->GetValueAsDouble() < constants[1]->GetValueAsDouble()) { found_result = true; - result = (cmp_opcode == spv::Op::OpFOrdLessThan || - cmp_opcode == spv::Op::OpFUnordLessThan); + result = (cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); } } @@ -1131,31 +1046,31 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { if (min_const->GetValueAsDouble() >= constants[1]->GetValueAsDouble()) { found_result = true; - result = !(cmp_opcode == spv::Op::OpFOrdLessThan || - cmp_opcode == spv::Op::OpFUnordLessThan); + result = !(cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); } } } break; - case spv::Op::OpFOrdGreaterThan: - case spv::Op::OpFUnordGreaterThan: - case spv::Op::OpFOrdLessThanEqual: - case spv::Op::OpFUnordLessThanEqual: + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: if (constants[0]) { if (min_const) { if (constants[0]->GetValueAsDouble() <= min_const->GetValueAsDouble()) { found_result = true; - result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual || - cmp_opcode == spv::Op::OpFUnordLessThanEqual); + result = (cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); } } if (max_const) { if (constants[0]->GetValueAsDouble() > max_const->GetValueAsDouble()) { found_result = true; - result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual || - cmp_opcode == spv::Op::OpFUnordLessThanEqual); + result = !(cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); } } } @@ -1165,8 +1080,8 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { if (max_const->GetValueAsDouble() <= constants[1]->GetValueAsDouble()) { found_result = true; - result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual || - cmp_opcode == spv::Op::OpFUnordLessThanEqual); + result = (cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); } } @@ -1174,8 +1089,8 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { if (min_const->GetValueAsDouble() > constants[1]->GetValueAsDouble()) { found_result = true; - result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual || - cmp_opcode == spv::Op::OpFUnordLessThanEqual); + result = !(cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); } } } @@ -1202,7 +1117,7 @@ ConstantFoldingRule FoldFMix() { const std::vector<const analysis::Constant*>& constants) -> const analysis::Constant* { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); - assert(inst->opcode() == spv::Op::OpExtInst && + assert(inst->opcode() == SpvOpExtInst && "Expecting an extended instruction."); assert(inst->GetSingleWordInOperand(0) == context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && @@ -1352,7 +1267,7 @@ const analysis::Constant* FoldMax(const analysis::Type* result_type, const analysis::Constant* FoldClamp1( IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) { - assert(inst->opcode() == spv::Op::OpExtInst && + assert(inst->opcode() == SpvOpExtInst && "Expecting an extended instruction."); assert(inst->GetSingleWordInOperand(0) == context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && @@ -1378,7 +1293,7 @@ const analysis::Constant* FoldClamp1( const analysis::Constant* FoldClamp2( IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) { - assert(inst->opcode() == spv::Op::OpExtInst && + assert(inst->opcode() == SpvOpExtInst && "Expecting an extended instruction."); assert(inst->GetSingleWordInOperand(0) == context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && @@ -1406,7 +1321,7 @@ const analysis::Constant* FoldClamp2( const analysis::Constant* FoldClamp3( IRContext* context, Instruction* inst, const std::vector<const analysis::Constant*>& constants) { - assert(inst->opcode() == spv::Op::OpExtInst && + assert(inst->opcode() == SpvOpExtInst && "Expecting an extended instruction."); assert(inst->GetSingleWordInOperand(0) == context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && @@ -1492,70 +1407,68 @@ void ConstantFoldingRules::AddFoldingRules() { // applies to the instruction, the rest of the rules will not be attempted. // Take that into consideration. - rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants()); + rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants()); - rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants()); - rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants()); + rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants()); - rules_[spv::Op::OpConvertFToS].push_back(FoldFToI()); - rules_[spv::Op::OpConvertFToU].push_back(FoldFToI()); - rules_[spv::Op::OpConvertSToF].push_back(FoldIToF()); - rules_[spv::Op::OpConvertUToF].push_back(FoldIToF()); + rules_[SpvOpConvertFToS].push_back(FoldFToI()); + rules_[SpvOpConvertFToU].push_back(FoldFToI()); + rules_[SpvOpConvertSToF].push_back(FoldIToF()); + rules_[SpvOpConvertUToF].push_back(FoldIToF()); - rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants()); - rules_[spv::Op::OpFAdd].push_back(FoldFAdd()); - rules_[spv::Op::OpFDiv].push_back(FoldFDiv()); - rules_[spv::Op::OpFMul].push_back(FoldFMul()); - rules_[spv::Op::OpFSub].push_back(FoldFSub()); + rules_[SpvOpDot].push_back(FoldOpDotWithConstants()); + rules_[SpvOpFAdd].push_back(FoldFAdd()); + rules_[SpvOpFDiv].push_back(FoldFDiv()); + rules_[SpvOpFMul].push_back(FoldFMul()); + rules_[SpvOpFSub].push_back(FoldFSub()); - rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual()); + rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual()); - rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual()); + rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual()); - rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual()); + rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual()); - rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual()); + rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual()); - rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan()); - rules_[spv::Op::OpFOrdLessThan].push_back( - FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan)); + rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan()); + rules_[SpvOpFOrdLessThan].push_back( + FoldFClampFeedingCompare(SpvOpFOrdLessThan)); - rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan()); - rules_[spv::Op::OpFUnordLessThan].push_back( - FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan)); + rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan()); + rules_[SpvOpFUnordLessThan].push_back( + FoldFClampFeedingCompare(SpvOpFUnordLessThan)); - rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); - rules_[spv::Op::OpFOrdGreaterThan].push_back( - FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan)); + rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); + rules_[SpvOpFOrdGreaterThan].push_back( + FoldFClampFeedingCompare(SpvOpFOrdGreaterThan)); - rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); - rules_[spv::Op::OpFUnordGreaterThan].push_back( - FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan)); + rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); + rules_[SpvOpFUnordGreaterThan].push_back( + FoldFClampFeedingCompare(SpvOpFUnordGreaterThan)); - rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); - rules_[spv::Op::OpFOrdLessThanEqual].push_back( - FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual)); + rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); + rules_[SpvOpFOrdLessThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual)); - rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); - rules_[spv::Op::OpFUnordLessThanEqual].push_back( - FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual)); + rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); + rules_[SpvOpFUnordLessThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual)); - rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); - rules_[spv::Op::OpFOrdGreaterThanEqual].push_back( - FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual)); + rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); + rules_[SpvOpFOrdGreaterThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual)); - rules_[spv::Op::OpFUnordGreaterThanEqual].push_back( - FoldFUnordGreaterThanEqual()); - rules_[spv::Op::OpFUnordGreaterThanEqual].push_back( - FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual)); + rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); + rules_[SpvOpFUnordGreaterThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual)); - rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); - rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar()); - rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix()); - rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector()); + rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); + rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar()); + rules_[SpvOpVectorTimesMatrix].push_back(FoldVectorTimesMatrix()); + rules_[SpvOpMatrixTimesVector].push_back(FoldMatrixTimesVector()); - rules_[spv::Op::OpFNegate].push_back(FoldFNegate()); - rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16()); + rules_[SpvOpFNegate].push_back(FoldFNegate()); + rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16()); // Add rules for GLSLstd450 FeatureManager* feature_manager = context_->get_feature_mgr(); |