Skip to content

线段树

介绍

线段树是一种用于维护区间信息的数据结构，它可以在 O(log N) 的时间复杂度内完成单点修改、区间修改（需配合懒标记）以及区间查询（如区间求和、区间最大值、区间最小值）等操作。下面的实现只包含建树（O(N)）和区间查询两部分。

实现代码

javascript
/**
 * Segment tree over a fixed array, stored flat with heap-style indexing:
 * the root is at position 0 and the children of position p are at 2p + 1
 * and 2p + 2.
 *
 * Building visits every node once (O(n)); queryRange visits O(log n) nodes.
 * Point and range updates are not implemented here.
 *
 * @param {Array} inputArray - values stored in the leaves. The array is kept
 *   by reference, not copied.
 * @param {(a: *, b: *) => *} operation - associative function that merges two
 *   adjacent sub-range results (e.g. sum, Math.max, Math.min).
 * @param {() => *} operationFallback - returns the identity element of
 *   `operation` (0 for sum, -Infinity for max, Infinity for min). It is
 *   combined into results for nodes that do not overlap a query, so a
 *   non-identity value would corrupt query results.
 */
class SegmentTree {
  constructor(inputArray, operation, operationFallback) {
    this.inputArray = inputArray
    // Combines two sub-range results; pass different functions for sum, max, min, etc.
    this.operation = operation
    this.operationFallback = operationFallback

    this.segmentTree = this.initSegmentTree(inputArray)
    this.buildSegmentTree()
  }

  /**
   * Allocates the flat tree array, with every slot initialised to null.
   *
   * With the midpoint split used by the build/query recursion the tree has
   * depth ceil(log2 n), so 2 * P - 1 slots are enough for every node index,
   * where P is the smallest power of two >= n.
   *
   * @param {Array} inputArray
   * @returns {Array} a null-filled array of length 2 * P - 1, or [] when the
   *   input is empty.
   */
  initSegmentTree(inputArray) {
    const inputArrayLength = inputArray.length
    if (inputArrayLength === 0) {
      return []
    }
    // Find P by integer doubling rather than a logarithm: Math.log is the
    // natural log (not log2), and this avoids relying on float rounding.
    let leafCapacity = 1
    while (leafCapacity < inputArrayLength) {
      leafCapacity *= 2
    }
    return new Array(2 * leafCapacity - 1).fill(null)
  }

  /** Fills this.segmentTree from this.inputArray, starting at the root. */
  buildSegmentTree() {
    // An empty input has no root segment; recursing on [0, -1] would never terminate.
    if (this.inputArray.length === 0) {
      return
    }
    // The root (position 0) covers the whole input range [0, n - 1].
    this.buildSegmentTreeRecursive(0, this.inputArray.length - 1, 0)
  }

  /**
   * Builds the node at `position`, which covers
   * inputArray[leftIndex..rightIndex] (inclusive), after recursively
   * building both of its children.
   */
  buildSegmentTreeRecursive(leftIndex, rightIndex, position) {
    // Leaf: a single-element segment stores the input value itself.
    if (leftIndex === rightIndex) {
      this.segmentTree[position] = this.inputArray[leftIndex]
      return
    }

    // Split at the midpoint; the left half gets the extra element when the length is odd.
    const middleIndex = Math.floor((leftIndex + rightIndex) / 2)
    const leftChildPosition = this.getLeftChildIndex(position)
    const rightChildPosition = this.getRightChildIndex(position)

    this.buildSegmentTreeRecursive(leftIndex, middleIndex, leftChildPosition)
    this.buildSegmentTreeRecursive(middleIndex + 1, rightIndex, rightChildPosition)

    // Both children are built now, so this node merges their values.
    this.segmentTree[position] = this.operation(
      this.segmentTree[leftChildPosition],
      this.segmentTree[rightChildPosition],
    )
  }

  /**
   * Combines inputArray[queryLeftIndex..queryRightIndex] (inclusive) with
   * `operation`. Indices outside [0, n - 1] simply contribute nothing.
   *
   * @param {number} queryLeftIndex
   * @param {number} queryRightIndex
   * @returns {*} the combined value, or operationFallback() when the tree is
   *   empty or the query overlaps no element.
   */
  queryRange(queryLeftIndex, queryRightIndex) {
    if (this.inputArray.length === 0) {
      return this.operationFallback()
    }
    // Start at the root, which covers [0, n - 1].
    return this.queryRangeRecursive(
      queryLeftIndex,
      queryRightIndex,
      0,
      this.inputArray.length - 1,
      0,
    )
  }

  // Same recursion shape as the build, but each node's segment
  // [leftIndex, rightIndex] is compared against the query range.
  queryRangeRecursive(queryLeftIndex, queryRightIndex, leftIndex, rightIndex, position) {
    // Case 1: the node's segment lies entirely inside the query; use its stored value.
    if (queryLeftIndex <= leftIndex && queryRightIndex >= rightIndex) {
      return this.segmentTree[position]
    }

    // Case 2: no overlap at all; contribute the identity element.
    if (queryLeftIndex > rightIndex || queryRightIndex < leftIndex) {
      return this.operationFallback()
    }

    // Partial overlap: ask both halves and merge their answers.
    const middleIndex = Math.floor((leftIndex + rightIndex) / 2)
    const leftChildPosition = this.getLeftChildIndex(position)
    const rightChildPosition = this.getRightChildIndex(position)
    const leftResult = this.queryRangeRecursive(
      queryLeftIndex,
      queryRightIndex,
      leftIndex,
      middleIndex,
      leftChildPosition,
    )
    const rightResult = this.queryRangeRecursive(
      queryLeftIndex,
      queryRightIndex,
      middleIndex + 1,
      rightIndex,
      rightChildPosition,
    )
    return this.operation(leftResult, rightResult)
  }

  // Parent/child positions follow the same layout as an array-backed binary heap.
  getLeftChildIndex(position) {
    return 2 * position + 1
  }
  getRightChildIndex(position) {
    return 2 * position + 2
  }
}

/**
 * Reports whether `number` equals 2^k for some integer k >= 0.
 *
 * Values below 1, odd values other than 1, and non-integers all yield false.
 *
 * @param {number} number
 * @returns {boolean}
 */
function isPowerOfTwo(number) {
  // 2^0 = 1 is the smallest power of two, so anything below it is rejected outright.
  if (number < 1) return false;

  // Keep halving; a power of two stays even (and integral) until it reaches 1.
  for (let current = number; current !== 1; current /= 2) {
    if (current % 2 !== 0) return false;
  }
  return true;
}

运行结果

javascript
// Range-sum tree: `+` merges sub-ranges and 0 is its identity element.
const seg = new SegmentTree(
  [1, 2, 3, 4, 5, 6],
  (a, b) => a + b,
  () => 0,
);

console.log('seg.segmentTree ==> ', seg.segmentTree);
// Heap-ordered node values, root first: 21, 6, 15, 3, 3, 9, 6, 1, 2, -, -, 4, 5
// Slots marked "-" (and any trailing slots) hold no node; depending on how the
// backing array was sized they print as `null` or as `<N empty items>`.

for (const [from, to] of [
  [0, 3], // 1 + 2 + 3 + 4 = 10
  [1, 2], // 2 + 3 = 5
  [4, 6], // 5 + 6 = 11; index 6 is past the end and contributes nothing
]) {
  console.log('seg ==> ', seg.queryRange(from, to));
}