bzoj 5306 [HAOI2018] 染色
推式子题
首先枚举有几种颜色恰好出现 \(s\) 次,可以得到一个式子:
\[\sum _{i = 0} ^ {\min(\lfloor \frac n s \rfloor,m )} \frac {\binom n {i \cdot s} \binom m i (i \cdot s) ! \cdot f_i \cdot (m-i)^{n - is}} {(s!) ^ i} \]
但是,\(f_i\) 不能单纯地等于 \(w_i\),因为会重复计算:我们不能保证当前选择的 \(i\) 种颜色之外的颜色都没有恰好出现 \(s\) 次
发现对于一种染色方案而言,假设其中恰好出现 \(s\) 次的颜色数量为 \(i\),那么它在枚举到 \(j\) 种颜色出现 \(s\) 次时被算进去 \(\binom i j\) 次,所有被算进去的 \(f\) 值之和就是 \(\sum _{j = 0} ^ i (f_j \cdot \binom i j)\)
我们令 \(w_i = \sum _{j = 0} ^i (f_j \cdot \binom i j)\), 那么这样可以推出来 \(f\) 序列,并且把这里的 \(f\) 序列代入之前的式子就可以求出答案了
怎么求 \(f\) 呢?有两种方式
- 发现 \(\sum _{j = 0} ^ i \frac{f_j} {j !} \cdot \frac 1 {(i - j) !} = \frac {w_i} {i!}\),可以用多项式除法直接求得 \(\frac {f_j} {j!}\)
- 倒推。
\(w_0 = f_0\)
\(w_1 = f_0 + f_1\)
\(w_2=f_0+2f_1+f_2\)
\(w_3=f_0+3f_1+3f_2+f_3\)
\(\dots\)
解得
\(f_0=w_0\)
\(f_1=w_1-w_0\)
\(f_2=w_2-2w_1+w_0\)
\(f_3=w_3-3w_2+3w_1-w_0\)
找规律,就是 \(f_i = \sum _{j = 0} ^ i (-1)^{i-j} \cdot \binom i j \cdot w_j = i! \cdot \sum_{j=0}^i \frac {(-1)^{i-j}} {(i-j)!} \cdot \frac {w_j} {j!}\)
卷积即可求出 \(f_i\)
惊讶地发现,\(1004535809\) 这个模数的原根也是 \(3\)
// copyright lzt
#include<stdio.h>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<cmath>
#include<iostream>
#include<queue>
#include<string>
#include<ctime>
using namespace std;

typedef long long ll;
typedef std::pair<int,int> pii;
typedef long double ld;
typedef unsigned long long ull;
typedef std::pair<long long,long long> pll;

#define fi first
#define se second
#define pb push_back
#define mp make_pair
// Inclusive ascending / descending integer loops: i runs from j to k.
// (`register` is not used: it is ill-formed since C++17.)
#define rep(i,j,k) for (int i = (int)(j); i <= (int)(k); i++)
#define rrep(i,j,k) for (int i = (int)(j); i >= (int)(k); i--)
#define Debug(...) fprintf(stderr,__VA_ARGS__)

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if end of input is reached before any digit is seen.
inline ll read() {
    ll x = 0, f = 1;
    int ch = getchar();  // int, so that EOF is distinguishable from a byte
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch <= '9' && ch >= '0') {
        x = 10 * x + ch - '0';
        ch = getchar();
    }
    return x * f;
}

const int mod = 1004535809;
const int GEN = 3;  // generator used to build the NTT roots of unity modulo `mod`

// Polynomial multiplication modulo `mod` via the number-theoretic transform.
namespace poly {
    // alpha[0][i]: i-th power of the principal n-th root; alpha[1][i]: its inverse.
    // rev[i]: bit-reversed index of i for the current transform length.
    // ntt_lst: transform length the tables were last built for.
    int alpha[2][400400], rev[400400], ntt_lst = -1;

    // (x + y) mod `mod` for x, y already in [0, mod).
    inline int sum(int x, int y) {
        x += y;
        return x >= mod ? x - mod : x;
    }

    // (x - y) mod `mod` for x, y already in [0, mod).
    inline int sub(int x, int y) {
        x -= y;
        return x < 0 ? x + mod : x;
    }

    // x^p mod `mod` by binary exponentiation; returns 1 when p == 0.
    inline int ksm(int x, int p) {
        int ret = 1;
        while (p) {
            if (p & 1) ret = ret * 1ll * x % mod;
            x = x * 1ll * x % mod;
            p >>= 1;
        }
        return ret;
    }

    // Builds the root and bit-reversal tables for transform length n.
    // n must be a power of two that fits the 400400-entry tables.
    // Does nothing if the tables are already built for this n.
    inline void ntt_init(int n) {
        if (ntt_lst == n) return;
        ntt_lst = n;
        alpha[0][0] = alpha[1][0] = 1;
        alpha[0][1] = ksm(GEN, (mod - 1) / n);
        alpha[1][1] = ksm(alpha[0][1], mod - 2);
        rep(i, 2, n) rep(x, 0, 1)
            alpha[x][i] = alpha[x][i - 1] * 1ll * alpha[x][1] % mod;
        // Reversed-binary counter: nw always holds the bit reversal of i.
        int nw = n >> 1;
        rep(i, 1, n - 2) {
            rev[i] = nw;
            int j = n >> 1;
            while (nw >= j) {
                nw -= j;
                j >>= 1;
            }
            nw += j;
        }
    }

    // In-place transform of a[0..n-1]. f == false: forward; f == true: inverse
    // (uses the inverse roots and finally scales every entry by n^{-1}).
    inline void ntt(int *a, int n, bool f) {
        ntt_init(n);
        // Indices 0 and n-1 are their own bit reversal, so they are skipped.
        rep(i, 1, n - 2) if (i < rev[i]) swap(a[i], a[rev[i]]);
        for (int i = 1; i < n; i <<= 1) {
            for (int j = 0, off = n / (i << 1); j + i < n; j += (i << 1)) {
                for (int k = j, cur = 0; k < j + i; k++, cur += off) {
                    int x = a[k], y = a[k + i] * 1ll * alpha[f][cur] % mod;
                    a[k] = sum(x, y);
                    a[k + i] = sub(x, y);
                }
            }
        }
        if (f) {
            int x = ksm(n, mod - 2);
            rep(i, 0, n - 1) a[i] = a[i] * 1ll * x % mod;
        }
    }

    // Cyclic convolution of length n (a power of two): a <- a * b.
    // b is overwritten with its forward transform.
    inline void mul(int *a, int *b, int n) {
        ntt(a, n, false);
        ntt(b, n, false);
        rep(i, 0, n - 1) a[i] = a[i] * 1ll * b[i] % mod;
        ntt(a, n, true);
    }

    // Multiplies a[0..n-1] by b[0..m-1] and writes the first mx_len product
    // coefficients to res (mx_len defaults to n + m). Inputs are copied into
    // scratch buffers first, so res may alias a or b.
    inline void mul(int *a, int n, int *b, int m, int *res, int mx_len = -1) {
        static int A[400400], B[400400];
        if (mx_len == -1) mx_len = n + m;
        n = min(n, mx_len);
        m = min(m, mx_len);
        // len must hold the full product and also cover mx_len, so that the
        // copy below never reads stale data left in the static buffers.
        int len = 1;
        while (len < n + m || len < mx_len) len <<= 1;
        rep(i, 0, len - 1) {
            A[i] = i < n ? a[i] : 0;
            B[i] = i < m ? b[i] : 0;
        }
        mul(A, B, len);
        memcpy(res, A, mx_len * sizeof(int));
    }
}

const int maxn = 400400;
int n, m, s;
int w[maxn], a[maxn], b[maxn], f[maxn];
int fac[10000100], inv[10000100];  // factorials and inverse factorials mod `mod`

// Binomial coefficient C(x, y) mod `mod`; requires 0 <= y <= x within the tables.
inline int C(int x, int y) {
    return fac[x] * 1ll * inv[y] % mod * inv[x - y] % mod;
}

// Reads n, m, s and w[0..m], then prints
//   sum_{i=0}^{lim} C(n,is) * C(m,i) * (is)! / (s!)^i * f_i * (m-i)^(n-is)
// where lim = min(n / s, m) and f_i = sum_j (-1)^(i-j) * C(i,j) * w_j.
void work() {
    n = read(), m = read(), s = read();
    rep(i, 0, m) w[i] = read();
    int lim = min(n / s, m), up = max(n, m);

    fac[0] = 1;
    rep(i, 1, up) fac[i] = fac[i - 1] * 1ll * i % mod;
    inv[up] = poly::ksm(fac[up], mod - 2);
    rrep(i, up - 1, 0) inv[i] = inv[i + 1] * 1ll * (i + 1) % mod;

    // f_i / i! is the convolution of w_j / j! with (-1)^k / k!.
    rep(i, 0, lim) {
        a[i] = w[i] * 1ll * inv[i] % mod;
        if (i & 1) b[i] = (mod - inv[i]) % mod;
        else b[i] = inv[i] % mod;
    }
    poly::mul(a, lim + 1, b, lim + 1, a);
    rep(i, 0, lim) f[i] = fac[i] * 1ll * a[i] % mod;

    int ans = 0;
    rep(i, 0, lim) {
        int nw = C(n, i * s) * 1ll * C(m, i) % mod
                 * fac[i * s] % mod
                 * poly::ksm(inv[s], i) % mod
                 * f[i] % mod
                 * poly::ksm(m - i, n - i * s) % mod;
        ans = (ans + nw) % mod;
    }
    printf("%d\n", ans);
}

int main() {
#ifdef LZT
    freopen("in", "r", stdin);
#endif
    work();
#ifdef LZT
    // clock() / CLOCKS_PER_SEC is in seconds; scale so the "ms" label is true.
    Debug("My Time: %.3lfms\n", (double)clock() / CLOCKS_PER_SEC * 1000.0);
#endif
}