aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorge Steed <george.steed@arm.com>2024-05-09 12:09:57 +0000
committerJames Zern <jzern@google.com>2024-05-10 16:28:12 -0700
commit8298efe1a467070e13ad1b5c22dd51c520998df4 (patch)
treeb2fa22f7ab749029ed8757fbbceb2498d02d992b
parentd2cb57022d7cd0af4780b6aa99466031ea25de04 (diff)
downloadlibaom-main.tar.gz
{,highbd_}intrapred_neon.c: Avoid over-reads in z1 and z3 predsHEADmastermain
The existing z1 and z3 predictors already contain checks to see if the first element of the vector would over-read, however this is not sufficient since the vector may straddle the end of the input array. To get around this, add an additional check against the end of the array. If we would over-read, load a full vector up to the end of the array and then use TBL to shuffle the data into the correct place. This also means that we no longer need the compare and BSL at the end of each loop iteration to select between the computed data or the value of the last element duplicated. Bug: aomedia:3571, b:338345960 Test: presubmit Change-Id: I03e2313b9bf0b44d64811fff1bedf4eb7381518a (cherry picked from commit f1b43b5c0d0c98a37713e9939a782ebe014c1d1f)
-rw-r--r--README.android3
-rw-r--r--aom_dsp/arm/highbd_intrapred_neon.c108
-rw-r--r--aom_dsp/arm/intrapred_neon.c65
3 files changed, 136 insertions, 40 deletions
diff --git a/README.android b/README.android
index 02616e0ff..54d2961a6 100644
--- a/README.android
+++ b/README.android
@@ -46,3 +46,6 @@ Tools needed to build libaom:
Generate config files that contain the source list for each platform.
A list of prerequisites is at the top of generate_config.sh.
+
+Cherry-picks:
+f1b43b5c0d {,highbd_}intrapred_neon.c: Avoid over-reads in z1 and z3 preds
diff --git a/aom_dsp/arm/highbd_intrapred_neon.c b/aom_dsp/arm/highbd_intrapred_neon.c
index dc47974c6..e66f523e3 100644
--- a/aom_dsp/arm/highbd_intrapred_neon.c
+++ b/aom_dsp/arm/highbd_intrapred_neon.c
@@ -1293,6 +1293,33 @@ static AOM_FORCE_INLINE uint16x8_t highbd_dr_z1_apply_shift_x8(uint16x8_t a0,
highbd_dr_z1_apply_shift_x4(vget_high_u16(a0), vget_high_u16(a1), shift));
}
+// clang-format off
+static const uint8_t kLoadMaxShuffles[] = {
+ 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ 8, 9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15,
+ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15,
+ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14, 15,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+};
+// clang-format on
+
+static INLINE uint16x8_t zn_load_masked_neon(const uint16_t *ptr,
+ int shuffle_idx) {
+ uint8x16_t shuffle = vld1q_u8(&kLoadMaxShuffles[16 * shuffle_idx]);
+ uint8x16_t src = vreinterpretq_u8_u16(vld1q_u16(ptr));
+#if AOM_ARCH_AARCH64
+ return vreinterpretq_u16_u8(vqtbl1q_u8(src, shuffle));
+#else
+ uint8x8x2_t src2 = { { vget_low_u8(src), vget_high_u8(src) } };
+ uint8x8_t lo = vtbl2_u8(src2, vget_low_u8(shuffle));
+ uint8x8_t hi = vtbl2_u8(src2, vget_high_u8(shuffle));
+ return vreinterpretq_u16_u8(vcombine_u8(lo, hi));
+#endif
+}
+
static void highbd_dr_prediction_z1_upsample0_neon(uint16_t *dst,
ptrdiff_t stride, int bw,
int bh,
@@ -1336,13 +1363,26 @@ static void highbd_dr_prediction_z1_upsample0_neon(uint16_t *dst,
} else {
int c = 0;
do {
- const uint16x8_t a0 = vld1q_u16(&above[base + c]);
- const uint16x8_t a1 = vld1q_u16(&above[base + c + 1]);
- const uint16x8_t val = highbd_dr_z1_apply_shift_x8(a0, a1, shift);
- const uint16x8_t cmp =
- vcgtq_s16(vdupq_n_s16(max_base_x - base - c), iota1x8);
- const uint16x8_t res = vbslq_u16(cmp, val, vdupq_n_u16(above_max));
- vst1q_u16(dst + c, res);
+ uint16x8_t a0;
+ uint16x8_t a1;
+ if (base + c >= max_base_x) {
+ a0 = a1 = vdupq_n_u16(above_max);
+ } else {
+ if (base + c + 7 >= max_base_x) {
+ int shuffle_idx = max_base_x - base - c;
+ a0 = zn_load_masked_neon(above + (max_base_x - 7), shuffle_idx);
+ } else {
+ a0 = vld1q_u16(above + base + c);
+ }
+ if (base + c + 8 >= max_base_x) {
+ int shuffle_idx = max_base_x - base - c - 1;
+ a1 = zn_load_masked_neon(above + (max_base_x - 7), shuffle_idx);
+ } else {
+ a1 = vld1q_u16(above + base + c + 1);
+ }
+ }
+
+ vst1q_u16(dst + c, highbd_dr_z1_apply_shift_x8(a0, a1, shift));
c += 8;
} while (c < bw);
}
@@ -2456,13 +2496,29 @@ void av1_highbd_dr_prediction_z2_neon(uint16_t *dst, ptrdiff_t stride, int bw,
val_lo = vmlal_lane_u16(val_lo, vget_low_u16(in1), (s1), (lane)); \
uint32x4_t val_hi = vmull_lane_u16(vget_high_u16(in0), (s0), (lane)); \
val_hi = vmlal_lane_u16(val_hi, vget_high_u16(in1), (s1), (lane)); \
- const uint16x8_t cmp = vaddq_u16((iota), vdupq_n_u16(base)); \
- const uint16x8_t res = vcombine_u16(vrshrn_n_u32(val_lo, (shift)), \
- vrshrn_n_u32(val_hi, (shift))); \
- *(out) = vbslq_u16(vcltq_u16(cmp, vdupq_n_u16(max_base_y)), res, \
- vdupq_n_u16(left_max)); \
+ *(out) = vcombine_u16(vrshrn_n_u32(val_lo, (shift)), \
+ vrshrn_n_u32(val_hi, (shift))); \
} while (0)
+static INLINE uint16x8x2_t z3_load_left_neon(const uint16_t *left0, int ofs,
+ int max_ofs) {
+ uint16x8_t r0;
+ uint16x8_t r1;
+ if (ofs + 7 >= max_ofs) {
+ int shuffle_idx = max_ofs - ofs;
+ r0 = zn_load_masked_neon(left0 + (max_ofs - 7), shuffle_idx);
+ } else {
+ r0 = vld1q_u16(left0 + ofs);
+ }
+ if (ofs + 8 >= max_ofs) {
+ int shuffle_idx = max_ofs - ofs - 1;
+ r1 = zn_load_masked_neon(left0 + (max_ofs - 7), shuffle_idx);
+ } else {
+ r1 = vld1q_u16(left0 + ofs + 1);
+ }
+ return (uint16x8x2_t){ { r0, r1 } };
+}
+
static void highbd_dr_prediction_z3_upsample0_neon(uint16_t *dst,
ptrdiff_t stride, int bw,
int bh, const uint16_t *left,
@@ -2561,34 +2617,30 @@ static void highbd_dr_prediction_z3_upsample0_neon(uint16_t *dst,
if (base0 >= max_base_y) {
out[0] = vdupq_n_u16(left_max);
} else {
- const uint16x8_t l00 = vld1q_u16(left + base0);
- const uint16x8_t l01 = vld1q_u16(left1 + base0);
- HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota1x8, base0, l00, l01,
- shifts0, shifts1, 0, 6);
+ const uint16x8x2_t l0 = z3_load_left_neon(left, base0, max_base_y);
+ HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota1x8, base0, l0.val[0],
+ l0.val[1], shifts0, shifts1, 0, 6);
}
if (base1 >= max_base_y) {
out[1] = vdupq_n_u16(left_max);
} else {
- const uint16x8_t l10 = vld1q_u16(left + base1);
- const uint16x8_t l11 = vld1q_u16(left1 + base1);
- HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota1x8, base1, l10, l11,
- shifts0, shifts1, 1, 6);
+ const uint16x8x2_t l1 = z3_load_left_neon(left, base1, max_base_y);
+ HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota1x8, base1, l1.val[0],
+ l1.val[1], shifts0, shifts1, 1, 6);
}
if (base2 >= max_base_y) {
out[2] = vdupq_n_u16(left_max);
} else {
- const uint16x8_t l20 = vld1q_u16(left + base2);
- const uint16x8_t l21 = vld1q_u16(left1 + base2);
- HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota1x8, base2, l20, l21,
- shifts0, shifts1, 2, 6);
+ const uint16x8x2_t l2 = z3_load_left_neon(left, base2, max_base_y);
+ HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota1x8, base2, l2.val[0],
+ l2.val[1], shifts0, shifts1, 2, 6);
}
if (base3 >= max_base_y) {
out[3] = vdupq_n_u16(left_max);
} else {
- const uint16x8_t l30 = vld1q_u16(left + base3);
- const uint16x8_t l31 = vld1q_u16(left1 + base3);
- HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota1x8, base3, l30, l31,
- shifts0, shifts1, 3, 6);
+ const uint16x8x2_t l3 = z3_load_left_neon(left, base3, max_base_y);
+ HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota1x8, base3, l3.val[0],
+ l3.val[1], shifts0, shifts1, 3, 6);
}
transpose_array_inplace_u16_4x8(out);
for (int r2 = 0; r2 < 4; ++r2) {
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index c3716b3a7..2c99154fd 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -1356,6 +1356,41 @@ static void dr_prediction_z1_32xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
}
}
+// clang-format off
+static const uint8_t kLoadMaxShuffles[] = {
+ 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+ 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15,
+ 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15,
+ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15,
+ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15,
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 15,
+ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15,
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+};
+// clang-format on
+
+static INLINE uint8x16_t z1_load_masked_neon(const uint8_t *ptr,
+ int shuffle_idx) {
+ uint8x16_t shuffle = vld1q_u8(&kLoadMaxShuffles[16 * shuffle_idx]);
+ uint8x16_t src = vld1q_u8(ptr);
+#if AOM_ARCH_AARCH64
+ return vqtbl1q_u8(src, shuffle);
+#else
+ uint8x8x2_t src2 = { { vget_low_u8(src), vget_high_u8(src) } };
+ uint8x8_t lo = vtbl2_u8(src2, vget_low_u8(shuffle));
+ uint8x8_t hi = vtbl2_u8(src2, vget_high_u8(shuffle));
+ return vcombine_u8(lo, hi);
+#endif
+}
+
static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
const uint8_t *above, int dx) {
const int frac_bits = 6;
@@ -1369,7 +1404,6 @@ static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
// (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]);
- const uint8x16_t max_base_x128 = vdupq_n_u8(max_base_x);
int x = dx;
for (int r = 0; r < N; r++, dst += stride) {
@@ -1391,12 +1425,24 @@ static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
vcreate_u8(0x0F0E0D0C0B0A0908)));
for (int j = 0; j < 64; j += 16) {
- int mdif = max_base_x - (base + j);
- if (mdif <= 0) {
+ if (base + j >= max_base_x) {
vst1q_u8(dst + j, a_mbase_x);
} else {
- uint8x16_t a0_128 = vld1q_u8(above + base + j);
- uint8x16_t a1_128 = vld1q_u8(above + base + 1 + j);
+ uint8x16_t a0_128;
+ uint8x16_t a1_128;
+ if (base + j + 15 >= max_base_x) {
+ int shuffle_idx = max_base_x - base - j;
+ a0_128 = z1_load_masked_neon(above + (max_base_x - 15), shuffle_idx);
+ } else {
+ a0_128 = vld1q_u8(above + base + j);
+ }
+ if (base + j + 16 >= max_base_x) {
+ int shuffle_idx = max_base_x - base - j - 1;
+ a1_128 = z1_load_masked_neon(above + (max_base_x - 15), shuffle_idx);
+ } else {
+ a1_128 = vld1q_u8(above + base + j + 1);
+ }
+
uint16x8_t diff_lo = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128));
uint16x8_t diff_hi =
vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128));
@@ -1406,13 +1452,8 @@ static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_128), vdup_n_u8(32));
uint16x8_t res_lo = vmlaq_u16(a32_lo, diff_lo, shift);
uint16x8_t res_hi = vmlaq_u16(a32_hi, diff_hi, shift);
- uint8x16_t v_temp =
- vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5));
-
- uint8x16_t mask128 =
- vcgtq_u8(vqsubq_u8(max_base_x128, base_inc128), vdupq_n_u8(0));
- uint8x16_t res128 = vbslq_u8(mask128, v_temp, a_mbase_x);
- vst1q_u8(dst + j, res128);
+ vst1q_u8(dst + j,
+ vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5)));
base_inc128 = vaddq_u8(base_inc128, vdupq_n_u8(16));
}