加载中...

树状数组和线段树


前言

今天做每日一题的时候,看到是一道关于区间的困难题,想了一阵感觉有点像需要所谓的线段树来解,之前感觉线段树老麻烦了一直没有学明白,今天就尽可能整理清楚,看看大佬们的题解试着消化理解一下

首先针对区间的一系列问题,可以依据实际情况区分为以下几种:

1、数组不变,求区间和:「前缀和」、「树状数组」、「线段树」

2、多次修改某个区间,输出最终结果:「差分」

3、多次修改某个数,求区间和:「树状数组」、「线段树」

4、多次修改某个区间,求区间和:「线段树」、「树状数组」(看修改区间范围大小)

5、多次将某个区间变成同一个数,求区间和:「线段树」、「树状数组」(看修改区间范围大小)

因为线段树的代码一般会比较繁琐,所以除了第4、5种问题不得不用线段树时,才考虑线段树。根据上面总结的情况,对于区间和的问题,可以按照这样的顺序选择解法:

1、简单求区间和,用「前缀和」

2、多次将某个区间变成同一个数,用「线段树」

3、其他情况,用「树状数组」

前缀和很常见也很容易理解,重点是理解树状数组和线段树,下面先了解了解树状数组

树状数组

树状数组和线段树又类似的功能,如果简单地类比描述,可以认为树状数组是线段树的子集,即树状数组有的功能线段树一定有,反之则不一定。

首先我们假设一个简单的场景需求:要能够方便更新一个数组中某个元素的值并方便快速地返回数组某个区间的和值,这种情况就可以使用树状数组。我们直接给出树状数组的全部代码,先从整体感受一下树状数组的写法:

class NumArray {
        // 累加和
        int[] sums;
        // 更新后数组
        int[] nums;

        public NumArray(int[] nums) {
            // sum从1开始有效, 因为计算lowbit时,如果使用下标0会进入死循环
            this.sums = new int[nums.length + 1];
            this.nums = nums;
            // 初始化累加和数组
            for (int i = 0; i < nums.length; i++) {
                add(i, nums[i]);
            }
        }

        private void add(int index, int val) {
            // 下标+1
            int indexSum = index + 1;
            while (indexSum < sums.length) {
                sums[indexSum] = sums[indexSum] + val;
                indexSum += lowBit(indexSum);
            }
        }

        private int lowBit(int x) {
            return x & (-x);
        }

        public void update(int index, int val) {
            int indexSum = index + 1;
            while (indexSum < sums.length) {
                // 减去之前nums[index]的值, 加上新的值
                sums[indexSum] = sums[indexSum] - nums[index] + val;
                indexSum += lowBit(indexSum);
            }
            nums[index] = val;
        }

        public int sumRange(int left, int right) {
            return query(right + 1) - query(left);
        }

        public int query(int x) {
            int ans = 0;
            while (x != 0) {
                ans += sums[x];
                x -= lowBit(x);
            }
            return ans;
        }
    }

累加和数组

首先看一下累加和sums,图形化理解如下图所示:

其实从这张图片可以清晰看到,累加和数组就是有种树的性质,然后其实 前缀和也可以很容易从累加和得到,例如:
$$
prefixSum[6] = sum[4] + sum[6]
$$
初始化累加和的时候,对应于上图其实是整体流程从下往上的,一次遍历数组nums,判断nums[i]影响了哪些sum,比如,nums[0]影响了sum[1]、sum[2]、sum[4]、sum[8],而nums[2]影响了sum[3]、sum[4]、sum[8],代码段:

while (indexSum < sums.length) {
    sums[indexSum] = sums[indexSum] + val;
    indexSum += lowBit(indexSum);
}

就是依次更新当前nums[i]影响的sum[j]的值,这个lowBit()函数

private int lowBit(int x) {
    return x & (-x);
}

通过位运算的方式,找到x的最低位的1并只保留这个1,比如5,二进制101,-5二进制011(忽略符号位),则两者进行&运算之后,结果是001,从而lowBit(5) = 1;再比如,10的二进制1010,-10的二进制0110(忽略符号位),二者取&结果为0010,即lowBit(10) = 4

回到上面的代码,通过找到当前indexSum的只保留最低位1的那个数,并与其相加,从而得到**上层中受到nums[i]影响的sum[j]**,图解如下:

更新操作

如果此时我们要更新nums数组中index位置的元素,那么就可以如同上述的流程一样,依次从indexSum = index + 1开始,依次判断nums[index]影响了哪些sum[j],用lowBit()找到下一个上层sum[j]的位置,这部分代码就很清晰易懂了。

public void update(int index, int val) {
    int indexSum = index + 1;
    while (indexSum < sums.length) {
        // 减去之前nums[index]的值, 加上新的值
        sums[indexSum] = sums[indexSum] - nums[index] + val;
        indexSum += lowBit(indexSum);
    }
    nums[index] = val;
}

查询区间和

有了上面的基础,其实这部分就更好理解了,平常最常用的区间和是通过前缀和相减获得,这里可以首先将累加和sum转换成前缀和,也就是对应这部分代码

public int query(int x) {
    int ans = 0;
    while (x != 0) {
        ans += sums[x];
        x -= lowBit(x);
    }
    return ans;
}

与更新操作不同,这部分通过每次减去lowBit(x)得到前一部分和值的位置,直到减为零,图解如下:

小结

其实树状数组的关键逻辑非常简单,明白lowBit()函数以及通过相加、相减找下一层的元素的流程,就可以理解整体的功能实现了!

线段树

写了半天,总算是到这个折磨了我很久的线段树了。前面已经介绍过,线段树满足的需求除了求区间和,还可满足修改某一个区间的值,时间复杂度都是O(logn),线段树的基本结构如下:

这里对应的nums数组为[1, 2, 3, 4, 5],线段树的叶子节点就对应于nums数组的每一个元素,父节点表示其两边子节点的的元素值之和,因此每一个节点的值就对应底下的区间的和值

对于线段树的每一个节点,我们可以用一个Node类表示:

class Node {
    Node left, right;
    int val;
}

查询操作

线段树中查找某一个区间的,就是在线段树中从根节点向树的两边递归找到包含于查询区间的小区间,比如上图的线段树中查找[2,4]范围的区间和,那整个过程会如下图所示

沿着上述的路径,找到两个黄色的节点,此时两个黄色节点的区间范围之和正好等于要查询的区间范围,因此此时查询的结果为3 + 9 = 12

代码模板:

public int query(Node node, int start, int end, int l, int r) {
    // [start, end]为当前查找到的节点的区间范围,[l,r]为需要查询的区间范围,[l,r]保持不变
    //此时当前找到的节点区间不包含在要查询找的区间之中
    if(l > end || r < start) return 0;
    // 当[l,r]包含[start, end]时,直接返回当前节点的值
    else if (l <= start && end <= r) return node.val;
    // 把当前区间 [start, end] 均分得到左右孩子的区间范围
    int mid = (start + end) >> 1, ans = 0;
    // [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
    if (l <= mid) ans += query(node.left, start, mid, l, r);
    // [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
    if (r > mid) ans += query(node.right, mid + 1, end, l, r);
    // ans 把左右子树的结果都累加起来了,与树的后续遍历同理
    return ans;
}

更新操作

更新操作和查询操作其实与查询是非常类似的,毕竟要更新元素的值首先肯定得先查找到对应得位置,但是,这里有个小优化:

如上图,朴素得想法是,我们每次更新都要更新到叶子节点,比如更新[2,4]区间,那一般肯定是索引为2、3、4得这三个节点都要更新,但是,查询的时候我们查到黄色的节点就可以了,更新了[3,4]区间那么理论上[3,3]、[4,4]区间都会更新,因此没必要查找到诸如红色的叶子节点,因此,这样的线段树可以使用一个懒标记标明该节点的所有子节点都应该有更新,懒标记的值就是每个叶子节点需要更新的值。

改进后的Node节点:

class Node {
    Node left, right;
    int val;
    // 懒标记
    int add;
}

在需要遍历孩子节点的时候,就将该「懒标记」下推给子节点,下推懒标记的的代码模板如下,代码中leftNum表示左子树的节点个数,rightNum表示右子树的节点个数

private void pushDown(Node node, int leftNum, int rightNum) {
    //如果节点不存在左右孩子节点,那么我们就创建左右孩子节点
    if (node.left == null) node.left = new Node();
    if (node.right == null) node.right = new Node();
    // 如果 add 为 0,表示没有标记,直接返回
    if (node.add == 0) return ;
    // 当前节点的左右子节点加上对应的标记总和值
    node.left.val += node.add * leftNum;
    node.right.val += node.add * rightNum;
    // 把标记下推给孩子节点
    node.left.add = node.add;
    node.right.add = node.add;
    // 取消当前节点标记
    node.add = 0;
}

有了以上的基础,那么,更新操作的代码模板如下:

public void update(Node node, int start, int end, int l, int r, int val) {
    if (l <= start && end <= r) {
        // 区间节点加上更新值
        node.val += (end - start + 1) * val;
        // 添加懒标记
        node.add = val;
        return ;
    }
    int mid = (start + end) >> 1;
    // 下推标记
    pushDown(node, mid - start + 1, end - mid);
    
    if (l <= mid) update(node.left, start, mid, l, r, val);
    if (r > mid) update(node.right, mid + 1, end, l, r, val);
    // 向上更新
    pushUp(node);
}

这里有一个向上更新pushUp函数,这个函数里面并不是固定的,而是根据实际问题的具体需求而定,常见的有:

  • 数字之和
  • 最大值
  • 最大公因数 等

例如数字之和,则对应的代码为:

private void pushUp(Node node) {
    node.val = node.left.val + node.right.val;
}

其实这个就相当于在当前节点的递归的最后一步:更新完子节点后回到当前节点,更新当前节点的值

总体代码

根据上述的阐释,线段树的一种完整的模板代码如下所示:

public class SegmentTreeDynamic {
    
    class Node {
        Node left, right;
        int val, add;
    }
    
    private Node root = new Node();
    
    public void update(Node node, int start, int end, int l, int r, int val) {
        if (l <= start && end <= r) {
            node.val += (end - start + 1) * val;
            node.add = val;
            return ;
        }
        int mid = (start + end) >> 1;
        pushDown(node, mid - start + 1, end - mid);
        if (l <= mid) update(node.left, start, mid, l, r, val);
        if (r > mid) update(node.right, mid + 1, end, l, r, val);
        pushUp(node);
    }
    
    public int query(Node node, int start, int end, int l, int r) {
        if(l > end || r < start) return 0;
        if (l <= start && end <= r) return node.val;
        int mid = (start + end) >> 1, ans = 0;
        pushDown(node, mid - start + 1, end - mid);
        if (l <= mid) ans += query(node.left, start, mid, l, r);
        if (r > mid) ans += query(node.right, mid + 1, end, l, r);
        return ans;
    }
    
    private void pushUp(Node node) {
        //对应于数值之和的实际问题
        node.val = node.left.val + node.right.val;
    }
    
    private void pushDown(Node node, int leftNum, int rightNum) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        if (node.add == 0) return ;
        node.left.val += node.add * leftNum;
        node.right.val += node.add * rightNum;
        node.left.add = node.add;
        node.right.add = node.add;
        node.add = 0;
    }
}

小结

经过一番整理,线段树的结构以及使用变得清晰许多,线段树依赖懒标记实现了更新数据的方便快捷性,由于线段树本身也是一颗实实在在的树,叶子节点对应nums数组的元素,因此,单纯只更新一个节点的值的时候,必须找到叶子节点去,更新这一条路径的所有节点,因此这也就是为什么此时的线段树和树状数组基本是一样的原因了,此时写树状数组当时是代码更简洁一些的


文章作者: DestiNation
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 DestiNation !
  目录