博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【BZOJ4944】【NOI2017】泳池 概率DP 常系数线性递推 特征多项式 多项式取模
阅读量:5046 次
发布时间:2019-06-12

本文共 7164 字,大约阅读时间需要 23 分钟。

题目大意

  有一个\(1001\times n\)的的网格,每个格子有\(q\)的概率是安全的,\(1-q\)的概率是危险的。

  定义一个矩形是合法的当且仅当:

  • 这个矩形中每个格子都是安全的
  • 必须紧贴网格的下边界

  问你最大的合法子矩形大小为\(k\)的概率是多少。

  \(n\leq {10}^9,k\leq 1000\)

  吉老师:这题本来是\(k\leq 20000\)

题解

  一道好题。

  我们计算最大子矩形不超过\(i\)的答案\(s_i\),那么答案就是\(s_k-s_{k-1}\)

  显然最后一行连续的安全格子不会超过\(k\)个。

  设\(g_{i,j}\)表示长度为\(j\),高度为\(i\)的海域全部是安全的,剩下的部分未知,最大子矩形\(\leq k\)的概率。

  设\(h_{i,j}\)表示长度为\(j\),高度为\(i+1\)的海域中,前\(i\)行全部是安全的,剩下的未知且\((i+1,j)\)是危险的,最大子矩形\(\leq k\)的概率。

  边界:

\[ \begin{align} g_{k,1}&=q^k(1-q)\\ g_{i,0}&=1\\ h_{i,0}&=1 \end{align} \]
  那么我们从\(k-1\)\(1\)DP,对于\(i\)\(j\)列,枚举第\(i+1\)行的下一个危险的格子在哪个地方,然后转移:
\[ \begin{align} g_{i,j}&=\sum_{k=0}^{j}h_{i,k}g_{i+1,j-k}\\ h_{i,j}&=\sum_{k=0}^{j-1}h_{i,k}g_{i+1,j-k-1}q^i(1-q) \end{align} \]
  因为第\(i\)行的宽度不会超过\(\lfloor\frac{k}{i}\rfloor\),所以的暴力的时间复杂度是\(\sum_{i=1}^k{\lfloor\frac{k}{i}\rfloor}^2=O(k^2)\)

  这已经足够了,但我们可以做的更好。

  设

\[ \begin{align} A_i(x)&=\sum_{j\geq 0}g_{i,j}x^j\\ B_i(x)&=\sum_{j\geq 0}h_{i,j}x^j\\ c_i&=q^i(1-q)\\ \end{align} \]
那么
\[ \begin{align} A_i(x)&=B_i(x)A_{i+1}(x)\\ B_i(x)&=c_ixA_{i+1}(x)B_i(x)+1\\ B_i(x)&=\frac{1}{1-c_ixA_{i+1}(x)}\\ \end{align} \]
  时间复杂度是\(\sum_{i=1}^k\lfloor\frac{k}{i}\rfloor\log\lfloor\frac{k}{i}\rfloor=O(k\log^2k)\)

  设\(f_i\)为前\(i\)列最大子矩形\(\leq k\)的概率,那么

\[ f_i=\sum_{j=1}^kf_{i-j-1}g_{1,j}(1-q) \]
  这就是一个常系数线性递推。
\[ \begin{align} a_i&=g_{1,i-1}(1-q)\\ f_i&=\sum_{j=1}^kf_{i-j}a_j \end{align} \]

  时间复杂度:

  • 暴力:\(O(nk)\)\(70\)pts
  • 矩阵快速幂:\(O(k^3\log n)\)\(90\)pts
  • 特征多项式+暴力:\(O(k^2\log n)\)\(100\)pts
  • 特征多项式+NTT取模:\(O(k\log k\log n)\)\(100\)pts

  这里简单讲一下最后一个做法

  矩阵快速幂是给你一个矩阵\(A\),求\((A^n)_{1,1}\)

  设矩阵的大小为\(k\)

  根据Cayley-Hamilton定理,\(|\lambda I-A|\)是一个关于\(\lambda\)\(k\)次多项式,记为\(g(\lambda)\)。对于任意矩阵\(A\),有\(g(A)=0\)

  对于常系数线性递推的矩阵,设\(f_i=\sum_{j=1}^kf_{i-j}a_j\)\(g(\lambda)=\lambda^k-\sum_{i=1}^{k}a_{i}\lambda^{k-i}\)

  所以我们只需要求\(A^n\mod g(A)\)。可以用快速幂(倍增取模)求解。

  然后还要求出\(f_1\ldots f_k\),可以通过其他方法计算(多项式求逆或者题目给你了)。

  最后一次卷积可以得到答案。

  如果要求\(f_{n-k+1}\ldots f_n\),那就把\(f_1\ldots f_{2k}\)带进去卷积。

  总时间复杂度:\(O(k\log^2k+k\log k\log n)\)

代码

  暴力取模

#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;typedef long long ll;typedef unsigned long long ull;typedef pair
pii;typedef pair
pll;void sort(int &a,int &b){ if(a>b) swap(a,b);}void open(const char *s){#ifndef ONLINE_JUDGE char str[100]; sprintf(str,"%s.in",s); freopen(str,"r",stdin); sprintf(str,"%s.out",s); freopen(str,"w",stdout);#endif}int rd(){ int s=0,c; while((c=getchar())<'0'||c>'9'); do { s=s*10+c-'0'; } while((c=getchar())>='0'&&c<='9'); return s;}int upmin(int &a,int b){ if(b
a) { a=b; return 1; } return 0;}ll p=998244353;void add(ll &a,ll b){ a=(a+b)%p;}ll fp(ll a,ll b){ ll s=1; for(;b;b>>=1,a=a*a%p) if(b&1) s=s*a%p; return s;}ll inv(ll a){ return fp(a,p-2);}ll pw1[1010];ll pw2[1010];ll q;ll q2;ll g[1010][1010];ll h[1010][1010];ll f[2010];ll a[2010];ll c[2010];ll d[2010];ll final[2010];void mul(ll *a,ll *b,ll *e,int len){ static ll c[2010]; int i,j; for(i=0;i<=2*len;i++) c[i]=0; for(i=0;i<=len;i++) for(j=0;j<=len;j++) add(c[i+j],a[i]*b[j]); for(i=2*len;i>=len;i--) { ll v=c[i]*inv(e[len]); if(v) for(j=0;j<=len;j++) c[i-len+j]=(c[i-len+j]-e[j]*v)%p; } for(i=0;i<=len;i++) a[i]=c[i];}ll solve(int n,int k){ if(!k) return fp(q2,n); memset(g,0,sizeof g); memset(h,0,sizeof h); g[k][1]=q2*pw1[k]%p; g[k][0]=1; int i,j,l; for(i=k-1;i>=1;i--) { int m=k/i; g[i][0]=1; h[i][0]=1; for(j=0;j<=m;j++) { for(l=j+1;l<=m;l++) add(h[i][l],h[i][j]*g[i+1][l-j-1]%p*q2%p*pw1[i]%p); for(l=j;l<=m;l++) if(l) add(g[i][l],h[i][j]*g[i+1][l-j]%p); } } memset(f,0,sizeof f); f[0]=1; for(i=1;i<=2*(k+1);i++) for(j=0;j
<=k;j++) add(f[i],f[i-j-1]*q2%p*g[1][j]); if(n<=2*(k+1)) { ll s=0; for(i=0;i<=n&&i<=k;i++) add(s,f[n-i]*g[1][i]); return s; } int len=k+1; for(i=0;i
>=1; } memset(final,0,sizeof final); for(i=1;i<=k+1;i++) for(j=0;j<=k;j++) add(final[i],d[j]*f[i+j]); ll s=0; for(i=1;i<=k+1;i++) add(s,final[i]*g[1][k+1-i]); return s;}int main(){ open("bzoj4944"); int n,k,x,y; scanf("%d%d%d%d",&n,&k,&x,&y); q=x*inv(y)%p; q2=(y-x)*inv(y)%p; pw1[0]=pw2[0]=1; int i; for(i=1;i<=k;i++) { pw1[i]=pw1[i-1]*q%p; pw2[i]=pw2[i-1]*q2%p; } ll ans1=solve(n,k); ll ans2=solve(n,k-1); ll ans=((ans1-ans2)%p+p)%p; printf("%lld\n",ans); return 0;}

  NTT取模

#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;typedef long long ll;typedef unsigned long long ull;typedef pair
pii;typedef pair
pll;void sort(int &a,int &b){ if(a>b) swap(a,b);}void open(const char *s){#ifndef ONLINE_JUDGE char str[100]; sprintf(str,"%s.in",s); freopen(str,"r",stdin); sprintf(str,"%s.out",s); freopen(str,"w",stdout);#endif}int rd(){ int s=0,c; while((c=getchar())<'0'||c>'9'); do { s=s*10+c-'0'; } while((c=getchar())>='0'&&c<='9'); return s;}int upmin(int &a,int b){ if(b
a) { a=b; return 1; } return 0;}const ll p=998244353;const int maxn=300000;ll fp(ll a,ll b){ ll s=1; for(;b;b>>=1,a=a*a%p) if(b&1) s=s*a%p; return s;}namespace ntt{ const ll g=3; ll w1[maxn]; ll w2[maxn]; int rev[maxn]; int n; void init(int m) { n=1; while(n
>1]>>1)|((i&1)*(n>>1)); } void ntt(ll *a,int t) { int i,j,k; ll u,v,w,wn; for(i=0;i
>1); init(m<<1); copy_clear(x,a,m); copy_clear(y,b,m>>1); ntt(x,1); ntt(y,1); int i; for(i=0;i
>1);// copy_clear(c,b,m>>1); int i; for(i=m;i
<<1;i++) b[i]=0; inverse(b,d,m); init(m<<1); for(i=m;i
<<1;i++) b[i]=d[i]=0; ll inv2=fp(2,p-2); copy_clear(x,a,m); ntt(x,1); ntt(d,1); for(i=0;i
=1;i--)// b[i]=a[i-1]*inv[i]%p; b[0]=0; } void ln(ll *a,ll *b,int m) { static ll c[maxn],d[maxn]; derivative(a,c,m); inverse(a,d,m); init(m<<1); int i; for(i=m;i
>1); int i; for(i=m>>1;i
>1); mul(a,a,c,m); if(n&1) mul(a,b,c,m);}ll solve(int n,int k){ memset(g,0,sizeof g); memset(h,0,sizeof h); int now=0; g[now][1]=q2*pw1[k]%p; g[now][0]=1; h[0]=1; int i,j; for(i=k-1;i>=1;i--) { now^=1; int m=k/i; ll c=q2*pw1[i]%p; int len=1; while(len<=m) len<<=1; for(j=1;j
>=1; for(i=0;i<=k;i++) printf("%lld ",(d[i]+p)%p); printf("\n");// } reverse(d,d+k+1); ntt::init(len<<2); ntt::ntt(d,1); ntt::ntt(f,1); for(i=0;i
<<2;i++) final[i]=d[i]*f[i]%p; ntt::ntt(final,-1); ll s=0; for(i=0;i<=k;i++) add(s,g[now][i]*final[2*k-i]); return s;// for(i=0;i<=k;i++)// g[now][i]=(g[now][i]+p)%p;// memset(f,0,sizeof f);// f[0]=1;// for(i=1;i<=2*(k+1);i++)// for(j=0;j
<=k;j++)// add(f[i],f[i-j-1]*q2%p*g[now][j]);// if(n<=2*(k+1))// {// ll s=0;// for(i=0;i<=n&&i<=k;i++)// add(s,f[n-i]*g[now][i]);// return s;// }// int len=k+1;// for(i=0;i
>=1;// }// memset(final,0,sizeof final);// for(i=1;i<=k+1;i++)// for(j=0;j<=k;j++)// add(final[i],d[j]*f[i+j]);// ll s=0;// for(i=1;i<=k+1;i++)// add(s,final[i]*g[now][k+1-i]);// return s;}int main(){ open("bzoj4944"); int n,k,x,y; scanf("%d%d%d%d",&n,&k,&x,&y); q=x*inv(y)%p; q2=(y-x)*inv(y)%p; pw1[0]=pw2[0]=1; int i; for(i=1;i<=k;i++) { pw1[i]=pw1[i-1]*q%p; pw2[i]=pw2[i-1]*q2%p; } ll ans1=solve(n,k); ll ans2=solve(n,k-1); ll ans=((ans1-ans2)%p+p)%p; printf("%lld\n",ans); return 0;}

转载于:https://www.cnblogs.com/ywwyww/p/8513404.html

你可能感兴趣的文章
日志框架--(一)基础篇
查看>>
Java设计模式之原型模式
查看>>
Spring学习(四)-----Spring Bean引用同xml和不同xml bean的例子
查看>>
哲理故事与管理之道(20)-用危机激励下属
查看>>
关于源程序到可运行程序的过程
查看>>
wepy的使用
查看>>
转载:mysql数据库密码忘记找回方法
查看>>
scratch少儿编程第一季——06、人在江湖混,没有背景怎么行。
查看>>
面向对象1
查看>>
在ns2.35中添加myevalvid框架
查看>>
【贪心+DFS】D. Field expansion
查看>>
为什么要使用href=”javascript:void(0);”
查看>>
二进制文件的查看和编辑
查看>>
Openstack neutron:SDN现状
查看>>
python 打印对象的所有属性值的方法
查看>>
HDU 1160 FatMouse&#39;s Speed (最长有序的上升子序列)
查看>>
[数字图像处理]常见噪声的分类与Matlab实现
查看>>
开发指南专题六:JEECG微云高速开发平台代码生成
查看>>
node-gyp rebuild 卡住?
查看>>
maven filter不起作用
查看>>