diff --git a/kissfft.hh b/kissfft.hh index 3c41213..96fdcb3 100644 --- a/kissfft.hh +++ b/kissfft.hh @@ -5,14 +5,12 @@ #include -template - > +template class kissfft { public: - typedef T_Scalar scalar_type; - typedef T_Complex cpx_type; + + using cpx_t = std::complex; kissfft( std::size_t nfft, bool inverse ) @@ -21,9 +19,9 @@ class kissfft { // fill twiddle factors _twiddles.resize(_nfft); - const scalar_type phinc = (_inverse?2:-2)* acos( (scalar_type) -1) / _nfft; + const scalar_t phinc = (_inverse?2:-2)* acos( (scalar_t) -1) / _nfft; for (std::size_t i=0;i<_nfft;++i) - _twiddles[i] = exp( cpx_type(0,i*phinc) ); + _twiddles[i] = exp( cpx_t(0,i*phinc) ); //factorize //start factoring out 4's, then 2's, then 3,5,7,9,... @@ -43,7 +41,7 @@ class kissfft _stageRadix.push_back(p); _stageRemainder.push_back(n); }while(n>1); - } + }; /// Changes the FFT-length and/or the transform direction. @@ -63,11 +61,11 @@ class kissfft else if ( inverse != _inverse ) { // conjugate the twiddle factors. - for ( typename std::vector::iterator it = _twiddles.begin(); + for ( typename std::vector::iterator it = _twiddles.begin(); it != _twiddles.end(); ++it ) it->imag( -it->imag() ); } - } + }; /// Calculates the complex Discrete Fourier Transform. /// @@ -81,11 +79,40 @@ class kissfft /// constructor. Hence when applying the same transform twice, but with /// the inverse flag changed the second time, then the result will /// be equal to the original input times @c N. - void transform( const cpx_type * src, - cpx_type * dst ) const + void transform(const cpx_t * fft_in, cpx_t * fft_out, std::size_t stage = 0, std::size_t fstride = 1, std::size_t in_stride = 1) const { - kf_work(0, dst, src, 1,1); - } + const std::size_t p = _stageRadix[stage]; + const std::size_t m = _stageRemainder[stage]; + cpx_t * const Fout_beg = fft_out; + cpx_t * const Fout_end = fft_out + p*m; + + if (m==1) { + do{ + *fft_out = *fft_in; + fft_in += fstride*in_stride; + }while(++fft_out != Fout_end ); + }else{ + do{ + // recursive call: + // DFT of size m*p performed by doing + // p instances of smaller DFTs of size m, + // each one takes a decimated version of the input + transform(fft_in, fft_out, stage+1, fstride*p,in_stride); + fft_in += fstride*in_stride; + }while( (fft_out += m) != Fout_end ); + } + + fft_out=Fout_beg; + + // recombine the p smaller DFTs + switch (p) { + case 2: kf_bfly2(fft_out,fstride,m); break; + case 3: kf_bfly3(fft_out,fstride,m); break; + case 4: kf_bfly4(fft_out,fstride,m); break; + case 5: kf_bfly5(fft_out,fstride,m); break; + default: kf_bfly_generic(fft_out,fstride,m,p); break; + } + }; /// Calculates the Discrete Fourier Transform (DFT) of a real input /// of size @c 2*N. @@ -101,48 +128,48 @@ class kissfft /// @endcode /// The same scaling factors as in @c transform() apply. /// - /// @note For this to work, the types @c scalar_type and @c cpx_type + /// @note For this to work, the types @c scalar_t and @c cpx_t /// must fulfill the following requirements: /// - /// For any object @c z of type @c cpx_type, - /// @c reinterpret_cast(z)[0] is the real part of @c z and - /// @c reinterpret_cast(z)[1] is the imaginary part of @c z. - /// For any pointer to an element of an array of @c cpx_type named @c p + /// For any object @c z of type @c cpx_t, + /// @c reinterpret_cast(z)[0] is the real part of @c z and + /// @c reinterpret_cast(z)[1] is the imaginary part of @c z. + /// For any pointer to an element of an array of @c cpx_t named @c p /// and any valid array index @c i, @c reinterpret_cast(p)[2*i] /// is the real part of the complex number @c p[i], and /// @c reinterpret_cast(p)[2*i+1] is the imaginary part of the /// complex number @c p[i]. /// /// Since C++11, these requirements are guaranteed to be satisfied for - /// @c scalar_types being @c float, @c double or @c long @c double - /// together with @c cpx_type being @c std::complex. - void transform_real( const scalar_type * src, - cpx_type * dst ) const + /// @c scalar_ts being @c float, @c double or @c long @c double + /// together with @c cpx_t being @c std::complex. + void transform_real( const scalar_t * src, + cpx_t * dst ) const { const std::size_t N = _nfft; if ( N == 0 ) return; // perform complex FFT - transform( reinterpret_cast(src), dst ); + transform( reinterpret_cast(src), dst ); // post processing for k = 0 and k = N - dst[0] = cpx_type( dst[0].real() + dst[0].imag(), + dst[0] = cpx_t( dst[0].real() + dst[0].imag(), dst[0].real() - dst[0].imag() ); // post processing for all the other k = 1, 2, ..., N-1 - const scalar_type pi = acos( (scalar_type) -1); - const scalar_type half_phi_inc = ( _inverse ? pi : -pi ) / N; - const cpx_type twiddle_mul = exp( cpx_type(0, half_phi_inc) ); + const scalar_t pi = acos( (scalar_t) -1); + const scalar_t half_phi_inc = ( _inverse ? pi : -pi ) / N; + const cpx_t twiddle_mul = exp( cpx_t(0, half_phi_inc) ); for ( std::size_t k = 1; 2*k < N; ++k ) { - const cpx_type w = (scalar_type)0.5 * cpx_type( + const cpx_t w = (scalar_t)0.5 * cpx_t( dst[k].real() + dst[N-k].real(), dst[k].imag() - dst[N-k].imag() ); - const cpx_type z = (scalar_type)0.5 * cpx_type( + const cpx_t z = (scalar_t)0.5 * cpx_t( dst[k].imag() + dst[N-k].imag(), -dst[k].real() + dst[N-k].real() ); - const cpx_type twiddle = + const cpx_t twiddle = k % 2 == 0 ? _twiddles[k/2] : _twiddles[k/2] * twiddle_mul; @@ -151,87 +178,26 @@ class kissfft } if ( N % 2 == 0 ) dst[N/2] = conj( dst[N/2] ); - } + }; private: - void kf_work( std::size_t stage, - cpx_type * Fout, - const cpx_type * f, - std::size_t fstride, - std::size_t in_stride) const - { - const std::size_t p = _stageRadix[stage]; - const std::size_t m = _stageRemainder[stage]; - cpx_type * const Fout_beg = Fout; - cpx_type * const Fout_end = Fout + p*m; - if (m==1) { - do{ - *Fout = *f; - f += fstride*in_stride; - }while(++Fout != Fout_end ); - }else{ - do{ - // recursive call: - // DFT of size m*p performed by doing - // p instances of smaller DFTs of size m, - // each one takes a decimated version of the input - kf_work(stage+1, Fout , f, fstride*p,in_stride); - f += fstride*in_stride; - }while( (Fout += m) != Fout_end ); - } - - Fout=Fout_beg; - - // recombine the p smaller DFTs - switch (p) { - case 2: kf_bfly2(Fout,fstride,m); break; - case 3: kf_bfly3(Fout,fstride,m); break; - case 4: kf_bfly4(Fout,fstride,m); break; - case 5: kf_bfly5(Fout,fstride,m); break; - default: kf_bfly_generic(Fout,fstride,m,p); break; - } - } - - void kf_bfly2( cpx_type * Fout, const size_t fstride, std::size_t m) const + void kf_bfly2( cpx_t * Fout, const size_t fstride, std::size_t m) const { for (std::size_t k=0;k _twiddles; - std::vector _stageRadix; - std::vector _stageRemainder; + std::size_t _nfft; + bool _inverse; + std::vector _twiddles; + std::vector _stageRadix; + std::vector _stageRemainder; }; #endif