线段树
介绍
线段树是一种主要用于维护区间信息的数据结构,它可以在 $O(\log N)$ 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和、求区间最大值、求区间最小值)等操作。
实现代码
javascript
/**
 * Segment tree built over an input array, stored as a flat array laid out
 * like a binary heap (children of node `p` live at `2p + 1` and `2p + 2`).
 * Each node holds the result of `operation` applied over the slice of the
 * input array that the node covers.
 */
class SegmentTree {
  /**
   * @param {Array} inputArray - Values placed at the leaves of the tree.
   * @param {Function} operation - Binary function combining the results of
   *   two adjacent ranges (for example sum, min or max).
   * @param {Function} operationFallback - Zero-argument function returning
   *   the value used for a sub-range that does not overlap the query. It is
   *   combined via `operation`, so it should be neutral for it (0 for a sum,
   *   Infinity for a minimum, ...).
   */
  constructor(inputArray, operation, operationFallback) {
    this.inputArray = inputArray;
    this.operation = operation;
    this.operationFallback = operationFallback;
    this.segmentTree = this.initSegmentTree(inputArray);
    this.buildSegmentTree();
  }

  /**
   * Allocates the flat array that backs the tree, filled with `null`.
   *
   * The leaf level is rounded up to the next power of two so that every
   * heap-style child index used while building stays inside the array.
   *
   * @param {Array} inputArray - Input values; only its length is read.
   * @returns {Array} Array of `2 * leafCapacity - 1` nulls, or an empty
   *   array when the input is empty.
   */
  initSegmentTree(inputArray) {
    const inputArrayLength = inputArray.length;
    if (inputArrayLength === 0) {
      return [];
    }
    // Integer doubling avoids floating-point logarithms altogether.
    let leafCapacity = 1;
    while (leafCapacity < inputArrayLength) {
      leafCapacity *= 2;
    }
    return new Array(2 * leafCapacity - 1).fill(null);
  }

  /**
   * Fills the tree, starting from the root (position 0), which covers the
   * whole input array. Does nothing for an empty input.
   */
  buildSegmentTree() {
    if (this.inputArray.length === 0) {
      return;
    }
    this.buildSegmentTreeRecursive(0, this.inputArray.length - 1, 0);
  }

  /**
   * Computes the value of the node at `position`, which covers
   * `inputArray[leftIndex..rightIndex]` (both inclusive).
   *
   * @param {number} leftIndex - First input index covered by the node.
   * @param {number} rightIndex - Last input index covered by the node.
   * @param {number} position - Index of the node inside `segmentTree`.
   */
  buildSegmentTreeRecursive(leftIndex, rightIndex, position) {
    // Base case: a leaf stores the input value itself.
    if (leftIndex === rightIndex) {
      this.segmentTree[position] = this.inputArray[leftIndex];
      return;
    }

    const middleIndex = Math.floor((leftIndex + rightIndex) / 2);
    const leftChildPosition = this.getLeftChildIndex(position);
    const rightChildPosition = this.getRightChildIndex(position);

    this.buildSegmentTreeRecursive(leftIndex, middleIndex, leftChildPosition);
    this.buildSegmentTreeRecursive(middleIndex + 1, rightIndex, rightChildPosition);

    // Both children are populated now; combine them exactly once.
    this.segmentTree[position] = this.operation(
      this.segmentTree[leftChildPosition],
      this.segmentTree[rightChildPosition],
    );
  }

  /**
   * Returns `operation` folded over `inputArray[queryLeftIndex..queryRightIndex]`
   * (both inclusive). Parts of the query that fall outside the input array
   * contribute nothing; an empty tree yields `operationFallback()`.
   *
   * @param {number} queryLeftIndex - First index of the queried range.
   * @param {number} queryRightIndex - Last index of the queried range.
   * @returns {*} Combined value for the queried range.
   */
  queryRange(queryLeftIndex, queryRightIndex) {
    if (this.inputArray.length === 0) {
      return this.operationFallback();
    }
    return this.queryRangeRecursive(
      queryLeftIndex,
      queryRightIndex,
      0,
      this.inputArray.length - 1,
      0,
    );
  }

  /**
   * Recursive step of `queryRange`; walks the tree the same way the build
   * does, comparing each node's range with the queried range.
   *
   * @param {number} queryLeftIndex - First index of the queried range.
   * @param {number} queryRightIndex - Last index of the queried range.
   * @param {number} leftIndex - First input index covered by the node.
   * @param {number} rightIndex - Last input index covered by the node.
   * @param {number} position - Index of the node inside `segmentTree`.
   * @returns {*} Combined value of the overlap between node and query.
   */
  queryRangeRecursive(queryLeftIndex, queryRightIndex, leftIndex, rightIndex, position) {
    // The node's range lies entirely inside the query: use its stored value.
    if (queryLeftIndex <= leftIndex && queryRightIndex >= rightIndex) {
      return this.segmentTree[position];
    }
    // No overlap at all: contribute the fallback value.
    if (queryLeftIndex > rightIndex || queryRightIndex < leftIndex) {
      return this.operationFallback();
    }

    // Partial overlap: combine the answers of both halves.
    const middleIndex = Math.floor((leftIndex + rightIndex) / 2);
    const leftResult = this.queryRangeRecursive(
      queryLeftIndex,
      queryRightIndex,
      leftIndex,
      middleIndex,
      this.getLeftChildIndex(position),
    );
    const rightResult = this.queryRangeRecursive(
      queryLeftIndex,
      queryRightIndex,
      middleIndex + 1,
      rightIndex,
      this.getRightChildIndex(position),
    );
    return this.operation(leftResult, rightResult);
  }

  /**
   * @param {number} position - Index of a node inside `segmentTree`.
   * @returns {number} Index of that node's left child (heap layout).
   */
  getLeftChildIndex(position) {
    return 2 * position + 1;
  }

  /**
   * @param {number} position - Index of a node inside `segmentTree`.
   * @returns {number} Index of that node's right child (heap layout).
   */
  getRightChildIndex(position) {
    return 2 * position + 2;
  }
}
/**
 * Tells whether `number` is a power of two (1, 2, 4, 8, ...).
 *
 * @param {number} number - Value to test.
 * @returns {boolean} `true` when repeatedly halving `number` reaches exactly 1
 *   without ever hitting a value that is not divisible by 2; `false` for
 *   anything below 1.
 */
function isPowerOfTwo(number) {
  if (number < 1) {
    return false;
  }

  // Halve until 1 is reached; any remainder on the way disqualifies the value.
  for (let remaining = number; remaining !== 1; remaining /= 2) {
    if (remaining % 2 !== 0) {
      return false;
    }
  }

  return true;
}
运行结果
javascript
// Sum tree over six values; the fallback 0 is neutral for addition.
const seg = new SegmentTree(
  [1, 2, 3, 4, 5, 6],
  (a, b) => a + b,
  () => 0,
);

// Prints the flat backing array: 21, 6, 15, 3, 3, 9, 6, 1, 2, two unused
// slots, then 4, 5. Unused slots appear as empty items or null depending on
// how the backing array was allocated.
console.log('seg.segmentTree ==> ', seg.segmentTree);

// Inclusive range sums: [0, 3] -> 10, [1, 2] -> 5, [4, 6] -> 11
// (index 6 is past the end of the input, so only 5 + 6 is summed).
for (const [from, to] of [[0, 3], [1, 2], [4, 6]]) {
  console.log('seg ==> ', seg.queryRange(from, to));
}