先假设对于所有 之间都连一条长度为 的边,则问题转换成:只保留树边,记 表示 ,求 。
考虑套路点分治,算出 ,则 。算的方式就是将当前处理的以 为根的子树中所有点 的 排序后双指针。
这一部分是一个板子,不做赘述。只需要注意点分治常见错误,如求重心某些地方写成 而不是 即可(我怎么又踩坑啊)。
但是注意到如果有树边 ,则实际上不会再连 ,但是此时直接走 不一定是最短路。于是分情况讨论:
此时不管怎么都是不能只走非树边的,但是可能先走非树边到另一个点 再走到 。于是令 , 边权为 ,则 。
此时由于没有点同时不与 相邻,所以也无法走两条非树边到。但是发现可以走三条到,或者先走非树边到另一个点,再一步到 。则有 。
此时可以走两条非树边到,但是也不能漏了 时就可行的一些情况,有 。
于是解决了。时间复杂度 。
code:
int n,m,rt,siz[N],deg[N],c[N];
bool vis[N];
ll cnt,ans,dis[N];
vector<ll> g,h;
int tot,head[N];
struct node{
int to,nxt,cw;
}e[N<<1];
il void add(int u,int v,int w){
e[++tot]={v,head[u],w},head[u]=tot;
}
void getRt(int u,int f,int s){
siz[u]=1;
int mx=0;
go(i,u){
int v=e[i].to;
if(v==f||vis[v]){
continue;
}
getRt(v,u,s);
siz[u]+=siz[v];
mx=max(mx,siz[v]);
}
mx=max(mx,s-siz[u]);
if(mx<=s/2){
rt=u;
}
}
void getSiz(int u,int f){
siz[u]=1;
go(i,u){
int v=e[i].to;
if(v==f||vis[v]){
continue;
}
getSiz(v,u);
siz[u]+=siz[v];
}
}
void dfs(int u,int f){
g.eb(dis[u]),h.eb(dis[u]);
go(i,u){
int v=e[i].to;
if(v==f||vis[v]){
continue;
}
dis[v]=dis[u]+e[i].cw;
dfs(v,u);
}
}
void calcD(){
sort(g.begin(),g.end());
int p=-1;
vector<ll> pre;
pre.eb(*g.begin());
rep(i,1,(int)g.size()-1){
pre.eb(g[i]+pre[i-1]);
}
drep(i,(int)g.size()-1,0){
while(p<i-1&&g[p+1]+g[i]<m){
p++;
}
while(p>=i){
p--;
}
if(p>=0){
ans-=g[i]*(p+1)+pre[p];
cnt-=p+1;
}
}
}
void calcA(){
sort(h.begin(),h.end());
int p=-1;
vector<ll> pre;
pre.eb(*h.begin());
rep(i,1,(int)h.size()-1){
pre.eb(h[i]+pre[i-1]);
}
drep(i,(int)h.size()-1,0){
while(p<i-1&&h[p+1]+h[i]<m){
p++;
}
while(p>=i){
p--;
}
if(p>=0){
ans+=h[i]*(p+1)+pre[p];
cnt+=p+1;
}
}
}
void solve(int u){
vis[u]=1,dis[u]=0;
h.clear(),h.eb(0);
go(i,u){
int v=e[i].to;
if(vis[v]){
continue;
}
g.clear();
dis[v]=e[i].cw,dfs(v,u);
calcD();
}
calcA();
getSiz(u,0);
go(i,u){
int v=e[i].to;
if(vis[v]){
continue;
}
getRt(v,u,siz[v]);
solve(rt);
}
}
void Yorushika(){
read(n,m);
rep(i,1,n-1){
int u,v,w;read(u,v,w);
add(u,v,w),add(v,u,w);
deg[u]++,deg[v]++;
}
rep(u,1,n){
c[u]=inf;
go(i,u){
c[u]=min(c[u],e[i].cw);
}
}
getRt(1,0,n);
solve(rt);
ans+=(1ll*n*(n-1)/2-cnt)*m;
for(int i=2;i<=tot;i+=2){
int u=e[i].to,v=e[i-1].to;
if(e[i].cw>m){
if(deg[u]+deg[v]==n){
if(deg[u]==n-1){
ans+=min(e[i].cw,c[u]+m)-m;
}else if(deg[v]==n-1){
ans+=min(e[i].cw,c[v]+m)-m;
}else{
ans+=min({e[i].cw,m+m+m,c[v]+m,c[u]+m})-m;
}
}else{
ans+=min({e[i].cw,m+m,c[u]+m,c[v]+m})-m;
}
}
}
printf("%lld\n",ans);
}
signed main(){
int t=1;
//read(t);
while(t--){
Yorushika();
}
}