aboutsummaryrefslogtreecommitdiff
path: root/test_conformance/subgroups/test_subgroup.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test_conformance/subgroups/test_subgroup.cpp')
-rw-r--r--test_conformance/subgroups/test_subgroup.cpp77
1 files changed, 34 insertions, 43 deletions
diff --git a/test_conformance/subgroups/test_subgroup.cpp b/test_conformance/subgroups/test_subgroup.cpp
index c0e49524..75e9d4ae 100644
--- a/test_conformance/subgroups/test_subgroup.cpp
+++ b/test_conformance/subgroups/test_subgroup.cpp
@@ -24,6 +24,13 @@ namespace {
// Any/All test functions
template <NonUniformVoteOp operation> struct AA
{
+ static void log_test(const WorkGroupParams &test_params,
+ const char *extra_text)
+ {
+ log_info(" sub_group_%s...%s\n", operation_names(operation),
+ extra_text);
+ }
+
static void gen(cl_int *x, cl_int *t, cl_int *m,
const WorkGroupParams &test_params)
{
@@ -35,7 +42,6 @@ template <NonUniformVoteOp operation> struct AA
int e;
ng = ng / nw;
ii = 0;
- log_info(" sub_group_%s...\n", operation_names(operation));
for (k = 0; k < ng; ++k)
{
for (j = 0; j < nj; ++j)
@@ -68,8 +74,8 @@ template <NonUniformVoteOp operation> struct AA
}
}
- static int chk(cl_int *x, cl_int *y, cl_int *mx, cl_int *my, cl_int *m,
- const WorkGroupParams &test_params)
+ static test_status chk(cl_int *x, cl_int *y, cl_int *mx, cl_int *my,
+ cl_int *m, const WorkGroupParams &test_params)
{
int ii, i, j, k, n;
int ng = test_params.global_workgroup_size;
@@ -124,51 +130,33 @@ template <NonUniformVoteOp operation> struct AA
y += nw;
m += 4 * nw;
}
- log_info(" sub_group_%s... passed\n", operation_names(operation));
return TEST_PASS;
}
};
-static const char *any_source = "__kernel void test_any(const __global Type "
- "*in, __global int4 *xy, __global Type *out)\n"
- "{\n"
- " int gid = get_global_id(0);\n"
- " XY(xy,gid);\n"
- " out[gid] = sub_group_any(in[gid]);\n"
- "}\n";
-
-static const char *all_source = "__kernel void test_all(const __global Type "
- "*in, __global int4 *xy, __global Type *out)\n"
- "{\n"
- " int gid = get_global_id(0);\n"
- " XY(xy,gid);\n"
- " out[gid] = sub_group_all(in[gid]);\n"
- "}\n";
-
-
template <typename T>
int run_broadcast_scan_reduction_for_type(RunTestForType rft)
{
int error = rft.run_impl<T, BC<T, SubgroupsBroadcastOp::broadcast>>(
- "test_bcast", bcast_source);
- error |= rft.run_impl<T, RED_NU<T, ArithmeticOp::add_>>("test_redadd",
- redadd_source);
- error |= rft.run_impl<T, RED_NU<T, ArithmeticOp::max_>>("test_redmax",
- redmax_source);
- error |= rft.run_impl<T, RED_NU<T, ArithmeticOp::min_>>("test_redmin",
- redmin_source);
- error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::add_>>("test_scinadd",
- scinadd_source);
- error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::max_>>("test_scinmax",
- scinmax_source);
- error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::min_>>("test_scinmin",
- scinmin_source);
- error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::add_>>("test_scexadd",
- scexadd_source);
- error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::max_>>("test_scexmax",
- scexmax_source);
- error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::min_>>("test_scexmin",
- scexmin_source);
+ "sub_group_broadcast");
+ error |=
+ rft.run_impl<T, RED_NU<T, ArithmeticOp::add_>>("sub_group_reduce_add");
+ error |=
+ rft.run_impl<T, RED_NU<T, ArithmeticOp::max_>>("sub_group_reduce_max");
+ error |=
+ rft.run_impl<T, RED_NU<T, ArithmeticOp::min_>>("sub_group_reduce_min");
+ error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::add_>>(
+ "sub_group_scan_inclusive_add");
+ error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::max_>>(
+ "sub_group_scan_inclusive_max");
+ error |= rft.run_impl<T, SCIN_NU<T, ArithmeticOp::min_>>(
+ "sub_group_scan_inclusive_min");
+ error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::add_>>(
+ "sub_group_scan_exclusive_add");
+ error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::max_>>(
+ "sub_group_scan_exclusive_max");
+ error |= rft.run_impl<T, SCEX_NU<T, ArithmeticOp::min_>>(
+ "sub_group_scan_exclusive_min");
return error;
}
@@ -181,11 +169,14 @@ int test_subgroup_functions(cl_device_id device, cl_context context,
constexpr size_t global_work_size = 2000;
constexpr size_t local_work_size = 200;
WorkGroupParams test_params(global_work_size, local_work_size);
+ test_params.save_kernel_source(sub_group_reduction_scan_source);
+ test_params.save_kernel_source(sub_group_generic_source,
+ "sub_group_broadcast");
+
RunTestForType rft(device, context, queue, num_elements, test_params);
int error =
- rft.run_impl<cl_int, AA<NonUniformVoteOp::any>>("test_any", any_source);
- error |=
- rft.run_impl<cl_int, AA<NonUniformVoteOp::all>>("test_all", all_source);
+ rft.run_impl<cl_int, AA<NonUniformVoteOp::any>>("sub_group_any");
+ error |= rft.run_impl<cl_int, AA<NonUniformVoteOp::all>>("sub_group_all");
error |= run_broadcast_scan_reduction_for_type<cl_int>(rft);
error |= run_broadcast_scan_reduction_for_type<cl_uint>(rft);
error |= run_broadcast_scan_reduction_for_type<cl_long>(rft);