Description
- 给你一个
n
n
n个点的树,你可以选择一个点在上面随机游走,每次等概率随机跳到一个距离不超过2的点(包括自己)。
- 现在给出
m
m
m个标记点,求每一个点跳到任意一个标记点的期望步数。
-
n
,
m
≤
1
e
5
n,m\le1e5
n,m≤1e5
Solution
- 考虑从叶子往上面推,那么一个点的期望
E
(
x
)
E(x)
E(x)可以表示成
s
u
m
[
f
a
x
]
,
E
(
f
a
x
)
,
E
(
f
a
f
a
x
)
sum[fa_x],E(fa_x),E(fa_{fa_x})
sum[fax],E(fax),E(fafax)的和,其中
s
u
m
[
x
]
=
∑
E
(
s
o
n
x
)
sum[x]=\sum E(son_x)
sum[x]=∑E(sonx)
- 考虑一个点,要解出所有儿子的
E
(
x
)
E(x)
E(x),而这些
E
(
x
)
E(x)
E(x)还互相有关。
- 实际上我们可以把所有
E
(
x
)
=
.
.
.
E(x)=...
E(x)=...的方程加在一起,这样左边就是
s
u
m
sum
sum了,就可以把
s
u
m
sum
sum解出来了,这样就可以表示为
E
(
f
a
x
)
,
E
(
f
a
f
a
x
)
E(fa_x),E(fa_{fa_x})
E(fax),E(fafax)的和了。
- 最后再从根节点推下来即可。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 100005
#define ll long long
#define mo 998244353
using namespace std;
int n,m,i,j,k,bz[maxn],du[maxn],cnt[maxn];
int em,e[maxn*2],nx[maxn*2],ls[maxn],fa[maxn];
ll f[maxn][3],g[maxn],inv[maxn];
ll ksm(ll x,ll y){
ll s=1;
for(;y;y/=2,x=x*x%mo) if (y&1)
s=s*x%mo;
return s;
}
void insert(int x,int y){
du[x]++,du[y]++;
em++; e[em]=y; nx[em]=ls[x]; ls[x]=em;
em++; e[em]=x; nx[em]=ls[y]; ls[y]=em;
}
ll s[4];
void dfs(int x,int p){
fa[x]=p;
if (!bz[x]) {
f[x][0]=inv[cnt[x]-1]*cnt[x]%mo;
if (fa[x]) f[x][1]=inv[cnt[x]-1],g[x]=inv[cnt[x]-1];
if (fa[fa[x]]) f[x][2]=inv[cnt[x]-1];
}
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p) dfs(e[i],x);
s[0]=s[1]=s[2]=0; ll psum=0;
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p) {
ll tmp=ksm(g[e[i]]+1,mo-2);
(f[e[i]][0]*=tmp)%=mo;
(f[e[i]][1]*=tmp)%=mo;
(f[e[i]][2]*=tmp)%=mo;
(g[e[i]]*=tmp)%=mo;
(s[0]+=f[e[i]][0])%=mo;
(s[1]+=f[e[i]][1])%=mo;
(s[2]+=f[e[i]][2])%=mo;
(psum+=g[e[i]])%=mo;
}
ll Inv=ksm(mo+1-psum,mo-2);
s[0]=s[0]*Inv%mo,s[1]=s[1]*Inv%mo,s[2]=s[2]*Inv%mo;
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p) {
(f[e[i]][0]+=s[0]*g[e[i]])%=mo;
(f[e[i]][1]+=s[1]*g[e[i]])%=mo;
(f[e[i]][2]+=s[2]*g[e[i]])%=mo;
}
if (!bz[x]){
s[0]=s[1]=s[2]=s[3]=0;
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p){
int y=e[i]; ll sumy=0;
for(int j=ls[y];j;j=nx[j]) if (e[j]!=x){
int z=e[j];
(s[0]+=f[z][0])%=mo;
(sumy+=f[z][1])%=mo;
(s[1]+=f[z][2])%=mo;
}
sumy++;
(s[0]+=f[y][0]*sumy)%=mo;
(s[1]+=f[y][1]*sumy)%=mo;
(s[2]+=f[y][2]*sumy)%=mo;
}
s[0]=s[0]*inv[cnt[x]-1]%mo;
s[1]=s[1]*inv[cnt[x]-1]%mo;
s[2]=s[2]*inv[cnt[x]-1]%mo;
(f[x][0]+=s[0])%=mo;
(f[x][1]+=s[2])%=mo;
Inv=ksm(mo+1-s[1],mo-2);
f[x][0]=f[x][0]*Inv%mo;
f[x][1]=f[x][1]*Inv%mo;
f[x][2]=f[x][2]*Inv%mo;
g[x]=g[x]*Inv%mo;
}
}
ll ans[maxn];
void dfs2(int x,int p){
ans[x]=f[x][0];
if (fa[x]) (ans[x]+=ans[fa[x]]*f[x][1])%=mo;
if (fa[fa[x]]) (ans[x]+=ans[fa[fa[x]]]*f[x][2])%=mo;
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p)
dfs2(e[i],x);
}
int main(){
freopen("ceshi.in","r",stdin);
freopen("ceshi.out","w",stdout);
scanf("%d%d",&n,&m);
inv[0]=1;for(i=1;i<=n;i++) inv[i]=ksm(i,mo-2);
for(i=1;i<n;i++) scanf("%d%d",&j,&k),insert(j,k);
for(i=1;i<=m;i++) scanf("%d",&k),bz[k]=1;
for(int x=1;x<=n;x++) {
cnt[x]=du[x]+1;
for(i=ls[x];i;i=nx[i]) cnt[x]+=du[e[i]]-1;
}
dfs(1,0);
dfs2(1,0);
for(i=1;i<=n;i++) printf("%lld\n",ans[i]);
}