cp-includes

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub rsalesc/cp-includes

⚠ FFT.cpp

Depends on: DFT.cpp, Complex.cpp, geometry/Trigonometry.cpp (and, through DFT.cpp, BitTricks.cpp)

Code

#ifndef _LIB_FFT
#define _LIB_FFT
#include "DFT.cpp"
#include "Complex.cpp"
#include "geometry/Trigonometry.cpp"
#include <bits/stdc++.h>

namespace lib {
using namespace std;
namespace linalg {

template<typename T>
struct ComplexRootProvider {
  typedef Complex<T> cd;
  typedef Complex<long double> cld;
  // Root tables: w in precision T (read by the transform) and wl in long
  // double (used to build new entries accurately). For every power of two
  // k that has been built, w[k + j] == exp(2*pi*i*j / (2k)) for 0 <= j < k.
  static vector<cd> w;
  static vector<cld> wl;

  // Unit complex number at angle ang (radians), computed in long double.
  static cld root(long double ang) {
    return cld(geo::trig::cos(ang), geo::trig::sin(ang));
  }

  // Returns exp(2*pi*i / (n / k)); note that n / k is integer division.
  cd operator()(int n, int k) {
    long double ang = 2.0l * geo::trig::PI / (n / k);
    return root(ang);
  }

  // Extends the tables so that w[j] is valid for every j < n.
  // n is rounded up to a power of two (at least 2) first: the doubling loop
  // fills whole ranges [k, 2k), so a non-power-of-two size would make it
  // write past the end of w and wl.
  void operator()(int n) {
    int target = 2;
    while (target < n)
      target <<= 1;
    n = target;
    // Tables are always power-of-two sized, so resume from the current size.
    int k = max((int)w.size(), 2);
    if ((int)w.size() >= n)
      return;
    w.resize(n), wl.resize(n);
    w[0] = w[1] = cd(1.0, 0.0);
    wl[0] = wl[1] = cld(1.0, 0.0);
    for (; k < n; k *= 2) {
      long double ang = 2.0l * geo::trig::PI / (2*k);
      cld step = root(ang);
      // Even slots reuse the previous level's root; odd slots rotate it by step.
      for (int i = k; i < 2*k; i++)
        w[i] = wl[i] = (i&1) ? wl[i/2] * step : wl[i/2];
    }
  }

  // Root at table index i (see the layout described above).
  cd operator[](int i) const {
    return w[i];
  }

  // Scaling factor applied after an inverse transform of size n.
  cd inverse(int n) const {
    return cd(1.0 / n, 0.0);
  }
};

template<typename T>
vector<Complex<T>> ComplexRootProvider<T>::w = vector<Complex<T>>();
template<typename T>
vector<Complex<long double>> ComplexRootProvider<T>::wl = vector<Complex<long double>>();

template<typename T = double>
struct FFT : public DFT<Complex<T>, ComplexRootProvider<T>> {
  using Parent = DFT<Complex<T>, ComplexRootProvider<T>>;
  using Parent::fa;

  // Convolves a and b into the shared buffer fa using a single complex
  // transform. a goes into the real part and b into the imaginary part.
  // Squaring the spectrum gives (A + iB)^2 = A^2 - B^2 + 2iAB, so the
  // imaginary part of the inverse is 2 * (a * b).
  // Afterwards, fa[k].real() holds coefficient k of a * b as a
  // floating-point value, and fa[k].imag() is zero.
  template <typename U>
  static void _convolve(const vector<U> &a, const vector<U> &b) {
    typedef Complex<T> cd;
    int n = Parent::ensure(a.size(), b.size());
    for (size_t i = 0; i < (size_t)n; i++)
      fa[i] = cd(i < a.size() ? (T)a[i] : T(),
                 i < b.size() ? (T)b[i] : T());
    Parent::dft(n);
    for (int i = 0; i < n; i++)
      fa[i] *= fa[i];
    Parent::idft(n);
    for (int i = 0; i < n; i++)
      fa[i] = cd(fa[i].imag() / 2, T());
  }

  // Returns the a.size() + b.size() - 1 coefficients of a * b, unrounded.
  // Returns an empty vector if either input is empty.
  template<typename U>
  static vector<U> convolve(const vector<U>& a, const vector<U>& b) {
    if (a.empty() || b.empty())
      return vector<U>();
    int sz = (int)(a.size() + b.size()) - 1;
    _convolve(a, b);
    return retrieve<Parent, U>(sz);
  }

  // Like convolve, but rounds each coefficient to the nearest integer
  // before converting it to U. std::llround is used because casting
  // x + 0.5 truncates toward zero, which is wrong for negative x
  // (-2.0 would become -1).
  template<typename U>
  static vector<U> convolve_rounded(const vector<U>& a, const vector<U>& b) {
    if (a.empty() || b.empty())
      return vector<U>();
    int sz = (int)(a.size() + b.size()) - 1;
    _convolve(a, b);
    vector<U> res(sz);
    for (int i = 0; i < sz; i++)
      res[i] = (U)std::llround(fa[i].real());
    return res;
  }

  // Multiplies polynomials over a single-modulus modular type M.
  // Each coefficient x is split as x = hi * base + lo, with base ~ sqrt(mod),
  // to keep floating-point magnitudes small.
  // Uses two forward and two inverse transforms on local buffers.
  // Returns an empty vector if either input is empty.
  // NOTE(review): assumes (type_int)x is the canonical residue in [0, mod) —
  // confirm against M.
  // TODO: use separate static buffers for this function
  template <typename M>
  static vector<M> convolve_mod(const vector<M> &a, const vector<M> &b) {
    typedef typename M::type_int type_int;
    typedef typename M::large_int large_int;
    typedef Complex<T> cd;
    typedef vector<cd> vcd;

    static_assert(sizeof(M::mods) / sizeof(type_int) == 1,
                  "cant multiply with multiple mods");
    if (a.empty() || b.empty())
      return vector<M>();
    type_int base = sqrtl(M::mods[0]) + 0.5;
    M base_m = base;
    int sza = a.size();
    int szb = b.size();
    int sz = sza+szb-1;
    int n = next_power_of_two(sz);
    Parent::dft_rev(n);

    // Local buffers, named so they do not shadow the inherited fa.
    vcd pa(n), pb(n), C1(n), C2(n);

    // Real part = high digit (x / base), imaginary part = low digit (x % base).
    for (int i = 0; i < n; i++)
      pa[i] = i < sza ? cd((type_int)a[i] / base, (type_int)a[i] % base) : cd();
    for (int i = 0; i < n; i++)
      pb[i] = i < szb ? cd((type_int)b[i] / base, (type_int)b[i] % base) : cd();
    Parent::dft(pa, n);
    Parent::dft(pb, n);

    for (int i = 0; i < n; i++) {
      int j = i ? n - i : 0;
      // Separate the spectra of the high-digit (a1, b1) and low-digit
      // (a2, b2) sequences from the packed transforms.
      cd a1 = (pa[i] + pa[j].conj()) * cd(0.5, 0.0);
      cd a2 = (pa[i] - pa[j].conj()) * cd(0.0, -0.5);
      cd b1 = (pb[i] + pb[j].conj()) * cd(0.5, 0.0);
      cd b2 = (pb[i] - pb[j].conj()) * cd(0.0, -0.5);
      cd c11 = a1 * b1, c12 = a1 * b2;
      cd c21 = a2 * b1, c22 = a2 * b2;
      // Products are stored at index j and read back at index j below.
      C1[j] = c11 + c12 * cd(0.0, 1.0);
      C2[j] = c21 + c22 * cd(0.0, 1.0);
    }
    Parent::idft(C1, n), Parent::idft(C2, n);

    vector<M> res(sz);
    for (int i = 0; i < sz; i++) {
      int j = i ? n - i : 0;
      M x = large_int(C1[j].real() + 0.5);
      M y1 = large_int(C1[j].imag() + 0.5);
      M y2 = large_int(C2[j].real() + 0.5);
      M z = large_int(C2[j].imag() + 0.5);
      // Recombine: hi*hi * base^2 + (hi*lo + lo*hi) * base + lo*lo.
      res[i] = x * base_m * base_m + (y1 + y2) * base_m + z;
    }

    return res;
  }
};
} // namespace linalg

namespace math {
// Multiplication policy: floating-point FFT whose coefficients are rounded
// to integers by FFT<U>::convolve_rounded.
struct FastMultiplication {
  template<typename T>
  using Transform = linalg::FFT<T>;

  // Product of the polynomials a and b, computed with Transform<U>.
  template <typename Field, typename U = double>
  vector<Field> operator()(const vector<Field> &a,
                           const vector<Field> &b) const {
    return Transform<U>::convolve_rounded(a, b);
  }
};

// Multiplication policy: floating-point FFT whose raw (unrounded)
// coefficients are returned by FFT<U>::convolve.
struct FFTMultiplication {
  template<typename T>
  using Transform = linalg::FFT<T>;

  // Product of the polynomials a and b, computed with Transform<U>.
  template <typename Field, typename U = double>
  vector<Field> operator()(const vector<Field> &a,
                           const vector<Field> &b) const {
    return Transform<U>::convolve(a, b);
  }
};

// Multiplication policy for modular types: splits coefficients and
// multiplies via FFT<U>::convolve_mod.
struct SafeMultiplication {
  template<typename T>
  using Transform = linalg::FFT<T>;

  // Product of the polynomials a and b, computed with Transform<U>.
  template <typename Field, typename U = double>
  vector<Field> operator()(const vector<Field> &a,
                           const vector<Field> &b) const {
    return Transform<U>::convolve_mod(a, b);
  }
};
} // namespace math
} // namespace lib

#endif
#line 1 "FFT.cpp"


#line 1 "DFT.cpp"


#include <bits/stdc++.h>
#line 1 "BitTricks.cpp"


#line 4 "BitTricks.cpp"

namespace lib {
// Returns the smallest power of two that is >= n, or 1 when n <= 0.
// n must not exceed 2^62, the largest power of two a long long can hold.
// Declared inline because this code is #included by other files; a
// non-inline definition would be duplicated in every translation unit.
inline long long next_power_of_two(long long n) {
  if (n <= 0) return 1;
  assert(n <= (1LL << 62));
  // Position of the highest set bit, plus one if n is not already a power of two.
  return 1LL << (sizeof(long long) * 8 - 1 - __builtin_clzll(n) +
                 ((n & (n - 1LL)) != 0));
}
} // namespace lib


#line 5 "DFT.cpp"

namespace lib {
using namespace std;
namespace linalg {
template <typename Ring, typename Provider>
struct DFT {
  static vector<int> rev;
  static vector<Ring> fa;

  // function used to precompute rev for fixed size fft (n is a power of two)
  static void dft_rev(int n) {
    Provider()(n);
    int lbn = __builtin_ctz(n);
    if ((int)rev.size() < (1 << lbn))
      rev.resize(1 << lbn);
    int h = -1;
    for (int i = 1; i < n; i++) {
      if ((i & (i - 1)) == 0)
        h++;
      rev[i] = rev[i ^ (1 << h)] | (1 << (lbn - h - 1));
    }
  }

  static void dft_iter(Ring *p, int n) {
    Provider w;
    for (int L = 2; L <= n; L <<= 1) {
      for (int i = 0; i < n; i += L) {
        for (int j = 0; j < L / 2; j++) {
          Ring z = p[i + j + L / 2] * w[j + L / 2];
          p[i + j + L / 2] = p[i + j] - z;
          p[i + j] += z;
        }
      }
    }
  }

  static void swap(vector<Ring> &buf) { std::swap(fa, buf); }
  static void _dft(Ring *p, int n) {
    dft_rev(n);
    for (int i = 0; i < n; i++)
      if (i < rev[i])
        std::swap(p[i], p[rev[i]]);
    dft_iter(p, n);
  }
  static void _idft(Ring *p, int n) {
    _dft(p, n);
    reverse(p + 1, p + n);
    Ring inv = Provider().inverse(n);
    for (int i = 0; i < n; i++)
      p[i] *= inv;
  }

  static void dft(int n) { _dft(fa.data(), n); }

  static void idft(int n) { _idft(fa.data(), n); }

  static void dft(vector<Ring> &v, int n) {
    swap(v);
    dft(n);
    swap(v);
  }
  static void idft(vector<Ring> &v, int n) {
    swap(v);
    idft(n);
    swap(v);
  }

  static int ensure(int a, int b = 0) {
    int n = a+b;
    n = next_power_of_two(n);
    if ((int)fa.size() < n)
      fa.resize(n);
    return n;
  }

  static void clear(int n) { fill(fa.begin(), fa.begin() + n, 0); }

  template<typename Iterator>
  static void fill(Iterator begin, Iterator end) {
    int n = ensure(distance(begin, end));
    int i = 0;
    for(auto it = begin; it != end; ++it) {
      fa[i++] = *it;
    }
    for(;i < n; i++) fa[i] = Ring();
  }
};

// Copies the first n entries of DF's shared buffer into a new vector,
// converting each one to U with a cast.
template<typename DF, typename U>
static vector<U> retrieve(int n) {
  assert(n <= DF::fa.size());
  vector<U> out;
  out.reserve(n);
  for (int k = 0; k < n; ++k)
    out.push_back((U)DF::fa[k]);
  return out;
}

// Out-of-class definitions of DFT's shared static buffers; both start empty.
template <typename Ring, typename Provider>
vector<int> DFT<Ring, Provider>::rev{};

template <typename Ring, typename Provider>
vector<Ring> DFT<Ring, Provider>::fa{};
}
} // namespace lib


#line 1 "Complex.cpp"


#line 4 "Complex.cpp"

namespace lib {
using namespace std;
// Minimal complex-number type with real part re and imaginary part im.
template <typename T> struct Complex {
  T re, im;
  Complex(T a = T(), T b = T()) : re(a), im(b) {}
  T real() const { return re; }
  T imag() const { return im; }
  // Explicit conversion to the real part.
  explicit operator T() const { return re; }
  // Implicit conversion to a Complex of another component type.
  template<typename G>
  operator Complex<G>() const { return Complex<G>(re, im); }
  Complex conj() const { return Complex(re, -im); }

  Complex &operator+=(const Complex<T> &rhs) {
    re += rhs.re, im += rhs.im;
    return *this;
  }
  Complex &operator-=(const Complex<T> &rhs) {
    re -= rhs.re, im -= rhs.im;
    return *this;
  }
  // Both components are computed before either is written, so x *= x is safe.
  Complex &operator*=(const Complex<T> &rhs) {
    const T new_re = re * rhs.re - im * rhs.im;
    const T new_im = re * rhs.im + im * rhs.re;
    re = new_re;
    im = new_im;
    return *this;
  }

  // Binary operators are const so they also work on const operands.
  Complex<T> operator+(const Complex<T> &rhs) const {
    Complex<T> res = *this;
    res += rhs;
    return res;
  }
  Complex<T> operator-(const Complex<T> &rhs) const {
    Complex<T> res = *this;
    res -= rhs;
    return res;
  }
  Complex<T> operator*(const Complex<T> &rhs) const {
    Complex<T> res = *this;
    res *= rhs;
    return res;
  }
  Complex<T> operator-() const {
    return {-re, -im};
  }
  // Divides both components by the scalar x.
  Complex &operator/=(const T x) {
    re /= x, im /= x;
    return *this;
  }
};
} // namespace lib


#line 1 "geometry/Trigonometry.cpp"


#line 4 "geometry/Trigonometry.cpp"

namespace lib {
using namespace std;
namespace geo {
namespace trig {
// pi at long double precision.
constexpr static long double PI = 3.141592653589793238462643383279502884197169399375105820974944l;
// Overloads forwarding to the C math library. The double versions call the
// double routines; the long double versions call the *l routines.
// They are inline because this file is #included by other files, and a
// non-inline definition would break the one-definition rule once two
// translation units include it.
inline double cos(double x) { return ::cos(x); }
inline double sin(double x) { return ::sin(x); }
inline double asin(double x) { return ::asin(x); }
inline double acos(double x) { return ::acos(x); }
inline double atan2(double y, double x) { return ::atan2(y, x); }
inline long double cos(long double x) { return ::cosl(x); }
inline long double sin(long double x) { return ::sinl(x); }
inline long double asin(long double x) { return ::asinl(x); }
inline long double acos(long double x) { return ::acosl(x); }
inline long double atan2(long double y, long double x) { return ::atan2l(y, x); }
} // namespace trig
} // namespace geo
} // namespace lib


#line 7 "FFT.cpp"

namespace lib {
using namespace std;
namespace linalg {

template<typename T>
struct ComplexRootProvider {
  typedef Complex<T> cd;
  typedef Complex<long double> cld;
  // Root tables: w in precision T (read by the transform) and wl in long
  // double (used to build new entries accurately). For every power of two
  // k that has been built, w[k + j] == exp(2*pi*i*j / (2k)) for 0 <= j < k.
  static vector<cd> w;
  static vector<cld> wl;

  // Unit complex number at angle ang (radians), computed in long double.
  static cld root(long double ang) {
    return cld(geo::trig::cos(ang), geo::trig::sin(ang));
  }

  // Returns exp(2*pi*i / (n / k)); note that n / k is integer division.
  cd operator()(int n, int k) {
    long double ang = 2.0l * geo::trig::PI / (n / k);
    return root(ang);
  }

  // Extends the tables so that w[j] is valid for every j < n.
  // n is rounded up to a power of two (at least 2) first: the doubling loop
  // fills whole ranges [k, 2k), so a non-power-of-two size would make it
  // write past the end of w and wl.
  void operator()(int n) {
    int target = 2;
    while (target < n)
      target <<= 1;
    n = target;
    // Tables are always power-of-two sized, so resume from the current size.
    int k = max((int)w.size(), 2);
    if ((int)w.size() >= n)
      return;
    w.resize(n), wl.resize(n);
    w[0] = w[1] = cd(1.0, 0.0);
    wl[0] = wl[1] = cld(1.0, 0.0);
    for (; k < n; k *= 2) {
      long double ang = 2.0l * geo::trig::PI / (2*k);
      cld step = root(ang);
      // Even slots reuse the previous level's root; odd slots rotate it by step.
      for (int i = k; i < 2*k; i++)
        w[i] = wl[i] = (i&1) ? wl[i/2] * step : wl[i/2];
    }
  }

  // Root at table index i (see the layout described above).
  cd operator[](int i) const {
    return w[i];
  }

  // Scaling factor applied after an inverse transform of size n.
  cd inverse(int n) const {
    return cd(1.0 / n, 0.0);
  }
};

template<typename T>
vector<Complex<T>> ComplexRootProvider<T>::w = vector<Complex<T>>();
template<typename T>
vector<Complex<long double>> ComplexRootProvider<T>::wl = vector<Complex<long double>>();

template<typename T = double>
struct FFT : public DFT<Complex<T>, ComplexRootProvider<T>> {
  using Parent = DFT<Complex<T>, ComplexRootProvider<T>>;
  using Parent::fa;

  // Convolves a and b into the shared buffer fa using a single complex
  // transform. a goes into the real part and b into the imaginary part.
  // Squaring the spectrum gives (A + iB)^2 = A^2 - B^2 + 2iAB, so the
  // imaginary part of the inverse is 2 * (a * b).
  // Afterwards, fa[k].real() holds coefficient k of a * b as a
  // floating-point value, and fa[k].imag() is zero.
  template <typename U>
  static void _convolve(const vector<U> &a, const vector<U> &b) {
    typedef Complex<T> cd;
    int n = Parent::ensure(a.size(), b.size());
    for (size_t i = 0; i < (size_t)n; i++)
      fa[i] = cd(i < a.size() ? (T)a[i] : T(),
                 i < b.size() ? (T)b[i] : T());
    Parent::dft(n);
    for (int i = 0; i < n; i++)
      fa[i] *= fa[i];
    Parent::idft(n);
    for (int i = 0; i < n; i++)
      fa[i] = cd(fa[i].imag() / 2, T());
  }

  // Returns the a.size() + b.size() - 1 coefficients of a * b, unrounded.
  // Returns an empty vector if either input is empty.
  template<typename U>
  static vector<U> convolve(const vector<U>& a, const vector<U>& b) {
    if (a.empty() || b.empty())
      return vector<U>();
    int sz = (int)(a.size() + b.size()) - 1;
    _convolve(a, b);
    return retrieve<Parent, U>(sz);
  }

  // Like convolve, but rounds each coefficient to the nearest integer
  // before converting it to U. std::llround is used because casting
  // x + 0.5 truncates toward zero, which is wrong for negative x
  // (-2.0 would become -1).
  template<typename U>
  static vector<U> convolve_rounded(const vector<U>& a, const vector<U>& b) {
    if (a.empty() || b.empty())
      return vector<U>();
    int sz = (int)(a.size() + b.size()) - 1;
    _convolve(a, b);
    vector<U> res(sz);
    for (int i = 0; i < sz; i++)
      res[i] = (U)std::llround(fa[i].real());
    return res;
  }

  // Multiplies polynomials over a single-modulus modular type M.
  // Each coefficient x is split as x = hi * base + lo, with base ~ sqrt(mod),
  // to keep floating-point magnitudes small.
  // Uses two forward and two inverse transforms on local buffers.
  // Returns an empty vector if either input is empty.
  // NOTE(review): assumes (type_int)x is the canonical residue in [0, mod) —
  // confirm against M.
  // TODO: use separate static buffers for this function
  template <typename M>
  static vector<M> convolve_mod(const vector<M> &a, const vector<M> &b) {
    typedef typename M::type_int type_int;
    typedef typename M::large_int large_int;
    typedef Complex<T> cd;
    typedef vector<cd> vcd;

    static_assert(sizeof(M::mods) / sizeof(type_int) == 1,
                  "cant multiply with multiple mods");
    if (a.empty() || b.empty())
      return vector<M>();
    type_int base = sqrtl(M::mods[0]) + 0.5;
    M base_m = base;
    int sza = a.size();
    int szb = b.size();
    int sz = sza+szb-1;
    int n = next_power_of_two(sz);
    Parent::dft_rev(n);

    // Local buffers, named so they do not shadow the inherited fa.
    vcd pa(n), pb(n), C1(n), C2(n);

    // Real part = high digit (x / base), imaginary part = low digit (x % base).
    for (int i = 0; i < n; i++)
      pa[i] = i < sza ? cd((type_int)a[i] / base, (type_int)a[i] % base) : cd();
    for (int i = 0; i < n; i++)
      pb[i] = i < szb ? cd((type_int)b[i] / base, (type_int)b[i] % base) : cd();
    Parent::dft(pa, n);
    Parent::dft(pb, n);

    for (int i = 0; i < n; i++) {
      int j = i ? n - i : 0;
      // Separate the spectra of the high-digit (a1, b1) and low-digit
      // (a2, b2) sequences from the packed transforms.
      cd a1 = (pa[i] + pa[j].conj()) * cd(0.5, 0.0);
      cd a2 = (pa[i] - pa[j].conj()) * cd(0.0, -0.5);
      cd b1 = (pb[i] + pb[j].conj()) * cd(0.5, 0.0);
      cd b2 = (pb[i] - pb[j].conj()) * cd(0.0, -0.5);
      cd c11 = a1 * b1, c12 = a1 * b2;
      cd c21 = a2 * b1, c22 = a2 * b2;
      // Products are stored at index j and read back at index j below.
      C1[j] = c11 + c12 * cd(0.0, 1.0);
      C2[j] = c21 + c22 * cd(0.0, 1.0);
    }
    Parent::idft(C1, n), Parent::idft(C2, n);

    vector<M> res(sz);
    for (int i = 0; i < sz; i++) {
      int j = i ? n - i : 0;
      M x = large_int(C1[j].real() + 0.5);
      M y1 = large_int(C1[j].imag() + 0.5);
      M y2 = large_int(C2[j].real() + 0.5);
      M z = large_int(C2[j].imag() + 0.5);
      // Recombine: hi*hi * base^2 + (hi*lo + lo*hi) * base + lo*lo.
      res[i] = x * base_m * base_m + (y1 + y2) * base_m + z;
    }

    return res;
  }
};
} // namespace linalg

namespace math {
// Multiplication policy: floating-point FFT whose coefficients are rounded
// to integers by FFT<U>::convolve_rounded.
struct FastMultiplication {
  template<typename T>
  using Transform = linalg::FFT<T>;

  // Product of the polynomials a and b, computed with Transform<U>.
  template <typename Field, typename U = double>
  vector<Field> operator()(const vector<Field> &a,
                           const vector<Field> &b) const {
    return Transform<U>::convolve_rounded(a, b);
  }
};

// Multiplication policy: floating-point FFT whose raw (unrounded)
// coefficients are returned by FFT<U>::convolve.
struct FFTMultiplication {
  template<typename T>
  using Transform = linalg::FFT<T>;

  // Product of the polynomials a and b, computed with Transform<U>.
  template <typename Field, typename U = double>
  vector<Field> operator()(const vector<Field> &a,
                           const vector<Field> &b) const {
    return Transform<U>::convolve(a, b);
  }
};

// Multiplication policy for modular types: splits coefficients and
// multiplies via FFT<U>::convolve_mod.
struct SafeMultiplication {
  template<typename T>
  using Transform = linalg::FFT<T>;

  // Product of the polynomials a and b, computed with Transform<U>.
  template <typename Field, typename U = double>
  vector<Field> operator()(const vector<Field> &a,
                           const vector<Field> &b) const {
    return Transform<U>::convolve_mod(a, b);
  }
};
} // namespace math
} // namespace lib
Back to top page