20230828 B 组模拟赛 题解
前言
赛时发挥当暴力哥,结果挂了 20pts,喜提 150pts。这次久违的有一道签到题。
T1
签到题,赛时五分钟秒了,但写了半小时对拍验证自己的结论。
考虑排序的过程,容易想到每个逆序对需要一次交换,且仅需一次交换。那么就可以猜结论了:每对逆序对 $(i,j)$ 提供 $|p_i-p_j|$ 的贡献。证明其实可以从刚刚猜结论的那句话得出,因为那两句话恰好说明它是是充分必要条件。
然后树状数组随便维护一下就好了。
参考代码
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define forup(i,s,e) for(int i=(s);i<=(e);i++)
#define fordown(i,s,e) for(int i=(s);i>=(e);i--)
using namespace std;
using i64=long long;
#define gc getchar()
inline int read(){
int x=0,f=1;char c;
while(!isdigit(c=gc)) if(c=='-') f=-1;
while(isdigit(c)){x=(x<<3)+(x<<1)+(c^48);c=gc;}
return x*f;
}
#undef gc
const int N=1e6+5,inf=0x3f3f3f3f;
int n,a[N];
i64 ans;
struct BIT{
i64 c[N];
void upd(i64 x,i64 k){for(;x<=n;x+=x&-x)c[x]+=k;}
i64 sum(i64 x){i64 res=0;for(;x>0;x-=x&-x)res+=c[x];return res;}
}t1,t2;
signed main(){
n=read();
forup(i,1,n){
a[i]=read();
ans+=t1.sum(n)-t1.sum(a[i])-a[i]*(t2.sum(n)-t2.sum(a[i]));
t1.upd(a[i],a[i]);t2.upd(a[i],1);
}
printf("%lld",ans);
}
T2
首先,设 $f(i)$ 为恰好在第 $i$ 轮删掉 $a_1$ 的概率,那么期望就是 $\sum_{i=1}^ki\cdot f(i)$(根据期望的定义)。但是容易发现 $f(i)$ 很不好求(或者说求出来也不好优化到比较能过的复杂度),那么考虑求 $\sum_{i=0}^{k-1}g(i)$,其中 $g(i)$ 为在第 $i$ 轮 $1$ 还没有被删掉的概率。正确性显然,因为每个 $f(i)$ 会对 $g(0,1,2\dots i-1)$ 分别提供 $f(i)$ 的贡献,加起来就是 $i\cdot f(i)$ 了。
结果我赛时想到这一步后没想到后面怎么做。
然后考虑如何求 $g(i)$,首先容易发现,在任何时候想要不删掉 $a_1$,那么需要满足以下条件:
- $[a_1,a_2)$ 中所有点加起来至多被选 $0$ 次。
- $[a_1,a_3)$ 中所有点加起来至多被选 $1$ 次。
- $[a_1,a_4)$ 中所有点加起来至多被选 $2$ 次。
- $\cdots$
- $[a_1,n]$ 中中所有点加起来至多被选 $k-1$ 次。
那么就可以想到 dp 求解,先令 $a_{k+1}=n+1$(方便处理边界)。设 $dp_{i,j}$ 表示考虑前 $i$ 个点(即考虑到 $[a_1,a_{i+1})$ 中所有点是否被选),选 $j$ 个后点 $1$ 不被删的方案数。转移考虑在 $[a_i,a_{i+1})$ 中选几个,然后插入选择顺序中的哪些位置。转移方程如下:
$$dp_{i,j}=\sum_{p=0}^jdp_{i-1,j-p}\times\dbinom{j}{k}\times\left(\dfrac{a_{i+1}-a_i}{n}\right)^k$$
然后根据前面的结论答案就是 $\sum_{i=0}^{k-1}dp_{k,i}$。
复杂度 $O(k^3)$,可以通过此题。
参考代码
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define forup(i,s,e) for(int i=(s);i<=(e);i++)
#define fordown(i,s,e) for(int i=(s);i>=(e);i--)
using namespace std;
#define gc getchar()
inline int read(){
int x=0,f=1;char c;
while(!isdigit(c=gc)) if(c=='-') f=-1;
while(isdigit(c)){x=(x<<3)+(x<<1)+(c^48);c=gc;}
return x*f;
}
#undef gc
const int N=505,inf=0x3f3f3f3f,mod=998244353;
int n,m,a[N],dp[N][N],invn,ans;
int C[N][N];
int ksm(int a,int b){
int c=1;
while(b){
if(b&1) c=1ll*c*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return c;
}
signed main(){
n=read();m=read();
invn=ksm(n,mod-2);
forup(i,1,m){
a[i]=read();
}
C[0][0]=1;
forup(i,0,m){
forup(j,0,i){
(C[i+1][j]+=C[i][j])%=mod;
(C[i+1][j+1]+=C[i][j])%=mod;
}
}
a[m+1]=n+1;
dp[0][0]=1;
forup(i,1,m){
forup(j,0,i-1){
dp[i][j]=0;
forup(k,0,j){
(dp[i][j]+=1ll*dp[i-1][j-k]*C[j][k]%mod*ksm(1ll*(a[i+1]-a[i])*invn%mod,k)%mod)%=mod;
}
}
}
forup(i,0,m-1){
(ans+=dp[m][i])%=mod;
}
printf("%d\n",ans);
}
T3
很有意思的一道题。
先不考虑第二问,容易想到考虑树形 DP,对于每一块的 $a$ 点,需要恰好一个长度为 $1$ 的儿子和一个长度为 $2$ 的儿子,剩下的不能向上延伸,那么容易想到设 $dp_{i,0/1/2}$ 表示考虑点 $i$ 及其字数,点 $i$ 不作为某一块的儿子/作为长度为 $1$ 的儿子/作为长度为 $2$ 的儿子的方案数。转移的话,$dp_{i,1}$ 从所有 $dp_{v\in son_i,0}$ 转移过来,$dp_{i,2}$ 需要一个儿子从 $dp_{v,1}$,其它从 $dp_{v,0}$ 转移,$dp_{i,0}$ 则分作不作为某一块的 $a$ 点考虑。直接做不好维护,因为要枚举哪个儿子的第二维不同。容易想到用 DP 套 DP 维护。
具体地,设 $g_{i,0/1,0/1}$ 表示考虑前 $i$ 个儿子,有没有选一个长度为 $1$ 的儿子,有没有选一个长度为 $2$ 的儿子的方案数。设点 $i$ 的儿子分别是 $1,2,\dots k$,其中 $k$ 为点 $i$ 儿子的个数,转移如下:
$$ g_{j,0,0}=g_{j-1,0,0}\times dp_{j,0}\\ g_{j,1,0}=g_{j-1,0,0}\times dp_{j,1}+g_{j-1,1,0}\times dp_{j,0}\\ g_{j,0,1}=g_{j-1,0,0}\times dp_{j,2}+g_{j-1,0,1}\times dp_{j,0}\\ g_{j,1,1}=g_{j-1,0,1}\times dp_{j,1}+g_{j-1,1,0}\times dp_{j,2}+g_{j-1,1,1}\times dp_{j,0} $$
最后 $dp_{i,1}=g_{k,0,0},dp_{i,2}=g_{k,1,0},dp_{i,0}=g_{k,0,0}+g_{k,1,1}$。然后就可以获得一半的分了。
考虑第二问怎么做。其实容易发现可以和第一问一起 DP,但是这样比较麻烦。可以考虑把两个 DP 数组开在一个结构体里面。
具体来说,设二元组 $(a,b)$ 表示分别的答案,考虑给它重载加法和乘法。
首先考虑 $(a_1,b_1)+(a_2,b_2)$,容易发现两边的方案和对应的集合数是绑定的,可以直接等于 $(a_1+a_2,b_1+b_2)$。
然后考虑 $(a_1,b_1)\times(a_2,b_2)$,乘法的意义是在两边每边选一个方案拼在一起的方案数。那么左边的每个集合都可以对应右边一个方案,右边同理,故等于 $(a_1a_2,a_1b_2+a_2b_1)$。
然后和之前一样转移就行了,需要注意此时 $dp_{i,0}=g_{k,0,0}+g_{k,1,1}\times(1,1)$,因为合并时集合数会增加。
复杂度 $O(n)$,因为每个数会在计算到它自己和父亲时各计算常数次,加起来也是常数的。
参考代码
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define forup(i,s,e) for(int i=(s);i<=(e);i++)
#define fordown(i,s,e) for(int i=(s);i>=(e);i--)
using namespace std;
#define gc getchar()
inline int read(){
int x=0,f=1;char c;
while(!isdigit(c=gc)) if(c=='-') f=-1;
while(isdigit(c)){x=(x<<3)+(x<<1)+(c^48);c=gc;}
return x*f;
}
#undef gc
const int N=5e5+5,inf=0x3f3f3f3f,mod=998244353;
int n,ty;
vector<int> e[N];
struct Node{
int a,b;
Node(int _a=0,int _b=0):a(_a),b(_b){}
Node operator +(const Node &r){
Node res;
res.a=(a+r.a)%mod;res.b=(b+r.b)%mod;
return res;
}
Node operator *(const Node &r){
Node res;
res.a=1ll*a*r.a%mod;res.b=(1ll*r.a*b%mod+1ll*a*r.b%mod)%mod;
return res;
}
}dp[N][3],g[2][2][2];
void dfs(int x,int fa){
for(auto i:e[x]){
if(i==fa) continue;
dfs(i,x);
}
int cnt=0;
g[0][0][0]=Node(1,0);
g[0][1][1]=g[0][1][0]=g[0][0][1]=Node(0,0);
for(auto i:e[x]){
if(i==fa) continue;
++cnt;
int p=cnt&1,q=p^1;
g[p][0][0]=g[q][0][0]*dp[i][0];
g[p][0][1]=g[q][0][0]*dp[i][2]+g[q][0][1]*dp[i][0];
g[p][1][0]=g[q][0][0]*dp[i][1]+g[q][1][0]*dp[i][0];
g[p][1][1]=g[q][1][0]*dp[i][2]+g[q][0][1]*dp[i][1]+g[q][1][1]*dp[i][0];
}
dp[x][1]=g[cnt&1][0][0];
dp[x][2]=g[cnt&1][1][0];
dp[x][0]=g[cnt&1][1][1]*Node(1,1);
dp[x][0]=dp[x][0]+dp[x][1];
}
signed main(){
n=read();ty=read();
forup(i,1,n-1){
int u=read(),v=read();
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,0);
printf("%d\n",dp[1][0].a);
if(ty==1){
printf("%d\n",dp[1][0].b);
}
}
T4
点分治套路题。
考虑进行一个简单的容斥,即“与点 $x$ 或 $y$ 的距离小于等于 $d$ 的点”等于“与点 $x$ 距离小于等于 $d$ 的点 $+$ 与点 $y$ 距离小于等于 $d$ 的点 $-$ 与点 $x,y$ 距离均小于等于 $d$ 的点”。
假设 $x,y$ 距离为偶数,那么它们必定有一个中点 $l$,容易发现“与点 $x,y$ 距离均小于等于 $d$ 的点”就是与 $l$ 距离小于等于 $d-dis(x,l)$ 的点。因为与 $l$ 距离小于等于 $d-dis(x,l)$ 的点必定与 $x,y$ 距离均小于等于 $d$ (必要性),而于 $l$ 距离大于 $d-dis(x,l)$ 的点必定与 $x,y$ 其中至少一个的距离大于 $d$(反证充分性)。
那么这个问题就拆成了三个经典问题,即“树上统计与某个点距离小于 $d$ 的点的个数”,可以离线下来用点分治做。
注意求重心的时候子树大小应取“子树内节点数 $+$ 子树内询问数”,因为点分治过程中这两个会各被遍历常数遍(其实或许乘点常数效率会更高?),实测证明“子树内节点数 $+$ 子树内询问数”快于“子树内节点数”快于“子树内询问数”。
参考代码
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define forup(i,s,e) for(int i=(s);i<=(e);i++)
#define fordown(i,s,e) for(int i=(s);i>=(e);i--)
using namespace std;
#define gc getchar()
inline int read(){
int x=0,f=1;char c;
while(!isdigit(c=gc)) if(c=='-') f=-1;
while(isdigit(c)){x=(x<<3)+(x<<1)+(c^48);c=gc;}
return x*f;
}
#undef gc
const int N=2e5+5,inf=0x3f3f3f3f;
int n,q,ans[N];
vector<int> e[N<<1];
struct Node{
int dis,pos,mult;
};
vector<Node> que[N<<1];
int f[N<<1][21],dpt[N<<1];
void dfs(int x,int fa){
f[x][0]=fa;dpt[x]=dpt[fa]+1;
forup(i,1,20){
f[x][i]=f[f[x][i-1]][i-1];
}
for(auto i:e[x]){
if(i==fa) continue;
dfs(i,x);
}
}
int lca(int x,int y){
if(dpt[x]>dpt[y]) swap(x,y);
for(int i=20;i>=0&&dpt[y]>dpt[x];--i){
if(dpt[f[y][i]]>=dpt[x]){
y=f[y][i];
}
}
if(x==y) return x;
for(int i=20;i>=0&&f[x][0]!=f[y][0];--i){
if(f[x][i]!=f[y][i]){
x=f[x][i];y=f[y][i];
}
}
return f[x][0];
}
int vis[N<<1],sz[N<<1],mx[N<<1],als,wp;
void dfs1(int x,int fa){
sz[x]=que[x].size()+2;mx[x]=0;
for(auto i:e[x]){
if(i==fa||vis[i]) continue;
dfs1(i,x);
sz[x]+=sz[i];
mx[x]=max(mx[x],sz[i]);
}
mx[x]=max(mx[x],als-mx[x]);
if(wp==-1||mx[x]<mx[wp]) wp=x;
}
int dis[N<<1];
int vec[N],cvec,v1[N],cv1;
void dfs3(int x,int fa){
if(x<=n){
vec[cvec++]=dis[x];
}
for(auto i:e[x]){
if(i==fa||vis[i]) continue;
dis[i]=dis[x]+1;
dfs3(i,x);
}
}
vector<Node> qu;
void dfs4(int x,int fa){
if(x<=n){
v1[cv1++]=dis[x];
}
for(auto i:que[x]){
qu.push_back(Node{i.dis-dis[x],i.pos,i.mult});
}
for(auto i:e[x]){
if(i==fa||vis[i]) continue;
dfs4(i,x);
}
}
void dfz(int rt){
cvec=0;dis[rt]=0;
dfs3(rt,0);
sort(vec,vec+cvec);
for(auto i:que[rt]){
int pos=i.pos,dis=i.dis,mult=i.mult;
ans[pos]+=mult*(upper_bound(vec,vec+cvec,dis)-vec);
}
for(auto i:e[rt]){
if(vis[i]) continue;
cv1=0;qu.clear();
dfs4(i,rt);
sort(v1,v1+cv1);
for(auto j:qu){
int pos=j.pos,dis=j.dis,mult=j.mult;
ans[pos]+=mult*((upper_bound(vec,vec+cvec,dis)-vec)-(upper_bound(v1,v1+cv1,dis)-v1));
}
}
vis[rt]=1;
for(auto i:e[rt]){
if(vis[i]) continue;
als=sz[i];wp=-1;
dfs1(i,rt);
dfz(wp);
}
}
signed main(){
n=read();q=read();
forup(i,1,n-1){
int u=read(),v=read();
e[u].push_back(n+i);e[v].push_back(n+i);
e[n+i].push_back(u);e[n+i].push_back(v);
}
dfs(1,0);
forup(i,1,q){
int x=read(),y=read(),d=read()*2,l;
int dis=(dpt[x]+dpt[y]-dpt[lca(x,y)]*2)>>1;
que[x].push_back(Node{d,i,1});
if(y!=x){
que[y].push_back(Node{d,i,1});
if(d>=dis){
if(dpt[x]>dpt[y]) swap(x,y);
l=y;
fordown(i,20,0){
if(dis&(1<<i)) l=f[l][i];
}
que[l].push_back(Node{d-dis,i,-1});
}
}
}
als=2*n+q;wp=-1;
dfs1(1,0);
dfz(wp);
forup(i,1,q){
printf("%d\n",ans[i]);
}
}