甲乙小朋友的房子

甲乙小朋友很笨,但甲乙小朋友不会放弃

0%

算法-线段树

线段树入门

本文主要参考自JustDoIT线段树知识点总结

线段树,类似区间树,它在各个节点保存一条线段(数组中的一段子数组),主要用于高效解决连续区间的动态查询问题,由于二叉结构的特性,它基本能保持每个操作的复杂度为O(logn)。

线段树的每个节点表示一个区间,子节点则分别表示父节点的左右半区间,例如父亲的区间是[a,b],那么(c=(a+b)/2)左儿子的区间是[a,c],右儿子的区间是[c+1,b]。线段树形如:

下面我们从一个经典的例子来了解线段树,问题描述如下:从数组arr[0...n-1]中查找某个数组某个区间内的最大值,其中数组大小固定,但是数组中的元素的值可以随时更新。从这题可以看出:区间(a,b)的最大值和区间(b,c)的最大值中,取较大的就是区间(a,c)的最大值。很明显这个操作具有区间的性质。

我们可以用线段树来解决这个区间最大值问题。根据这个问题我们构造如下的二叉线段树。区间的第三维就是区间的最大值。

加入第三维的时候,只需要在构建完左右区间后,根据左右区间的最大值更新当前区间最大值即可。

因为每次将区间长度一分为二,所有构造的节点个数为:

n + 1/2 * n + 1/4 * n + 1/8 * n + ...

= (1 + 1/2 + 1/4 + 1/8 + ...) * n

= 2n

所以构造线段树的时空复杂度都为O(n)。

线段树常见题型

一道题可不可以用线段树来做,基本是看这道题的操作有没有区间的性质。也就是在一个区间上的操作是否可以转化为两个子区间上的操作。

  • 求区间和,积,最小值,gcd等
  • 以当前节点的值作为节点处理。例如给出N个数字,再给一个数,问比这个数大的有多少个。
  • 区间加减同一个值,或者区间同时赋一个值。

链式线段树

我们常见的二叉树都是链式结构。因此我们先完成链式的线段树。

建树

复杂度$O( n ) $

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/**
节点区间定义
(start, end) 代表节点的区间范围
max 是节点在(start, end)区间上的最大值
left, right 是当前节点区间划分后的左右节点区间
**/
public class SegmentTreeNode{ // 节点
int start, end;
int max;
SegmentTreeNode left = null, right = null;
public SegmentTreeNode(int start, int end, int max){
this.start = start;
this.end = end;
this.max = max;
}
}

// 构造树
public SegmentTreeNode build(int[] A){
return buildhelper(0, A.length - 1, A);
}

// 建树耗时O(n),空间O(n)
public SegmentTreeNode buildhelper(int low, int high, int[] A){
if(low > high) return null;
SegmentTreeNode root = new SegmentTreeNode(low, high, A[low]); // 根据节点区间左边界的值为节点赋初值
if(low == high) return root;
int mid = (low + high) / 2; // 划分当前区间的左右区间
root.left = buildhelper(low, mid, A);
root.right = buildhelper(mid + 1, high, A);
root.max = Math.max(root.left.max, root.right.max); // 更新当前节点值
return root;
}

一些变种:

1
2
3
4
5
//如果需要区间最小值
root.min = Math.min(root.left.min, root.right.min);

// 如果需要区间和
root.sum = root.left.sum + root.right.sum;

区间查询

复杂度 \(O(log(n))\)

构造线段树目的是为了更快地查询。例如给定区间,要求区间中的最大值。而线段树的区间查询操作就是将当前区间分解为较小的子区间,然后由子区间的最大值就可以快速得到需要查询区间的最大值。例如

1
query(1,3) = max(query(1,1), query(2,3)) = max(4,3) = 4

查询实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public int query(SegmentTreeNode root, int low, int high){
if(low == root.start && root.end == high){
// 如果查询区间就是当前节点区间,直接输出结果
return root.max;
}
int mid = (root.start + root.end) / 2; // 将当前节点区间分割为左右两个区间
int ans = Integer.MIN_VALUE;
/**
[root.left,...,mid,...,root.right]
[low, high]
**/
if(mid >= low){ // 如果查询区间与左子节点有交集,则查询区间的最大值有可能是左子节点的最大值。
ans = Math.max(ans, query(root.left, low, high));
}
if(mid + 1 <= high){ // 如果查询区间与右子节点有交集,则查询区间的最大值有可能是右子节点的最大值。
ans = Math.max(ans, query(root.right, low, high));
}
return ans;
}

单点更新

复杂度 \(O(log(n))\)

更新序列中的一个节点,那么如何把这种变化体现到线段树中呢?

例如要将第4个点更新为5.就要变动3个区间的值,分别为[3,3], [2,3], [0,3]

改动一个节点,与这个节点对应的叶子结点都要变动。并且,这个节点变动后,这个节点的属性值也有可能会变动,那么就有可能印象到这个节点的父亲节点的属性值(例如可能影响到最大值)。所以需要从叶子节点一路走到根节点。

单点更新实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
public void modify(SegmentTreeNode root, int idx, int val){
if(root.start == root.end && root.start == idx){ // 找到要改动的叶子节点
root.max = val;
return;
}
int mid = (root.start + root.end) / 2; // 将当前节点区间分割为2个区间
if(idx <= mid){ // 如果修改节点在左边
modify(root.left, idx, val);
}else{ // 如果修改节点在右边
modify(root.right, idx, val);
}
root.max = Math.max(root.right.max, root.left.max) ;// 更新
}

数组式线段树

由于输入的是数组,那么树节点个数不会变化,而且线段树是趋于完全二叉树的。因此我们可以考虑用数组形式的线段树。

建树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 节点定义与上面保持一致
// (start, end) 代表节点的区间范围
// max 是节点在(start, end)区间上的最大值
// left, right 是当前节点区间划分后的左右节点区间

public class SegmentTreeNode{
int start, end;
int max;
SegmentTreeNode left = null, right = null;
}
SegmentTreeNode[] SegmentTree; // 替代原来的root

i的左节点:SegmentTree[2*i + 1]
i的右节点:SegmentTree[2*i]

其它操作都差不多,省略

相关题

Range Sum Query - Mutable

给一个数组,求数组的i到j的和。而且数组的值实时更新。

老套路不行了,就超时了。得换新套路——线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
 // 线段树节点定义  
class SegmentTreeNode{
int start, end;
int sum = 0;
SegmentTreeNode left = null, right = null;
SegmentTreeNode(int start, int end){
this.start = start;
this.end = end;
}
}
// 树根
SegmentTreeNode root;

// 构造
public NumArray(int[] nums) {
root = NumArrayHelper(nums, 0, nums.length - 1);
}
private SegmentTreeNode NumArrayHelper(int[] nums, int low, int high){
if(low > high) return null;
SegmentTreeNode node = new SegmentTreeNode(low, high);
if(low == high){
node.sum = nums[low];
}else{
int mid = (low + high) / 2;
node.left = NumArrayHelper(nums, low, mid);
node.right = NumArrayHelper(nums, mid + 1, high);
node.sum = node.right.sum + node.left.sum;
}
return node;
}

// 更新节点
public void update(int i, int val) {
updateHelper(root, i, val);
}
private void updateHelper(SegmentTreeNode node,int idx, int val){
if(node.start == node.end){
if(node.start == idx) node.sum = val;
else System.out.println("error");
return;
}
int mid = (node.start + node.end) / 2;
if(idx <= mid){
updateHelper(node.left, idx, val);
}else{
updateHelper(node.right, idx, val);
}
node.sum = node.left.sum + node.right.sum;
}

// 求sum
public int sumRange(int i, int j) {
return sumRangeHelper(root, i, j);
}
private int sumRangeHelper(SegmentTreeNode node, int low, int high){
if(node == null || high < node.start || low > node.end) return 0;
if(node.start == low && node.end == high) return node.sum;
/**
* [node.start, ...mid..., node.end] node是一个较大的区间
* 1. [low, , high] 查找区间完全在左子区间
* 2. [low....high] 查找区间完全在右子区间
* 3. [low, ..........,high] 查找区间跨越左右子区间 **/
int mid = (node.start + node.end) / 2;
if(high <= mid){
return sumRangeHelper(node.left, low, high);
}else if(low > mid){
return sumRangeHelper(node.right,low, high);
}else{
return sumRangeHelper(node.left, low, mid) + sumRangeHelper(node.right, mid + 1, high);
}
}