UnrealEngine/Engine/Source/Runtime/BinkAudioDecoder/SDK/BinkAudio/Src/radfft.cpp
// Copyright Epic Games, Inc. All Rights Reserved.
#include "rrCore.h"
#include "radfft.h"
// Algorithm:
//
// This code computes FFTs and derived transforms using a decimation-in-time, conjugate-pair
// split-radix algorithm that closely follows Blake et al., "The Fastest Fourier Transform
// in the South", albeit without any of the runtime code generation.
//
// FFTS considers only 64-bit x86 and ARM, and thus assumes that 16 vector registers are
// available, and that the cache is at least 8-way set associative. This implementation also
// tries to perform well on 32-bit x86; this reduces the practical size of the recursion
// base cases from 8 to 4. As a side effect, we also don't need more than 4 cache ways, which
// is the right choice should we want to port this to in-order cores.
//
// We also don't bother with the transition codelets in the base cases. The base case loops
// process as much data as possible while staying aligned; we don't bother with specialized
// SIMD code for the transition to another loop size though, we just run scalar versions
// at the edges. This is slightly less efficient, but a lot simpler.
//
// Real FFTs are computed using a complex FFT of half the size plus a Cooley-Tukey radix-2
// DIT step (the standard "packing" algorithm).
//
// The DCTs are computed using the standard reduction of N-element DCT-II and DCT-III to
// N-element real FFTs.
//
// Some more notes:
// - I looked into merging the DCT-II modulate step with the preceding rfft post-process.
// This works and saves a pass over the data, but the code is more complicated and wasn't
// faster in my tests. (Same goes for merging DCT-III modulate with rifft pre-process)
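//
// For orientation, here is a minimal reference sketch of the transform these kernels
// compute: the forward complex DFT with e^(-2*pi*i*n*k/N) twiddles, which presumably
// matches radfft_cfft's convention given the minus sign used in calc_twiddle() below
// (the inverse appears to be unnormalized, i.e. no 1/N factor). Illustrative only,
// not verified bit-for-bit and not part of the build:
//
// static void naive_cfft(rfft_complex out[], rfft_complex const in[], UINTa N)
// {
//     for (UINTa k = 0; k < N; ++k)
//     {
//         F64 sr = 0.0, si = 0.0;
//         for (UINTa n = 0; n < N; ++n)
//         {
//             F64 ph = -2.0 * 3.1415926535897932385 * (F64)(n * k) / (F64)N;
//             sr += in[n].re * cos(ph) - in[n].im * sin(ph); // real part of x[n]*e^(i*ph)
//             si += in[n].re * sin(ph) + in[n].im * cos(ph); // imaginary part
//         }
//         out[k].re = (F32)sr;
//         out[k].im = (F32)si;
//     }
// }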
//#define TABLEGEN // Compile this file with /DTABLEGEN to write out prepared tables!
#ifndef FORCE_NO_FFT_TABLES
#define USETABLES // Use pre-generated twiddle/permutation tables
#endif
#ifdef TABLEGEN // TABLEGEN shouldn't use the prebuilt tables or we're getting circular! :)
#undef USETABLES
#endif
#define FFTASSERT(cond) if (!(cond)) RR_BREAK()
#define FFTALIGNED(type, name) static RAD_ALIGN(type const, name, RADFFT_ALIGN)
#define ALIGNHINT(var,align)
#define FFTTABLE(type, name) FFTALIGNED(type, name)
#if defined __has_builtin
#define RAD_HAS_BUILTIN(n) __has_builtin(n)
#else
#define RAD_HAS_BUILTIN(n) 0
#endif
#if RAD_HAS_BUILTIN(__builtin_cos) && RAD_HAS_BUILTIN(__builtin_sin)
#define cos(v) (__builtin_cos(v))
#define sin(v) (__builtin_sin(v))
#define _LIBCPP_MATH_H
#else
#include <math.h>
#endif
#ifdef BIG_OLE_FFT
static UINTa const kMaxN = 4096; // Largest FFT size we support. This is easy to change.
#else
static UINTa const kMaxN = 2048; // Largest FFT size we support. This is easy to change.
#endif
static UINTa const kLeafN = 4; // Size of leaf transforms. This isn't easy to change at all. :)
static UINTa const kMaxPlan = 256; // Largest FFT size we have pre-planned.
typedef U16 Index; // Index into FFT - need to change this if kMaxN > 65536
// For the small (leaf) FFTs, instead of doing an explicit recursion, we just loop over this list
// (the "plan"), which is a sequence of conjugate split-radix steps to do: offset and number of
// loop iterations.
//
// We have a decimation-in-time decomposition, which builds up from smaller towards larger FFTs.
// That means one plan is sufficient for all FFT sizes up to kMaxPlan: stopping after the first
// step with Nloop >= N/4 yields a N-element FFT.
struct PlanElement
{
Index offs;
Index Nloop;
};
// Prepared recursion plan up to kMaxPlan. This stores the sizes and offsets
// of the FFT passes. Generated by this program:
//
// ----
// #include <stdio.h>
// void plan(int offs, int N, int Nleaf)
// {
// static int counter = 0;
// if (N <= Nleaf)
// return;
// // Split-radix recursion pattern
// plan(offs, N/2, Nleaf);
// plan(offs + N/2, N/4, Nleaf);
// plan(offs + 3*N/4, N/4, Nleaf);
// printf("{ %3d, %2d },", offs, N/4);
// printf((++counter % 8) ? " " : "\n");
// }
// int main()
// {
// plan(0, 256, 4); // kMaxPlan=256, kLeafN=4
// printf("\n");
// return 0;
// }
RADDEFSTART
static
PlanElement s_recursion_plan[] =
{
{ 0, 2 }, { 0, 4 }, { 16, 2 }, { 24, 2 }, { 0, 8 }, { 32, 2 }, { 32, 4 }, { 48, 2 },
{ 48, 4 }, { 0, 16 }, { 64, 2 }, { 64, 4 }, { 80, 2 }, { 88, 2 }, { 64, 8 }, { 96, 2 },
{ 96, 4 }, { 112, 2 }, { 120, 2 }, { 96, 8 }, { 0, 32 }, { 128, 2 }, { 128, 4 }, { 144, 2 },
{ 152, 2 }, { 128, 8 }, { 160, 2 }, { 160, 4 }, { 176, 2 }, { 176, 4 }, { 128, 16 }, { 192, 2 },
{ 192, 4 }, { 208, 2 }, { 216, 2 }, { 192, 8 }, { 224, 2 }, { 224, 4 }, { 240, 2 }, { 240, 4 },
{ 192, 16 }, { 0, 64 },
};
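// Example of how the plan is consumed (see fft_driver below): for a 16-element FFT,
// N/4 = 4, so the pass kernels run entries up to and including the first one with
// Nloop >= 4, i.e. { 0, 2 } then { 0, 4 }, and stop. A 256-element FFT runs the
// entire table, ending with the final { 0, 64 } pass.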
// Twiddles and permutation tables for the FFT.
// If you change this, update the TABLEGEN code below!
#ifdef USETABLES
#ifdef BIG_OLE_FFT
#include "radfft_tables_4096.inl"
#else
#include "radfft_tables.inl"
#endif
#else
FFTALIGNED(rfft_complex, s_twiddles[(kMaxN / 4) * 2]); // Regular FFT twiddles: N-elem FFT needs a quarter circle; *2 because we store all "mip levels"
FFTALIGNED(rfft_complex, s_dct_twiddles[kMaxN / 4 + kMaxN / 2]); // DCT needs one eighth of a circle at 4N rate, so we need two more levels.
Index s_permute[(kMaxN / kLeafN) * 2]; // *2 because we keep a "mip chain" for smaller transforms.
#endif
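// Layout note (derivable from init_tables() below and from how the kernels index these
// arrays): the quarter-circle twiddles for a 4*M-point sub-FFT are the M entries
// s_twiddles[M..2*M-1], i.e. e^(-2*pi*i*k/(4*M)) for k = 0..M-1, which is why the pass
// kernels can simply index s_twiddles + Nloop. Likewise, the base-case permutation for
// an N-point FFT starts at s_permute + (N / kLeafN), with the kMaxN table at the top
// and subsampled copies below it.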
RADDEFEND
// --------------------------------------------------------------------------
// Kernel types
// --------------------------------------------------------------------------
// Real (I)FFT pre/post-pass
typedef void RFFTPrePostKernel(rfft_complex data[], UINTa kEnd, UINTa N4);
// Complex (I)FFT
typedef void CFFTKernel(rfft_complex data[], PlanElement const *plan, UINTa Nover4);
// Radix4 or 2xRadix2 CFFT base cases
typedef void BaseKernel(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm);
// (I)DCT merge/split
typedef void MergeSplitKernel(F32 out[], F32 const in[], UINTa N);
typedef void MergeSplitKernelS16(S16 out[], F32 scale, F32 const in[], UINTa N);
typedef void MergeSplitKernelS16S(S16 out[], S16 left[], F32 scale, F32 const in[], UINTa N);
// (I)DCT modulate
typedef void ModulateKernel(F32 out[], F32 const in[], UINTa N, UINTa Nlast, rfft_complex const *twiddle);
struct KernelSet
{
RFFTPrePostKernel *rfpost; // Real FFT post
RFFTPrePostKernel *ripre; // Real IFFT pre
CFFTKernel *cfpass; // Complex FFT pass
CFFTKernel *cipass; // Complex IFFT pass
BaseKernel *radix4; // Radix4 base case
BaseKernel *radix2_2; // 2x Radix2 base case
MergeSplitKernel *dct_split; // DCT-II split pass
MergeSplitKernel *dct_merge; // DCT-III merge pass
MergeSplitKernelS16 *dct_merge_s16; // DCT-III merge pass with s16 out
MergeSplitKernelS16S *dct_merge_s16s; // DCT-III merge pass with s16 stereo out (left chan passed in)
ModulateKernel *dct2_mod; // DCT-II modulation pass
ModulateKernel *dct3_mod; // DCT-III modulation pass
};
// --------------------------------------------------------------------------
// Scalar kernels. These are always there.
// --------------------------------------------------------------------------
// Size 1 or 2 complex (I)FFTs - trivial (for base cases)
static void scalar_cfft_tiny(rfft_complex out[], rfft_complex const in[], UINTa N)
{
if (N == 2)
{
rfft_complex a = in[0];
rfft_complex b = in[1];
out[0].re = a.re + b.re;
out[0].im = a.im + b.im;
out[1].re = a.re - b.re;
out[1].im = a.im - b.im;
}
else if (N == 1)
out[0] = in[0];
}
// Real FFT post-pass.
static void scalar_rfpost(rfft_complex out[], UINTa kEnd, UINTa N4)
{
rfft_complex const *twiddle = s_twiddles + N4;
rfft_complex *out0 = out + 1;
rfft_complex *out1 = out + (N4 * 2) - 1;
// out'[k] = 0.5 * ((1 - i*conj(twiddle[k])) * out[k] + (1 + i*conj(twiddle[k])) * conj(out[N-k]))
for (UINTa k = 1; k < kEnd; ++k, ++out0, --out1)
{
rfft_complex const &w = twiddle[k];
F32 dr = 0.5f * (out1->re - out0->re);
F32 di = 0.5f * (out0->im + out1->im);
F32 evr = out0->re + dr;
F32 evi = out0->im - di;
F32 odr = w.re*di + w.im*dr;
F32 odi = w.re*dr - w.im*di;
out0->re = evr + odr;
out0->im = evi + odi;
out1->re = evr - odr;
out1->im = odi - evi;
}
}
// Real IFFT pre-pass.
static void scalar_ripre(rfft_complex out[], UINTa kEnd, UINTa N4)
{
rfft_complex const *twiddle = s_twiddles + N4;
rfft_complex *out0 = out + 1;
rfft_complex *out1 = out + (N4 * 2) - 1;
for (UINTa k = 1; k < kEnd; ++k, ++out0, --out1)
{
rfft_complex const &w = twiddle[k];
F32 dr = 0.5f * (out0->re - out1->re);
F32 di = 0.5f * (out0->im + out1->im);
F32 evr = out0->re - dr;
F32 evi = out0->im - di;
F32 odr = w.re*di + w.im*dr;
F32 odi = w.re*dr - w.im*di;
out0->re = evr - odr;
out0->im = evi + odi;
out1->re = evr + odr;
out1->im = odi - evi;
}
}
// Scalar conjugate split-radix forward pass.
static void scalar_cfpass(rfft_complex data[], PlanElement const *plan, UINTa Nover4)
{
--plan;
do
{
++plan;
UINTa N1 = plan->Nloop;
rfft_complex *out = data + plan->offs;
rfft_complex const *twiddle = s_twiddles + N1;
// k=0 has twiddle factors 1 so we can save a bunch of work
// (this is worthwhile because we keep subdividing into shorter and
// shorter transforms; it's not just a one-time thing, we win on
// every level of the recursion)
{
rfft_complex &x0 = out[0*N1];
rfft_complex &x1 = out[1*N1];
rfft_complex &x2 = out[2*N1];
rfft_complex &x3 = out[3*N1];
F32 Zsumr = x2.re + x3.re;
F32 Zsumi = x2.im + x3.im;
F32 Zdifr = x2.im - x3.im;
F32 Zdifi = x3.re - x2.re;
F32 U0r = x0.re;
F32 U0i = x0.im;
x2.re = U0r - Zsumr;
x2.im = U0i - Zsumi;
x0.re = U0r + Zsumr;
x0.im = U0i + Zsumi;
F32 U1r = x1.re;
F32 U1i = x1.im;
x3.re = U1r - Zdifr;
x3.im = U1i - Zdifi;
x1.re = U1r + Zdifr;
x1.im = U1i + Zdifi;
++out;
}
for (UINTa k = 1; k < N1; ++k)
{
rfft_complex const &w = twiddle[k];
rfft_complex &x0 = out[0*N1];
rfft_complex &x1 = out[1*N1];
rfft_complex &x2 = out[2*N1];
rfft_complex &x3 = out[3*N1];
// This is the general case: (complex values, r=real part, i=imaginary part)
// w_k is the twiddle factor, (omega)^k.
//
// Z_k = w_k * x[k + 2N/4]
// Z'_k = conj(w_k) * x[k + 3N/4]
// U0 = x[k + 0N/4]
// U1 = x[k + 1N/4]
//
// Zsum_k = Z_k + Z'_k
// Zdif_k = -i * (Z_k - Z'_k)
//
// new_x[k + 0N/4] = U0 + Zsum_k = U0 + (Z_k + Z'_k)
// new_x[k + 1N/4] = U1 + Zdif_k = U1 - i*(Z_k - Z'_k)
// new_x[k + 2N/4] = U0 - Zsum_k = U0 - (Z_k + Z'_k)
// new_x[k + 3N/4] = U1 - Zdif_k = U1 + i*(Z_k - Z'_k)
F32 Zkr = w.re*x2.re - w.im*x2.im;
F32 Zki = w.re*x2.im + w.im*x2.re;
F32 Zpkr = w.re*x3.re + w.im*x3.im;
F32 Zpki = w.re*x3.im - w.im*x3.re;
F32 Zsumr = Zkr + Zpkr;
F32 Zsumi = Zki + Zpki;
F32 Zdifr = Zki - Zpki;
F32 Zdifi = Zpkr - Zkr;
F32 U0r = x0.re;
F32 U0i = x0.im;
x2.re = U0r - Zsumr;
x2.im = U0i - Zsumi;
x0.re = U0r + Zsumr;
x0.im = U0i + Zsumi;
F32 U1r = x1.re;
F32 U1i = x1.im;
x3.re = U1r - Zdifr;
x3.im = U1i - Zdifi;
x1.re = U1r + Zdifr;
x1.im = U1i + Zdifi;
++out;
}
} while (plan->Nloop < Nover4);
}
// Scalar conjugate split-radix inverse pass.
static void scalar_cipass(rfft_complex data[], PlanElement const *plan, UINTa Nover4)
{
--plan;
do
{
++plan;
UINTa N1 = plan->Nloop;
rfft_complex *out = data + plan->offs;
rfft_complex const *twiddle = s_twiddles + N1;
// k=0 has twiddle factors 1 so we can save a bunch of work
// (this is worthwhile because we keep subdividing into shorter and
// shorter transforms; it's not just a one-time thing, we win on
// every level of the recursion)
{
rfft_complex &x0 = out[0*N1];
rfft_complex &x1 = out[1*N1];
rfft_complex &x2 = out[2*N1];
rfft_complex &x3 = out[3*N1];
F32 Zsumr = x2.re + x3.re;
F32 Zsumi = x2.im + x3.im;
F32 Zdifr = x3.im - x2.im;
F32 Zdifi = x2.re - x3.re;
F32 U0r = x0.re;
F32 U0i = x0.im;
x2.re = U0r - Zsumr;
x2.im = U0i - Zsumi;
x0.re = U0r + Zsumr;
x0.im = U0i + Zsumi;
F32 U1r = x1.re;
F32 U1i = x1.im;
x3.re = U1r - Zdifr;
x3.im = U1i - Zdifi;
x1.re = U1r + Zdifr;
x1.im = U1i + Zdifi;
++out;
}
for (UINTa k = 1; k < N1; ++k)
{
rfft_complex const &w = twiddle[k];
rfft_complex &x0 = out[0*N1];
rfft_complex &x1 = out[1*N1];
rfft_complex &x2 = out[2*N1];
rfft_complex &x3 = out[3*N1];
// This is the general case: (complex values, r=real part, i=imaginary part)
// w_k is the twiddle factor, (omega)^k.
//
// Z_k = conj(w_k) * x[k + 2N/4]
// Z'_k = w_k * x[k + 3N/4]
// U0 = x[k + 0N/4]
// U1 = x[k + 1N/4]
//
// Zsum_k = Z_k + Z'_k
// Zdif_k = i * (Z_k - Z'_k)
//
// new_x[k + 0N/4] = U0 + Zsum_k = U0 + (Z_k + Z'_k)
// new_x[k + 1N/4] = U1 + Zdif_k = U1 + i*(Z_k - Z'_k)
// new_x[k + 2N/4] = U0 - Zsum_k = U0 - (Z_k + Z'_k)
// new_x[k + 3N/4] = U1 - Zdif_k = U1 - i*(Z_k - Z'_k)
//
// Note that this is essentially the same as the forward transform (cfpass), except
// the twiddle factors (the w_k and i) are conjugated.
F32 Zkr = w.re*x2.re + w.im*x2.im;
F32 Zki = w.re*x2.im - w.im*x2.re;
F32 Zpkr = w.re*x3.re - w.im*x3.im;
F32 Zpki = w.re*x3.im + w.im*x3.re;
F32 Zsumr = Zkr + Zpkr;
F32 Zsumi = Zki + Zpki;
F32 Zdifr = Zpki - Zki;
F32 Zdifi = Zkr - Zpkr;
F32 U0r = x0.re;
F32 U0i = x0.im;
x2.re = U0r - Zsumr;
x2.im = U0i - Zsumi;
x0.re = U0r + Zsumr;
x0.im = U0i + Zsumi;
F32 U1r = x1.re;
F32 U1i = x1.im;
x3.re = U1r - Zdifr;
x3.im = U1i - Zdifi;
x1.re = U1r + Zdifr;
x1.im = U1i + Zdifi;
++out;
}
} while (plan->Nloop < Nover4);
}
// Radix-4 FFT codelet. To get the inverse (up to scale), swap in1 and in3.
static void scalar_radix4(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm)
{
for (UINTa j = j0; j < j1; ++j)
{
rfft_complex const &a = in0[j];
rfft_complex const &b = in2[j];
rfft_complex const &c = in1[j];
rfft_complex const &d = in3[j];
rfft_complex *o = out + perm[j];
F32 Ar = a.re + b.re;
F32 Ai = a.im + b.im;
F32 Br = a.re - b.re;
F32 Bi = a.im - b.im;
F32 Cr = c.re + d.re;
F32 Ci = c.im + d.im;
F32 Dr = c.im - d.im;
F32 Di = d.re - c.re;
o[0].re = Ar + Cr;
o[0].im = Ai + Ci;
o[2].re = Ar - Cr;
o[2].im = Ai - Ci;
o[1].re = Br + Dr;
o[1].im = Bi + Di;
o[3].re = Br - Dr;
o[3].im = Bi - Di;
}
}
// 2x Radix-2 FFT codelet. Self-inverse (up to scale).
static void scalar_2radix2(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm)
{
for (UINTa j = j0; j < j1; ++j)
{
rfft_complex const &a = in0[j];
rfft_complex const &b = in2[j];
rfft_complex const &c = in3[j];
rfft_complex const &d = in1[j];
rfft_complex *o = out + perm[j];
o[0].re = a.re + b.re;
o[0].im = a.im + b.im;
o[1].re = a.re - b.re;
o[1].im = a.im - b.im;
o[2].re = c.re + d.re;
o[2].im = c.im + d.im;
o[3].re = c.re - d.re;
o[3].im = c.im - d.im;
}
}
// DCT even/odd split
static void scalar_dct_split(F32 out[], F32 const in[], UINTa N)
{
UINTa N2 = N / 2;
F32 *out0 = out;
F32 *out1 = out + N;
F32 const *inp = in;
// Even-indexed input elems go to first half of out
// Odd-indexed input elems go to second half of out, in reverse order
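// e.g. for N = 8: in = [x0 x1 x2 x3 x4 x5 x6 x7] -> out = [x0 x2 x4 x6  x7 x5 x3 x1]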
for (UINTa k = 0; k < N2; ++k)
{
F32 a = *inp++;
F32 b = *inp++;
*out0++ = a;
*--out1 = b;
}
}
// DCT merge (inverse of split)
static void scalar_dct_merge(F32 out[], F32 const in[], UINTa N)
{
UINTa N2 = N / 2;
F32 *outp = out;
F32 const *in0 = in;
F32 const *in1 = in + N;
for (UINTa k = 0; k < N2; ++k)
{
F32 a = *in0++;
F32 b = *--in1;
*outp++ = a;
*outp++ = b;
}
}
// DCT merge (inverse of split)
static void scalar_dct_merge_s16(S16 out[], F32 scale, F32 const in[], UINTa N)
{
UINTa N2 = N / 2;
S16 *outp = out;
F32 const *in0 = in;
F32 const *in1 = in + N;
for (UINTa k = 0; k < N2; ++k)
{
F32 a = *in0++;
F32 b = *--in1;
S32 va = (S32)(a*scale);
S32 vb = (S32)(b*scale);
if ( va > 32767 ) va = 32767; if ( va <= -32768 ) va = -32768;
if ( vb > 32767 ) vb = 32767; if ( vb <= -32768 ) vb = -32768;
*outp++ = (S16)va;
*outp++ = (S16)vb;
}
}
// DCT merge (inverse of split)
static void scalar_dct_merge_s16s(S16 out[], S16 left[], F32 scale, F32 const in[], UINTa N)
{
UINTa N2 = N / 2;
S16 *outp = out;
F32 const *in0 = in;
F32 const *in1 = in + N;
for (UINTa k = 0; k < N2; ++k)
{
F32 a = *in0++;
F32 b = *--in1;
S32 va = (S32)(a*scale);
S32 vb = (S32)(b*scale);
if ( va > 32767 ) va = 32767; if ( va <= -32768 ) va = -32768;
if ( vb > 32767 ) vb = 32767; if ( vb <= -32768 ) vb = -32768;
*outp++ = *left++;
*outp++ = (S16)va;
*outp++ = *left++;
*outp++ = (S16)vb;
}
}
// DCT-II modulation step
static void scalar_dct2_modulate(F32 out[], F32 const in[], UINTa N, UINTa Nlast, rfft_complex const *twiddle)
{
F32 const kSqrtOneHalf = 0.7071067811865475244f;
// First and Nyquist buckets
out[0] = in[0];
out[N/2] = kSqrtOneHalf * in[1];
// Rest
for (UINTa k = 1; k < Nlast; ++k)
{
F32 Zr = in[k*2 + 0];
F32 Zi = in[k*2 + 1];
F32 wr = twiddle[k].re;
F32 wi = twiddle[k].im;
out[k] = wr*Zr + wi*Zi;
out[N-k] = wr*Zi - wi*Zr;
}
}
// DCT-III modulation step
static void scalar_dct3_modulate(F32 out[], F32 const in[], UINTa N, UINTa Nlast, rfft_complex const *twiddle)
{
F32 const kSqrtTwo = 1.4142135623730950488f;
// First and Nyquist buckets
out[0] = in[0] + in[0];
out[1] = kSqrtTwo * in[N/2];
// Rest
for (UINTa k = 1; k < Nlast; ++k)
{
F32 Zr = in[k];
F32 Zi = in[N-k];
F32 wr = twiddle[k].re;
F32 wi = twiddle[k].im;
out[k*2 + 0] = wr*Zr - wi*Zi;
out[k*2 + 1] = wr*Zi + wi*Zr;
}
}
#if !defined(__RADNEON__)
// Fully scalar kernel set is always an option
static KernelSet const s_kernel_scalar = {
scalar_rfpost,
scalar_ripre,
scalar_cfpass,
scalar_cipass,
scalar_radix4,
scalar_2radix2,
scalar_dct_split,
scalar_dct_merge,
scalar_dct_merge_s16,
scalar_dct_merge_s16s,
scalar_dct2_modulate,
scalar_dct3_modulate
};
#endif
// --------------------------------------------------------------------------
// SSE/SSE3 kernels for x86.
// --------------------------------------------------------------------------
#ifdef __RADX86__
#include <emmintrin.h> // we require SSE2.
#if defined( RAD_USES_SSE3 )
// This is just to get this to compile with GCC 4.8 (there will be a warning, but it works);
// for GCC 4.9+ and clang, this isn't necessary.
#ifdef __RAD_GCC_VERSION__
#if __RAD_GCC_VERSION__ < 40900
#pragma GCC push_options
#pragma GCC target("sse3")
#define __SSE3__
#endif
#endif
#include <pmmintrin.h>
// First time round: actual SSE3
#define RADFFT_SSE3_PREFIX(name) RAD_USES_SSE3 sse3_##name
#include "radfft_sse3.inl"
#undef RADFFT_SSE3_PREFIX
// This is just to get this to compile with GCC 4.8 (there will be a warning, but it works);
// for GCC 4.9+ and clang, this isn't necessary.
#ifdef __RAD_GCC_VERSION__
#if __RAD_GCC_VERSION__ < 40900
#pragma GCC pop_options
#endif
#endif
#endif // defined( RAD_USES_SSE3 )
#if !defined(RAD_GUARANTEED_SSE3)
// Second time round: plain SSE with these glorious hacks.
// This doesn't give optimal SSE code but it's not terrible either.
#define RADFFT_SSE3_PREFIX(name) sse_##name
#define _mm_moveldup_ps(x) _mm_shuffle_ps((x), (x), 0xa0)
#define _mm_movehdup_ps(x) _mm_shuffle_ps((x), (x), 0xf5)
#define _mm_addsub_ps(a,b) _mm_add_ps((a), _mm_xor_ps(b, _mm_setr_ps(-0.0f, 0.0f, -0.0f, 0.0f)))
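// (For reference: ADDSUBPS computes { a0-b0, a1+b1, a2-b2, a3+b3 }; the emulation above
// gets the same effect by flipping the sign bits of the even lanes of b before a regular add.)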
#include "radfft_sse3.inl"
#undef _mm_moveldup_ps
#undef _mm_movehdup_ps
#undef _mm_addsub_ps
#undef RADFFT_SSE3_PREFIX
#endif
// SSE 2x Radix-2 FFT codelet. Self-inverse (up to scale).
static void sse_2radix2(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm)
{
if (j0 == j1)
return;
// SIMD loop wants an even number of elements at an aligned
// offset. Thus, we may have extra elements at the beginning
// or end.
UINTa j1a = j1 & ~1;
if (j0 & 1)
{
scalar_2radix2(out, in0, in1, in2, in3, j0, j0 + 1, perm);
++j0;
}
for (UINTa j = j0; j < j1a; j += 2)
{
__m128 a = _mm_load_ps((const F32 *) &in0[j]);
__m128 b = _mm_load_ps((const F32 *) &in2[j]);
__m128 c = _mm_load_ps((const F32 *) &in3[j]);
__m128 d = _mm_load_ps((const F32 *) &in1[j]);
F32 *o0 = (F32 *) (out + perm[j+0]);
F32 *o1 = (F32 *) (out + perm[j+1]);
__m128 E = _mm_add_ps(a, b);
__m128 F = _mm_sub_ps(a, b);
__m128 G = _mm_add_ps(c, d);
__m128 H = _mm_sub_ps(c, d);
_mm_store_ps(o0 + 0, _mm_movelh_ps(E, F));
_mm_store_ps(o1 + 0, _mm_movehl_ps(F, E));
_mm_store_ps(o0 + 4, _mm_movelh_ps(G, H));
_mm_store_ps(o1 + 4, _mm_movehl_ps(H, G));
}
if (j1a != j1)
scalar_2radix2(out, in0, in1, in2, in3, j1a, j1, perm);
}
// SSE Radix-4 FFT codelet. To get the inverse (up to scale), swap in1 and in3.
static void sse_radix4(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm)
{
if (j0 == j1)
return;
// SIMD loop wants an even number of elements at an aligned
// offset. Thus, we may have extra elements at the beginning
// or end.
UINTa j1a = j1 & ~1;
if (j0 & 1)
{
scalar_radix4(out, in0, in1, in2, in3, j0, j0 + 1, perm);
++j0;
}
__m128 conjflip = _mm_setr_ps(0.0f, -0.0f, 0.0f, -0.0f);
for (UINTa j = j0; j < j1a; j += 2)
{
__m128 a = _mm_load_ps((const F32 *) &in0[j]);
__m128 b = _mm_load_ps((const F32 *) &in2[j]);
__m128 c = _mm_load_ps((const F32 *) &in1[j]);
__m128 d = _mm_load_ps((const F32 *) &in3[j]);
F32 *o0 = &out[perm[j+0]].re;
F32 *o1 = &out[perm[j+1]].re;
__m128 A = _mm_add_ps(a, b);
__m128 B = _mm_sub_ps(a, b);
__m128 C = _mm_add_ps(c, d);
__m128 D = _mm_sub_ps(c, d);
// D *= -i
D = _mm_xor_ps(_mm_shuffle_ps(D, D, 0xb1), conjflip);
__m128 E = _mm_add_ps(A, C);
__m128 F = _mm_add_ps(B, D);
__m128 G = _mm_sub_ps(A, C);
__m128 H = _mm_sub_ps(B, D);
_mm_store_ps(o0 + 0, _mm_movelh_ps(E, F));
_mm_store_ps(o1 + 0, _mm_movehl_ps(F, E));
_mm_store_ps(o0 + 4, _mm_movelh_ps(G, H));
_mm_store_ps(o1 + 4, _mm_movehl_ps(H, G));
}
if (j1a != j1)
scalar_radix4(out, in0, in1, in2, in3, j1a, j1, perm);
}
static void sse_dct_split(F32 out[], F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 8)
{
scalar_dct_split(out, in, N);
return;
}
#endif
// N is pow2 and >=8
F32 *out0 = out;
F32 *out1 = out + N;
F32 const *inp = in;
F32 const *inp_end = in + N;
do
{
__m128 v0 = _mm_load_ps(inp);
__m128 v1 = _mm_load_ps(inp + 4);
__m128 s0 = _mm_shuffle_ps(v0, v1, 0x88); // x0,x2,x4,x6
__m128 s1 = _mm_shuffle_ps(v1, v0, 0x77); // x7,x5,x3,x1
_mm_store_ps(out0, s0);
_mm_store_ps(out1 - 4, s1);
inp += 8;
out0 += 4;
out1 -= 4;
} while (inp != inp_end);
}
static void sse_dct_merge(F32 out[], F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 8)
{
scalar_dct_merge(out, in, N);
return;
}
#endif
// N is pow2 and >=8
F32 const *in0 = in;
F32 const *in1 = in + N;
F32 *outp = out;
F32 *outp_end = out + N;
do
{
__m128 i0 = _mm_load_ps(in0); // x0,x2,x4,x6
__m128 i1 = _mm_load_ps(in1 - 4); // x7,x5,x3,x1
__m128 s1 = _mm_shuffle_ps(i1, i1, 0x1b); // x1,x3,x5,x7
__m128 o0 = _mm_unpacklo_ps(i0, s1);
__m128 o1 = _mm_unpackhi_ps(i0, s1);
_mm_store_ps(outp, o0);
_mm_store_ps(outp + 4, o1);
outp += 8;
in0 += 4;
in1 -= 4;
}
while (outp != outp_end);
}
static void sse_dct_merge_s16(S16 out[], F32 scale, F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 8)
{
scalar_dct_merge_s16(out, scale, in, N);
return;
}
#endif
// N is pow2 and >=8
F32 const *in0 = in;
F32 const *in1 = in + N;
S16 *outp = out;
S16 *outp_end = out + N;
__m128 scale128 = _mm_load1_ps( (float*)(&scale) );
do
{
__m128 i0 = _mm_load_ps(in0); // x0,x2,x4,x6
__m128 i1 = _mm_load_ps(in1 - 4); // x7,x5,x3,x1
__m128 s1 = _mm_shuffle_ps(i1, i1, 0x1b); // x1,x3,x5,x7
__m128 o0 = _mm_unpacklo_ps(i0, s1);
__m128 o1 = _mm_unpackhi_ps(i0, s1);
o0 = _mm_mul_ps( o0, scale128 );
o1 = _mm_mul_ps( o1, scale128 );
// [ x, y, z, w ]
__m128i io0 = _mm_cvtps_epi32( o0 );
__m128i io1 = _mm_cvtps_epi32( o1 );
__m128i p0 = _mm_packs_epi32( io0, io1 );
_mm_store_si128((__m128i*)outp, p0);
outp += 8;
in0 += 4;
in1 -= 4;
}
while (outp != outp_end);
}
static void sse_dct_merge_s16s(S16 out[], S16 left[], F32 scale, F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 8)
{
scalar_dct_merge_s16s(out, left, scale, in, N);
return;
}
#endif
// N is pow2 and >=8
F32 const *in0 = in;
F32 const *in1 = in + N;
S16 *outp = out;
S16 *outp_end = out + (N*2);
__m128 scale128 = _mm_load1_ps( (float*)(&scale) );
do
{
__m128 i0 = _mm_load_ps(in0); // x0,x2,x4,x6
__m128 i1 = _mm_load_ps(in1 - 4); // x7,x5,x3,x1
__m128 s1 = _mm_shuffle_ps(i1, i1, 0x1b); // x1,x3,x5,x7
__m128 o0 = _mm_unpacklo_ps(i0, s1);
__m128 o1 = _mm_unpackhi_ps(i0, s1);
o0 = _mm_mul_ps( o0, scale128 );
o1 = _mm_mul_ps( o1, scale128 );
// [ x, y, z, w ]
__m128i io0 = _mm_cvtps_epi32( o0 );
__m128i io1 = _mm_cvtps_epi32( o1 );
__m128i p0 = _mm_packs_epi32( io0, io1 );
io1 = _mm_load_si128((__m128i*)left);
io0 = _mm_unpacklo_epi16( io1, p0 );
io1 = _mm_unpackhi_epi16( io1, p0 );
_mm_store_si128((__m128i*)outp, io0);
_mm_store_si128((__m128i*)(outp+8), io1);
outp += 16;
left += 8;
in0 += 4;
in1 -= 4;
}
while (outp != outp_end);
}
static void sse_dct3_modulate(F32 out[], F32 const in[], UINTa N, UINTa Nlast, rfft_complex const *twiddle)
{
#ifdef WANT_TINY
if (N < 8)
{
scalar_dct3_modulate(out, in, N, Nlast, twiddle);
return;
}
#endif
// First few bins have exceptional cases, let scalar routine handle it
FFTASSERT((Nlast % 4) == 0 && Nlast >= 4);
scalar_dct3_modulate(out, in, N, 4, twiddle);
for (UINTa k = 4; k < Nlast; k += 4)
{
__m128 Zr = _mm_load_ps(in + k);
__m128 rZi = _mm_loadu_ps(in + N - 3 - k); // reversed Zi
__m128 Zi = _mm_shuffle_ps(rZi, rZi, 0x1b); // reverse it to get values the right way around
__m128 w0 = _mm_load_ps(&twiddle[k + 0].re);
__m128 w1 = _mm_load_ps(&twiddle[k + 2].re);
__m128 wr = _mm_shuffle_ps(w0, w1, 0x88); // real parts of twiddles
__m128 wi = _mm_shuffle_ps(w0, w1, 0xdd); // imag parts of twiddles
__m128 re = _mm_sub_ps(_mm_mul_ps(wr, Zr), _mm_mul_ps(wi, Zi));
__m128 im = _mm_add_ps(_mm_mul_ps(wr, Zi), _mm_mul_ps(wi, Zr));
_mm_store_ps(out + k*2 + 0, _mm_unpacklo_ps(re, im));
_mm_store_ps(out + k*2 + 4, _mm_unpackhi_ps(re, im));
}
}
#if !defined(RAD_GUARANTEED_SSE3)
// Kernels to use when SSE (but not SSE3) is available
static KernelSet const s_kernel_sse = {
sse_rfpost,
sse_ripre,
sse_cfpass,
sse_cipass,
sse_radix4,
sse_2radix2,
sse_dct_split,
sse_dct_merge,
sse_dct_merge_s16,
sse_dct_merge_s16s,
sse_dct2_modulate,
sse_dct3_modulate
};
#endif
#ifdef RAD_USES_SSE3
// Kernels to use when SSE3 is available
static KernelSet const s_kernel_sse3 = {
sse3_rfpost,
sse3_ripre,
sse3_cfpass,
sse3_cipass,
sse_radix4,
sse_2radix2,
sse_dct_split,
sse_dct_merge,
sse_dct_merge_s16,
sse_dct_merge_s16s,
sse3_dct2_modulate,
sse_dct3_modulate
};
#endif
#ifdef RADFFT_AVX
#include <immintrin.h>
// AVX complex forward conjugate split-radix reduction pass.
// This is literally sse3_cfpass copied, with "__m128" replaced with "__m256",
// "_mm_" replaced with "_mm256_" and the offset increments of "4" replaced with
// "8".
//
// The N1=2 special case needed to be added, but that's it.
static void avx_cfpass(rfft_complex data[], PlanElement const *plan, UINTa Nover4)
{
__m256 conjflip = _mm256_setr_ps(0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f);
--plan;
do
{
++plan;
UINTa N1 = plan->Nloop;
UINTa step = N1 * 2;
F32 *out = (F32 *) (data + plan->offs);
F32 *out_end = out + step;
F32 const *twiddle = (F32 const *) (s_twiddles + N1);
if (N1 > 2)
{
do
{
__m256 Zk = _mm256_load_ps(out + 2*step);
__m256 Zpk = _mm256_load_ps(out + 3*step);
__m256 w = _mm256_load_ps(twiddle);
__m256 w_re = _mm256_moveldup_ps(w);
__m256 w_im = _mm256_movehdup_ps(w);
// Twiddle Zk, Z'k
Zk = _mm256_addsub_ps(_mm256_mul_ps(Zk, w_re), _mm256_mul_ps(_mm256_shuffle_ps(Zk, Zk, 0xb1 /* yxwz */), w_im));
Zpk = _mm256_addsub_ps(_mm256_mul_ps(_mm256_shuffle_ps(Zpk, Zpk, 0xb1), w_re), _mm256_mul_ps(Zpk, w_im));
__m256 Zsum = _mm256_add_ps(_mm256_shuffle_ps(Zpk, Zpk, 0xb1), Zk);
__m256 Zdif = _mm256_sub_ps(_mm256_shuffle_ps(Zk, Zk, 0xb1), Zpk);
// Even inputs
__m256 Uk0 = _mm256_load_ps(out + 0*step);
__m256 Uk1 = _mm256_load_ps(out + 1*step);
// Output butterflies
_mm256_store_ps(out + 0*step, _mm256_add_ps(Uk0, Zsum));
_mm256_store_ps(out + 1*step, _mm256_add_ps(Uk1, _mm256_xor_ps(Zdif, conjflip)));
_mm256_store_ps(out + 2*step, _mm256_sub_ps(Uk0, Zsum));
_mm256_store_ps(out + 3*step, _mm256_addsub_ps(Uk1, Zdif));
out += 8;
twiddle += 8;
} while (out < out_end);
}
else
{
// N=2 (small case)
__m128 Zk = _mm_load_ps(out + 2*step);
__m128 Zpk = _mm_load_ps(out + 3*step);
__m128 w = _mm_load_ps(twiddle);
__m128 w_re = _mm_moveldup_ps(w);
__m128 w_im = _mm_movehdup_ps(w);
// Twiddle Zk, Z'k
Zk = _mm_addsub_ps(_mm_mul_ps(Zk, w_re), _mm_mul_ps(_mm_shuffle_ps(Zk, Zk, 0xb1 /* yxwz */), w_im));
Zpk = _mm_addsub_ps(_mm_mul_ps(_mm_shuffle_ps(Zpk, Zpk, 0xb1), w_re), _mm_mul_ps(Zpk, w_im));
__m128 Zsum = _mm_add_ps(_mm_shuffle_ps(Zpk, Zpk, 0xb1), Zk);
__m128 Zdif = _mm_sub_ps(_mm_shuffle_ps(Zk, Zk, 0xb1), Zpk);
// Even inputs
__m128 Uk0 = _mm_load_ps(out + 0*step);
__m128 Uk1 = _mm_load_ps(out + 1*step);
// Output butterflies
_mm_store_ps(out + 0*step, _mm_add_ps(Uk0, Zsum));
_mm_store_ps(out + 1*step, _mm_add_ps(Uk1, _mm_xor_ps(Zdif, _mm256_extractf128_ps(conjflip, 0))));
_mm_store_ps(out + 2*step, _mm_sub_ps(Uk0, Zsum));
_mm_store_ps(out + 3*step, _mm_addsub_ps(Uk1, Zdif));
}
} while (plan->Nloop < Nover4);
}
// AVX complex inverse conjugate split-radix reduction pass. This is the main workhorse inner loop for IFFTs.
// This is literally sse3_cipass copied, with "__m128" replaced with "__m256",
// "_mm_" replaced with "_mm256_" and the offset increments of "4" replaced with
// "8".
//
// The N1=2 special case needed to be added, but that's it.
static void avx_cipass(rfft_complex data[], PlanElement const *plan, UINTa Nover4)
{
__m256 conjflip = _mm256_setr_ps(0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f);
--plan;
do
{
++plan;
UINTa N1 = plan->Nloop;
UINTa step = N1 * 2;
F32 *out = (F32 *) (data + plan->offs);
F32 *out_end = out + step;
F32 const *twiddle = (F32 const *) (s_twiddles + N1);
if (N1 > 2)
{
do
{
__m256 Zk = _mm256_load_ps(out + 2*step);
__m256 Zpk = _mm256_load_ps(out + 3*step);
__m256 w = _mm256_load_ps(twiddle);
__m256 w_re = _mm256_moveldup_ps(w);
__m256 w_im = _mm256_movehdup_ps(w);
// Twiddle Zk, Z'k
Zpk = _mm256_addsub_ps(_mm256_mul_ps(Zpk, w_re), _mm256_mul_ps(_mm256_shuffle_ps(Zpk, Zpk, 0xb1), w_im));
Zk = _mm256_addsub_ps(_mm256_mul_ps(_mm256_shuffle_ps(Zk, Zk, 0xb1), w_re), _mm256_mul_ps(Zk, w_im));
__m256 Zsum = _mm256_add_ps(_mm256_shuffle_ps(Zk, Zk, 0xb1), Zpk);
__m256 Zdif = _mm256_sub_ps(_mm256_shuffle_ps(Zpk, Zpk, 0xb1), Zk);
// Even inputs
__m256 Uk0 = _mm256_load_ps(out + 0*step);
__m256 Uk1 = _mm256_load_ps(out + 1*step);
// Output butterflies
_mm256_store_ps(out + 0*step, _mm256_add_ps(Uk0, Zsum));
_mm256_store_ps(out + 1*step, _mm256_add_ps(Uk1, _mm256_xor_ps(Zdif, conjflip)));
_mm256_store_ps(out + 2*step, _mm256_sub_ps(Uk0, Zsum));
_mm256_store_ps(out + 3*step, _mm256_addsub_ps(Uk1, Zdif));
out += 8;
twiddle += 8;
} while (out < out_end);
}
else
{
__m128 Zk = _mm_load_ps(out + 2*step);
__m128 Zpk = _mm_load_ps(out + 3*step);
__m128 w = _mm_load_ps(twiddle);
__m128 w_re = _mm_moveldup_ps(w);
__m128 w_im = _mm_movehdup_ps(w);
// Twiddle Zk, Z'k
Zpk = _mm_addsub_ps(_mm_mul_ps(Zpk, w_re), _mm_mul_ps(_mm_shuffle_ps(Zpk, Zpk, 0xb1), w_im));
Zk = _mm_addsub_ps(_mm_mul_ps(_mm_shuffle_ps(Zk, Zk, 0xb1), w_re), _mm_mul_ps(Zk, w_im));
__m128 Zsum = _mm_add_ps(_mm_shuffle_ps(Zk, Zk, 0xb1), Zpk);
__m128 Zdif = _mm_sub_ps(_mm_shuffle_ps(Zpk, Zpk, 0xb1), Zk);
// Even inputs
__m128 Uk0 = _mm_load_ps(out + 0*step);
__m128 Uk1 = _mm_load_ps(out + 1*step);
// Output butterflies
_mm_store_ps(out + 0*step, _mm_add_ps(Uk0, Zsum));
_mm_store_ps(out + 1*step, _mm_add_ps(Uk1, _mm_xor_ps(Zdif, _mm256_extractf128_ps(conjflip, 0))));
_mm_store_ps(out + 2*step, _mm_sub_ps(Uk0, Zsum));
_mm_store_ps(out + 3*step, _mm_addsub_ps(Uk1, Zdif));
}
} while (plan->Nloop < Nover4);
}
// Kernels to use when AVX is available
static KernelSet const s_kernel_avx = {
sse3_rfpost,
sse3_ripre,
avx_cfpass,
avx_cipass,
sse_radix4,
sse_2radix2,
sse_dct_split,
sse_dct_merge,
sse_dct_merge_s16,
sse_dct_merge_s16s,
sse3_dct2_modulate,
sse_dct3_modulate
};
#endif
#ifdef _MSC_VER
#include <intrin.h>
#else
// Assume GCC-like and try to provide MSVC-style __cpuid
static void __cpuid(int *info, int which)
{
void *saved=0;
// NOTE(fg): With PIC on Mac, can't overwrite ebx, so we
// need to play it safe.
#ifdef __RAD64__
__asm__ __volatile__ (
"movq %%rbx, %5\n"
"cpuid\n"
"movl %%ebx, %1\n"
"movq %5, %%rbx"
: "=a" (info[0]),
"=r" (info[1]),
"=c" (info[2]),
"=d" (info[3])
: "a" (which),
"m" (saved)
: "cc");
#else
__asm__ __volatile__ (
"movl %%ebx, %5\n"
"cpuid\n"
"movl %%ebx, %1\n"
"movl %5, %%ebx"
: "=a" (info[0]),
"=r" (info[1]),
"=c" (info[2]),
"=d" (info[3])
: "a"(which),
"m" (saved)
: "cc");
#endif
}
#endif
static KernelSet const *x86_select_kernels()
{
#if defined(__RADCONSOLE__)
return &s_kernel_avx;
#else
// query features
int info[4];
__cpuid(info, 1);
#ifdef RADFFT_AVX
if (info[2] & (1u << 28)) // AVX available? (bit 28 in ECX)
return &s_kernel_avx;
#endif
#if defined(RAD_GUARANTEED_SSE3)
return &s_kernel_sse3;
#else
#ifdef RAD_USES_SSE3
if (info[2] & (1u << 0)) // SSE3 available? (bit 0 in ECX)
return &s_kernel_sse3;
#endif
#ifdef INC_BINK2
return &s_kernel_sse; // bink2 requires minimum of SSE2
#else
if (info[3] & (1u << 25)) // SSE available? (bit 25 in EDX)
return &s_kernel_sse;
return &s_kernel_scalar;
#endif
#endif
#endif
}
#define CHOOSE_KERNELS x86_select_kernels()
#endif // __RADX86__
// --------------------------------------------------------------------------
// NEON kernels for ARM.
// --------------------------------------------------------------------------
#ifdef __RADNEON__
#include <arm_neon.h>
static inline float32x4_t neon_reverse(float32x4_t v)
{
// Swap low/high halves
// Since VREV is on d registers, ideally this just copy-propagates away
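// e.g. [a b c d] -> (swap halves) [c d a b] -> (vrev64q) [d c b a]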
v = vcombine_f32(vget_high_f32(v), vget_low_f32(v));
return vrev64q_f32(v);
}
// NEON real FFT post-pass
static void neon_rfpost(rfft_complex out[], UINTa kEnd, UINTa N4)
{
if (N4 < 4)
{
scalar_rfpost(out, N4, N4);
return;
}
// Handle first few bins scalar.
scalar_rfpost(out, 4, N4);
F32 const *twiddle = (F32 const *) (s_twiddles + (N4 + 4));
F32 *out0 = (F32 *) (out + 4);
F32 *out1 = (F32 *) (out + (N4 * 2) - 7);
F32 *out0_end = (F32 *) (out + kEnd);
while (out0 < out0_end)
{
float32x4x2_t w = vld2q_f32(twiddle);
float32x4x2_t i0 = vld2q_f32(out0);
float32x4x2_t i1 = vld2q_f32(out1);
// Reverse i1
i1.val[0] = neon_reverse(i1.val[0]);
i1.val[1] = neon_reverse(i1.val[1]);
float32x4_t dr = vmulq_n_f32(vsubq_f32(i1.val[0], i0.val[0]), 0.5f);
float32x4_t di = vmulq_n_f32(vaddq_f32(i1.val[1], i0.val[1]), 0.5f);
float32x4_t evr = vaddq_f32(i0.val[0], dr);
float32x4_t evi = vsubq_f32(i0.val[1], di);
float32x4_t odr = vaddq_f32(vmulq_f32(w.val[0], di), vmulq_f32(w.val[1], dr));
float32x4_t odi = vsubq_f32(vmulq_f32(w.val[0], dr), vmulq_f32(w.val[1], di));
float32x4x2_t o0, o1;
o0.val[0] = vaddq_f32(evr, odr);
o0.val[1] = vaddq_f32(evi, odi);
o1.val[0] = vsubq_f32(evr, odr);
o1.val[1] = vsubq_f32(odi, evi);
o1.val[0] = neon_reverse(o1.val[0]);
o1.val[1] = neon_reverse(o1.val[1]);
vst2q_f32(out0, o0);
vst2q_f32(out1, o1);
out0 += 8;
out1 -= 8;
twiddle += 8;
}
}
// NEON real IFFT pre-pass
static void neon_ripre(rfft_complex out[], UINTa kEnd, UINTa N4)
{
if (N4 < 4)
{
scalar_ripre(out, N4, N4);
return;
}
// Handle first few bins scalar.
scalar_ripre(out, 4, N4);
F32 const *twiddle = (F32 const *) (s_twiddles + (N4 + 4));
F32 *out0 = (F32 *) (out + 4);
F32 *out1 = (F32 *) (out + (N4 * 2) - 7);
F32 *out0_end = (F32 *) (out + kEnd);
while (out0 < out0_end)
{
float32x4x2_t w = vld2q_f32(twiddle);
float32x4x2_t i0 = vld2q_f32(out0);
float32x4x2_t i1 = vld2q_f32(out1);
// Reverse i1
i1.val[0] = neon_reverse(i1.val[0]);
i1.val[1] = neon_reverse(i1.val[1]);
float32x4_t dr = vmulq_n_f32(vsubq_f32(i0.val[0], i1.val[0]), 0.5f);
float32x4_t di = vmulq_n_f32(vaddq_f32(i0.val[1], i1.val[1]), 0.5f);
float32x4_t evr = vsubq_f32(i0.val[0], dr);
float32x4_t evi = vsubq_f32(i0.val[1], di);
float32x4_t odr = vaddq_f32(vmulq_f32(w.val[0], di), vmulq_f32(w.val[1], dr));
float32x4_t odi = vsubq_f32(vmulq_f32(w.val[0], dr), vmulq_f32(w.val[1], di));
float32x4x2_t o0, o1;
o0.val[0] = vsubq_f32(evr, odr);
o0.val[1] = vaddq_f32(evi, odi);
o1.val[0] = vaddq_f32(evr, odr);
o1.val[1] = vsubq_f32(odi, evi);
o1.val[0] = neon_reverse(o1.val[0]);
o1.val[1] = neon_reverse(o1.val[1]);
vst2q_f32(out0, o0);
vst2q_f32(out1, o1);
out0 += 8;
out1 -= 8;
twiddle += 8;
}
}
// NEON complex forward conjugate split-radix reduction pass. This is the main workhorse inner loop for FFTs.
static void neon_cfpass(rfft_complex data[], PlanElement const *plan, UINTa Nover4)
{
UINTa N1;
do
{
N1 = plan->Nloop;
UINTa counter = N1/2;
float *out0 = (float *) (data + plan->offs);
float *out1 = out0 + 2*N1;
float *out2 = out0 + 4*N1;
float *out3 = out0 + 6*N1;
float const *twiddle = (float const *) (s_twiddles + N1);
++plan;
if (counter > 1)
{
do
{
// Load input complex values, unpacking as we go
float32x4x2_t w = vld2q_f32(twiddle);
float32x4x2_t x2 = vld2q_f32(out2);
float32x4x2_t x3 = vld2q_f32(out3);
// This is a straight translation of the scalar code
float32x4_t Zkr = vsubq_f32(vmulq_f32(w.val[0], x2.val[0]), vmulq_f32(w.val[1], x2.val[1]));
float32x4_t Zki = vaddq_f32(vmulq_f32(w.val[0], x2.val[1]), vmulq_f32(w.val[1], x2.val[0]));
float32x4_t Zpkr = vaddq_f32(vmulq_f32(w.val[0], x3.val[0]), vmulq_f32(w.val[1], x3.val[1]));
float32x4_t Zpki = vsubq_f32(vmulq_f32(w.val[0], x3.val[1]), vmulq_f32(w.val[1], x3.val[0]));
float32x4_t Zsumr = vaddq_f32(Zkr, Zpkr);
float32x4_t Zsumi = vaddq_f32(Zki, Zpki);
float32x4_t Zdifr = vsubq_f32(Zki, Zpki);
float32x4_t Zdifi = vsubq_f32(Zpkr, Zkr);
float32x4x2_t x0 = vld2q_f32(out0);
float32x4x2_t x1 = vld2q_f32(out1);
x2.val[0] = vsubq_f32(x0.val[0], Zsumr);
x2.val[1] = vsubq_f32(x0.val[1], Zsumi);
x0.val[0] = vaddq_f32(x0.val[0], Zsumr);
x0.val[1] = vaddq_f32(x0.val[1], Zsumi);
x3.val[0] = vsubq_f32(x1.val[0], Zdifr);
x3.val[1] = vsubq_f32(x1.val[1], Zdifi);
x1.val[0] = vaddq_f32(x1.val[0], Zdifr);
x1.val[1] = vaddq_f32(x1.val[1], Zdifi);
// Store back
vst2q_f32(out0, x0);
vst2q_f32(out1, x1);
vst2q_f32(out2, x2);
vst2q_f32(out3, x3);
out0 += 8;
out1 += 8;
out2 += 8;
out3 += 8;
twiddle += 8;
} while (counter -= 2);
}
else
{
// Load input complex values, unpacking as we go
float32x2x2_t w = vld2_f32(twiddle);
float32x2x2_t x2 = vld2_f32(out2);
float32x2x2_t x3 = vld2_f32(out3);
// This is a straight translation of the scalar code
float32x2_t Zkr = vsub_f32(vmul_f32(w.val[0], x2.val[0]), vmul_f32(w.val[1], x2.val[1]));
float32x2_t Zki = vadd_f32(vmul_f32(w.val[0], x2.val[1]), vmul_f32(w.val[1], x2.val[0]));
float32x2_t Zpkr = vadd_f32(vmul_f32(w.val[0], x3.val[0]), vmul_f32(w.val[1], x3.val[1]));
float32x2_t Zpki = vsub_f32(vmul_f32(w.val[0], x3.val[1]), vmul_f32(w.val[1], x3.val[0]));
float32x2_t Zsumr = vadd_f32(Zkr, Zpkr);
float32x2_t Zsumi = vadd_f32(Zki, Zpki);
float32x2_t Zdifr = vsub_f32(Zki, Zpki);
float32x2_t Zdifi = vsub_f32(Zpkr, Zkr);
float32x2x2_t x0 = vld2_f32(out0);
float32x2x2_t x1 = vld2_f32(out1);
x2.val[0] = vsub_f32(x0.val[0], Zsumr);
x2.val[1] = vsub_f32(x0.val[1], Zsumi);
x0.val[0] = vadd_f32(x0.val[0], Zsumr);
x0.val[1] = vadd_f32(x0.val[1], Zsumi);
x3.val[0] = vsub_f32(x1.val[0], Zdifr);
x3.val[1] = vsub_f32(x1.val[1], Zdifi);
x1.val[0] = vadd_f32(x1.val[0], Zdifr);
x1.val[1] = vadd_f32(x1.val[1], Zdifi);
// Store back
vst2_f32(out0, x0);
vst2_f32(out1, x1);
vst2_f32(out2, x2);
vst2_f32(out3, x3);
}
} while (N1 < Nover4);
}
// NEON complex inverse conjugate split-radix reduction pass. This is the main workhorse inner loop for IFFTs.
static void neon_cipass(rfft_complex data[], PlanElement const *plan, UINTa Nover4)
{
UINTa N1;
do
{
N1 = plan->Nloop;
UINTa counter = N1 / 2;
float *out0 = (float *) (data + plan->offs);
float *out1 = out0 + 2*N1;
float *out2 = out0 + 4*N1;
float *out3 = out0 + 6*N1;
float const *twiddle = (float const *) (s_twiddles + N1);
++plan;
if (counter > 1)
{
do
{
// Load input complex values, unpacking as we go
float32x4x2_t w = vld2q_f32(twiddle);
float32x4x2_t x2 = vld2q_f32(out2);
float32x4x2_t x3 = vld2q_f32(out3);
// This is a straight translation of the scalar code
float32x4_t Zkr = vaddq_f32(vmulq_f32(w.val[0], x2.val[0]), vmulq_f32(w.val[1], x2.val[1]));
float32x4_t Zpkr = vsubq_f32(vmulq_f32(w.val[0], x3.val[0]), vmulq_f32(w.val[1], x3.val[1]));
float32x4_t Zki = vsubq_f32(vmulq_f32(w.val[0], x2.val[1]), vmulq_f32(w.val[1], x2.val[0]));
float32x4_t Zpki = vaddq_f32(vmulq_f32(w.val[0], x3.val[1]), vmulq_f32(w.val[1], x3.val[0]));
float32x4_t Zsumr = vaddq_f32(Zkr, Zpkr);
float32x4_t Zdifi = vsubq_f32(Zkr, Zpkr);
float32x4_t Zsumi = vaddq_f32(Zki, Zpki);
float32x4_t Zdifr = vsubq_f32(Zpki, Zki);
float32x4x2_t x0 = vld2q_f32(out0);
float32x4x2_t x1 = vld2q_f32(out1);
x2.val[0] = vsubq_f32(x0.val[0], Zsumr);
x0.val[0] = vaddq_f32(x0.val[0], Zsumr);
x3.val[1] = vsubq_f32(x1.val[1], Zdifi);
x1.val[1] = vaddq_f32(x1.val[1], Zdifi);
x2.val[1] = vsubq_f32(x0.val[1], Zsumi);
x0.val[1] = vaddq_f32(x0.val[1], Zsumi);
x3.val[0] = vsubq_f32(x1.val[0], Zdifr);
x1.val[0] = vaddq_f32(x1.val[0], Zdifr);
// Store back
vst2q_f32(out2, x2);
vst2q_f32(out0, x0);
vst2q_f32(out3, x3);
vst2q_f32(out1, x1);
out0 += 8;
out1 += 8;
out2 += 8;
out3 += 8;
twiddle += 8;
} while (counter -= 2);
}
else
{
// Load input complex values, unpacking as we go
float32x2x2_t w = vld2_f32(twiddle);
float32x2x2_t x2 = vld2_f32(out2);
float32x2x2_t x3 = vld2_f32(out3);
// This is a straight translation of the scalar code
float32x2_t Zkr = vadd_f32(vmul_f32(w.val[0], x2.val[0]), vmul_f32(w.val[1], x2.val[1]));
float32x2_t Zpkr = vsub_f32(vmul_f32(w.val[0], x3.val[0]), vmul_f32(w.val[1], x3.val[1]));
float32x2_t Zki = vsub_f32(vmul_f32(w.val[0], x2.val[1]), vmul_f32(w.val[1], x2.val[0]));
float32x2_t Zpki = vadd_f32(vmul_f32(w.val[0], x3.val[1]), vmul_f32(w.val[1], x3.val[0]));
float32x2_t Zsumr = vadd_f32(Zkr, Zpkr);
float32x2_t Zdifi = vsub_f32(Zkr, Zpkr);
float32x2_t Zsumi = vadd_f32(Zki, Zpki);
float32x2_t Zdifr = vsub_f32(Zpki, Zki);
float32x2x2_t x0 = vld2_f32(out0);
float32x2x2_t x1 = vld2_f32(out1);
x2.val[0] = vsub_f32(x0.val[0], Zsumr);
x0.val[0] = vadd_f32(x0.val[0], Zsumr);
x3.val[1] = vsub_f32(x1.val[1], Zdifi);
x1.val[1] = vadd_f32(x1.val[1], Zdifi);
x2.val[1] = vsub_f32(x0.val[1], Zsumi);
x0.val[1] = vadd_f32(x0.val[1], Zsumi);
x3.val[0] = vsub_f32(x1.val[0], Zdifr);
x1.val[0] = vadd_f32(x1.val[0], Zdifr);
// Store back
vst2_f32(out2, x2);
vst2_f32(out0, x0);
vst2_f32(out3, x3);
vst2_f32(out1, x1);
}
} while (N1 < Nover4);
}
// NEON 2x Radix-2 FFT codelet.
static void neon_2radix2(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm)
{
if (j0 == j1)
return;
// SIMD loop wants an even number of elements at an aligned
// offset. Thus, we may have extra elements at the beginning
// or end.
UINTa j1a = j1 & ~1;
if (j0 & 1)
{
scalar_2radix2(out, in0, in1, in2, in3, j0, j0 + 1, perm);
++j0;
}
for (UINTa j = j0; j < j1a; j += 2)
{
float32x4_t a = vld1q_f32(&in0[j].re);
float32x4_t b = vld1q_f32(&in2[j].re);
float32x4_t c = vld1q_f32(&in3[j].re);
float32x4_t d = vld1q_f32(&in1[j].re);
F32 *o0 = &out[perm[j+0]].re;
F32 *o1 = &out[perm[j+1]].re;
float32x4_t E = vaddq_f32(a, b);
float32x4_t F = vsubq_f32(a, b);
float32x4_t G = vaddq_f32(c, d);
float32x4_t H = vsubq_f32(c, d);
// Store
vst1_f32(o0 + 0, vget_low_f32(E));
vst1_f32(o0 + 2, vget_low_f32(F));
vst1_f32(o0 + 4, vget_low_f32(G));
vst1_f32(o0 + 6, vget_low_f32(H));
vst1_f32(o1 + 0, vget_high_f32(E));
vst1_f32(o1 + 2, vget_high_f32(F));
vst1_f32(o1 + 4, vget_high_f32(G));
vst1_f32(o1 + 6, vget_high_f32(H));
}
if (j1a != j1)
scalar_2radix2(out, in0, in1, in2, in3, j1a, j1, perm);
}
// NEON Radix-4 FFT codelet.
static void neon_radix4(rfft_complex out[], rfft_complex const in0[], rfft_complex const in1[], rfft_complex const in2[], rfft_complex const in3[], UINTa j0, UINTa j1, Index const *perm)
{
if (j0 == j1)
return;
// SIMD loop wants an even number of elements at an aligned
// offset. Thus, we may have extra elements at the beginning
// or end.
UINTa j1a = j1 & ~1;
if (j0 & 1)
{
scalar_radix4(out, in0, in1, in2, in3, j0, j0 + 1, perm);
++j0;
}
// NOTE: pretty un-NEON to work with interleaved values like that, but
// it ends up better than the split variants I could come up with.
U32 conjflip_c[4] = { 0, 0x80000000u, 0, 0x80000000u };
uint32x4_t conjflip = vld1q_u32(conjflip_c);
for (UINTa j = j0; j < j1a; j += 2)
{
float32x4_t a = vld1q_f32(&in0[j].re);
float32x4_t b = vld1q_f32(&in2[j].re);
float32x4_t c = vld1q_f32(&in1[j].re);
float32x4_t d = vld1q_f32(&in3[j].re);
F32 *o0 = &out[perm[j+0]].re;
F32 *o1 = &out[perm[j+1]].re;
float32x4_t A = vaddq_f32(a, b);
float32x4_t B = vsubq_f32(a, b);
float32x4_t C = vaddq_f32(c, d);
float32x4_t D = vsubq_f32(c, d);
// D *= -i
D = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(vrev64q_f32(D)), conjflip));
float32x4_t E = vaddq_f32(A, C);
float32x4_t F = vaddq_f32(B, D);
float32x4_t G = vsubq_f32(A, C);
float32x4_t H = vsubq_f32(B, D);
// Store
vst1_f32(o0 + 0, vget_low_f32(E));
vst1_f32(o0 + 2, vget_low_f32(F));
vst1_f32(o0 + 4, vget_low_f32(G));
vst1_f32(o0 + 6, vget_low_f32(H));
vst1_f32(o1 + 0, vget_high_f32(E));
vst1_f32(o1 + 2, vget_high_f32(F));
vst1_f32(o1 + 4, vget_high_f32(G));
vst1_f32(o1 + 6, vget_high_f32(H));
}
if (j1a != j1)
scalar_radix4(out, in0, in1, in2, in3, j1a, j1, perm);
}
static void neon_dct_split(F32 out[], F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 16)
{
scalar_dct_split(out, in, N);
return;
}
#endif
// N is pow2 and >=16
F32 *out0 = out;
F32 *out1 = out + N;
F32 const *inp = in;
do
{
float32x4x2_t v0 = vld2q_f32(inp);
float32x4x2_t v1 = vld2q_f32(inp + 8);
vst1q_f32(out0, v0.val[0]);
vst1q_f32(out0 + 4, v1.val[0]);
vst1q_f32(out1 - 4, neon_reverse(v0.val[1]));
vst1q_f32(out1 - 8, neon_reverse(v1.val[1]));
inp += 16;
out0 += 8;
out1 -= 8;
} while (N -= 16);
}
static void neon_dct_merge(F32 out[], F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 16)
{
scalar_dct_merge(out, in, N);
return;
}
#endif
// N is pow2 and >=16
F32 const *in0 = in;
F32 const *in1 = in + N;
F32 *outp = out;
do
{
float32x4x2_t v0, v1;
v0.val[0] = vld1q_f32(in0);
v0.val[1] = neon_reverse(vld1q_f32(in1 - 4));
v1.val[0] = vld1q_f32(in0 + 4);
v1.val[1] = neon_reverse(vld1q_f32(in1 - 8));
vst2q_f32(outp, v0);
vst2q_f32(outp + 8, v1);
outp += 16;
in0 += 8;
in1 -= 8;
}
while (N -= 16);
}
static void neon_dct_merge_s16(S16 out[], F32 scale, F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 16)
{
scalar_dct_merge_s16(out, scale, in, N);
return;
}
#endif
// N is pow2 and >=16
F32 const *in0 = in;
F32 const *in1 = in + N;
S16 *outp = out;
float32x4_t scale128 = vmovq_n_f32( scale );
do
{
float32x4x2_t v0, v1;
int32x4x2_t i0, i1;
int16x8x2_t is16;
v0.val[0] = vld1q_f32(in0);
v0.val[1] = vld1q_f32(in0 + 4);
v1.val[0] = neon_reverse(vld1q_f32(in1 - 4));
v1.val[1] = neon_reverse(vld1q_f32(in1 - 8));
// scale the values
v0.val[0] = vmulq_f32( v0.val[0], scale128 );
v0.val[1] = vmulq_f32( v0.val[1], scale128 );
v1.val[0] = vmulq_f32( v1.val[0], scale128 );
v1.val[1] = vmulq_f32( v1.val[1], scale128 );
// convert to 32-bit ints
i0.val[0] = vcvtq_s32_f32( v0.val[0] );
i0.val[1] = vcvtq_s32_f32( v0.val[1] );
i1.val[0] = vcvtq_s32_f32( v1.val[0] );
i1.val[1] = vcvtq_s32_f32( v1.val[1] );
// merge them
is16.val[0] = vcombine_s16(vqmovn_s32(i0.val[0]), vqmovn_s32(i0.val[1]));
is16.val[1] = vcombine_s16(vqmovn_s32(i1.val[0]), vqmovn_s32(i1.val[1]));
// store it
vst2q_s16(outp, is16);
outp += 16;
in0 += 8;
in1 -= 8;
}
while (N -= 16);
}
static void neon_dct_merge_s16s(S16 out[], S16 left[], F32 scale, F32 const in[], UINTa N)
{
#ifdef WANT_TINY
if (N < 16)
{
scalar_dct_merge_s16s(out, left, scale, in, N);
return;
}
#endif
// N is pow2 and >=16
F32 const *in0 = in;
F32 const *in1 = in + N;
S16 *outp = out;
float32x4_t scale128 = vmovq_n_f32( scale );
do
{
float32x4x2_t v0, v1;
int32x4x2_t i0, i1;
int16x8x2_t rs16;
int16x8x2_t ls16;
int16x8x4_t zs16;
v0.val[0] = vld1q_f32(in0);
v0.val[1] = vld1q_f32(in0 + 4);
v1.val[0] = neon_reverse(vld1q_f32(in1 - 4));
v1.val[1] = neon_reverse(vld1q_f32(in1 - 8));
// scale the values
v0.val[0] = vmulq_f32( v0.val[0], scale128 );
v0.val[1] = vmulq_f32( v0.val[1], scale128 );
v1.val[0] = vmulq_f32( v1.val[0], scale128 );
v1.val[1] = vmulq_f32( v1.val[1], scale128 );
// convert to 32-bit ints
i0.val[0] = vcvtq_s32_f32( v0.val[0] );
i0.val[1] = vcvtq_s32_f32( v0.val[1] );
i1.val[0] = vcvtq_s32_f32( v1.val[0] );
i1.val[1] = vcvtq_s32_f32( v1.val[1] );
// merge them
rs16.val[0] = vcombine_s16(vqmovn_s32(i0.val[0]), vqmovn_s32(i0.val[1]));
rs16.val[1] = vcombine_s16(vqmovn_s32(i1.val[0]), vqmovn_s32(i1.val[1]));
ls16 = vld2q_s16( left );
zs16.val[0] = ls16.val[0];
zs16.val[1] = rs16.val[0];
zs16.val[2] = ls16.val[1];
zs16.val[3] = rs16.val[1];
vst4q_s16(outp, zs16);
left += 16;
outp += 32;
in0 += 8;
in1 -= 8;
}
while (N -= 16);
}
static void neon_dct2_modulate(F32 out[], F32 const in[], UINTa N, UINTa Nlast, rfft_complex const *twiddle)
{
if (N < 16)
{
scalar_dct2_modulate(out, in, N, Nlast, twiddle);
return;
}
// First few bins have exceptional cases, let scalar routine handle it
FFTASSERT((Nlast % 8) == 0 && Nlast >= 8);
scalar_dct2_modulate(out, in, N, 8, twiddle);
for (UINTa k = 8; k < Nlast; k += 8)
{
float32x4x2_t Z0 = vld2q_f32(in + k*2);
float32x4x2_t Z1 = vld2q_f32(in + k*2 + 8);
float32x4x2_t w0 = vld2q_f32(&twiddle[k].re);
float32x4x2_t w1 = vld2q_f32(&twiddle[k + 4].re);
float32x4_t x0r = vaddq_f32(vmulq_f32(w0.val[0], Z0.val[0]), vmulq_f32(w0.val[1], Z0.val[1]));
float32x4_t x0i = vsubq_f32(vmulq_f32(w0.val[0], Z0.val[1]), vmulq_f32(w0.val[1], Z0.val[0]));
float32x4_t x1r = vaddq_f32(vmulq_f32(w1.val[0], Z1.val[0]), vmulq_f32(w1.val[1], Z1.val[1]));
float32x4_t x1i = vsubq_f32(vmulq_f32(w1.val[0], Z1.val[1]), vmulq_f32(w1.val[1], Z1.val[0]));
x0i = neon_reverse(x0i);
x1i = neon_reverse(x1i);
vst1q_f32(out + k, x0r);
vst1q_f32(out + k + 4, x1r);
vst1q_f32(out + N - 3 - k, x0i);
vst1q_f32(out + N - 7 - k, x1i);
}
}
static void neon_dct3_modulate(F32 out[], F32 const in[], UINTa N, UINTa Nlast, rfft_complex const *twiddle)
{
if (N < 16)
{
scalar_dct3_modulate(out, in, N, Nlast, twiddle);
return;
}
// First few bins have exceptional cases, let scalar routine handle it
FFTASSERT((Nlast % 8) == 0 && Nlast >= 8);
scalar_dct3_modulate(out, in, N, 8, twiddle);
for (UINTa k = 8; k < Nlast; k += 8)
{
float32x4_t x0r = vld1q_f32(in + k);
float32x4_t x0i = vld1q_f32(in + N - 3 - k);
float32x4_t x1r = vld1q_f32(in + k + 4);
float32x4_t x1i = vld1q_f32(in + N - 7 - k);
float32x4x2_t w0 = vld2q_f32(&twiddle[k].re);
float32x4x2_t w1 = vld2q_f32(&twiddle[k + 4].re);
// Reverse xi
x0i = neon_reverse(x0i);
x1i = neon_reverse(x1i);
float32x4x2_t Z0, Z1;
Z0.val[0] = vsubq_f32(vmulq_f32(w0.val[0], x0r), vmulq_f32(w0.val[1], x0i));
Z0.val[1] = vaddq_f32(vmulq_f32(w0.val[1], x0r), vmulq_f32(w0.val[0], x0i));
Z1.val[0] = vsubq_f32(vmulq_f32(w1.val[0], x1r), vmulq_f32(w1.val[1], x1i));
Z1.val[1] = vaddq_f32(vmulq_f32(w1.val[1], x1r), vmulq_f32(w1.val[0], x1i));
vst2q_f32(out + k*2, Z0);
vst2q_f32(out + k*2 + 8, Z1);
}
}
// Kernels to use when NEON is available
static KernelSet const s_kernel_neon = {
neon_rfpost,
neon_ripre,
neon_cfpass,
neon_cipass,
neon_radix4,
neon_2radix2,
neon_dct_split,
neon_dct_merge,
neon_dct_merge_s16,
neon_dct_merge_s16s,
neon_dct2_modulate,
neon_dct3_modulate
};
#define STATIC_KERNEL &s_kernel_neon
#endif
// --------------------------------------------------------------------------
// Driver/glue layer
// --------------------------------------------------------------------------
static void fft_driver(rfft_complex outc[], UINTa N, CFFTKernel *kernel)
{
UINTa N1 = N/4;
PlanElement local;
PlanElement const *plan = s_recursion_plan;
if (N > kMaxPlan)
{
// We're gonna call the kernel once after we're done recursing
// to combine partial results; set that up here.
local.offs = 0;
local.Nloop = (Index)N1;
plan = &local;
// Recursion pattern for our conjugate split-radix FFT
fft_driver(outc, N1*2, kernel);
fft_driver(outc + N1*2, N1, kernel);
fft_driver(outc + N1*3, N1, kernel);
}
kernel(outc, plan, N1);
}
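// Example: with kMaxPlan = 256, a 1024-point FFT recurses into sub-FFTs of size 512, 256
// and 256 (the 512 recursing once more itself); each level above kMaxPlan then finishes
// with a single combining kernel pass using a synthesized PlanElement { offs = 0, Nloop = N/4 },
// while sizes <= kMaxPlan just walk the prebuilt s_recursion_plan.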
static void fft_base_driver(rfft_complex out[], rfft_complex const in[], UINTa N, bool inverse, BaseKernel *kern_radix4, BaseKernel *kern_2radix2)
{
// Handle base cases and permutation
UINTa N1 = N / 4;
U16 const *delta = s_permute + (N / kLeafN);
rfft_complex const *in0 = in + 0*N1;
rfft_complex const *in1 = in + 1*N1;
rfft_complex const *in2 = in + 2*N1;
rfft_complex const *in3 = in + 3*N1;
UINTa j1 = N1 / 3 + 1;
UINTa j2 = N1 - j1 + 1;
if (!inverse)
{
kern_radix4 (out, in0, in1, in2, in3, 0, j1, delta);
kern_2radix2(out, in0, in1, in2, in3, j1, j2, delta);
kern_radix4 (out, in3, in0, in1, in2, j2, N1, delta);
}
else
{
kern_radix4 (out, in0, in3, in2, in1, 0, j1, delta);
kern_2radix2(out, in0, in1, in2, in3, j1, j2, delta);
kern_radix4 (out, in3, in2, in1, in0, j2, N1, delta);
}
}
// --------------------------------------------------------------------------
// API / frontend layer
// --------------------------------------------------------------------------
// If we have no kernel set defined, default to everything scalar
// (this is for platforms we don't have optimizations for)
#if !defined(CHOOSE_KERNELS) && !defined(STATIC_KERNEL)
#define STATIC_KERNEL &s_kernel_scalar
#endif
#ifdef STATIC_KERNEL
#define s_kernel (STATIC_KERNEL)
#else
static KernelSet const *s_kernel;
#endif
static bool is_pow2(UINTa N)
{
return N != 0 && (N & (N - 1)) == 0;
}
static rfft_complex const *get_dct_twiddle(UINTa N)
{
#ifdef NO_DCT_TABLES
return NULL;
#else
if (N*4 <= kMaxN)
return s_twiddles + N;
else // DCT twiddles: two more levels! (but only one eighth of a circle so the density is different)
return s_dct_twiddles + ((N/2) - (kMaxN/4));
#endif
}
#ifndef USETABLES
static void calc_twiddle(rfft_complex *twiddle, UINTa count, UINTa freq)
{
F64 const kPi = 3.1415926535897932384626433832795;
F64 step = -2.0 * kPi / (F64)freq;
for (UINTa k = 0; k < count; k++)
{
F64 phase = step * (F64)k;
twiddle[k].re = (F32)cos(phase);
twiddle[k].im = (F32)sin(phase);
}
}
static void init_permute_rec(U16 perm[], UINTa mask, UINTa N, UINTa offs_in, UINTa offs_out, UINTa stride)
{
if (N <= kLeafN)
perm[offs_in & mask] = (U16) offs_out;
else
{
init_permute_rec(perm, mask, N/2, offs_in, offs_out, stride * 2);
init_permute_rec(perm, mask, N/4, offs_in + stride, offs_out + N/2, stride * 4);
if (N/4 >= kLeafN)
init_permute_rec(perm, mask, N/4, offs_in - stride, offs_out + 3*N/4, stride * 4);
}
}
static void init_tables()
{
// Build twiddles
for (UINTa N = 1; N <= kMaxN / 4; N *= 2)
calc_twiddle(s_twiddles + N, N, N * 4);
// Two more at higher freq for DCT
calc_twiddle(s_dct_twiddles, kMaxN / 4, kMaxN * 2);
calc_twiddle(s_dct_twiddles + kMaxN / 4, kMaxN / 2, kMaxN * 4);
// Base permutation table
init_permute_rec(s_permute + (kMaxN / kLeafN), (kMaxN / kLeafN) - 1, kMaxN, 0, 0, 1);
// Then subsample to get smaller versions
for (UINTa N = (kMaxN / kLeafN) / 2; N >= 1; N /= 2)
{
U16 const *in = s_permute + N*2;
U16 *out = s_permute + N;
for (UINTa i = 0; i < N; i++)
out[i] = in[i*2];
}
}
#endif
static int radfft_init_helper()
{
#ifdef CHOOSE_KERNELS
s_kernel = CHOOSE_KERNELS;
#endif
#ifndef USETABLES
init_tables();
#endif
return 1;
}
void RADLINK radfft_init()
{
// lean on c++ static init here: this is guaranteed by c++ to be called once
// and will be wrapped with a mutex by the compiler.
static int done_init = radfft_init_helper();
}
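// Note: the transform entry points below assume radfft_init() has already run; without it,
// s_kernel is unset on platforms that pick kernels at runtime (CHOOSE_KERNELS), and the
// twiddle/permutation tables are empty in non-USETABLES builds.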
// Complex FFT
void RADLINK radfft_cfft(rfft_complex out[], rfft_complex const in[], UINTa N)
{
FFTASSERT(is_pow2(N) && N <= kMaxN);
#ifdef WANT_TINY
if (N <= 2)
{
scalar_cfft_tiny(out, in, N);
return;
}
#endif
fft_base_driver(out, in, N, false, s_kernel->radix4, s_kernel->radix2_2);
if (N >= 8)
fft_driver(out, N, s_kernel->cfpass);
}
// Complex IFFT
void RADLINK radfft_cifft(rfft_complex out[], rfft_complex const in[], UINTa N)
{
FFTASSERT(is_pow2(N) && N <= kMaxN);
#ifdef WANT_TINY
if (N <= 2)
{
scalar_cfft_tiny(out, in, N);
return;
}
#endif
fft_base_driver(out, in, N, true, s_kernel->radix4, s_kernel->radix2_2);
if (N >= 8)
fft_driver(out, N, s_kernel->cipass);
}
// Real FFT
void RADLINK radfft_rfft(rfft_complex out[], F32 const in[], UINTa N)
{
FFTASSERT(is_pow2(N) && 2 <= N && N <= kMaxN);
UINTa N4 = N/4;
// First do a size-N/2 complex IFFT on "in", computing two
// size-N/2 FFTs of the even and odd samples from "in".
radfft_cifft(out, (rfft_complex const *)in, N / 2);
// We now have FFT_N/2[even_samples] + i*FFT_N/2[odd_samples],
// and because of symmetries in FFTs of real samples, we can
// disentangle this just fine.
//
// That means we're almost done with a size-N FFT: all we need
// to do is a final radix-2 butterfly step. "rfpost" does just
// that: the disentangling using symmetry followed by a final
// radix-2 butterfly.
//
// Again, this is a real FFT which has conjugate symmetry:
// x[0..N-1] = input signal
// X[0..N-1] = output signal = FFT(x)
// then X[N - k] = conj(X[k]) (addressing mod N). In particular,
// this means that X[0] and X[N/2] are real, and the remaining
// values are uniquely determined by X[1..N/2-1].
//
// So we pack the real values for X[0] and X[N/2] into a single
// complex value at offset 0, and otherwise just return the first
// half.
// 0 / Nyquist bins
rfft_complex v = out[0];
out[0].re = v.re + v.im;
out[0].im = v.re - v.im;
// Remaining bins
if (N4 > 0)
s_kernel->rfpost(out, N4, N4);
}
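// Illustrative usage (a sketch, not from the original source; buffers must have the
// RADFFT_ALIGN alignment the SIMD kernels assume):
//
//   // F32 samples[1024];        // time-domain input, suitably aligned
//   // rfft_complex spec[512];   // receives the packed half-spectrum (N/2 bins)
//   // radfft_init();
//   // radfft_rfft(spec, samples, 1024);
//   // spec[0].re == X[0] (DC) and spec[0].im == X[512] (Nyquist), both real;
//   // spec[k] == X[k] for k = 1..511; X[N-k] = conj(X[k]) gives the rest.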
// DCT-II
void RADLINK radfft_dct(F32 out[], F32 in[], UINTa N)
{
FFTASSERT(is_pow2(N) && 2 <= N && N <= kMaxN);
// A size-N DCT-II can be expressed as the first N elements of
// a real DFT of 4N samples
// [0,in_0, 0,in_1, 0,in_2, ..., 0,in_{N-1},
// 0,in_{N-1}, 0,in_{N-2}, 0,in_{N-3}, ..., 0,in_0]
//
// First, do a radix-4 Cooley-Tukey DIT step, yielding four
// size-N sub-DFTs on the expanded input samples. Both of
// the even sub-DFTs are of all-0 samples (so themselves 0),
// and the two odd DFTs are closely related because the
// input vectors are symmetric.
//
// Long story short, this reduces into a size-N DFT of
// permuted input data followed by modulation (point-wise
// complex multiplication) with a bunch of twiddle factors.
s_kernel->dct_split(out, in, N);
radfft_rfft((rfft_complex *)in, out, N);
s_kernel->dct2_mod(out, in, N, N/2, get_dct_twiddle(N));
}
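// For reference, up to this implementation's overall (uniform) scale factor, which is not
// restated here, the transform computed above is the standard DCT-II:
//   X[k] = sum_{n=0..N-1} x[n] * cos(pi * (2*n + 1) * k / (2*N)),   k = 0..N-1
// and radfft_idct below is its (similarly unnormalized) inverse, the DCT-III.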
// Real IFFT
void RADLINK radfft_rifft(F32 out[], rfft_complex in[], UINTa N)
{
FFTASSERT(is_pow2(N) && 2 <= N && N <= kMaxN);
UINTa N4 = N/4;
// Dual to the process in "rfft": we have the first half of the
// FFT of a signal we know is real (and hence has conjugate
// symmetry) and want to do the IFFT. So we unpack the packed
// 0/Nyquist bin and then do a combined symmetric expand and
// radix-2 butterfly, after which we're back where we were
// after "radfft_cifft" above:
// FFT_N/2[even_samples] + i*FFT_N/2[odd_samples]
// From then on, it's just a complex IFFT.
// Pre-pass
rfft_complex v = in[0];
in[0].re = 0.5f * (v.re + v.im);
in[0].im = 0.5f * (v.re - v.im);
if (N4 > 0)
s_kernel->ripre(in, N4, N4);
// Complex IFFT computes the result
radfft_cfft((rfft_complex *)out, in, N/2);
}
// DCT-III (IDCT)
void RADLINK radfft_idct(F32 out[], F32 in[], UINTa N)
{
FFTASSERT(is_pow2(N) && 2 <= N && N <= kMaxN);
// The corresponding IDCT is the inverse of the above:
// (De)modulate, IFFT, then un-permute.
s_kernel->dct3_mod(out, in, N, N/2, get_dct_twiddle(N));
radfft_rifft(in, (rfft_complex *)out, N);
s_kernel->dct_merge(out, in, N);
}
// DCT-III (IDCT)
void RADLINK radfft_idct_to_S16(S16 outs16[], F32 scale, F32 tmp[], F32 in[], UINTa N)
{
FFTASSERT(is_pow2(N) && 2 <= N && N <= kMaxN);
// The corresponding IDCT is the inverse of the above:
// (De)modulate, IFFT, then un-permute.
s_kernel->dct3_mod(tmp, in, N, N/2, get_dct_twiddle(N));
radfft_rifft(in, (rfft_complex *)tmp, N);
s_kernel->dct_merge_s16(outs16, scale, in, N);
}
// DCT-III (IDCT)
void RADLINK radfft_idct_to_S16_stereo_interleave(S16 outs16[], S16 left[], F32 scale, F32 tmp[], F32 in[], UINTa N)
{
FFTASSERT(is_pow2(N) && 2 <= N && N <= kMaxN);
// The corresponding IDCT is the inverse of the above:
// (De)modulate, IFFT, then un-permute.
s_kernel->dct3_mod(tmp, in, N, N/2, get_dct_twiddle(N));
radfft_rifft(in, (rfft_complex *)tmp, N);
s_kernel->dct_merge_s16s(outs16, left, scale, in, N);
}
#ifdef TABLEGEN
#include <stdio.h>
#define COUNTOF(x) (sizeof(x)/sizeof(*(x)))
static void print_twiddles(FILE *f, char const *name, rfft_complex *vals, UINTa count)
{
static UINTa const kValsPerLine = 16;
U32 const *data = (U32 const *)vals;
count *= 2; // complex numbers are pairs of real values
fprintf(f, "#define %s ((rfft_complex const *)radfft_%s_data)\n", name, name);
fprintf(f, "FFTTABLE(U32, radfft_%s_data[%d]) = {\n", name, (int)count);
for (UINTa i = 0; i < count; ++i)
{
if ((i % kValsPerLine) == 0)
fprintf(f, " ");
fprintf(f, "0x%08x,", data[i]);
if (i == count - 1 || ((i % kValsPerLine) == kValsPerLine - 1))
fprintf(f, "\n");
}
fprintf(f, "};\n\n");
}
static void print_permute(FILE *f, char const *name, Index const *vals, UINTa count)
{
static UINTa const kValsPerLine = 16;
fprintf(f, "#define %s ((Index const *)radfft_%s)\n", name, name);
fprintf(f, "FFTTABLE(Index, radfft_%s[%d]) = {\n", name, (int)count);
for (UINTa i = 0; i < count; ++i)
{
if ((i % kValsPerLine) == 0)
fprintf(f, " ");
fprintf(f, "%4d,", (int)vals[i]);
if (i == count - 1 || ((i % kValsPerLine) == kValsPerLine - 1))
fprintf(f, "\n");
else
fprintf(f, " ");
}
fprintf(f, "};\n\n");
}
int main()
{
radfft_init();
static char const filename[] = "radfft_tables.inl";
FILE *f = fopen(filename, "w");
if (!f)
{
printf("error opening '%s' for writing!\n", filename);
return 0;
}
// Size asserts
fprintf(f, "typedef int fft_assert_MaxN[(kMaxN == %d) ? 1 : -1];\n", (int)kMaxN);
fprintf(f, "typedef int fft_assert_LeafN[(kLeafN == %d) ? 1 : -1];\n", (int)kLeafN);
fprintf(f, "\n");
print_twiddles(f, "s_twiddles", s_twiddles, COUNTOF(s_twiddles));
print_twiddles(f, "s_dct_twiddles", s_dct_twiddles, COUNTOF(s_dct_twiddles));
print_permute(f, "s_permute", s_permute, COUNTOF(s_permute));
fclose(f);
printf("%s written.\n", filename);
return 0;
}
#endif
// vim:et:sts=4:sw=4