// Copyright Epic Games, Inc. All Rights Reserved.

/*=============================================================================
	GPUFastFourierTransformCore.usf: Core Fast Fourier Transform Code
=============================================================================*/

//------------------------------------------------------- RECOMPILE HASH
#pragma message("UESHADERMETADATA_VERSION B96BD6F6-825A-4B16-AA27-96502C97998C")

// This file requires the two values to be defined.
// SCAN_LINE_LENGTH must be a power of RADIX and RADIX in turn must be a power of two
// Need define: RADIX and SCAN_LINE_LENGTH

#pragma once

#include "Common.ush"

// The expansion is parenthesized so that the macro behaves like a single value
// in any expression (e.g. "x / TWO_PI" or "-TWO_PI"), not only as a leading factor.
#ifndef TWO_PI
	#define TWO_PI (2.0f * PI)
#endif

// Group-wide barrier used between writes to and reads from group shared memory.
#define FFTMemoryBarrier() GroupMemoryBarrierWithGroupSync()

// Use floating point number representation.
// Complex = x + iy, where x, y are both real numbers (floats) and i = sqrt(-1)
#define Complex float2

// Complex Multiplication using Complex as a complex number
// The full complex arithmetic is defined as:
// Real(A) = A.x,                         Imag(A) = A.y
// Real(A + B) = A.x + B.x;               Imag(A + B) = A.y + B.y
// Real(A * B) = A.x * B.x - A.y * B.y;   Imag(A * B) = A.x * B.y + B.x * A.y

// Returns the complex product A * B.
Complex ComplexMult(in Complex A, in Complex B)
{
	const float RealPart = A.x * B.x - A.y * B.y;
	const float ImagPart = A.x * B.y + B.x * A.y;
	return Complex(RealPart, ImagPart);
}

// Returns the complex conjugate of Z: the imaginary part changes sign.
Complex ComplexConjugate(in Complex Z)
{
	return Complex(Z.x, -Z.y);
}

// ------------------------------------------------
// The Definition of the Fourier Transform:
// ------------------------------------------------
// There is ambiguity in the definition of the Fourier Transform
// so we explicitly state our definition.
//
// Given N data samples f_n (generally complex) with n = [0, N-1]
// the transform produces N values F_k with k = [0, N-1].
//
// These values (generally complex) are defined as
//
// F_k = Sum[f_n * Exp[i * n * k * 2 * Pi / N], {n, 0, N-1}]
//
// This may be inverted by computing
//
// f_n = (1/N) Sum[F_k * Exp[-i * n * k * 2 * Pi / N], {k, 0, N-1}]
//
// In both these sums 'i' is the complex number sqrt[-1], and
// Exp[i x] = Cos[x] + i Sin[x].
//
// The actual Fast Fourier Transform (FFT) exploits symmetries in the
// trigonometric functions that allow us to compute the N different values
// in the transform in N*Log(N) time (not N*N as a direct implementation of
// the sum would give)
//
// Note: the 'normalization' term (1/N) in the inverse.
//
// Also Note :
// If the complex input signal z_n = x_n + i y_n is in fact two separate real signals x_n, y_n
// then the FFT of each signal can be extracted from Z_k = FFT(z).
// X_k = FFT(x) = (Z_k + Conjugate(Z_{N-k})) / 2
// Y_k = FFT(y) = -i (Z_k - Conjugate(Z_{N-k})) / 2
//
// In general the transform of real data f_n has the symmetry
// F_k = Conjugate(F_{N-k})

// Selects which direction carries the 1/N normalization:
// 1 -> the inverse transform is scaled by 1/N, anything else -> the forward transform is.
#define NORMALIZE_ORDER 1

// Scale factor applied to the data of a forward transform of length ArrayLength.
float ForwardScale(uint ArrayLength)
{
#if NORMALIZE_ORDER == 1
	return 1.0f;
#else
	return 1.0f / float(ArrayLength);
#endif
}

// Scale factor applied to the data of an inverse transform of length ArrayLength.
float InverseScale(uint ArrayLength)
{
#if NORMALIZE_ORDER == 1
	return 1.0f / float(ArrayLength);
#else
	return 1.0f;
#endif
}

// ------------------------------------------------
// FFT for different Radix size:
// ------------------------------------------------
// Specialized code for doing the FFT or inverse FFT on arrays of size 2, 3, 4, or 8 elements.
// Note: because these small transforms are used as building blocks for transforming longer
// arrays, the normalization term is omitted from the small inverse transforms.
//

// Two point transform: (V0, V1) -> (V0 + V1, V0 - V1).  bIsForward is not used by the body.
void Radix2FFT(in bool bIsForward, inout Complex V0, inout Complex V1)
{
	// Reference form, which needs a temporary:
	//   Complex Tmp = V1;
	//   V1 = V0 - Tmp;
	//   V0 = V0 + Tmp;
	//
	// In-place form: after the first statement V0 holds the sum, so the
	// difference is recovered by removing V1 twice.
	V0 += V1;
	V1 = (V0 - V1) - V1;
}

// Two point transform on a two element array, same in-place arithmetic as the scalar overload.
void Radix2FFT(in bool bIsForward, inout Complex V[2])
{
	V[0] += V[1];
	V[1] = (V[0] - V[1]) - V[1];
}

// Three point transform.  The forward and inverse directions differ only in
// which cube root of unity multiplies which input.
void Radix3FFT(in bool bIsForward, inout Complex V[3])
{
	// sin(2 Pi / 3) = sqrt(3) / 2
	const float SinTwoPiOn3 = 0.5 * sqrt(3.f);

	// exp(i 2 Pi / 3) = Cos(2 Pi / 3) + i Sin(2 Pi / 3)
	const Complex RootTwoPiOn3 = Complex(-0.5, SinTwoPiOn3);
	// exp(i 4 Pi / 3) = Cos(4 Pi / 3) + i Sin(4 Pi / 3)
	const Complex RootFourPiOn3 = Complex(-0.5, -SinTwoPiOn3);

	// Root applied to V[1] for output 1 (and to V[2] for output 2), and its partner.
	Complex RootA = RootFourPiOn3;
	Complex RootB = RootTwoPiOn3;
	if (bIsForward)
	{
		RootA = RootTwoPiOn3;
		RootB = RootFourPiOn3;
	}

	const Complex Out0 = V[0] + V[1] + V[2];
	const Complex Out1 = V[0] + ComplexMult(RootA, V[1]) + ComplexMult(RootB, V[2]);
	const Complex Out2 = V[0] + ComplexMult(RootB, V[1]) + ComplexMult(RootA, V[2]);

	V[0] = Out0;
	V[1] = Out1;
	V[2] = Out2;
}

// Four point transform built from two two-point transforms (even and odd inputs)
// followed by a butterfly merge.
void Radix4FFT(in bool bIsForward, inout Complex V0, inout Complex V1, inout Complex V2, inout Complex V3)
{
	// The even and odd transforms
	Radix2FFT(bIsForward, V0, V2);
	Radix2FFT(bIsForward, V1, V3);

	// Sum of the odd inputs, kept because V1 is overwritten below.
	const Complex OddSum = V1;

	// Odd difference rotated by +i (forward) or -i (inverse).
	Complex RotatedOddDiff;
	if (bIsForward)
	{
		// Complex(0, 1) * V3
		RotatedOddDiff = Complex(-V3.y, V3.x);
	}
	else
	{
		// Complex(0, -1) * V3
		RotatedOddDiff = Complex(V3.y, -V3.x);
	}

	// Reference result:
	//   { V0 + V1, V2 + Rot, V0 - V1, V2 - Rot }
	// V2 must be written last: it is read for both V1 and V3, and the new V2
	// is recovered from the updated V0 (same in-place trick as Radix2FFT).
	V0 = V0 + OddSum;
	V3 = V2 - RotatedOddDiff;
	V1 = V2 + RotatedOddDiff;
	V2 = V0 - OddSum - OddSum;
}

// Four point transform on a four element array; same arithmetic as the scalar overload.
void Radix4FFT(in bool bIsForward, inout Complex V[4])
{
	// The even transform: (V[0], V[2])
	V[0] += V[2];
	V[2] = (V[0] - V[2]) - V[2];

	// The odd transform: (V[1], V[3])
	V[1] += V[3];
	V[3] = (V[1] - V[3]) - V[3];

	// The butterfly merge of the even and odd transforms
	const Complex OddSum = V[1];

	Complex RotatedOddDiff;
	if (bIsForward)
	{
		// Complex(0, 1) * V3
		RotatedOddDiff = Complex(-V[3].y, V[3].x);
	}
	else
	{
		// Complex(0, -1) * V3
		RotatedOddDiff = Complex(V[3].y, -V[3].x);
	}

	V[0] = V[0] + OddSum;
	V[3] = V[2] - RotatedOddDiff;
	V[1] = V[2] + RotatedOddDiff;
	V[2] = V[0] - OddSum - OddSum;
}

// Eight point transform built from two four-point transforms (even and odd inputs)
// followed by a butterfly merge with the eighth roots of unity.
void Radix8FFT(in bool bIsForward, inout Complex V0, inout Complex V1, inout Complex V2, inout Complex V3, inout Complex V4, inout Complex V5, inout Complex V6, inout Complex V7)
{
	// The even and odd transforms
	Radix4FFT(bIsForward, V0, V2, V4, V6);
	Radix4FFT(bIsForward, V1, V3, V5, V7);

	// 0.7071067811865475 = 1/sqrt(2)
	const float InvSqrtTwo = float(1.f) / sqrt(2.f);

	// exp(+i Pi / 4) for the forward transform, exp(-i Pi / 4) for the inverse.
	Complex Twiddle = Complex(InvSqrtTwo, bIsForward ? InvSqrtTwo : -InvSqrtTwo);

	// Outputs 0 and 4: the twiddle is 1.
	const Complex Out0 = V0 + V1;
	const Complex Out4 = V0 - V1;

	// Outputs 1 and 5: first eighth root.
	Complex Rotated = ComplexMult(Twiddle, V3);
	const Complex Out1 = V2 + Rotated;
	const Complex Out5 = V2 - Rotated;

	// Outputs 2 and 6: the twiddle is +i (forward) or -i (inverse).
	const Complex V4PlusIV5 = Complex(V4.x - V5.y, V4.y + V5.x);
	const Complex V4MinusIV5 = Complex(V4.x + V5.y, V4.y - V5.x);
	Complex Out2;
	Complex Out6;
	if (bIsForward)
	{
		Out2 = V4PlusIV5;
		Out6 = V4MinusIV5;
	}
	else
	{
		Out2 = V4MinusIV5;
		Out6 = V4PlusIV5;
	}

	// Outputs 3 and 7: third eighth root, obtained by mirroring the real part.
	Twiddle.x = -Twiddle.x;
	Rotated = ComplexMult(Twiddle, V7);
	const Complex Out3 = V6 + Rotated;
	const Complex Out7 = V6 - Rotated;

	V0 = Out0;
	V1 = Out1;
	V2 = Out2;
	V3 = Out3;
	V4 = Out4;
	V5 = Out5;
	V6 = Out6;
	V7 = Out7;
}

// Wrapper to select the correct radix FFT
// at compile time.
//
// Radix in [2, 3, 4, 8]; any other value of RADIX is a compile error.
void RadixFFT(in bool bIsForward, inout Complex v[RADIX])
{
#if (RADIX == 2)
	// Array overload; scalar equivalent: Radix2FFT(bIsForward, v[0], v[1]);
	Radix2FFT(bIsForward, v);
#elif (RADIX == 3)
	Radix3FFT(bIsForward, v);
#elif (RADIX == 4)
	// Array overload; scalar equivalent: Radix4FFT(bIsForward, v[0], v[1], v[2], v[3]);
	Radix4FFT(bIsForward, v);
#elif (RADIX == 8)
	Radix8FFT(bIsForward, v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
#else
	#error "Undefined radix length!"
#endif
}

// Utility: As a function of j,
// returns runs of Ns contiguous values, each run starting at a multiple of R*Ns
// (e.g. R = 3, Ns = 2:   0, 1, 6, 7, 12, 13..)
// (e.g. R = 2, Ns = 4:   0, 1, 2, 3, 8, 9, 10, 11,..)
uint Expand(in uint j, in uint Ns, in uint R)
{
	const uint RunIndex = j / Ns;     // which run of Ns values j falls in
	const uint OffsetInRun = j % Ns;  // position of j inside that run
	return RunIndex * Ns * R + OffsetInRun;
}

///--------------------------------------------------------------------------
// Performs a single pass Stockham FFT using group shared memory.
//
// Inspired by
// "High Performance Discrete Fourier Transforms on Graphics Processors" - Naga K Govindaraju et al, 2008
//  Published in SC'08: Proceedings of the 2008 ACM/IEEE Conference on Supercomputing

// One dimensional arrays of data (e.g. scanlines) are interpreted as blocks of row-major data,
// the columns of which are independently transformed with one thread group per column.
//
// NB: it is required that a single column of Complex data fits in group shared memory.
//
// NUMTHREADSX : Number of threads in a thread group.
// Multiple threads in a single thread group transform an array of length
// 'ArrayLength' using the group shared memory as temporary scratch space.

// #define RADIX;        // this should come from the CPU
// #define NUMTHREADSX;  // this should come from the CPU .  BlockHeight / RADIX = NUMTHREADSX

// Group shared memory limited to 32k.  We need to split this over 2 buffers.
// 32k * (1024 byte / k) * (1 float / 4 byte ) = 8192 floats - or 4096 Complex
// This will allow for transforms of length 4096 to be done in a single dispatch.

// Number of group shared arrays used to exchange data between threads.
// Note that both branches below currently select a single buffer.
#if SCAN_LINE_LENGTH > 4096
	#define SHARED_MEMORY_BUFFER_COUNT 1
// Testing shows that in general having 2 buffers is faster: the overhead of additional syncs.
#else
	#define SHARED_MEMORY_BUFFER_COUNT 1
#endif

#if SHARED_MEMORY_BUFFER_COUNT==2

// Separate scratch arrays for the real and imaginary parts of one scan line.
groupshared float SharedReal[ SCAN_LINE_LENGTH ];
groupshared float SharedImag[ SCAN_LINE_LENGTH ];

// Writes the RADIX values held by this thread to group shared memory at
// indices Head, Head + Stride, Head + 2 * Stride, ...
// All real parts are written first, then all imaginary parts.  No barrier is issued.
void CopyLocalToGroupShared(in Complex Local[RADIX], in uint Head, in uint Stride)
{
	{
		for (uint r = 0; r < RADIX; ++r)
		{
			const uint Slot = Head + r * Stride;
			SharedReal[ Slot ] = Local[ r ].x;
		}
	}

	{
		for (uint r = 0; r < RADIX; ++r)
		{
			const uint Slot = Head + r * Stride;
			SharedImag[ Slot ] = Local[ r ].y;
		}
	}
}

// Same as CopyLocalToGroupShared(), followed by a group-wide barrier.
void CopyLocalToGroupSharedWSync(in Complex Local[RADIX], in uint Head, in uint Stride)
{
	CopyLocalToGroupShared(Local, Head, Stride);

	FFTMemoryBarrier();
}

// Reads RADIX values from group shared memory at indices
// Head, Head + Stride, Head + 2 * Stride, ... into Local.
// All real parts are read first, then all imaginary parts.  No barrier is issued.
void CopyGroupSharedToLocal(inout Complex Local[RADIX], in uint Head, in uint Stride)
{
	{
		for (uint r = 0; r < RADIX; ++r)
		{
			const uint Slot = Head + r * Stride;
			Local[ r ].x = SharedReal[ Slot ];
		}
	}

	{
		for (uint r = 0; r < RADIX; ++r)
		{
			const uint Slot = Head + r * Stride;
			Local[ r ].y = SharedImag[ Slot ];
		}
	}
}

// Exchange data with other threads by affecting a transpose
// This accounts for about 1/3rd of the total time.
void TransposeData(inout Complex Local[RADIX], uint AHead, uint AStride, uint BHead, uint BStride)
{
	// Publish this thread's values with the 'A' addressing ...
	CopyLocalToGroupShared(Local, AHead, AStride);

	// ... wait until every thread in the group has published ...
	FFTMemoryBarrier();

	// ... then gather with the 'B' addressing.
	CopyGroupSharedToLocal(Local, BHead, BStride);
}

#else // SHARED_MEMORY_BUFFER_COUNT == 1

// A single scratch array that is used for the real parts and the imaginary parts in turn.
// It is twice the scan line length, which leaves room for the BankSkip padding applied below.
groupshared float SharedReal[ 2 * SCAN_LINE_LENGTH ];

// Divisor used to group logical indices when BankSkip padding is applied.
#define NUM_BANKS 32

// The four helpers below move either the .x or the .y components of the RADIX values held
// by this thread to / from SharedReal.  The logical indices are
//     i = Head, Head + Stride, Head + 2 * Stride, ...
// and each logical index is stored at the padded location
//     i + (i / NUM_BANKS) * BankSkip
// None of the helpers issues a barrier.

// Local[r].x -> SharedReal
void CopyLocalXToGroupShared(in Complex Local[RADIX], in uint Head, in uint Stride, in uint BankSkip)
{
	UNROLL
	for (uint r = 0; r < RADIX; ++r)
	{
		const uint Logical = Head + r * Stride;
		const uint Padded = Logical + (Logical / NUM_BANKS) * BankSkip;
		SharedReal[ Padded ] = Local[ r ].x;
	}
}

// Local[r].y -> SharedReal
void CopyLocalYToGroupShared(in Complex Local[RADIX], in uint Head, in uint Stride, in uint BankSkip)
{
	UNROLL
	for (uint r = 0; r < RADIX; ++r)
	{
		const uint Logical = Head + r * Stride;
		const uint Padded = Logical + (Logical / NUM_BANKS) * BankSkip;
		SharedReal[ Padded ] = Local[ r ].y;
	}
}

// SharedReal -> Local[r].x
void CopyGroupSharedToLocalX(inout Complex Local[RADIX], in uint Head, in uint Stride, in uint BankSkip)
{
	UNROLL
	for (uint r = 0; r < RADIX; ++r)
	{
		const uint Logical = Head + r * Stride;
		const uint Padded = Logical + (Logical / NUM_BANKS) * BankSkip;
		Local[ r ].x = SharedReal[ Padded ];
	}
}

// SharedReal -> Local[r].y
void CopyGroupSharedToLocalY(inout Complex Local[RADIX], in uint Head, in uint Stride, in uint BankSkip)
{
	UNROLL
	for (uint r = 0; r < RADIX; ++r)
	{
		const uint Logical = Head + r * Stride;
		const uint Padded = Logical + (Logical / NUM_BANKS) * BankSkip;
		Local[ r ].y = SharedReal[ Padded ];
	}
}

// Unpadded convenience overloads (BankSkip = 0).
void CopyLocalXToGroupShared(in Complex Local[RADIX], in uint Head, in uint Stride)
{
	CopyLocalXToGroupShared(Local, Head, Stride, 0);
}

void CopyLocalYToGroupShared(in Complex Local[RADIX], in uint Head, in uint Stride)
{
	CopyLocalYToGroupShared(Local, Head, Stride, 0);
}

void CopyGroupSharedToLocalX(inout Complex Local[RADIX], in uint Head, in uint Stride)
{
	CopyGroupSharedToLocalX(Local, Head, Stride, 0);
}

void CopyGroupSharedToLocalY(inout Complex Local[RADIX], in uint Head, in uint Stride)
{
	CopyGroupSharedToLocalY(Local, Head, Stride, 0);
}

// Exchange data with other threads by affecting a transpose
//
// The values are written with the 'A' addressing and read back with the 'B' addressing.
// Because there is only one scratch array, the real parts are exchanged first and the
// imaginary parts second, with a barrier between every write phase and read phase.
void TransposeData(inout Complex Local[RADIX], uint AHead, uint AStride, uint BHead, uint BStride)
{
	// Padding is only applied when the write stride is below NUM_BANKS.
	// Writers and readers must use the same value so they agree on the padded locations.
	uint BankSkip = 0;
	if (AStride < NUM_BANKS)
	{
		BankSkip = AStride;
	}

	// Real parts.
	CopyLocalXToGroupShared(Local, AHead, AStride, BankSkip);
	FFTMemoryBarrier();
	CopyGroupSharedToLocalX(Local, BHead, BStride, BankSkip);

	// All reads of the real parts must finish before the array is reused.
	FFTMemoryBarrier();

	// Imaginary parts.
	CopyLocalYToGroupShared(Local, AHead, AStride, BankSkip);
	FFTMemoryBarrier();
	CopyGroupSharedToLocalY(Local, BHead, BStride, BankSkip);
}

#endif // SHARED_MEMORY_BUFFER_COUNT

// Multiplies Local[r], r = 1 .. RADIX-1, by the r-th power of
//     exp(+/- i * 2 Pi * (ThreadIdx % Length) / (Length * RADIX))
// with '+' for the forward transform and '-' for the inverse.  Local[0] is left untouched.
// This accounts for about 9% of total time
void Butterfly(bool bIsForward, inout Complex Local[RADIX], uint ThreadIdx, uint Length)
{
	float Angle = TWO_PI * float( ThreadIdx % Length) / float(Length * RADIX);
	if (!bIsForward)
	{
		Angle = -Angle;
	}

	// The base twiddle: (cos, sin) of the angle.
	Complex TwiddleStep;
	sincos(Angle, TwiddleStep.y, TwiddleStep.x);

	// Successive powers are built by repeated multiplication.
	Complex TwiddlePower = TwiddleStep;
	for (uint r = 1; r < RADIX; ++r)
	{
		Local[r] = ComplexMult(TwiddlePower, Local[r]);
		TwiddlePower = ComplexMult(TwiddlePower, TwiddleStep);
	}
}

// Twiddle step used ahead of the two interleaved Radix4FFT calls in the mixed-radix tail of GroupSharedFFT.
// The even entries Local[2], Local[4], .. use powers of exp(+/- i 2 Pi ThreadIdx / (Length * RADIX)),
// the odd entries Local[3], Local[5], .. use powers of exp(+/- i 2 Pi (ThreadIdx + Length) / (Length * RADIX)).
// Local[0] and Local[1] are left untouched.
void ButterflyRadix4(bool bIsForward, inout Complex Local[RADIX], uint ThreadIdx, uint Length)
{
	float EvenAngle = TWO_PI * float(ThreadIdx) / float(Length * RADIX);
	float OddAngle = TWO_PI * float(ThreadIdx+Length) / float(Length * RADIX);
	if (!bIsForward)
	{
		EvenAngle = -EvenAngle;
		OddAngle = -OddAngle;
	}

	Complex EvenStep;
	sincos(EvenAngle, EvenStep.y, EvenStep.x);

	Complex OddStep;
	sincos(OddAngle, OddStep.y, OddStep.x);

	Complex EvenPower = EvenStep;
	Complex OddPower = OddStep;
	for (uint r = 2; r < RADIX; r += 2)
	{
		Local[r] = ComplexMult(EvenPower, Local[r]);
		EvenPower = ComplexMult(EvenPower, EvenStep);

		Local[r+1] = ComplexMult(OddPower, Local[r+1]);
		OddPower = ComplexMult(OddPower, OddStep);
	}
}

// Twiddle step used ahead of the four interleaved Radix2FFT calls in the mixed-radix tail of GroupSharedFFT.
// Local[4 + m] is multiplied by exp(+/- i 2 Pi (ThreadIdx + m * Length) / (Length * RADIX)), m = 0, 1, ..
// Local[0] .. Local[3] are left untouched.
void ButterflyRadix2(bool bIsForward, inout Complex Local[RADIX], uint ThreadIdx, uint Length)
{
	float StartAngle = TWO_PI * float(ThreadIdx) / float(Length * RADIX);
	float StepAngle = TWO_PI * float(Length) / float(Length * RADIX);
	if (!bIsForward)
	{
		StartAngle = -StartAngle;
		StepAngle = -StepAngle;
	}

	// Rotation applied when moving from one entry to the next.
	Complex TwiddleStep;
	sincos(StepAngle, TwiddleStep.y, TwiddleStep.x);

	// Twiddle for the first entry that is modified (Local[4]).
	Complex TwiddleCurrent;
	sincos(StartAngle, TwiddleCurrent.y, TwiddleCurrent.x);

	for (uint r = 4; r < RADIX; ++r)
	{
		Local[r] = ComplexMult(TwiddleCurrent, Local[r]);
		TwiddleCurrent = ComplexMult(TwiddleCurrent, TwiddleStep);
	}
}

// Performs a single pass Stockham FFT using group shared memory.
//
// bIsForward  : direction of the transform.
// Local       : the RADIX values held by this thread; transformed in place.
// ArrayLength : length of the array being transformed by the thread group.
// ThreadIdx   : index of this thread within the group.
//
// No normalization is applied here.
void GroupSharedFFT(in const bool bIsForward, inout Complex Local[RADIX], in const uint ArrayLength, in const uint ThreadIdx)
{
	// Parallelism from writing the 1d array as a 2d row-major array:  n = J * m + j.   0 <= j < m,  0 <= J < M
	// f(n) = f(J*m + j) = f(J,j) transforms to
	// F(k) = F(w*M + W) = F(w, W)    0 <= w < m,  0 <= W < M
	//
	// F(k) = sum_{n=0}^{N-1} \exp\left( 2\pi i \frac{n k}{N} \right) f(n)
	//
	// can be expressed as
	//
	// F(w*M + W) = \sum_{j=0}^{m-1} \exp\left( 2\pi i \frac{w j}{m} \right)
	//              \left[ \exp\left( 2\pi i \frac{W j}{m M} \right) \sum_{J=0}^{M-1} \exp\left( 2\pi i \frac{J W}{M} \right) f(J,j) \right]

	// Each thread gathers RADIX values that are NumCols apart.
	const uint NumCols = ArrayLength / RADIX;

	// Source (gather) index: constant for all passes.
	const uint IdxS = ThreadIdx;

	// Destination (scatter) index for the first pass: Expand(ThreadIdx, 1, RADIX).
	uint IdxD = ThreadIdx * RADIX;

	// Transform these RADIX values.
	RadixFFT(bIsForward, Local);

	// Exchange data with other threads
	TransposeData(Local, IdxD, 1, IdxS, NumCols);

	// NumCols = ArrayLength / RADIX:  NumCols is an int power of RADIX (ie. 0, 1, 2..)
	// Ns is read again after the loop, so it is declared outside of it.
	uint Ns = RADIX;
#if COMPILER_PSSL
	LOOP
#endif
	for ( ; Ns < NumCols; Ns *= RADIX )
	{
		// IdxD(ThreadIdx) is contiguous blocks of length Ns
		IdxD = Expand(ThreadIdx, Ns, RADIX);

		Butterfly(bIsForward, Local, ThreadIdx, Ns);

		// Transform these RADIX values.
		RadixFFT(bIsForward, Local);

		// Wait for the rest of the thread group to finish working on this ArrayLength of data.
		FFTMemoryBarrier();

		TransposeData(Local, IdxD, Ns, IdxS, NumCols);
	}

	// Final pass.  Lengths that are not a power of RADIX finish with smaller transforms
	// that are interleaved over the RADIX values held by this thread.
#if MIXED_RADIX == 0 || SCAN_LINE_LENGTH == 4096 || SCAN_LINE_LENGTH == 512 || SCAN_LINE_LENGTH == 64 || SCAN_LINE_LENGTH <= 8
	Butterfly(bIsForward, Local, ThreadIdx, Ns);

	// Transform these RADIX values.
	RadixFFT(bIsForward, Local);
#elif SCAN_LINE_LENGTH == 2048 || SCAN_LINE_LENGTH == 256 || SCAN_LINE_LENGTH == 32
	ButterflyRadix4(bIsForward, Local, ThreadIdx, NumCols);

	Radix4FFT(bIsForward, Local[0], Local[2], Local[4], Local[6]);
	Radix4FFT(bIsForward, Local[1], Local[3], Local[5], Local[7]);
#elif SCAN_LINE_LENGTH == 1024 || SCAN_LINE_LENGTH == 128 || SCAN_LINE_LENGTH == 16
	ButterflyRadix2(bIsForward, Local, ThreadIdx, NumCols);

	Radix2FFT(bIsForward, Local[0], Local[4]);
	Radix2FFT(bIsForward, Local[1], Local[5]);
	Radix2FFT(bIsForward, Local[2], Local[6]);
	Radix2FFT(bIsForward, Local[3], Local[7]);
#else
	#error Unknown SCAN_LINE_LENGTH
#endif

	// not needed for the simple case
	FFTMemoryBarrier();
}

// Scale two complex signals.
void Scale(inout Complex LocalBuffer[2][RADIX], in float ScaleValue)
{
	for (uint Signal = 0; Signal < 2; ++Signal)
	{
		for (uint r = 0; r < RADIX; ++r)
		{
			LocalBuffer[Signal][r] *= ScaleValue;
		}
	}
}

// Transforming two Complex signals.
void GroupSharedFFT(in bool bIsForward, inout Complex LocalBuffer[2][RADIX], in uint ArrayLength, in uint ThreadIdx)
{
	// Note: The forward and inverse FFT require a 'normalization' scale, such that the normalization scale
	// of the forward times normalization scale of the inverse = 1 / ArrayLength.
	// ForwardScale(ArrayLength) * InverseScale(ArrayLength) = 1 / ArrayLength;

	// The scale is applied to the input data, before the transform.
	const float Normalization = bIsForward ? ForwardScale(ArrayLength) : InverseScale(ArrayLength);
	Scale(LocalBuffer, Normalization);

	// Transform each buffer.
	GroupSharedFFT(bIsForward, LocalBuffer[0], ArrayLength, ThreadIdx);
	GroupSharedFFT(bIsForward, LocalBuffer[1], ArrayLength, ThreadIdx);
}

// ---------------------------------------------------------------------------------------------------------------------------------------
// Specialized Utilities used when transforming two real signals as one complex
// The "Two-For-One" trick
// Note these use group-shared memory to detangle the data
//
// Transforming 4 channels (r.g.b.a) as two complex signals (r+ig, b+ia)
// and unpacks the 4 resulting 1/2 length complex transforms R, G, B, A
// ---------------------------------------------------------------------------------------------------------------------------------------

// The Two-For-One trick:  FFT two real signals (f,g) of length N
// are transformed simultaneously by treating them as a single complex signal z.
// z_n = f_n + I * g_n.
//
// Giving the complex amplitudes (i.e. the transform) Z_k = F_k + I * G_k.
//
// The transforms of the real signals can be recovered as
//
// F_k = (1/2) (Z_k + ComplexConjugate(Z_{N-k}))
// G_k = (-i/2)(Z_k - ComplexConjugate(Z_{N-k}))
//
// Furthermore, the transform of any real signal x_n has the symmetry
// X_k = ComplexConjugate(X_{N-k})

// Separate the FFT of two real signals that were processed together
// using the two-for-one trick.
//
// On input the data in LocalBuffer across the threads in this group
// holds the transform of the complex signal z = f + I * g  of length N.
//
// On return the transform of f and g have been separated and stored
// in the LocalBuffer in a form consistent with the packing the x-signal transform forward,
// followed by the y-signal transform in reverse.
//
// {F_o, G_o}, {F_1}, {F_2},..,{F_{N/2-1}}, {F_N/2, G_N/2}, {G_{N/2 +1}}, {G_{N/2 +2}}, .., {G_{N-2}}, {G_{N-1}}
//     0     ,   1  ,   2  ,..,   N/2-1   ,      N/2      ,    N/2 +1   ,    N/2 +2   , ..,    N-2   ,    N-1
// here {} := Complex
//
// Entries 0 and N/2 keep the unmodified complex value Z_0 and Z_{N/2}.

#if SHARED_MEMORY_BUFFER_COUNT==2

// Two-buffer variant: the whole complex transform is published once and every
// thread then combines its own Z_k with the mirrored entry Z_{N-k}.
void SplitTwoForOne(inout Complex LocalBuffer[RADIX], in uint Head, in uint Stride, in uint N)
{
	const uint Non2 = N / 2;

	// On input, Each thread has, in LocalBuffer, the frequencies
	// K = Head + j * Stride; where j = 0, .., RADIX-1..   Head = ThreadIdx, Stride = NumberOfThreads

	// Write the complex FFT into group shared memory (ends with a group barrier).
	CopyLocalToGroupSharedWSync(LocalBuffer, Head, Stride);

	// Construct the transform for the two real signals in the LocalBuffer.
	// Every step below only touches LocalBuffer[i], so it is done in a single sweep.
	{
		UNROLL
		for (uint i = 0; i < RADIX; ++i)
		{
			const uint K = Head + i * Stride;

			// N - k, with k = 0 mirrored onto itself
			const uint MirrorK = (K > 0) ? (N - K) : 0;

			// Z_k = LocalBuffer[i]
			// If k < N/2 :  Store F_k = 1/2 (Z_k + Z*_{N-k})
			// If k > N/2 :  Compute I*G_k = 1/2 (Z_k - Z*_{N-k})
			const float MirrorSign = (K > Non2) ? -1.0f : 1.0f;
			Complex MirrorTerm = Complex( SharedReal[MirrorK], -SharedImag[MirrorK] );
			MirrorTerm *= MirrorSign;

			Complex Value = LocalBuffer[i];
			Value += MirrorTerm;
			Value *= 0.5f;

			// If k > N/2 get G_k from I*G_k:  G_k = -I * (I G_k)
			if (K > Non2)
			{
				Value = ComplexMult(Complex(0, -1), Value);
			}

			// F_N/2 + I * G_N/2
			if (K == Non2)
			{
				Value = Complex(SharedReal[Non2], SharedImag[Non2]);
			}

			LocalBuffer[i] = Value;
		}
	}

	// F_o + I * G_o
	if (Head == 0)
	{
		LocalBuffer[0] = Complex(SharedReal[0], SharedImag[0]);
	}
}

#else // SHARED_MEMORY_BUFFER_COUNT == 1

// Single-buffer variant: the scratch array only holds one component at a time, so
// the real parts and the imaginary parts are combined in two separate passes.
void SplitTwoForOne(inout Complex LocalBuffer[RADIX], in uint Head, in uint Stride, in uint N)
{
	const uint Non2 = N / 2;

	// On input, Each thread has, in LocalBuffer, the frequencies
	// K = Head + j * Stride; where j = 0, .., RADIX-1..   Head = ThreadIdx, Stride = NumberOfThreads

	// ---- Pass 1: real parts.  Publish Re(Z_k) for the whole group.
	CopyLocalXToGroupShared(LocalBuffer, Head, Stride);
	FFTMemoryBarrier();

	{
		UNROLL
		for (uint i = 0; i < RADIX; ++i)
		{
			const uint K = Head + i * Stride;

			// N - k, with k = 0 mirrored onto itself
			const uint MirrorK = (K > 0) ? (N - K) : 0;

			// Real part of {+,-} ComplexConjugate(Z_{N-k}):
			// '+' for k <= N/2 (towards F_k), '-' for k > N/2 (towards I*G_k)
			const float MirrorSign = (K > Non2) ? -1.0f : 1.0f;
			LocalBuffer[i].x += MirrorSign * SharedReal[MirrorK];
		}
	}

	// Entry 0 keeps Z_0; it is stored doubled because of the 0.5 scale below.
	if (Head == 0)
	{
		LocalBuffer[0].x = 2.f * SharedReal[0];
	}

	// ---- Pass 2: imaginary parts.  Every thread must be done reading the real
	// parts before the scratch array is overwritten.
	FFTMemoryBarrier();
	CopyLocalYToGroupShared(LocalBuffer, Head, Stride);
	FFTMemoryBarrier();

	{
		UNROLL
		for (uint i = 0; i < RADIX; ++i)
		{
			const uint K = Head + i * Stride;

			// N - k, with k = 0 mirrored onto itself
			const uint MirrorK = (K > 0) ? (N - K) : 0;

			// Imaginary part of {+,-} ComplexConjugate(Z_{N-k}).
			// The test is 'k < N/2' here so that for k == N/2 the two imaginary
			// parts add up: together with pass 1 that entry ends up as Z_{N/2}.
			const float MirrorSign = (K < Non2) ? 1.0f : -1.0f;
			LocalBuffer[i].y += MirrorSign * (-SharedReal[MirrorK]);
		}
	}

	// Entry 0 keeps Z_0; stored doubled because of the 0.5 scale below.
	if (Head == 0)
	{
		LocalBuffer[0].y = 2.f * SharedReal[0];
	}

	// ---- Finish: apply the 1/2 and, for k > N/2, turn I*G_k into G_k.
	{
		UNROLL
		for (uint i = 0; i < RADIX; ++i)
		{
			const uint K = Head + i * Stride;

			LocalBuffer[i] *= 0.5;

			// If k > N/2 get G_k from I*G_k:  G_k = -I * (I G_k)
			if (K > Non2)
			{
				LocalBuffer[i] = ComplexMult(Complex(0, -1), LocalBuffer[i] );
			}
		}
	}
}
#endif // SHARED_MEMORY_BUFFER_COUNT

// Splits both complex signals held in LocalBuffer.  The barrier between the two calls
// keeps the second call from overwriting group shared memory the first is still reading.
void SplitTwoForOne(inout Complex LocalBuffer[2][RADIX], in uint Head, in uint Stride, in uint ArrayLength)
{
	SplitTwoForOne(LocalBuffer[ 0 ], Head, Stride, ArrayLength);

	FFTMemoryBarrier();

	SplitTwoForOne(LocalBuffer[ 1 ], Head, Stride, ArrayLength);
}

// Inverse of SplitTwoForOne()
// Combine the FFT of two real signals (f,g) into a single complex signal
// for use in inverting a "two-for-one" FFT .
//
//
// On input the transforms of f and g are stored across the threads in this group
// in the packed form written by SplitTwoForOne(): the x-signal transform forward,
// followed by the y-signal transform in reverse.
//
// {F_o, G_o}, {F_1}, {F_2},..,{F_{N/2-1}}, {F_N/2, G_N/2}, {G_{N/2 +1}}, {G_{N/2 +2}}, .., {G_{N-2}}, {G_{N-1}}
//     0     ,   1  ,   2  ,..,   N/2-1   ,      N/2      ,    N/2 +1   ,    N/2 +2   , ..,    N-2   ,    N-1
// here {} := Complex
//
// On return the data in LocalBuffer across the threads in this group
// hold the transform of the complex signal z = f + ig of length N.
//

#if SHARED_MEMORY_BUFFER_COUNT==2

// Two-buffer variant.
//
// LocalBuffer : the RADIX packed values held by this thread; replaced by Z_k = F_k + I * G_k.
// Head        : frequency index of LocalBuffer[0].
// Stride      : frequency step between consecutive entries of LocalBuffer.
// N           : length of the transform.
void MergeTwoForOne(inout Complex LocalBuffer[RADIX], in uint Head, in uint Stride, in uint N)
{
	// Unsigned, like K and N: it is compared against K and used as an array index.
	const uint Non2 = N / 2;

	// Write the two FFTs into shared memory (ends with a group barrier).
	CopyLocalToGroupSharedWSync(LocalBuffer, Head, Stride);

	// Compose the transform of a single f+ig signal from the two real signals in the LocalBuffer
	{
		UNROLL
		for (uint i = 0, K = Head; i < RADIX; ++i, K += Stride)
		{
			// N - k, with k = 0 mirrored onto itself
			uint NmK = (K > 0) ? (N - K) : 0 ;

			// If k < N/2 : LocalBuffer[i] = F_k,
			//              Shared[N-K] = G[N-K] = Conjugate(G[k])
			//              so that I * G[k] = Complex(Shared[N-K]_i, Shared[N-K]_r)
			//              Tmp = I * G_k
			// If k > N/2 : LocalBuffer[i] = G_k
			//              Shared[N-K] = F[N-k] = Conjugate(F[k])
			//              so that -I * F[k] = -Complex(Shared[N-K]_i, Shared[N-K]_r)
			//              Tmp = -I * F_k
			Complex Tmp = Complex( SharedImag[ NmK ], SharedReal[ NmK ] );
			Tmp *= (K > Non2) ? -1 : 1;

			LocalBuffer[i] += Tmp;
		}
	}

	// Fix up the upper half and the middle entry.
	{
		UNROLL
		for (uint i = 0, K = Head; i < RADIX; ++i, K += Stride)
		{
			// if k > N/2 we have G - I * F.  Multiply by I to get F + I * G
			if (K > Non2) LocalBuffer[ i ] = ComplexMult(Complex(0, 1), LocalBuffer[ i ]);

			if (K == Non2)
			{
				// F_N/2 + I * G_N/2 is stored as is
				LocalBuffer[ i ] = Complex(SharedReal[ Non2 ], SharedImag[ Non2 ]);
			}
		}
	}

	if (Head == 0)
	{
		// F_0 + I * G_0 is stored as is
		LocalBuffer[ 0 ] = Complex(SharedReal[ 0 ], SharedImag[ 0 ]);
	}
}

#else // SHARED_MEMORY_BUFFER_COUNT==1

// Single-buffer variant.  The scratch array only holds one component at a time:
// the imaginary parts are exchanged first (they feed the real part of the result),
// then the original real parts (they feed the imaginary part of the result).
//
// LocalBuffer : the RADIX packed values held by this thread; replaced by Z_k = F_k + I * G_k.
// Head        : frequency index of LocalBuffer[0].
// Stride      : frequency step between consecutive entries of LocalBuffer.
// N           : length of the transform.
void MergeTwoForOne(inout Complex LocalBuffer[RADIX], in uint Head, in uint Stride, in uint N)
{
	const uint Non2 = N / 2;

	// Keep the original real parts: LocalBuffer[].x is modified in the first pass
	// but the unmodified values have to be published in the second pass.
	float TmpX[RADIX];
	{
		for (uint i = 0; i < RADIX; ++i)
		{
			TmpX[i] = LocalBuffer[i].x;
		}
	}

	// ---- Pass 1: publish the imaginary parts.
	CopyLocalYToGroupShared(LocalBuffer, Head, Stride);
	FFTMemoryBarrier();

	{
		UNROLL
		for (uint i = 0, K = Head; i < RADIX; ++i, K += Stride)
		{
			// N - k, with k = 0 mirrored onto itself
			uint NmK = (K > 0) ? (N - K) : 0 ;

			// Real part of Tmp, where
			//   k < N/2 : Tmp = I * G_k   (LocalBuffer[i] = F_k, Shared[N-K] belongs to G[N-K] = Conjugate(G[k]))
			//   k > N/2 : Tmp = -I * F_k  (LocalBuffer[i] = G_k, Shared[N-K] belongs to F[N-k] = Conjugate(F[k]))
			// In both cases it is the mirrored imaginary part, negated for k > N/2.
			float Tmp = SharedReal[ NmK ];
			Tmp *= (K > Non2) ? -1 : 1;

			LocalBuffer[i].x += Tmp;
		}
	}

	// Entries 0 and N/2 are stored as is; grab their imaginary parts while they are published.
	float2 FirstElement = float2(0, SharedReal[0]);
	float2 MiddleElement = float2(0, SharedReal[Non2]);

	// ---- Pass 2: publish the original real parts.  Every thread must be done
	// reading the imaginary parts before the scratch array is overwritten.
	FFTMemoryBarrier();

	// Copy TmpX to GroupShared
	UNROLL
	for (uint r = 0, i = Head; r < RADIX; ++r, i += Stride)
	{
		SharedReal[ i ] = TmpX[ r ];
	}

	FFTMemoryBarrier();

	FirstElement.x = SharedReal[0];
	MiddleElement.x = SharedReal[Non2];

	{
		UNROLL
		for (uint i = 0, K = Head; i < RADIX; ++i, K += Stride)
		{
			// N - k, with k = 0 mirrored onto itself
			uint NmK = (K > 0) ? (N - K) : 0 ;

			// Imaginary part of Tmp (see pass 1): the mirrored real part, negated for k > N/2.
			float Tmp = SharedReal[ NmK ];
			Tmp *= (K > Non2) ? -1 : 1;

			LocalBuffer[i].y += Tmp;
		}
	}

	// Fix up the upper half and the middle entry.
	{
		UNROLL
		for (uint i = 0, K = Head; i < RADIX; ++i, K += Stride)
		{
			// if k > N/2 we have G - I * F.  Multiply by I to get F + I * G
			if (K > Non2) LocalBuffer[ i ] = ComplexMult(Complex(0, 1), LocalBuffer[ i ]);

			if (K == Non2)
			{
				// F_N/2 + I * G_N/2 is stored as is
				LocalBuffer[ i ] = MiddleElement;
			}
		}
	}

	// F_0 + I * G_0 is stored as is
	if (Head == 0) LocalBuffer[ 0 ] = FirstElement;
}

#endif //SHARED_MEMORY_BUFFER_COUNT

// Merges both complex signals held in LocalBuffer.  The barrier between the two calls
// keeps the second call from overwriting group shared memory the first is still reading.
void MergeTwoForOne(inout Complex LocalBuffer[2][RADIX], in uint Head, in uint Stride, in uint ArrayLength)
{
	MergeTwoForOne(LocalBuffer[ 0 ], Head, Stride, ArrayLength);

	FFTMemoryBarrier();

	MergeTwoForOne(LocalBuffer[ 1 ], Head, Stride, ArrayLength);
}