1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
# Build a binary-lifting ancestor table for the tree described by `edges`.
# NOTE(review): `edges` is defined outside this snippet; it is read as an
# iterable of (u, v) pairs over node ids 0..len(edges), presumably forming a
# tree (n-1 edges, connected) — confirm against the caller. Node 0 is the root.
n = len(edges) + 1          # a tree with n-1 edges has n nodes
m = n.bit_length()          # levels: any depth difference < n < 2**m fits in m bits
g = [[] for _ in range(n)]  # undirected adjacency lists
for u, v in edges:
    g[u].append(v)
    g[v].append(u)
D = [0] * n                       # D[x]: depth of x; the root 0 has depth 0
f = [[-1] * m for _ in range(n)]  # f[x][i]: the 2**i-th ancestor of x, or -1 if none
# Iterative DFS (no recursion). Each stack entry is (node, parent);
# the root's parent is -1.
sk = [(0, -1)]
while sk:
    u, fa = sk.pop()
    # f[u][0] was set when u was pushed, and every ancestor of u was popped
    # (and had its row filled) before u, so this doubling step reads only
    # completed rows: the 2**(i+1)-th ancestor is the 2**i-th ancestor of
    # the 2**i-th ancestor.
    for i in range(m - 1):
        p = f[u][i]
        if p != -1:
            f[u][i + 1] = f[p][i]
    for v in g[u]:
        if v != fa:
            sk.append((v, u))
            f[v][0] = u
            D[v] = D[u] + 1
def lca(x, y):
    """Return the lowest common ancestor of nodes ``x`` and ``y``.

    Reads the module-level globals ``D`` (``D[v]`` is the depth of ``v``),
    ``f`` (``f[v][i]`` is the ``2**i``-th ancestor of ``v``, or -1 if none)
    and ``m`` (the number of levels stored per row of ``f``).
    """
    if D[x] > D[y]:
        x, y = y, x  # ensure y is the deeper (or equally deep) node
    # Lift y by the depth difference k, one power-of-two jump per set bit of k.
    k = D[y] - D[x]
    for i in range(k.bit_length()):
        if k >> i & 1:
            y = f[y][i]
    if x != y:
        # From the largest jump down, move both nodes up whenever their
        # ancestors still differ; afterwards x and y are children of the LCA.
        # Jumps past the root give -1 for both, which compare equal and are
        # therefore skipped.
        for i in range(m - 1, -1, -1):
            px, py = f[x][i], f[y][i]
            if px != py:
                x, y = px, py
        x = f[x][0]
    return x
|