目录

3549:两个多项式相乘(★★)

力扣第 3549 题

题目

给定两个整数数组 poly1poly2,其中每个数组中下标 i 的元素表示多项式中 xi 的系数。

A(x)B(x) 分别是 poly1poly2 表示的多项式。

返回一个长度为 (poly1.length + poly2.length - 1) 的整数数组 result 表示乘积多项式 R(x) = A(x) * B(x) 的系数,其中 result[i] 表示 R(x)xi 的系数。

示例 1:

输入:poly1 = [3,2,5], poly2 = [1,4]

输出:[3,14,13,20]

解释:

  • A(x) = 3 + 2x + 5x2B(x) = 1 + 4x
  • R(x) = (3 + 2x + 5x2) * (1 + 4x)
  • R(x) = 3 * 1 + (3 * 4 + 2 * 1)x + (2 * 4 + 5 * 1)x2 + (5 * 4)x3
  • R(x) = 3 + 14x + 13x2 + 20x3
  • 因此,result = [3, 14, 13, 20]

示例 2:

输入:poly1 = [1,0,-2], poly2 = [-1]

输出:[-1,0,2]

解释:

  • A(x) = 1 + 0x - 2x2B(x) = -1
  • R(x) = (1 + 0x - 2x2) * (-1)
  • R(x) = -1 + 0x + 2x2
  • 因此,result = [-1, 0, 2]

示例 3:

输入:poly1 = [1,5,-3], poly2 = [-4,2,0]

输出:[-4,-18,22,-6,0]

解释:

  • A(x) = 1 + 5x - 3x2B(x) = -4 + 2x + 0x2
  • R(x) = (1 + 5x - 3x2) * (-4 + 2x + 0x2)
  • R(x) = 1 * -4 + (1 * 2 + 5 * -4)x + (5 * 2 + -3 * -4)x2 + (-3 * 2)x3 + 0x4
  • R(x) = -4 -18x + 22x2 -6x3 + 0x4
  • 因此,result = [-4, -18, 22, -6, 0]

提示:

  • 1 <= poly1.length, poly2.length <= 5 * 104
  • -103 <= poly1[i], poly2[i] <= 103
  • poly1poly2 至少包含一个非零系数。

分析

  • fft 模板

解答

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import math
I = complex(0,1)

def fft(A,N,sgn=1):
    L = N.bit_length()-1
    rev = [0]*N
    for i in range(N):
        j = rev[i] = (rev[i >> 1] >> 1) | (i&1)*(N>>1)
        if i<j:
            A[i],A[j] = A[j],A[i]
    for i in range(L):
        a = 1<<i
        step = math.e**(math.pi/a*sgn*I)    
        for j in range(0,N,a*2):
            w = 1
            for k in range(j,j+a):
                x,y = A[k],A[k+a]*w
                A[k],A[k+a] = x+y,x-y
                w *= step
    if sgn==-1:
        for i,x in enumerate(A):
            A[i] = round(x.real/N)

class Solution:
    def multiply(self, poly1: List[int], poly2: List[int]) -> List[int]:
        A,B = poly1,poly2
        n = len(A)+len(B)-1
        N = 1<<n.bit_length()
        A += [0]*(N-len(A))
        B += [0]*(N-len(B))
        fft(A,N)
        fft(B,N)
        C = [a*b for a,b in zip(A,B)]
        fft(C,N,-1)
        return C[:n]

3480 ms