blob: c3e3d85bdeaf05a8b7942a4bb7df6fa4eb862542 [file] [log] [blame]
// Copyright 2016 The CMux Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
14
15package cmux
16
17import (
18 "bytes"
19 "io"
20)
21
22// patriciaTree is a simple patricia tree that handles []byte instead of string
23// and cannot be changed after instantiation.
24type patriciaTree struct {
25 root *ptNode
26 maxDepth int // max depth of the tree.
27}
28
29func newPatriciaTree(bs ...[]byte) *patriciaTree {
30 max := 0
31 for _, b := range bs {
32 if max < len(b) {
33 max = len(b)
34 }
35 }
36 return &patriciaTree{
37 root: newNode(bs),
38 maxDepth: max + 1,
39 }
40}
41
42func newPatriciaTreeString(strs ...string) *patriciaTree {
43 b := make([][]byte, len(strs))
44 for i, s := range strs {
45 b[i] = []byte(s)
46 }
47 return newPatriciaTree(b...)
48}
49
50func (t *patriciaTree) matchPrefix(r io.Reader) bool {
51 buf := make([]byte, t.maxDepth)
52 n, _ := io.ReadFull(r, buf)
53 return t.root.match(buf[:n], true)
54}
55
56func (t *patriciaTree) match(r io.Reader) bool {
57 buf := make([]byte, t.maxDepth)
58 n, _ := io.ReadFull(r, buf)
59 return t.root.match(buf[:n], false)
60}
61
// ptNode is a single node of a patriciaTree.
type ptNode struct {
	// prefix holds the bytes the input must start with for this node to
	// match.
	prefix []byte
	// next maps the byte that follows prefix to the child node responsible
	// for the rest of the input.
	next map[byte]*ptNode
	// terminal is true when a stored entry ends at this node.
	terminal bool
}
67
68func newNode(strs [][]byte) *ptNode {
69 if len(strs) == 0 {
70 return &ptNode{
71 prefix: []byte{},
72 terminal: true,
73 }
74 }
75
76 if len(strs) == 1 {
77 return &ptNode{
78 prefix: strs[0],
79 terminal: true,
80 }
81 }
82
83 p, strs := splitPrefix(strs)
84 n := &ptNode{
85 prefix: p,
86 }
87
88 nexts := make(map[byte][][]byte)
89 for _, s := range strs {
90 if len(s) == 0 {
91 n.terminal = true
92 continue
93 }
94 nexts[s[0]] = append(nexts[s[0]], s[1:])
95 }
96
97 n.next = make(map[byte]*ptNode)
98 for first, rests := range nexts {
99 n.next[first] = newNode(rests)
100 }
101
102 return n
103}
104
// splitPrefix returns the longest prefix shared by every slice in bss along
// with what is left of each slice once that prefix is removed.
//
// When bss is empty or its first element is empty, the prefix is nil and bss
// itself is returned as rest. When bss has exactly one element, that element
// is the prefix and rest holds a single empty slice. In the general case the
// prefix is a freshly allocated slice (nil if nothing is shared) and rest has
// one entry per input, in the same order.
func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) {
	switch {
	case len(bss) == 0 || len(bss[0]) == 0:
		return nil, bss
	case len(bss) == 1:
		return bss[0], [][]byte{{}}
	}

	// Shrink the candidate length until it fits every other slice.
	first := bss[0]
	shared := len(first)
	for _, other := range bss[1:] {
		if len(other) < shared {
			shared = len(other)
		}
		for k := 0; k < shared; k++ {
			if other[k] != first[k] {
				shared = k
				break
			}
		}
	}

	if shared > 0 {
		prefix = append(prefix, first[:shared]...)
	}

	rest = make([][]byte, len(bss))
	for k, b := range bss {
		rest[k] = b[shared:]
	}
	return prefix, rest
}
148
149func (n *ptNode) match(b []byte, prefix bool) bool {
150 l := len(n.prefix)
151 if l > 0 {
152 if l > len(b) {
153 l = len(b)
154 }
155 if !bytes.Equal(b[:l], n.prefix) {
156 return false
157 }
158 }
159
160 if n.terminal && (prefix || len(n.prefix) == len(b)) {
161 return true
162 }
163
164 if l >= len(b) {
165 return false
166 }
167
168 nextN, ok := n.next[b[l]]
169 if !ok {
170 return false
171 }
172
173 if l == len(b) {
174 b = b[l:l]
175 } else {
176 b = b[l+1:]
177 }
178 return nextN.match(b, prefix)
179}