aboutsummaryrefslogtreecommitdiff
path: root/source/opt/const_folding_rules.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt/const_folding_rules.cpp')
-rw-r--r--source/opt/const_folding_rules.cpp257
1 files changed, 85 insertions, 172 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 14f22089..0ad755c9 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -19,7 +19,8 @@
namespace spvtools {
namespace opt {
namespace {
-constexpr uint32_t kExtractCompositeIdInIdx = 0;
+
+const uint32_t kExtractCompositeIdInIdx = 0;
// Returns a constants with the value NaN of the given type. Only works for
// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
@@ -119,97 +120,11 @@ ConstantFoldingRule FoldExtractWithConstants() {
};
}
-// Folds an OpcompositeInsert where input is a composite constant.
-ConstantFoldingRule FoldInsertWithConstants() {
- return [](IRContext* context, Instruction* inst,
- const std::vector<const analysis::Constant*>& constants)
- -> const analysis::Constant* {
- analysis::ConstantManager* const_mgr = context->get_constant_mgr();
- const analysis::Constant* object = constants[0];
- const analysis::Constant* composite = constants[1];
- if (object == nullptr || composite == nullptr) {
- return nullptr;
- }
-
- // If there is more than 1 index, then each additional constant used by the
- // index will need to be recreated to use the inserted object.
- std::vector<const analysis::Constant*> chain;
- std::vector<const analysis::Constant*> components;
- const analysis::Type* type = nullptr;
- const uint32_t final_index = (inst->NumInOperands() - 1);
-
- // Work down hierarchy of all indexes
- for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
- type = composite->type();
-
- if (composite->AsNullConstant()) {
- // Make new composite so it can be inserted in the index with the
- // non-null value
- const auto new_composite = const_mgr->GetNullCompositeConstant(type);
- // Keep track of any indexes along the way to last index
- if (i != final_index) {
- chain.push_back(new_composite);
- }
- components = new_composite->AsCompositeConstant()->GetComponents();
- } else {
- // Keep track of any indexes along the way to last index
- if (i != final_index) {
- chain.push_back(composite);
- }
- components = composite->AsCompositeConstant()->GetComponents();
- }
- const uint32_t index = inst->GetSingleWordInOperand(i);
- composite = components[index];
- }
-
- // Final index in hierarchy is inserted with new object.
- const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
- std::vector<uint32_t> ids;
- for (size_t i = 0; i < components.size(); i++) {
- const analysis::Constant* constant =
- (i == final_operand) ? object : components[i];
- Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
- ids.push_back(member_inst->result_id());
- }
- const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
-
- // Work backwards up the chain and replace each index with new constant.
- for (size_t i = chain.size(); i > 0; i--) {
- // Need to insert any previous instruction into the module first.
- // Can't just insert in types_values_begin() because it will move above
- // where the types are declared.
- // Can't compare with location of inst because not all new added
- // instructions are added to types_values_
- auto iter = context->types_values_end();
- Module::inst_iterator* pos = &iter;
- const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
-
- composite = chain[i - 1];
- components = composite->AsCompositeConstant()->GetComponents();
- type = composite->type();
- ids.clear();
- for (size_t k = 0; k < components.size(); k++) {
- const uint32_t index =
- inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
- const analysis::Constant* constant =
- (k == index) ? new_constant : components[k];
- const uint32_t constant_id =
- const_mgr->FindDeclaredConstant(constant, 0);
- ids.push_back(constant_id);
- }
- new_constant = const_mgr->GetConstant(type, ids);
- }
-
- // If multiple constants were created, only need to return the top index.
- return new_constant;
- };
-}
-
ConstantFoldingRule FoldVectorShuffleWithConstants() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- assert(inst->opcode() == spv::Op::OpVectorShuffle);
+ assert(inst->opcode() == SpvOpVectorShuffle);
const analysis::Constant* c1 = constants[0];
const analysis::Constant* c2 = constants[1];
if (c1 == nullptr || c2 == nullptr) {
@@ -265,7 +180,7 @@ ConstantFoldingRule FoldVectorTimesScalar() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
+ assert(inst->opcode() == SpvOpVectorTimesScalar);
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
@@ -340,7 +255,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
+ assert(inst->opcode() == SpvOpVectorTimesMatrix);
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
@@ -433,7 +348,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
- assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
+ assert(inst->opcode() == SpvOpMatrixTimesVector);
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
@@ -543,9 +458,9 @@ ConstantFoldingRule FoldCompositeWithConstants() {
}
uint32_t component_type_id = 0;
- if (type_inst->opcode() == spv::Op::OpTypeStruct) {
+ if (type_inst->opcode() == SpvOpTypeStruct) {
component_type_id = type_inst->GetSingleWordInOperand(i);
- } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
+ } else if (type_inst->opcode() == SpvOpTypeArray) {
component_type_id = type_inst->GetSingleWordInOperand(0);
}
@@ -594,7 +509,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
}
const analysis::Constant* arg =
- (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
+ (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
if (arg == nullptr) {
return nullptr;
@@ -684,7 +599,7 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr;
}
- if (inst->opcode() == spv::Op::OpExtInst) {
+ if (inst->opcode() == SpvOpExtInst) {
return FoldFPBinaryOp(scalar_rule, inst->type_id(),
{constants[1], constants[2]}, context);
}
@@ -1042,7 +957,7 @@ UnaryScalarFoldingRule FoldFNegateOp() {
ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
-ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
+ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
return [cmp_opcode](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
@@ -1070,7 +985,7 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
return nullptr;
}
- if (operand_inst->opcode() != spv::Op::OpExtInst) {
+ if (operand_inst->opcode() != SpvOpExtInst) {
return nullptr;
}
@@ -1094,25 +1009,25 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
bool result = false;
switch (cmp_opcode) {
- case spv::Op::OpFOrdLessThan:
- case spv::Op::OpFUnordLessThan:
- case spv::Op::OpFOrdGreaterThanEqual:
- case spv::Op::OpFUnordGreaterThanEqual:
+ case SpvOpFOrdLessThan:
+ case SpvOpFUnordLessThan:
+ case SpvOpFOrdGreaterThanEqual:
+ case SpvOpFUnordGreaterThanEqual:
if (constants[0]) {
if (min_const) {
if (constants[0]->GetValueAsDouble() <
min_const->GetValueAsDouble()) {
found_result = true;
- result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
- cmp_opcode == spv::Op::OpFUnordLessThan);
+ result = (cmp_opcode == SpvOpFOrdLessThan ||
+ cmp_opcode == SpvOpFUnordLessThan);
}
}
if (max_const) {
if (constants[0]->GetValueAsDouble() >=
max_const->GetValueAsDouble()) {
found_result = true;
- result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
- cmp_opcode == spv::Op::OpFUnordLessThan);
+ result = !(cmp_opcode == SpvOpFOrdLessThan ||
+ cmp_opcode == SpvOpFUnordLessThan);
}
}
}
@@ -1122,8 +1037,8 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
if (max_const->GetValueAsDouble() <
constants[1]->GetValueAsDouble()) {
found_result = true;
- result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
- cmp_opcode == spv::Op::OpFUnordLessThan);
+ result = (cmp_opcode == SpvOpFOrdLessThan ||
+ cmp_opcode == SpvOpFUnordLessThan);
}
}
@@ -1131,31 +1046,31 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
if (min_const->GetValueAsDouble() >=
constants[1]->GetValueAsDouble()) {
found_result = true;
- result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
- cmp_opcode == spv::Op::OpFUnordLessThan);
+ result = !(cmp_opcode == SpvOpFOrdLessThan ||
+ cmp_opcode == SpvOpFUnordLessThan);
}
}
}
break;
- case spv::Op::OpFOrdGreaterThan:
- case spv::Op::OpFUnordGreaterThan:
- case spv::Op::OpFOrdLessThanEqual:
- case spv::Op::OpFUnordLessThanEqual:
+ case SpvOpFOrdGreaterThan:
+ case SpvOpFUnordGreaterThan:
+ case SpvOpFOrdLessThanEqual:
+ case SpvOpFUnordLessThanEqual:
if (constants[0]) {
if (min_const) {
if (constants[0]->GetValueAsDouble() <=
min_const->GetValueAsDouble()) {
found_result = true;
- result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
- cmp_opcode == spv::Op::OpFUnordLessThanEqual);
+ result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
+ cmp_opcode == SpvOpFUnordLessThanEqual);
}
}
if (max_const) {
if (constants[0]->GetValueAsDouble() >
max_const->GetValueAsDouble()) {
found_result = true;
- result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
- cmp_opcode == spv::Op::OpFUnordLessThanEqual);
+ result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
+ cmp_opcode == SpvOpFUnordLessThanEqual);
}
}
}
@@ -1165,8 +1080,8 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
if (max_const->GetValueAsDouble() <=
constants[1]->GetValueAsDouble()) {
found_result = true;
- result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
- cmp_opcode == spv::Op::OpFUnordLessThanEqual);
+ result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
+ cmp_opcode == SpvOpFUnordLessThanEqual);
}
}
@@ -1174,8 +1089,8 @@ ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
if (min_const->GetValueAsDouble() >
constants[1]->GetValueAsDouble()) {
found_result = true;
- result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
- cmp_opcode == spv::Op::OpFUnordLessThanEqual);
+ result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
+ cmp_opcode == SpvOpFUnordLessThanEqual);
}
}
}
@@ -1202,7 +1117,7 @@ ConstantFoldingRule FoldFMix() {
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
- assert(inst->opcode() == spv::Op::OpExtInst &&
+ assert(inst->opcode() == SpvOpExtInst &&
"Expecting an extended instruction.");
assert(inst->GetSingleWordInOperand(0) ==
context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
@@ -1352,7 +1267,7 @@ const analysis::Constant* FoldMax(const analysis::Type* result_type,
const analysis::Constant* FoldClamp1(
IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
- assert(inst->opcode() == spv::Op::OpExtInst &&
+ assert(inst->opcode() == SpvOpExtInst &&
"Expecting an extended instruction.");
assert(inst->GetSingleWordInOperand(0) ==
context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
@@ -1378,7 +1293,7 @@ const analysis::Constant* FoldClamp1(
const analysis::Constant* FoldClamp2(
IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
- assert(inst->opcode() == spv::Op::OpExtInst &&
+ assert(inst->opcode() == SpvOpExtInst &&
"Expecting an extended instruction.");
assert(inst->GetSingleWordInOperand(0) ==
context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
@@ -1406,7 +1321,7 @@ const analysis::Constant* FoldClamp2(
const analysis::Constant* FoldClamp3(
IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
- assert(inst->opcode() == spv::Op::OpExtInst &&
+ assert(inst->opcode() == SpvOpExtInst &&
"Expecting an extended instruction.");
assert(inst->GetSingleWordInOperand(0) ==
context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
@@ -1492,70 +1407,68 @@ void ConstantFoldingRules::AddFoldingRules() {
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
- rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
+ rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
- rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
- rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
+ rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
- rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
- rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
- rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
- rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
+ rules_[SpvOpConvertFToS].push_back(FoldFToI());
+ rules_[SpvOpConvertFToU].push_back(FoldFToI());
+ rules_[SpvOpConvertSToF].push_back(FoldIToF());
+ rules_[SpvOpConvertUToF].push_back(FoldIToF());
- rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
- rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
- rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
- rules_[spv::Op::OpFMul].push_back(FoldFMul());
- rules_[spv::Op::OpFSub].push_back(FoldFSub());
+ rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
+ rules_[SpvOpFAdd].push_back(FoldFAdd());
+ rules_[SpvOpFDiv].push_back(FoldFDiv());
+ rules_[SpvOpFMul].push_back(FoldFMul());
+ rules_[SpvOpFSub].push_back(FoldFSub());
- rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
+ rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
- rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
+ rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
- rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
+ rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
- rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
+ rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
- rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
- rules_[spv::Op::OpFOrdLessThan].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
+ rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
+ rules_[SpvOpFOrdLessThan].push_back(
+ FoldFClampFeedingCompare(SpvOpFOrdLessThan));
- rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
- rules_[spv::Op::OpFUnordLessThan].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
+ rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
+ rules_[SpvOpFUnordLessThan].push_back(
+ FoldFClampFeedingCompare(SpvOpFUnordLessThan));
- rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
- rules_[spv::Op::OpFOrdGreaterThan].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
+ rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
+ rules_[SpvOpFOrdGreaterThan].push_back(
+ FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
- rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
- rules_[spv::Op::OpFUnordGreaterThan].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
+ rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
+ rules_[SpvOpFUnordGreaterThan].push_back(
+ FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
- rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
- rules_[spv::Op::OpFOrdLessThanEqual].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
+ rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
+ rules_[SpvOpFOrdLessThanEqual].push_back(
+ FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
- rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
- rules_[spv::Op::OpFUnordLessThanEqual].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
+ rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
+ rules_[SpvOpFUnordLessThanEqual].push_back(
+ FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
- rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
- rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
+ rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
+ rules_[SpvOpFOrdGreaterThanEqual].push_back(
+ FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
- rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
- FoldFUnordGreaterThanEqual());
- rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
- FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
+ rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
+ rules_[SpvOpFUnordGreaterThanEqual].push_back(
+ FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
- rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
- rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
- rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
- rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
+ rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
+ rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
+ rules_[SpvOpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
+ rules_[SpvOpMatrixTimesVector].push_back(FoldMatrixTimesVector());
- rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
- rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
+ rules_[SpvOpFNegate].push_back(FoldFNegate());
+ rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
// Add rules for GLSLstd450
FeatureManager* feature_manager = context_->get_feature_mgr();