目录

3526:范围异或查询与子数组反转(★★)

力扣第 3526 题

题目

给定一个长度为 n 的整数数组 nums 和一个长度为 q 的二维整数数组 queries,其中的每个查询是以下三种类型之一:

  1. 更新queries[i] = [1, index, value]
    赋值 nums[index] = value

  2. 范围异或查询queries[i] = [2, left, right]
    计算 子数组 中所有元素的按位异或 nums[left...right],并记录结果。

  3. 反转 子数组queries[i] = [3, left, right]
    原地反转 nums[left...right] 子数组。

按照遇到的顺序返回所有范围异或查询的结果数组。

示例 1:

输入:nums = [1,2,3,4,5], queries = [[2,1,3],[1,2,10],[3,0,4],[2,0,4]]

输出:[5,8]

解释:

  • 查询 1:[2, 1, 3] – 计算 [2, 3, 4] 子数组的异或和,结果为 5。

  • 查询 2:[1, 2, 10] – 将 nums[2] 更新为 10,数组更新为 [1, 2, 10, 4, 5]

  • 查询 3:[3, 0, 4] – 反转整个数组,得到 [5, 4, 10, 2, 1]

  • 查询 4:[2, 0, 4] – 计算 [5, 4, 10, 2, 1] 子数组的异或和,结果为 8。

示例 2:

输入:nums = [7,8,9], queries = [[1,0,3],[2,0,2],[3,1,2]]

输出:[2]

解释:

  • 查询 1:[1, 0, 3] – 将 nums[0] 更新为 3,数组更新为 [3, 8, 9]

  • 查询 2:[2, 0, 2] – 计算 [3, 8, 9] 子数组的异或和,结果为 2。

  • 查询 3:[3, 1, 2] – 反转子数组 [8, 9],得到 [9, 8]

提示:

  • 1 <= nums.length <= 105
  • 0 <= nums[i] <= 109
  • 1 <= queries.length <= 105
  • queries[i].length == 3​
  • queries[i][0] ∈ {1, 2, 3}​
  • 如果 queries[i][0] == 1:
    • 0 <= index < nums.length​
    • 0 <= value <= 109
  • 如果 queries[i][0] == 2queries[i][0] == 3
    • 0 <= left <= right < nums.length​

分析

  • treap 树模板

解答

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class Node:
    __slots__ = ('l','r','prio','sz','val','xo','rev')
    def __init__(self, v):
        self.l = self.r = None
        self.prio = random.randrange(1 << 30)
        self.sz = 1
        self.val = v
        self.xo = v
        self.rev = 0

def pull(t):                    # 结构改变后更新信息
    t.sz = 1
    t.xo = t.val
    if t.l:
        t.sz += t.l.sz
        t.xo ^= t.l.xo
    if t.r:
        t.sz += t.r.sz
        t.xo ^= t.r.xo

def push(t):                    # 下传懒标记
    if t and t.rev:
        t.l, t.r = t.r, t.l
        if t.l: t.l.rev ^= 1
        if t.r: t.r.rev ^= 1
        t.rev = 0

def split(t,k):                 # 按前k个节点/剩余节点拆分
    if not t:
        return None,None
    push(t)
    lsz = t.l.sz if t.l else 0
    if lsz >= k:
        L,R = split(t.l,k)
        t.l = R
        pull(t)
        return L,t
    else:
        L,R = split(t.r,k-lsz-1)
        t.r = L
        pull(t)
        return t,R

def merge(u,v):                 # 合并,u中最大值小于v中最小值
    if not u or not v:
        return u or v
    if u.prio<v.prio:
        push(u)
        u.r = merge(u.r, v)
        pull(u)
        return u
    else:
        push(v)
        v.l = merge(u,v.l)
        pull(v)
        return v

class Solution:
    def getResults(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        root = None
        for v in nums:
            node = Node(v)
            root = merge(root, node)
        res = []
        for tp,x,y in queries:
            if tp == 1:
                L,R = split(root,x)
                u,R = split(R,1)
                u.val = y
                pull(u)
                root = merge(L,merge(u,R))
            elif tp == 2:
                L,R = split(root,x)
                u,R = split(R,y-x+1)
                res.append(u.xo)
                root = merge(L,merge(u,R))
            else:
                L,R = split(root,x)
                u,R = split(R,y-x+1)
                u.rev ^= 1
                root = merge(L,merge(u,R))
        return res

7538 ms