1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
# NTT-friendly prime modulus (== 998244353).  Written as 119 * 2**23 + 1 so the
# power-of-two factor of mod - 1 is visible: it bounds the longest supported
# power-of-two transform length at 2**23.
mod = 119 * 2**23 + 1
# Generator used to build roots of unity; 3 is a primitive root modulo mod.
G = 3
def ntt(A, N, sgn=1):
    """Apply an in-place number-theoretic transform to ``A[:N]`` modulo ``mod``.

    Iterative radix-2 transform: a bit-reversal permutation followed by
    log2(N) butterfly stages whose twiddle factors are powers of ``G``.

    Args:
        A: Mutable sequence of ints with ``len(A) >= N``.  Its first ``N``
            entries are transformed in place; entries at index ``N`` and
            beyond are left untouched.
        N: Transform length.  Must be a power of two that divides
            ``mod - 1``.
        sgn: ``1`` for the forward transform, ``-1`` for the inverse
            transform (which additionally scales ``A[:N]`` by ``N**-1``
            modulo ``mod``).

    Returns:
        None; ``A`` is modified in place.

    Raises:
        ValueError: if ``N`` is not a power of two dividing ``mod - 1``, or
            if ``sgn`` is neither ``1`` nor ``-1``.
    """
    # A non-power-of-two length breaks the bit-reversal / butterfly indexing,
    # and a length not dividing mod - 1 makes the integer division below pick
    # a wrong root of unity silently.
    if N < 1 or N & (N - 1) or (mod - 1) % N:
        raise ValueError(f"N must be a power of two dividing {mod - 1}, got {N}")
    if sgn not in (1, -1):
        raise ValueError(f"sgn must be 1 or -1, got {sgn}")

    log_n = N.bit_length() - 1

    # Bit-reversal permutation: rev[i] is i with its log_n bits reversed,
    # derived from rev[i >> 1] so each entry costs O(1).
    rev = [0] * N
    for i in range(N):
        j = rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (N >> 1))
        if i < j:  # swap each pair only once
            A[i], A[j] = A[j], A[i]

    for stage in range(log_n):
        half = 1 << stage  # half-width of the butterflies in this stage
        # (2*half)-th root of unity (its inverse when sgn == -1), assuming G is
        # a primitive root modulo mod.  Adding mod - 1 keeps the exponent
        # non-negative without changing the result, since G**(mod-1) == 1.
        step = pow(G, (mod - 1) // (half * 2) * sgn + mod - 1, mod)
        for start in range(0, N, half * 2):
            w = 1
            for k in range(start, start + half):
                x, y = A[k], A[k + half] * w % mod
                A[k], A[k + half] = (x + y) % mod, (x - y) % mod
                w = w * step % mod

    if sgn == -1:
        inv_n = pow(N, -1, mod)
        # Scale only the transformed prefix A[:N].
        for i in range(N):
            A[i] = A[i] * inv_n % mod
|