blob: c2b7b30dd9177e669196bc8e336535654b62fb16 [file] [log] [blame]
khenaidooffe076b2019-01-15 16:08:08 -05001package utilities
2
3import (
4 "sort"
5)
6
// DoubleArray is a Double Array implementation of trie on sequences of strings.
//
// Tokens are mapped to integer codes through Encoding; the code
// len(Encoding) is reserved as the end-of-sequence terminator. A trie node
// occupying slot p has its children at slots Base[p]+code, and a slot j that
// is a child of p records Check[j] == p+1 (0 marks a free slot). Slot 0 is
// the root.
type DoubleArray struct {
	// Encoding assigns each known token its integer code.
	Encoding map[string]int
	// Base and Check are the two parallel arrays of the double-array trie.
	Base, Check []int
}
16
17// NewDoubleArray builds a DoubleArray from a set of sequences of strings.
18func NewDoubleArray(seqs [][]string) *DoubleArray {
19 da := &DoubleArray{Encoding: make(map[string]int)}
20 if len(seqs) == 0 {
21 return da
22 }
23
24 encoded := registerTokens(da, seqs)
25 sort.Sort(byLex(encoded))
26
27 root := node{row: -1, col: -1, left: 0, right: len(encoded)}
28 addSeqs(da, encoded, 0, root)
29
30 for i := len(da.Base); i > 0; i-- {
31 if da.Check[i-1] != 0 {
32 da.Base = da.Base[:i]
33 da.Check = da.Check[:i]
34 break
35 }
36 }
37 return da
38}
39
40func registerTokens(da *DoubleArray, seqs [][]string) [][]int {
41 var result [][]int
42 for _, seq := range seqs {
43 var encoded []int
44 for _, token := range seq {
45 if _, ok := da.Encoding[token]; !ok {
46 da.Encoding[token] = len(da.Encoding)
47 }
48 encoded = append(encoded, da.Encoding[token])
49 }
50 result = append(result, encoded)
51 }
52 for i := range result {
53 result[i] = append(result[i], len(da.Encoding))
54 }
55 return result
56}
57
// node identifies a subtree of the trie under construction in terms of the
// sorted, encoded sequences: rows [left, right) of seqs share their first
// col+1 elements, and seqs[row][col] is the code on the edge leading into
// the node. The root uses row = col = -1 and spans every row.
type node struct {
	row, col    int
	left, right int
}

// value returns the code on the edge leading into n. It must not be called
// on the root, whose row and col are -1.
func (n node) value(seqs [][]int) int {
	return seqs[n.row][n.col]
}

// children splits n's row range into one child per distinct code found in
// column n.col+1. seqs must be sorted lexicographically (see byLex) so that
// rows sharing a code are contiguous. Each child covers exactly the rows
// [left, right) that carry its code, and its row is the first of them.
func (n node) children(seqs [][]int) []*node {
	var result []*node
	// Codes are non-negative, so the first row always starts a new child.
	lastVal := -1
	// Placeholder so the first iteration can "close" a previous child.
	last := new(node)
	for i := n.left; i < n.right; i++ {
		v := seqs[i][n.col+1]
		if v == lastVal {
			// Same code as the previous row: it belongs to the current child.
			continue
		}
		// Remember the code. Without this, every row becomes its own child,
		// and duplicate codes are placed and recursed into separately,
		// clobbering Base entries.
		lastVal = v
		last.right = i
		last = &node{row: i, col: n.col + 1, left: i}
		result = append(result, last)
	}
	last.right = n.right
	return result
}
86
87func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) {
88 ensureSize(da, pos)
89
90 children := n.children(seqs)
91 var i int
92 for i = 1; ; i++ {
93 ok := func() bool {
94 for _, child := range children {
95 code := child.value(seqs)
96 j := i + code
97 ensureSize(da, j)
98 if da.Check[j] != 0 {
99 return false
100 }
101 }
102 return true
103 }()
104 if ok {
105 break
106 }
107 }
108 da.Base[pos] = i
109 for _, child := range children {
110 code := child.value(seqs)
111 j := i + code
112 da.Check[j] = pos + 1
113 }
114 terminator := len(da.Encoding)
115 for _, child := range children {
116 code := child.value(seqs)
117 if code == terminator {
118 continue
119 }
120 j := i + code
121 addSeqs(da, seqs, j, *child)
122 }
123}
124
125func ensureSize(da *DoubleArray, i int) {
126 for i >= len(da.Base) {
127 da.Base = append(da.Base, make([]int, len(da.Base)+1)...)
128 da.Check = append(da.Check, make([]int, len(da.Check)+1)...)
129 }
130}
131
// byLex implements sort.Interface over encoded sequences. It orders them
// lexicographically by element value, with a proper prefix sorting before
// any longer sequence that extends it.
type byLex [][]int

func (l byLex) Len() int { return len(l) }

func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] }

// Less reports whether l[i] sorts strictly before l[j].
func (l byLex) Less(i, j int) bool {
	a, b := l[i], l[j]
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	for k := 0; k < n; k++ {
		if a[k] != b[k] {
			return a[k] < b[k]
		}
	}
	// One is a prefix of the other (or they are equal): shorter goes first.
	return len(a) < len(b)
}
153
154// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence.
155func (da *DoubleArray) HasCommonPrefix(seq []string) bool {
156 if len(da.Base) == 0 {
157 return false
158 }
159
160 var i int
161 for _, t := range seq {
162 code, ok := da.Encoding[t]
163 if !ok {
164 break
165 }
166 j := da.Base[i] + code
167 if len(da.Check) <= j || da.Check[j] != i+1 {
168 break
169 }
170 i = j
171 }
172 j := da.Base[i] + len(da.Encoding)
173 if len(da.Check) <= j || da.Check[j] != i+1 {
174 return false
175 }
176 return true
177}