最近公共祖先,LCA

// Solution 1: binary lifting that also stores the summed edge weight of every 2^i jump, so a query accumulates the path length while climbing.
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>

using namespace std;

// Capacity of every array indexed by vertex id: 40000 plus 50 spare slots.
const int maxn = 40000 + 50;
// Number of columns in the jump tables; column i describes a jump of 2^i parent steps.
const int logtwo = 30;

// One entry of an adjacency list: the vertex at the other end of an edge and
// the weight of that edge.
struct edge{
    int to ,w ;
    // Default-constructed entries are zeroed instead of holding indeterminate values.
    edge() : to(0), w(0) {}
    // Note the argument order: weight first, then the destination vertex.
    edge(int _w,int _to) : to(_to), w(_w) {}
};

vector<edge>G[maxn];
int grand[maxn][logtwo],depth[maxn],gw[maxn][logtwo];
int n,m,N;

void dfs(int x) {
    for (int i = 1;i <= N ; i ++) {
        grand[x][i] = grand[grand[x][i-1]][i-1];
        gw[x][i] = gw[grand[x][i-1]][i-1] + gw[x][i-1];
    }
    int len = G[x].size();
    for (int i = 0;i < len ; i ++ ) {
        edge e = G[x][i];
        if ( grand[x][0] != e.to ) {
            depth[e.to] = depth[x] + 1 ;
            grand[e.to][0] = x; gw[e.to][0] = e.w;
            dfs(e.to);
        }
    }
}

void init() {
    N = floor( log(n + 0.0) / log(2.0) );
    memset(depth,0,sizeof(depth));
    memset(grand,0,sizeof(grand));
    memset(gw,0,sizeof(gw));
    dfs(1);
}

// Returns the sum of edge weights on the tree path between a and b.
// (Despite the name, the common ancestor itself is not returned.)
// Requires the tables built by init()/dfs().
int lca(int a,int b) {
    if ( depth[a] > depth[b] ) std::swap(a,b);
    int ans = 0;
    // Raise b by exactly depth[b] - depth[a] levels, one jump per set bit of the
    // gap. Lifting by the exact difference never steps onto the sentinel
    // vertex 0, so the result does not rely on depth[0] or gw[0][*] being zero.
    int gap = depth[b] - depth[a];
    for (int i = 0; i <= N && gap > 0; i++) {
        if ( gap & 1 ) {
            ans += gw[b][i];
            b = grand[b][i];
        }
        gap >>= 1;
    }
    // Both vertices are now at the same depth: climb together while their
    // 2^j-th ancestors still differ, largest jump first.
    for (int j = N; j >= 0 ; j -- ) {
        if ( grand[a][j] != grand[b][j]  ) {
            ans += gw[a][j]; ans += gw[b][j];
            a = grand[a][j]; b = grand[b][j];
        }
    }
    // If they are still distinct they are siblings under the common ancestor:
    // add the two final parent edges.
    if(a != b) ans += gw[a][0], ans += gw[b][0];
    return ans;
}

int main(){
    int t; cin>>t;
    while ( t -- ) {
        scanf("%d %d",&n,&m);
        for (int i = 0;i <= n;i++) G[i].clear();
        int u,v,w;
        for ( int i = 0; i < n-1 ; i ++ ) {
            scanf("%d %d %d",&u,&v,&w);
            G[u].push_back(edge(w,v));
            G[v].push_back(edge(w,u));
        }
        init();  int a ,b ;
        for (int i = 1;i <= m; i ++ ) {
            scanf("%d %d",&a,&b);
            printf("%d\n",lca(a,b));
        }
    }   
    return 0;
}
// Solution 2: binary lifting to find the common ancestor, combined with per-vertex distances from the root (fw) to obtain the path length.
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

// Capacity of every array indexed by vertex id: 40000 plus 50 spare slots.
const int maxn = 40000 + 50;
// Number of columns in the ancestor table; column i holds the 2^i-th ancestor.
const int logtwo = 30;

// One directed arc stored in the linked adjacency lists.
struct edge {
    int nxt;   // index in ss of the next arc of the same list; -1 ends the list
    int w;     // weight of the arc
    int to;    // vertex the arc points to
};

// Arc pool shared by all lists; slots are handed out through the counter cnt.
edge ss[maxn * 4];

// head[u]: index in ss of the most recently added arc leaving u, -1 if none.
int head[maxn];
// depth[v]: number of edges between v and the vertex the traversal started from.
int depth[maxn];
// father[v][i]: the 2^i-th ancestor of v; 0 where dfs stored nothing.
int father[maxn][logtwo];
// fw[v]: sum of edge weights on the path from the traversal's start vertex to v.
int fw[maxn];

// n: vertex count of the current test case, m: number of queries.
int n, m;
// cnt: index of the last slot of ss handed out.
int cnt;
// Scratch variables used by main when reading an edge (u, v, w) or a query (s1, s2).
int u, v, w;
int s1, s2;

void init() {
    cnt = 0;
    memset(depth ,0 ,sizeof(depth)) ;
    memset(fw ,0 ,sizeof(fw)) ;
    memset(father ,0 ,sizeof(father)) ;
    memset(head ,-1 ,sizeof(head)) ;
}

void _add(int u,int v,int w) {
    ss[++cnt].w = w;
    ss[cnt].to = v;
    ss[cnt].nxt = head[u];
    head[u] = cnt ;
}

// Registers the undirected edge {u, v} of weight w as two directed arcs,
// u -> v first and v -> u second.
void add_edge(int u,int v,int w) {
    _add(u, v, w);
    _add(v, u, w);
}

// Depth-first traversal that completes the ancestor row of x and then assigns
// parent, depth and accumulated distance to every neighbour that is not x's
// parent before visiting it. Expects depth[x], father[x][0] and fw[x] to be
// set already (all zero for the starting vertex).
void dfs(int x) {
    // Only levels whose 2^level-th ancestor can exist (2^level <= depth) are
    // filled; the remaining entries keep their previous value.
    int level = 1;
    while ((1 << level) <= depth[x]) {
        father[x][level] = father[father[x][level - 1]][level - 1];
        ++level;
    }

    for (int arc = head[x]; arc != -1; arc = ss[arc].nxt) {
        const int child = ss[arc].to;
        if (child == father[x][0]) continue;   // do not walk back to the parent

        father[child][0] = x;
        depth[child] = depth[x] + 1;
        fw[child] = fw[x] + ss[arc].w;
        dfs(child);
    }
}

int lca(int a,int b) {
    if ( depth[a] > depth[b] ) swap(a ,b ) ;
    int dis = depth[b] - depth[a] ;
    for (int i = 0 ; (1<<i) <= dis ; i ++ ) {
        if ( dis & ( 1 << i) ) {
            b = father[b][i];
        }
    }
    if ( a == b ) return a;
    for (int i = 29 ; i >= 0 ; i -- ) {
        if ( father[a][i] != father[b][i] ) {
            a = father[a][i];
            b = father[b][i];
        }
    }
    return father[a][0];
}

int main() {
    int t; cin >> t ;
    while( t -- ) {
        init() ;
        scanf("%d %d",&n,&m);
        for (int i = 1 ; i < n  ; i ++ ) {
            scanf("%d %d %d",&u ,&v ,&w ) ;
            add_edge(u ,v ,w );
        }
        dfs(1) ;
        for (int i = 1 ; i <= m ; i ++ ) {
            scanf("%d %d",&s1 ,&s2 ) ;
            printf("%d\n",fw[s1] + fw[s2] - 2 * fw[lca(s1 ,s2 )] );
        } 
    }
    return 0;
}