blob: ec302e4a7a963420be6bc31eff93c0478baa0639 [file] [log] [blame]
khenaidooffe076b2019-01-15 16:08:08 -05001// Copyright 2016 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package adt
16
17import (
18 "bytes"
19 "math"
20)
21
// Comparable is implemented by values that support a three-way
// (trichotomic) comparison.
type Comparable interface {
	// Compare reports how the receiver orders relative to c:
	//   a.Compare(b) = 1  => a > b
	//   a.Compare(b) = 0  => a == b
	//   a.Compare(b) = -1 => a < b
	Compare(c Comparable) int
}

// rbcolor is the color of a red-black tree node.
type rbcolor int

const (
	// black is the zero value of rbcolor.
	black rbcolor = iota
	red
)
37
38// Interval implements a Comparable interval [begin, end)
39// TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
40type Interval struct {
41 Begin Comparable
42 End Comparable
43}
44
45// Compare on an interval gives == if the interval overlaps.
46func (ivl *Interval) Compare(c Comparable) int {
47 ivl2 := c.(*Interval)
48 ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
49 ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
50 iveCmpBegin := ivl.End.Compare(ivl2.Begin)
51
52 // ivl is left of ivl2
53 if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
54 return -1
55 }
56
57 // iv is right of iv2
58 if ivbCmpEnd >= 0 {
59 return 1
60 }
61
62 return 0
63}
64
65type intervalNode struct {
66 // iv is the interval-value pair entry.
67 iv IntervalValue
68 // max endpoint of all descendent nodes.
69 max Comparable
70 // left and right are sorted by low endpoint of key interval
71 left, right *intervalNode
72 // parent is the direct ancestor of the node
73 parent *intervalNode
74 c rbcolor
75}
76
77func (x *intervalNode) color() rbcolor {
78 if x == nil {
79 return black
80 }
81 return x.c
82}
83
84func (n *intervalNode) height() int {
85 if n == nil {
86 return 0
87 }
88 ld := n.left.height()
89 rd := n.right.height()
90 if ld < rd {
91 return rd + 1
92 }
93 return ld + 1
94}
95
96func (x *intervalNode) min() *intervalNode {
97 for x.left != nil {
98 x = x.left
99 }
100 return x
101}
102
103// successor is the next in-order node in the tree
104func (x *intervalNode) successor() *intervalNode {
105 if x.right != nil {
106 return x.right.min()
107 }
108 y := x.parent
109 for y != nil && x == y.right {
110 x = y
111 y = y.parent
112 }
113 return y
114}
115
116// updateMax updates the maximum values for a node and its ancestors
117func (x *intervalNode) updateMax() {
118 for x != nil {
119 oldmax := x.max
120 max := x.iv.Ivl.End
121 if x.left != nil && x.left.max.Compare(max) > 0 {
122 max = x.left.max
123 }
124 if x.right != nil && x.right.max.Compare(max) > 0 {
125 max = x.right.max
126 }
127 if oldmax.Compare(max) == 0 {
128 break
129 }
130 x.max = max
131 x = x.parent
132 }
133}
134
135type nodeVisitor func(n *intervalNode) bool
136
137// visit will call a node visitor on each node that overlaps the given interval
138func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool {
139 if x == nil {
140 return true
141 }
142 v := iv.Compare(&x.iv.Ivl)
143 switch {
144 case v < 0:
145 if !x.left.visit(iv, nv) {
146 return false
147 }
148 case v > 0:
149 maxiv := Interval{x.iv.Ivl.Begin, x.max}
150 if maxiv.Compare(iv) == 0 {
151 if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) {
152 return false
153 }
154 }
155 default:
156 if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) {
157 return false
158 }
159 }
160 return true
161}
162
163type IntervalValue struct {
164 Ivl Interval
165 Val interface{}
166}
167
168// IntervalTree represents a (mostly) textbook implementation of the
169// "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
170// and chapter 14.3 interval tree with search supporting "stabbing queries".
171type IntervalTree struct {
172 root *intervalNode
173 count int
174}
175
176// Delete removes the node with the given interval from the tree, returning
177// true if a node is in fact removed.
178func (ivt *IntervalTree) Delete(ivl Interval) bool {
179 z := ivt.find(ivl)
180 if z == nil {
181 return false
182 }
183
184 y := z
185 if z.left != nil && z.right != nil {
186 y = z.successor()
187 }
188
189 x := y.left
190 if x == nil {
191 x = y.right
192 }
193 if x != nil {
194 x.parent = y.parent
195 }
196
197 if y.parent == nil {
198 ivt.root = x
199 } else {
200 if y == y.parent.left {
201 y.parent.left = x
202 } else {
203 y.parent.right = x
204 }
205 y.parent.updateMax()
206 }
207 if y != z {
208 z.iv = y.iv
209 z.updateMax()
210 }
211
212 if y.color() == black && x != nil {
213 ivt.deleteFixup(x)
214 }
215
216 ivt.count--
217 return true
218}
219
220func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
221 for x != ivt.root && x.color() == black && x.parent != nil {
222 if x == x.parent.left {
223 w := x.parent.right
224 if w.color() == red {
225 w.c = black
226 x.parent.c = red
227 ivt.rotateLeft(x.parent)
228 w = x.parent.right
229 }
230 if w == nil {
231 break
232 }
233 if w.left.color() == black && w.right.color() == black {
234 w.c = red
235 x = x.parent
236 } else {
237 if w.right.color() == black {
238 w.left.c = black
239 w.c = red
240 ivt.rotateRight(w)
241 w = x.parent.right
242 }
243 w.c = x.parent.color()
244 x.parent.c = black
245 w.right.c = black
246 ivt.rotateLeft(x.parent)
247 x = ivt.root
248 }
249 } else {
250 // same as above but with left and right exchanged
251 w := x.parent.left
252 if w.color() == red {
253 w.c = black
254 x.parent.c = red
255 ivt.rotateRight(x.parent)
256 w = x.parent.left
257 }
258 if w == nil {
259 break
260 }
261 if w.left.color() == black && w.right.color() == black {
262 w.c = red
263 x = x.parent
264 } else {
265 if w.left.color() == black {
266 w.right.c = black
267 w.c = red
268 ivt.rotateLeft(w)
269 w = x.parent.left
270 }
271 w.c = x.parent.color()
272 x.parent.c = black
273 w.left.c = black
274 ivt.rotateRight(x.parent)
275 x = ivt.root
276 }
277 }
278 }
279 if x != nil {
280 x.c = black
281 }
282}
283
284// Insert adds a node with the given interval into the tree.
285func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
286 var y *intervalNode
287 z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
288 x := ivt.root
289 for x != nil {
290 y = x
291 if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
292 x = x.left
293 } else {
294 x = x.right
295 }
296 }
297
298 z.parent = y
299 if y == nil {
300 ivt.root = z
301 } else {
302 if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
303 y.left = z
304 } else {
305 y.right = z
306 }
307 y.updateMax()
308 }
309 z.c = red
310 ivt.insertFixup(z)
311 ivt.count++
312}
313
314func (ivt *IntervalTree) insertFixup(z *intervalNode) {
315 for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
316 if z.parent == z.parent.parent.left {
317 y := z.parent.parent.right
318 if y.color() == red {
319 y.c = black
320 z.parent.c = black
321 z.parent.parent.c = red
322 z = z.parent.parent
323 } else {
324 if z == z.parent.right {
325 z = z.parent
326 ivt.rotateLeft(z)
327 }
328 z.parent.c = black
329 z.parent.parent.c = red
330 ivt.rotateRight(z.parent.parent)
331 }
332 } else {
333 // same as then with left/right exchanged
334 y := z.parent.parent.left
335 if y.color() == red {
336 y.c = black
337 z.parent.c = black
338 z.parent.parent.c = red
339 z = z.parent.parent
340 } else {
341 if z == z.parent.left {
342 z = z.parent
343 ivt.rotateRight(z)
344 }
345 z.parent.c = black
346 z.parent.parent.c = red
347 ivt.rotateLeft(z.parent.parent)
348 }
349 }
350 }
351 ivt.root.c = black
352}
353
354// rotateLeft moves x so it is left of its right child
355func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
356 y := x.right
357 x.right = y.left
358 if y.left != nil {
359 y.left.parent = x
360 }
361 x.updateMax()
362 ivt.replaceParent(x, y)
363 y.left = x
364 y.updateMax()
365}
366
367// rotateLeft moves x so it is right of its left child
368func (ivt *IntervalTree) rotateRight(x *intervalNode) {
369 if x == nil {
370 return
371 }
372 y := x.left
373 x.left = y.right
374 if y.right != nil {
375 y.right.parent = x
376 }
377 x.updateMax()
378 ivt.replaceParent(x, y)
379 y.right = x
380 y.updateMax()
381}
382
383// replaceParent replaces x's parent with y
384func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
385 y.parent = x.parent
386 if x.parent == nil {
387 ivt.root = y
388 } else {
389 if x == x.parent.left {
390 x.parent.left = y
391 } else {
392 x.parent.right = y
393 }
394 x.parent.updateMax()
395 }
396 x.parent = y
397}
398
399// Len gives the number of elements in the tree
400func (ivt *IntervalTree) Len() int { return ivt.count }
401
402// Height is the number of levels in the tree; one node has height 1.
403func (ivt *IntervalTree) Height() int { return ivt.root.height() }
404
405// MaxHeight is the expected maximum tree height given the number of nodes
406func (ivt *IntervalTree) MaxHeight() int {
407 return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
408}
409
410// IntervalVisitor is used on tree searches; return false to stop searching.
411type IntervalVisitor func(n *IntervalValue) bool
412
413// Visit calls a visitor function on every tree node intersecting the given interval.
414// It will visit each interval [x, y) in ascending order sorted on x.
415func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
416 ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
417}
418
419// find the exact node for a given interval
420func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
421 f := func(n *intervalNode) bool {
422 if n.iv.Ivl != ivl {
423 return true
424 }
425 ret = n
426 return false
427 }
428 ivt.root.visit(&ivl, f)
429 return ret
430}
431
432// Find gets the IntervalValue for the node matching the given interval
433func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
434 n := ivt.find(ivl)
435 if n == nil {
436 return nil
437 }
438 return &n.iv
439}
440
441// Intersects returns true if there is some tree node intersecting the given interval.
442func (ivt *IntervalTree) Intersects(iv Interval) bool {
443 x := ivt.root
444 for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
445 if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
446 x = x.left
447 } else {
448 x = x.right
449 }
450 }
451 return x != nil
452}
453
454// Contains returns true if the interval tree's keys cover the entire given interval.
455func (ivt *IntervalTree) Contains(ivl Interval) bool {
456 var maxEnd, minBegin Comparable
457
458 isContiguous := true
459 ivt.Visit(ivl, func(n *IntervalValue) bool {
460 if minBegin == nil {
461 minBegin = n.Ivl.Begin
462 maxEnd = n.Ivl.End
463 return true
464 }
465 if maxEnd.Compare(n.Ivl.Begin) < 0 {
466 isContiguous = false
467 return false
468 }
469 if n.Ivl.End.Compare(maxEnd) > 0 {
470 maxEnd = n.Ivl.End
471 }
472 return true
473 })
474
475 return isContiguous && minBegin != nil && maxEnd.Compare(ivl.End) >= 0 && minBegin.Compare(ivl.Begin) <= 0
476}
477
478// Stab returns a slice with all elements in the tree intersecting the interval.
479func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
480 if ivt.count == 0 {
481 return nil
482 }
483 f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
484 ivt.Visit(iv, f)
485 return ivs
486}
487
488// Union merges a given interval tree into the receiver.
489func (ivt *IntervalTree) Union(inIvt IntervalTree, ivl Interval) {
490 f := func(n *IntervalValue) bool {
491 ivt.Insert(n.Ivl, n.Val)
492 return true
493 }
494 inIvt.Visit(ivl, f)
495}
496
497type StringComparable string
498
499func (s StringComparable) Compare(c Comparable) int {
500 sc := c.(StringComparable)
501 if s < sc {
502 return -1
503 }
504 if s > sc {
505 return 1
506 }
507 return 0
508}
509
510func NewStringInterval(begin, end string) Interval {
511 return Interval{StringComparable(begin), StringComparable(end)}
512}
513
514func NewStringPoint(s string) Interval {
515 return Interval{StringComparable(s), StringComparable(s + "\x00")}
516}
517
518// StringAffineComparable treats "" as > all other strings
519type StringAffineComparable string
520
521func (s StringAffineComparable) Compare(c Comparable) int {
522 sc := c.(StringAffineComparable)
523
524 if len(s) == 0 {
525 if len(sc) == 0 {
526 return 0
527 }
528 return 1
529 }
530 if len(sc) == 0 {
531 return -1
532 }
533
534 if s < sc {
535 return -1
536 }
537 if s > sc {
538 return 1
539 }
540 return 0
541}
542
543func NewStringAffineInterval(begin, end string) Interval {
544 return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
545}
546func NewStringAffinePoint(s string) Interval {
547 return NewStringAffineInterval(s, s+"\x00")
548}
549
550func NewInt64Interval(a int64, b int64) Interval {
551 return Interval{Int64Comparable(a), Int64Comparable(b)}
552}
553
554func NewInt64Point(a int64) Interval {
555 return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
556}
557
558type Int64Comparable int64
559
560func (v Int64Comparable) Compare(c Comparable) int {
561 vc := c.(Int64Comparable)
562 cmp := v - vc
563 if cmp < 0 {
564 return -1
565 }
566 if cmp > 0 {
567 return 1
568 }
569 return 0
570}
571
572// BytesAffineComparable treats empty byte arrays as > all other byte arrays
573type BytesAffineComparable []byte
574
575func (b BytesAffineComparable) Compare(c Comparable) int {
576 bc := c.(BytesAffineComparable)
577
578 if len(b) == 0 {
579 if len(bc) == 0 {
580 return 0
581 }
582 return 1
583 }
584 if len(bc) == 0 {
585 return -1
586 }
587
588 return bytes.Compare(b, bc)
589}
590
591func NewBytesAffineInterval(begin, end []byte) Interval {
592 return Interval{BytesAffineComparable(begin), BytesAffineComparable(end)}
593}
594func NewBytesAffinePoint(b []byte) Interval {
595 be := make([]byte, len(b)+1)
596 copy(be, b)
597 be[len(b)] = 0
598 return NewBytesAffineInterval(b, be)
599}