1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
# Preprocess the weighted edge list for ancestor queries from node 0.
# NOTE(review): assumes `edges` holds (u, v, w) triples describing a tree
# whose nodes are labelled 0 .. len(edges) -- confirm against the caller.
n = len(edges) + 1
m = n.bit_length()  # number of levels kept in the ancestor table

# Undirected adjacency list: g[a] holds (neighbour, weight) pairs.
g = [[] for _ in range(n)]
for a, b, weight in edges:
    g[a].append((b, weight))
    g[b].append((a, weight))

D = [0] * n  # D[v]: number of edges on the path from node 0 to v
W = [0] * n  # W[v]: sum of edge weights on the path from node 0 to v
# f[v][i]: the ancestor 2**i steps above v, or -1 when there is none.
f = [[-1] * m for _ in range(n)]

# Iterative DFS from node 0.  A node is only popped after its parent has
# been fully processed, so the parent's table rows are already complete.
stack = [0]
while stack:
    node = stack.pop()
    up = f[node]
    # Fill the higher levels: the 2**k-th ancestor is the 2**(k-1)-th
    # ancestor of the 2**(k-1)-th ancestor.
    for level in range(1, m):
        half = up[level - 1]
        if half != -1:
            up[level] = f[half][level - 1]
    for child, weight in g[node]:
        # Skip the edge leading back to the parent.
        if child == up[0]:
            continue
        stack.append(child)
        f[child][0] = node
        D[child] = D[node] + 1
        W[child] = W[node] + weight
def lca(x, y):
    """Return the lowest common ancestor of nodes ``x`` and ``y``.

    Relies on the module-level tables: ``D`` (depth of each node),
    ``f`` (``f[v][i]`` is the ancestor ``2**i`` steps above ``v``, or -1)
    and ``m`` (number of levels stored in ``f``).
    """
    # Ensure y is the deeper (or equally deep) node.
    if D[x] > D[y]:
        x, y = y, x

    # Raise y to the depth of x, consuming the depth gap bit by bit.
    gap = D[y] - D[x]
    level = 0
    while gap:
        if gap & 1:
            y = f[y][level]
        gap >>= 1
        level += 1

    # x was an ancestor of y (or the two nodes were the same).
    if x == y:
        return x

    # Lift both nodes together, largest jump first, while their ancestors
    # still differ; afterwards they sit directly below the answer.
    for level in reversed(range(m)):
        anc_x = f[x][level]
        anc_y = f[y][level]
        if anc_x != anc_y:
            x, y = anc_x, anc_y
    return f[x][0]
|