aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoralan-baker <alanbaker@google.com>2024-05-14 15:13:54 -0400
committerGitHub <noreply@github.com>2024-05-14 15:13:54 -0400
commitccf3e3c1035189bbb50793d1e249a2c0ba3388a3 (patch)
treee3c5304bdf1c8613a9859dc7322ab22fcb420ec6
parent199038f10cbe56bf7cbfeb5472eb0a25af2f09f5 (diff)
downloadSPIRV-Tools-upstream-main.tar.gz
Improve matrix layout validation (#5662)upstream-main
* Check for matrix decorations on arrays of matrices * MatrixStide, RowMajor and ColMajor can be applied to matrix or arrays of matrix members * Check that matrix stride satisfies alignment in arrays
-rw-r--r--source/val/validate_decorations.cpp19
-rw-r--r--test/val/val_decoration_test.cpp226
2 files changed, 244 insertions, 1 deletions
diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp
index 0a7df658..811c0e67 100644
--- a/source/val/validate_decorations.cpp
+++ b/source/val/validate_decorations.cpp
@@ -623,6 +623,14 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
seen[next_offset % 16] = true;
}
+ } else if (spv::Op::OpTypeMatrix == element_inst->opcode()) {
+ // Matrix stride would be on the array element in the struct.
+ const auto stride = constraint.matrix_stride;
+ if (!IsAlignedTo(stride, alignment)) {
+ return fail(memberIdx)
+ << "is a matrix with stride " << stride
+ << " not satisfying alignment to " << alignment;
+ }
}
// Proceed to the element in case it is an array.
@@ -675,7 +683,16 @@ bool checkForRequiredDecoration(uint32_t struct_id,
spv::Op type, ValidationState_t& vstate) {
const auto& members = getStructMembers(struct_id, vstate);
for (size_t memberIdx = 0; memberIdx < members.size(); memberIdx++) {
- const auto id = members[memberIdx];
+ auto id = members[memberIdx];
+ if (type == spv::Op::OpTypeMatrix) {
+ // Matrix decorations also apply to arrays of matrices.
+ auto memberInst = vstate.FindDef(id);
+ while (memberInst->opcode() == spv::Op::OpTypeArray ||
+ memberInst->opcode() == spv::Op::OpTypeRuntimeArray) {
+ memberInst = vstate.FindDef(memberInst->GetOperandAs<uint32_t>(1u));
+ }
+ id = memberInst->id();
+ }
if (type != vstate.FindDef(id)->opcode()) continue;
bool found = false;
for (auto& dec : vstate.id_decorations(id)) {
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp
index ba0e9597..b525477c 100644
--- a/test/val/val_decoration_test.cpp
+++ b/test/val/val_decoration_test.cpp
@@ -9444,6 +9444,232 @@ OpFunctionEnd
"contains an array with stride 0, but with an element size of 4"));
}
+TEST_F(ValidateDecorations, MatrixArrayMissingMajorness) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 MatrixStride 16
+OpDecorate %array ArrayStride 32
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%int = OpTypeInt 32 0
+%int_2 = OpConstant %int 2
+%vec = OpTypeVector %float 2
+%mat = OpTypeMatrix %vec 2
+%array = OpTypeArray %mat %int_2
+%block = OpTypeStruct %array
+%ptr = OpTypePointer Uniform %block
+%var = OpVariable %ptr Uniform
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "must be explicitly laid out with RowMajor or ColMajor decorations"));
+}
+
+TEST_F(ValidateDecorations, MatrixArrayMissingStride) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 ColMajor
+OpDecorate %array ArrayStride 32
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%int = OpTypeInt 32 0
+%int_2 = OpConstant %int 2
+%vec = OpTypeVector %float 2
+%mat = OpTypeMatrix %vec 2
+%array = OpTypeArray %mat %int_2
+%block = OpTypeStruct %array
+%ptr = OpTypePointer Uniform %block
+%var = OpVariable %ptr Uniform
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("must be explicitly laid out with MatrixStride decorations"));
+}
+
+TEST_F(ValidateDecorations, MatrixArrayBadStride) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 ColMajor
+OpMemberDecorate %block 0 MatrixStride 8
+OpDecorate %array ArrayStride 32
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%int = OpTypeInt 32 0
+%int_2 = OpConstant %int 2
+%vec = OpTypeVector %float 2
+%mat = OpTypeMatrix %vec 2
+%array = OpTypeArray %mat %int_2
+%block = OpTypeStruct %array
+%ptr = OpTypePointer Uniform %block
+%var = OpVariable %ptr Uniform
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("is a matrix with stride 8 not satisfying alignment to 16"));
+}
+
+TEST_F(ValidateDecorations, MatrixArrayArrayMissingMajorness) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 MatrixStride 16
+OpDecorate %array ArrayStride 32
+OpDecorate %rta ArrayStride 64
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%int = OpTypeInt 32 0
+%int_2 = OpConstant %int 2
+%vec = OpTypeVector %float 2
+%mat = OpTypeMatrix %vec 2
+%array = OpTypeArray %mat %int_2
+%rta = OpTypeRuntimeArray %array
+%block = OpTypeStruct %rta
+%ptr = OpTypePointer StorageBuffer %block
+%var = OpVariable %ptr StorageBuffer
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "must be explicitly laid out with RowMajor or ColMajor decorations"));
+}
+
+TEST_F(ValidateDecorations, MatrixArrayArrayMissingStride) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 ColMajor
+OpDecorate %array ArrayStride 32
+OpDecorate %rta ArrayStride 64
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%int = OpTypeInt 32 0
+%int_2 = OpConstant %int 2
+%vec = OpTypeVector %float 2
+%mat = OpTypeMatrix %vec 2
+%array = OpTypeArray %mat %int_2
+%rta = OpTypeRuntimeArray %array
+%block = OpTypeStruct %rta
+%ptr = OpTypePointer StorageBuffer %block
+%var = OpVariable %ptr StorageBuffer
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("must be explicitly laid out with MatrixStride decorations"));
+}
+
+TEST_F(ValidateDecorations, MatrixArrayArrayBadStride) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %var DescriptorSet 0
+OpDecorate %var Binding 0
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 0 ColMajor
+OpMemberDecorate %block 0 MatrixStride 8
+OpDecorate %array ArrayStride 32
+OpDecorate %a ArrayStride 64
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%int = OpTypeInt 32 0
+%int_2 = OpConstant %int 2
+%vec = OpTypeVector %float 2
+%mat = OpTypeMatrix %vec 2
+%array = OpTypeArray %mat %int_2
+%a = OpTypeArray %array %int_2
+%block = OpTypeStruct %a
+%ptr = OpTypePointer Uniform %block
+%var = OpVariable %ptr Uniform
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("is a matrix with stride 8 not satisfying alignment to 16"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools