diff --git a/kissfft.hh b/kissfft.hh index 4f6ac92..9a7fb7c 100644 --- a/kissfft.hh +++ b/kissfft.hh @@ -1,87 +1,169 @@ #ifndef KISSFFT_CLASS_HH #define KISSFFT_CLASS_HH #include +#include #include -namespace kissfft_utils { - -template -struct traits -{ - typedef T_scalar scalar_type; - typedef std::complex cpx_type; - void fill_twiddles( std::complex * dst ,int nfft,bool inverse) - { - T_scalar phinc = (inverse?2:-2)* acos( (T_scalar) -1) / nfft; - for (int i=0;i(0,i*phinc) ); - } - - void prepare( - std::vector< std::complex > & dst, - int nfft,bool inverse, - std::vector & stageRadix, - std::vector & stageRemainder ) - { - _twiddles.resize(nfft); - fill_twiddles( &_twiddles[0],nfft,inverse); - dst = _twiddles; - - //factorize - //start factoring out 4's, then 2's, then 3,5,7,9,... - int n= nfft; - int p=4; - do { - while (n % p) { - switch (p) { - case 4: p = 2; break; - case 2: p = 3; break; - default: p += 2; break; - } - if (p*p>n) - p=n;// no more factors - } - n /= p; - stageRadix.push_back(p); - stageRemainder.push_back(n); - }while(n>1); - } - std::vector _twiddles; - - - const cpx_type twiddle(int i) { return _twiddles[i]; } -}; - -} template + typename T_Complex=std::complex > class kissfft { public: - typedef T_traits traits_type; - typedef typename traits_type::scalar_type scalar_type; - typedef typename traits_type::cpx_type cpx_type; + typedef T_Scalar scalar_type; + typedef T_Complex cpx_type; - kissfft(int nfft,bool inverse,const traits_type & traits=traits_type() ) - :_nfft(nfft),_inverse(inverse),_traits(traits) + kissfft( std::size_t nfft, + bool inverse ) + :_nfft(nfft) + ,_inverse(inverse) { - _traits.prepare(_twiddles, _nfft,_inverse ,_stageRadix, _stageRemainder); + // fill twiddle factors + _twiddles.resize(_nfft); + const scalar_type phinc = (_inverse?2:-2)* acos( (scalar_type) -1) / _nfft; + for (std::size_t i=0;i<_nfft;++i) + _twiddles[i] = exp( cpx_type(0,i*phinc) ); + + //factorize + //start factoring out 4's, then 2's, then 
3,5,7,9,... + std::size_t n= _nfft; + std::size_t p=4; + do { + while (n % p) { + switch (p) { + case 4: p = 2; break; + case 2: p = 3; break; + default: p += 2; break; + } + if (p*p>n) + p = n;// no more factors + } + n /= p; + _stageRadix.push_back(p); + _stageRemainder.push_back(n); + }while(n>1); } - void transform(const cpx_type * src , cpx_type * dst) + + /// Changes the FFT-length and/or the transform direction. + /// + /// @post The @c kissfft object will be in the same state as if it + /// had been newly constructed with the passed arguments. + /// However, the implementation may be faster than constructing a + /// new fft object. + void assign( std::size_t nfft, + bool inverse ) + { + if ( nfft != _nfft ) + { + kissfft tmp( nfft, inverse ); // O(n) time. + std::swap( tmp, *this ); // this is O(1) in C++11, O(n) otherwise. + } + else if ( inverse != _inverse ) + { + // conjugate the twiddle factors. + for ( typename std::vector::iterator it = _twiddles.begin(); + it != _twiddles.end(); ++it ) + it->imag( -it->imag() ); + } + } + + /// Calculates the complex Discrete Fourier Transform. + /// + /// The size of the passed arrays must equal the FFT length passed to the constructor. + /// The sum of the squares of the absolute values in the @c dst + /// array will be @c N times the sum of the squares of the absolute + /// values in the @c src array, where @c N is the size of the array. + /// In other words, the l_2 norm of the resulting array will be + /// @c sqrt(N) times as big as the l_2 norm of the input array. + /// This is also the case when the inverse flag is set in the + /// constructor. Hence when applying the same transform twice, but with + /// the inverse flag changed the second time, then the result will + /// be equal to the original input times @c N. 
+ void transform( const cpx_type * src, + cpx_type * dst ) const { kf_work(0, dst, src, 1,1); } - private: - void kf_work( int stage,cpx_type * Fout, const cpx_type * f, size_t fstride,size_t in_stride) + /// Calculates the Discrete Fourier Transform (DFT) of a real input + /// of size @c 2*N. + /// + /// The 0-th and N-th value of the DFT are real numbers. These are + /// stored in @c dst[0].real() and @c dst[0].imag() respectively. + /// The remaining DFT values up to the index N-1 are stored in + /// @c dst[1] to @c dst[N-1]. + /// The other half of the DFT values can be calculated from the + /// symmetry relation + /// @code + /// DFT(src)[2*N-k] == conj( DFT(src)[k] ); + /// @endcode + /// The same scaling factors as in @c transform() apply. + /// + /// @note For this to work, the types @c scalar_type and @c cpx_type + /// must fulfill the following requirements: + /// + /// For any object @c z of type @c cpx_type, + /// @c reinterpret_cast<scalar_type(&)[2]>(z)[0] is the real part of @c z and + /// @c reinterpret_cast<scalar_type(&)[2]>(z)[1] is the imaginary part of @c z. + /// For any pointer to an element of an array of @c cpx_type named @c p + /// and any valid array index @c i, @c reinterpret_cast<scalar_type*>(p)[2*i] + /// is the real part of the complex number @c p[i], and + /// @c reinterpret_cast<scalar_type*>(p)[2*i+1] is the imaginary part of the + /// complex number @c p[i]. + /// + /// Since C++11, these requirements are guaranteed to be satisfied for + /// @c scalar_types being @c float, @c double or @c long @c double + /// together with @c cpx_type being @c std::complex<scalar_type>. 
+ void transform_real( const scalar_type * src, + cpx_type * dst ) const { - int p = _stageRadix[stage]; - int m = _stageRemainder[stage]; - cpx_type * Fout_beg = Fout; - cpx_type * Fout_end = Fout + p*m; + const std::size_t N = _nfft; + if ( N == 0 ) + return; + + // perform complex FFT + transform( reinterpret_cast(src), dst ); + + // post processing for k = 0 and k = N + dst[0] = cpx_type( dst[0].real() + dst[0].imag(), + dst[0].real() - dst[0].imag() ); + + // post processing for all the other k = 1, 2, ..., N-1 + const scalar_type pi = acos( (scalar_type) -1); + const scalar_type half_phi_inc = ( _inverse ? pi : -pi ) / N; + const cpx_type twiddle_mul = exp( cpx_type(0, half_phi_inc) ); + for ( std::size_t k = 1; 2*k < N; ++k ) + { + const cpx_type w = 0.5 * cpx_type( + dst[k].real() + dst[N-k].real(), + dst[k].imag() - dst[N-k].imag() ); + const cpx_type z = 0.5 * cpx_type( + dst[k].imag() + dst[N-k].imag(), + -dst[k].real() + dst[N-k].real() ); + const cpx_type twiddle = + k % 2 == 0 ? 
+ _twiddles[k/2] : + _twiddles[k/2] * twiddle_mul; + dst[ k] = w + twiddle * z; + dst[N-k] = conj( w - twiddle * z ); + } + if ( N % 2 == 0 ) + dst[N/2] = conj( dst[N/2] ); + } + + private: + void kf_work( std::size_t stage, + cpx_type * Fout, + const cpx_type * f, + std::size_t fstride, + std::size_t in_stride) const + { + const std::size_t p = _stageRadix[stage]; + const std::size_t m = _stageRemainder[stage]; + cpx_type * const Fout_beg = Fout; + cpx_type * const Fout_end = Fout + p*m; if (m==1) { do{ @@ -111,92 +193,75 @@ class kissfft } } - // these were #define macros in the original kiss_fft - void C_ADD( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a+b;} - void C_MUL( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a*b;} - void C_SUB( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a-b;} - void C_ADDTO( cpx_type & c,const cpx_type & a) { c+=a;} - void C_FIXDIV( cpx_type & ,int ) {} // NO-OP for float types - scalar_type S_MUL( const scalar_type & a,const scalar_type & b) { return a*b;} - scalar_type HALF_OF( const scalar_type & a) { return a*.5;} - void C_MULBYSCALAR(cpx_type & c,const scalar_type & a) {c*=a;} - - void kf_bfly2( cpx_type * Fout, const size_t fstride, int m) + void kf_bfly2( cpx_type * Fout, const size_t fstride, std::size_t m) const { - for (int k=0;kreal() - HALF_OF(scratch[3].real() ) , Fout->imag() - HALF_OF(scratch[3].imag() ) ); + Fout[m] = Fout[0] - scratch[3]*scalar_type(0.5); + scratch[0] *= epi3.imag(); - C_MULBYSCALAR( scratch[0] , epi3.imag() ); - - C_ADDTO(*Fout,scratch[3]); + Fout[0] += scratch[3]; Fout[m2] = cpx_type( Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() ); - C_ADDTO( Fout[m] , cpx_type( -scratch[0].imag(),scratch[0].real() ) ); + Fout[m] += cpx_type( -scratch[0].imag(),scratch[0].real() ); ++Fout; }while(--k); } - void kf_bfly5( cpx_type * Fout, const size_t fstride, const size_t m) + void kf_bfly5( cpx_type * Fout, const std::size_t fstride, const std::size_t m) 
const { cpx_type *Fout0,*Fout1,*Fout2,*Fout3,*Fout4; - size_t u; cpx_type scratch[13]; - cpx_type * twiddles = &_twiddles[0]; - cpx_type *tw; - cpx_type ya,yb; - ya = twiddles[fstride*m]; - yb = twiddles[fstride*2*m]; + const cpx_type ya = _twiddles[fstride*m]; + const cpx_type yb = _twiddles[fstride*2*m]; Fout0=Fout; Fout1=Fout0+m; @@ -204,52 +269,54 @@ class kissfft Fout3=Fout0+3*m; Fout4=Fout0+4*m; - tw=twiddles; - for ( u=0; u=Norig) twidx-=Norig; - C_MUL(t,scratchbuf[q] , twiddles[twidx] ); - C_ADDTO( Fout[ k ] ,t); + if (twidx>=_nfft) + twidx-=_nfft; + Fout[ k ] += scratchbuf[q] * twiddles[twidx]; } k += m; } } } - int _nfft; + std::size_t _nfft; bool _inverse; std::vector _twiddles; - std::vector _stageRadix; - std::vector _stageRemainder; - traits_type _traits; + std::vector _stageRadix; + std::vector _stageRemainder; }; #endif