From bff3be38befa1261c646a904f9a38f991edb4266 Mon Sep 17 00:00:00 2001
From: Stanislaw Grams
Date: Tue, 17 Mar 2026 01:22:56 +0100
Subject: [PATCH] [feat](trx-backend-soapysdr): implement ARM NEON optimizations

Add NEON (aarch64) vectorized paths mirroring the existing AVX2 paths:

- demod/math_arm.rs: replace the no-op placeholder with a full NEON FM
  discriminator that processes 4 samples per iteration using a 7th-order
  minimax atan polynomial and branchless atan2 with argument reduction,
  matching the accuracy of the AVX2 path (max error ~2.4e-7 rad). 32-bit
  ARM retains the scalar fallback.

- dsp/filter.rs: add mul_freq_domain_neon() that deinterleaves 4 complex
  pairs via vuzpq/vzipq, performs complex multiply with
  vmulq/vaddq/vsubq, then reinterleaves. On aarch64 this path is always
  taken (NEON is mandatory); scalar fallback remains for other targets.

Co-Authored-By: Claude Sonnet 4.6
Signed-off-by: Stanislaw Grams
---
 .../src/demod/math_arm.rs                     | 158 +++++++++++++++++-
 .../trx-backend-soapysdr/src/dsp/filter.rs    |  63 +++++++
 2 files changed, 219 insertions(+), 2 deletions(-)

diff --git a/src/trx-server/trx-backend/trx-backend-soapysdr/src/demod/math_arm.rs b/src/trx-server/trx-backend/trx-backend-soapysdr/src/demod/math_arm.rs
index 254192c..d979674 100644
--- a/src/trx-server/trx-backend/trx-backend-soapysdr/src/demod/math_arm.rs
+++ b/src/trx-server/trx-backend/trx-backend-soapysdr/src/demod/math_arm.rs
@@ -5,8 +5,162 @@
 #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
 use num_complex::Complex;
 
-/// Placeholder hook for future ARM/NEON FM discriminator vectorization.
-#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+/// 7th-order minimax atan approximation for |z| <= 1.
+///
+/// # Safety
+///
+/// The CPU must support NEON.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn atan_poly_neon(
+    z: std::arch::aarch64::float32x4_t,
+) -> std::arch::aarch64::float32x4_t {
+    use std::arch::aarch64::*;
+    let c0 = vdupq_n_f32(0.999_999_5_f32);
+    let c1 = vdupq_n_f32(-0.333_326_1_f32);
+    let c2 = vdupq_n_f32(0.199_777_1_f32);
+    let c3 = vdupq_n_f32(-0.138_776_8_f32);
+    let z2 = vmulq_f32(z, z);
+    let p = vaddq_f32(c2, vmulq_f32(z2, c3));
+    let p = vaddq_f32(c1, vmulq_f32(z2, p));
+    let p = vaddq_f32(c0, vmulq_f32(z2, p));
+    vmulq_f32(z, p)
+}
+
+/// Branchless NEON atan2 using argument reduction and polynomial evaluation.
+///
+/// # Safety
+///
+/// The CPU must support NEON.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn fast_atan2_4_neon(
+    y: std::arch::aarch64::float32x4_t,
+    x: std::arch::aarch64::float32x4_t,
+) -> std::arch::aarch64::float32x4_t {
+    use std::arch::aarch64::*;
+    let abs_mask = vdupq_n_u32(0x7FFF_FFFF_u32);
+    let sign_mask = vdupq_n_u32(0x8000_0000_u32);
+    let pi = vdupq_n_f32(std::f32::consts::PI);
+    let pi_2 = vdupq_n_f32(std::f32::consts::FRAC_PI_2);
+
+    let abs_y = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(y), abs_mask));
+    let abs_x = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), abs_mask));
+
+    let swap_mask = vcgtq_f32(abs_y, abs_x);
+    let num = vbslq_f32(swap_mask, x, y);
+    let den = vbslq_f32(swap_mask, y, x);
+
+    let eps = vdupq_n_f32(1.0e-30_f32);
+    let den_is_zero = vceqq_f32(den, vdupq_n_f32(0.0));
+    let safe_den = vreinterpretq_f32_u32(vorrq_u32(
+        vreinterpretq_u32_f32(den),
+        vandq_u32(den_is_zero, vreinterpretq_u32_f32(eps)),
+    ));
+    let atan_input = vdivq_f32(num, safe_den);
+    let mut result = atan_poly_neon(atan_input);
+
+    let pi_2_with_sign = vreinterpretq_f32_u32(vorrq_u32(
+        vreinterpretq_u32_f32(pi_2),
+        vandq_u32(vreinterpretq_u32_f32(atan_input), sign_mask),
+    ));
+    let adj = vsubq_f32(pi_2_with_sign, result);
+    result = vbslq_f32(swap_mask, adj, result);
+
+    let x_sign_mask = vreinterpretq_f32_s32(vshrq_n_s32::<31>(vreinterpretq_s32_f32(x)));
+    let pi_xor_y_sign = vreinterpretq_f32_u32(veorq_u32(
+        vreinterpretq_u32_f32(pi),
+        vandq_u32(sign_mask, vreinterpretq_u32_f32(y)),
+    ));
+    let correction = vreinterpretq_f32_u32(vandq_u32(
+        vreinterpretq_u32_f32(pi_xor_y_sign),
+        vreinterpretq_u32_f32(x_sign_mask),
+    ));
+    vaddq_f32(result, correction)
+}
+
+/// NEON FM discriminator: processes 4 samples per iteration.
+///
+/// Appends `atan2(im, re) * inv_pi` of `s[n] * conj(s[n-1])` to `output` for each
+/// complete group of four samples beginning at `start`, and returns the index of
+/// the first sample left unprocessed.
+///
+/// `start` must be at least 1, because every output reads the preceding sample.
+#[cfg(target_arch = "aarch64")]
+pub(super) fn demod_fm_body_neon(
+    samples: &[Complex<f32>],
+    start: usize,
+    inv_pi: f32,
+    output: &mut Vec<f32>,
+) -> usize {
+    // SAFETY: NEON is a mandatory part of the aarch64 baseline, so the
+    // `neon` target feature required by the callee is always available.
+    unsafe { demod_fm_body_neon_impl(samples, start, inv_pi, output) }
+}
+
+/// NEON body of [`demod_fm_body_neon`]; see that function for the contract.
+///
+/// # Safety
+///
+/// The CPU must support NEON.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn demod_fm_body_neon_impl(
+    samples: &[Complex<f32>],
+    start: usize,
+    inv_pi: f32,
+    output: &mut Vec<f32>,
+) -> usize {
+    use std::arch::aarch64::*;
+
+    let len = samples.len();
+    let mut idx = start;
+    let mut cur_re = [0.0_f32; 4];
+    let mut cur_im = [0.0_f32; 4];
+    let mut prev_re = [0.0_f32; 4];
+    let mut prev_im = [0.0_f32; 4];
+    let mut angles = [0.0_f32; 4];
+    let inv_pi_v = vdupq_n_f32(inv_pi);
+
+    while idx + 4 <= len {
+        for lane in 0..4 {
+            let cur = samples[idx + lane];
+            let prev = samples[idx + lane - 1];
+            cur_re[lane] = cur.re;
+            cur_im[lane] = cur.im;
+            prev_re[lane] = prev.re;
+            prev_im[lane] = prev.im;
+        }
+
+        let cur_re_v = vld1q_f32(cur_re.as_ptr());
+        let cur_im_v = vld1q_f32(cur_im.as_ptr());
+        let prev_re_v = vld1q_f32(prev_re.as_ptr());
+        let prev_im_v = vld1q_f32(prev_im.as_ptr());
+
+        // Conjugate multiply: s[n] * conj(s[n-1])
+        // re = cur_re * prev_re + cur_im * prev_im
+        // im = cur_im * prev_re - cur_re * prev_im
+        let re_v = vaddq_f32(
+            vmulq_f32(cur_re_v, prev_re_v),
+            vmulq_f32(cur_im_v, prev_im_v),
+        );
+        let im_v = vsubq_f32(
+            vmulq_f32(cur_im_v, prev_re_v),
+            vmulq_f32(cur_re_v, prev_im_v),
+        );
+
+        let angle_v = vmulq_f32(fast_atan2_4_neon(im_v, re_v), inv_pi_v);
+        vst1q_f32(angles.as_mut_ptr(), angle_v);
+        output.extend_from_slice(&angles);
+
+        idx += 4;
+    }
+
+    idx
+}
+
+/// On 32-bit ARM, fall back to the scalar path.
+#[cfg(target_arch = "arm")]
 pub(super) fn demod_fm_body_neon(
     _samples: &[Complex<f32>],
     start: usize,
diff --git a/src/trx-server/trx-backend/trx-backend-soapysdr/src/dsp/filter.rs b/src/trx-server/trx-backend/trx-backend-soapysdr/src/dsp/filter.rs
index b945ba1..d99e0e5 100644
--- a/src/trx-server/trx-backend/trx-backend-soapysdr/src/dsp/filter.rs
+++ b/src/trx-server/trx-backend/trx-backend-soapysdr/src/dsp/filter.rs
@@ -179,6 +179,60 @@ unsafe fn mul_freq_domain_avx2(
     mul_freq_domain_scalar(&mut buf[i..len], &h_freq[i..len], scale);
 }
 
+/// NEON frequency-domain complex multiply: processes 4 complex pairs per iteration.
+///
+/// Replaces `buf[k]` with `buf[k] * h_freq[k] * scale` for every `k` below the shorter
+/// of the two slice lengths; the final `len % 4` elements use the scalar path.
+///
+/// # Safety
+///
+/// The CPU must support NEON.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn mul_freq_domain_neon(
+    buf: &mut [FftComplex<f32>],
+    h_freq: &[FftComplex<f32>],
+    scale: f32,
+) {
+    use std::arch::aarch64::*;
+
+    let len = buf.len().min(h_freq.len());
+    let mut i = 0usize;
+    let scale_v = vdupq_n_f32(scale);
+
+    while i + 4 <= len {
+        // `i + 4 <= len` keeps the 8 floats accessed per slice in bounds; the casts
+        // rely on each complex value being stored as an `[re, im]` pair of `f32`.
+        let x_ptr = buf.as_mut_ptr().add(i) as *mut f32;
+        let h_ptr = h_freq.as_ptr().add(i) as *const f32;
+
+        // Load 4 complex numbers as two float32x4_t: [re0,im0,re1,im1] and [re2,im2,re3,im3]
+        let x_lo = vld1q_f32(x_ptr);
+        let x_hi = vld1q_f32(x_ptr.add(4));
+        let h_lo = vld1q_f32(h_ptr);
+        let h_hi = vld1q_f32(h_ptr.add(4));
+
+        // Deinterleave: .0 = [re0..re3], .1 = [im0..im3]
+        let x_ri = vuzpq_f32(x_lo, x_hi);
+        let h_ri = vuzpq_f32(h_lo, h_hi);
+        let (x_re, x_im) = (x_ri.0, x_ri.1);
+        let (h_re, h_im) = (h_ri.0, h_ri.1);
+
+        // Complex multiply: out.re = x.re*h.re - x.im*h.im, out.im = x.re*h.im + x.im*h.re
+        let out_re = vmulq_f32(vsubq_f32(vmulq_f32(x_re, h_re), vmulq_f32(x_im, h_im)), scale_v);
+        let out_im = vmulq_f32(vaddq_f32(vmulq_f32(x_re, h_im), vmulq_f32(x_im, h_re)), scale_v);
+
+        // Reinterleave: .0 = [re0,im0,re1,im1], .1 = [re2,im2,re3,im3]
+        let out = vzipq_f32(out_re, out_im);
+        vst1q_f32(x_ptr, out.0);
+        vst1q_f32(x_ptr.add(4), out.1);
+
+        i += 4;
+    }
+
+    mul_freq_domain_scalar(&mut buf[i..len], &h_freq[i..len], scale);
+}
+
 fn mul_freq_domain(buf: &mut [FftComplex<f32>], h_freq: &[FftComplex<f32>], scale: f32) {
     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
     {
@@ -190,6 +244,15 @@ fn mul_freq_domain(buf: &mut [FftComplex<f32>], h_freq: &[FftComplex<f32>], scal
         }
     }
 
+    #[cfg(target_arch = "aarch64")]
+    {
+        // SAFETY: NEON is mandatory on aarch64, so the `neon` feature is present.
+        unsafe {
+            mul_freq_domain_neon(buf, h_freq, scale);
+        }
+    }
+
+    #[cfg(not(target_arch = "aarch64"))]
     mul_freq_domain_scalar(buf, h_freq, scale);
 }
 