AT_ddcc_2016_qual_d

先假设对于所有 (i,j)(i,j) 之间都连一条长度为 XX 的边,则问题转换成:只保留树边,记 D(i,j)D(i,j) 表示 min(dis(i,j),X)\min(dis(i,j),X),求 D(i,j)\sum D(i,j)

考虑套路点分治,算出 ij>i[dis(i,j)<X]=A,ij>i[dis(i,j)<X]×dis(i,j)=B\sum_i\sum_{j>i}[dis(i,j)<X]=A,\sum_i\sum_{j>i}[dis(i,j)<X]\times dis(i,j)=B,则 ans=B+((n2)A)×Xans=B+(\binom n2-A)\times X。算的方式就是将当前处理的以 rtrt 为根的子树中所有点 iidis(i,rt)dis(i,rt) 排序后双指针。

这一部分是一个板子,不做赘述。只需要注意点分治常见错误,如求重心某些地方写成 nn 而不是 sizsiz 即可(我怎么又踩坑啊)。

但是注意到如果有树边 (u,v)(u,v),则实际上不会再连 (i,j,X)(i,j,X),但是此时直接走 (u,v)(u,v) 不一定是最短路。于是分情况讨论:

  • degu=n1degv=n1deg_u=n-1\vee deg_v=n-1

此时不管怎么都是不能只走非树边的,但是可能先走非树边到另一个点 xx 再走到 vv。于是令 degv=n1,ci=min(i,j,w)Ewdeg_v=n-1,c_i=\min_{(i,j,w)\in E}w(u,v)(u,v) 边权为 WW,则 D(u,v)=min(W,X+cv)D(u,v)=\min(W,X+c_v)

  • degu+degv=ndeg_u+deg_v=n

此时由于没有点同时不与 u,vu,v 相邻,所以也无法走两条非树边到。但是发现可以走三条到,或者先走非树边到另一个点,再一步到 vv。则有 D(u,v)=min(W,3X,X+min(cu,cv))D(u,v)=\min(W,3X,X+\min(c_u,c_v))

  • degu+degv<ndeg_u+deg_v<n

此时可以走两条非树边到,但是也不能漏了 degu+degv=ndeg_u+deg_v=n 时就可行的一些情况,有 D(u,v)=min(W,2X,X+min(cu,cv))D(u,v)=\min(W,2X,X+\min(c_u,c_v))

于是解决了。时间复杂度 O(nlog2n)O(n\log^2n)

code:

int n,m,rt,siz[N],deg[N],c[N];
bool vis[N];
ll cnt,ans,dis[N];
vector<ll> g,h;
int tot,head[N];
struct node{
	int to,nxt,cw;
}e[N<<1];
il void add(int u,int v,int w){
	e[++tot]={v,head[u],w},head[u]=tot;
}
void getRt(int u,int f,int s){
	siz[u]=1;
	int mx=0;
	go(i,u){
		int v=e[i].to;
		if(v==f||vis[v]){
			continue;
		}
		getRt(v,u,s);
		siz[u]+=siz[v];
		mx=max(mx,siz[v]);
	}
	mx=max(mx,s-siz[u]);
	if(mx<=s/2){
		rt=u;
	}
}
void getSiz(int u,int f){
	siz[u]=1;
	go(i,u){
		int v=e[i].to;
		if(v==f||vis[v]){
			continue;
		}
		getSiz(v,u);
		siz[u]+=siz[v];
	}
}
void dfs(int u,int f){
	g.eb(dis[u]),h.eb(dis[u]);
	go(i,u){
		int v=e[i].to;
		if(v==f||vis[v]){
			continue;
		}
		dis[v]=dis[u]+e[i].cw;
		dfs(v,u);
	}
}
void calcD(){
	sort(g.begin(),g.end());
	int p=-1;
	vector<ll> pre;
	pre.eb(*g.begin());
	rep(i,1,(int)g.size()-1){
		pre.eb(g[i]+pre[i-1]);
	}
	drep(i,(int)g.size()-1,0){
		while(p<i-1&&g[p+1]+g[i]<m){
			p++;
		}
		while(p>=i){
			p--;
		}
		if(p>=0){
			ans-=g[i]*(p+1)+pre[p];
			cnt-=p+1;
		}
	}
}
void calcA(){
	sort(h.begin(),h.end());
	int p=-1;
	vector<ll> pre;
	pre.eb(*h.begin());
	rep(i,1,(int)h.size()-1){
		pre.eb(h[i]+pre[i-1]);
	}
	drep(i,(int)h.size()-1,0){
		while(p<i-1&&h[p+1]+h[i]<m){
			p++;
		}
		while(p>=i){
			p--;
		}
		if(p>=0){
			ans+=h[i]*(p+1)+pre[p];
			cnt+=p+1;
		}
	}
}
void solve(int u){
	vis[u]=1,dis[u]=0;
	h.clear(),h.eb(0);
	go(i,u){
		int v=e[i].to;
		if(vis[v]){
			continue;
		}
		g.clear();
		dis[v]=e[i].cw,dfs(v,u);
		calcD();
	}
	calcA();
	getSiz(u,0);
	go(i,u){
		int v=e[i].to;
		if(vis[v]){
			continue;
		}
		getRt(v,u,siz[v]);
		solve(rt);
	}
}
void Yorushika(){
	read(n,m);
	rep(i,1,n-1){
		int u,v,w;read(u,v,w);
		add(u,v,w),add(v,u,w);
		deg[u]++,deg[v]++;
	}
	rep(u,1,n){
		c[u]=inf;
		go(i,u){
			c[u]=min(c[u],e[i].cw);
		}
	}
	getRt(1,0,n);
	solve(rt);
	ans+=(1ll*n*(n-1)/2-cnt)*m;
	for(int i=2;i<=tot;i+=2){
		int u=e[i].to,v=e[i-1].to;
		if(e[i].cw>m){
			if(deg[u]+deg[v]==n){
				if(deg[u]==n-1){
					ans+=min(e[i].cw,c[u]+m)-m;
				}else if(deg[v]==n-1){
					ans+=min(e[i].cw,c[v]+m)-m;
				}else{
					ans+=min({e[i].cw,m+m+m,c[v]+m,c[u]+m})-m;
				}
			}else{
				ans+=min({e[i].cw,m+m,c[u]+m,c[v]+m})-m;
			}
		}
	}
	printf("%lld\n",ans);
}
signed main(){
	int t=1;
	//read(t);
	while(t--){
		Yorushika();
	}
}