[feat](trx-backend-soapysdr): implement ARM NEON optimizations

Add NEON (aarch64) vectorized paths mirroring the existing AVX2 paths:

- demod/math_arm.rs: replace the no-op placeholder with a full NEON FM
  discriminator that processes 4 samples per iteration using a 7th-order
  minimax atan polynomial and branchless atan2 with argument reduction,
  matching the accuracy of the AVX2 path (max error ~2.4e-7 rad).
  32-bit ARM retains the scalar fallback.

- dsp/filter.rs: add mul_freq_domain_neon() that deinterleaves 4 complex
  pairs via vuzpq/vzipq, performs complex multiply with vmulq/vaddq/vsubq,
  then reinterleaves. On aarch64 this path is always taken (NEON is
  mandatory); scalar fallback remains for other targets.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Stanislaw Grams <stanislawgrams@gmail.com>
This commit is contained in:
2026-03-17 01:22:56 +01:00
parent 088f05050c
commit bff3be38be
2 changed files with 188 additions and 2 deletions
@@ -5,8 +5,141 @@
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
use num_complex::Complex; use num_complex::Complex;
/// Placeholder hook for future ARM/NEON FM discriminator vectorization. /// 7th-order minimax atan approximation for |z| <= 1.
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] #[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn atan_poly_neon(
z: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
use std::arch::aarch64::*;
let c0 = vdupq_n_f32(0.999_999_5_f32);
let c1 = vdupq_n_f32(-0.333_326_1_f32);
let c2 = vdupq_n_f32(0.199_777_1_f32);
let c3 = vdupq_n_f32(-0.138_776_8_f32);
let z2 = vmulq_f32(z, z);
let p = vaddq_f32(c2, vmulq_f32(z2, c3));
let p = vaddq_f32(c1, vmulq_f32(z2, p));
let p = vaddq_f32(c0, vmulq_f32(z2, p));
vmulq_f32(z, p)
}
/// Branchless NEON atan2 using argument reduction and polynomial evaluation.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn fast_atan2_4_neon(
y: std::arch::aarch64::float32x4_t,
x: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
use std::arch::aarch64::*;
let abs_mask = vdupq_n_u32(0x7FFF_FFFF_u32);
let sign_mask = vdupq_n_u32(0x8000_0000_u32);
let pi = vdupq_n_f32(std::f32::consts::PI);
let pi_2 = vdupq_n_f32(std::f32::consts::FRAC_PI_2);
let abs_y = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(y), abs_mask));
let abs_x = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), abs_mask));
let swap_mask = vcgtq_f32(abs_y, abs_x);
let num = vbslq_f32(swap_mask, x, y);
let den = vbslq_f32(swap_mask, y, x);
let eps = vdupq_n_f32(1.0e-30_f32);
let den_is_zero = vceqq_f32(den, vdupq_n_f32(0.0));
let safe_den = vreinterpretq_f32_u32(vorrq_u32(
vreinterpretq_u32_f32(den),
vandq_u32(den_is_zero, vreinterpretq_u32_f32(eps)),
));
let atan_input = vdivq_f32(num, safe_den);
let mut result = atan_poly_neon(atan_input);
let pi_2_with_sign = vreinterpretq_f32_u32(vorrq_u32(
vreinterpretq_u32_f32(pi_2),
vandq_u32(vreinterpretq_u32_f32(atan_input), sign_mask),
));
let adj = vsubq_f32(pi_2_with_sign, result);
result = vbslq_f32(swap_mask, adj, result);
let x_sign_mask = vreinterpretq_f32_s32(vshrq_n_s32::<31>(vreinterpretq_s32_f32(x)));
let pi_xor_y_sign = vreinterpretq_f32_u32(veorq_u32(
vreinterpretq_u32_f32(pi),
vandq_u32(sign_mask, vreinterpretq_u32_f32(y)),
));
let correction = vreinterpretq_f32_u32(vandq_u32(
vreinterpretq_u32_f32(pi_xor_y_sign),
vreinterpretq_u32_f32(x_sign_mask),
));
vaddq_f32(result, correction)
}
/// NEON FM discriminator: processes 4 samples per iteration.
#[cfg(target_arch = "aarch64")]
pub(super) fn demod_fm_body_neon(
samples: &[Complex<f32>],
start: usize,
inv_pi: f32,
output: &mut Vec<f32>,
) -> usize {
unsafe { demod_fm_body_neon_impl(samples, start, inv_pi, output) }
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn demod_fm_body_neon_impl(
samples: &[Complex<f32>],
start: usize,
inv_pi: f32,
output: &mut Vec<f32>,
) -> usize {
use std::arch::aarch64::*;
let len = samples.len();
let mut idx = start;
let mut cur_re = [0.0_f32; 4];
let mut cur_im = [0.0_f32; 4];
let mut prev_re = [0.0_f32; 4];
let mut prev_im = [0.0_f32; 4];
let mut angles = [0.0_f32; 4];
let inv_pi_v = vdupq_n_f32(inv_pi);
while idx + 4 <= len {
for lane in 0..4 {
let cur = samples[idx + lane];
let prev = samples[idx + lane - 1];
cur_re[lane] = cur.re;
cur_im[lane] = cur.im;
prev_re[lane] = prev.re;
prev_im[lane] = prev.im;
}
let cur_re_v = vld1q_f32(cur_re.as_ptr());
let cur_im_v = vld1q_f32(cur_im.as_ptr());
let prev_re_v = vld1q_f32(prev_re.as_ptr());
let prev_im_v = vld1q_f32(prev_im.as_ptr());
// Conjugate multiply: s[n] * conj(s[n-1])
// re = cur_re * prev_re + cur_im * prev_im
// im = cur_im * prev_re - cur_re * prev_im
let re_v = vaddq_f32(
vmulq_f32(cur_re_v, prev_re_v),
vmulq_f32(cur_im_v, prev_im_v),
);
let im_v = vsubq_f32(
vmulq_f32(cur_im_v, prev_re_v),
vmulq_f32(cur_re_v, prev_im_v),
);
let angle_v = vmulq_f32(fast_atan2_4_neon(im_v, re_v), inv_pi_v);
vst1q_f32(angles.as_mut_ptr(), angle_v);
output.extend_from_slice(&angles);
idx += 4;
}
idx
}
/// On 32-bit ARM, fall back to the scalar path.
#[cfg(target_arch = "arm")]
pub(super) fn demod_fm_body_neon( pub(super) fn demod_fm_body_neon(
_samples: &[Complex<f32>], _samples: &[Complex<f32>],
start: usize, start: usize,
@@ -179,6 +179,51 @@ unsafe fn mul_freq_domain_avx2(
mul_freq_domain_scalar(&mut buf[i..len], &h_freq[i..len], scale); mul_freq_domain_scalar(&mut buf[i..len], &h_freq[i..len], scale);
} }
/// NEON frequency-domain complex multiply: processes 4 complex pairs per iteration.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn mul_freq_domain_neon(
buf: &mut [FftComplex<f32>],
h_freq: &[FftComplex<f32>],
scale: f32,
) {
use std::arch::aarch64::*;
let len = buf.len().min(h_freq.len());
let mut i = 0usize;
let scale_v = vdupq_n_f32(scale);
while i + 4 <= len {
let x_ptr = buf.as_mut_ptr().add(i) as *mut f32;
let h_ptr = h_freq.as_ptr().add(i) as *const f32;
// Load 4 complex numbers as two float32x4_t: [re0,im0,re1,im1] and [re2,im2,re3,im3]
let x_lo = vld1q_f32(x_ptr);
let x_hi = vld1q_f32(x_ptr.add(4));
let h_lo = vld1q_f32(h_ptr);
let h_hi = vld1q_f32(h_ptr.add(4));
// Deinterleave: .0 = [re0..re3], .1 = [im0..im3]
let x_ri = vuzpq_f32(x_lo, x_hi);
let h_ri = vuzpq_f32(h_lo, h_hi);
let (x_re, x_im) = (x_ri.0, x_ri.1);
let (h_re, h_im) = (h_ri.0, h_ri.1);
// Complex multiply: out.re = x.re*h.re - x.im*h.im, out.im = x.re*h.im + x.im*h.re
let out_re = vmulq_f32(vsubq_f32(vmulq_f32(x_re, h_re), vmulq_f32(x_im, h_im)), scale_v);
let out_im = vmulq_f32(vaddq_f32(vmulq_f32(x_re, h_im), vmulq_f32(x_im, h_re)), scale_v);
// Reinterleave: .0 = [re0,im0,re1,im1], .1 = [re2,im2,re3,im3]
let out = vzipq_f32(out_re, out_im);
vst1q_f32(x_ptr, out.0);
vst1q_f32(x_ptr.add(4), out.1);
i += 4;
}
mul_freq_domain_scalar(&mut buf[i..len], &h_freq[i..len], scale);
}
fn mul_freq_domain(buf: &mut [FftComplex<f32>], h_freq: &[FftComplex<f32>], scale: f32) { fn mul_freq_domain(buf: &mut [FftComplex<f32>], h_freq: &[FftComplex<f32>], scale: f32) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{ {
@@ -190,6 +235,14 @@ fn mul_freq_domain(buf: &mut [FftComplex<f32>], h_freq: &[FftComplex<f32>], scal
} }
} }
#[cfg(target_arch = "aarch64")]
{
unsafe {
mul_freq_domain_neon(buf, h_freq, scale);
}
return;
}
mul_freq_domain_scalar(buf, h_freq, scale); mul_freq_domain_scalar(buf, h_freq, scale);
} }