blob: 10bdd46f155c2cd1b4d9ce522cce66c0ec44e432 [file] [log] [blame]
Takahiro Suzuki241c10e2020-12-17 20:17:57 +09001/*
2Open Source Initiative OSI - The MIT License (MIT):Licensing
3
4The MIT License (MIT)
5Copyright (c) 2013 Ralph Caraveo (deckarep@gmail.com)
6
7Permission is hereby granted, free of charge, to any person obtaining a copy of
8this software and associated documentation files (the "Software"), to deal in
9the Software without restriction, including without limitation the rights to
10use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
11of the Software, and to permit persons to whom the Software is furnished to do
12so, subject to the following conditions:
13
14The above copyright notice and this permission notice shall be included in all
15copies or substantial portions of the Software.
16
17THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23SOFTWARE.
24*/
25
26package mapset
27
28import (
29 "bytes"
30 "encoding/json"
31 "fmt"
32 "reflect"
33 "strings"
34)
35
// threadUnsafeSet stores its elements as the keys of a map whose values are
// empty structs. None of the methods defined on it in this file take a lock.
type threadUnsafeSet map[interface{}]struct{}
37
// OrderedPair is a 2-tuple: First holds the left-hand value and Second the
// right-hand value.
type OrderedPair struct {
	First, Second interface{}
}
43
44func newThreadUnsafeSet() threadUnsafeSet {
45 return make(threadUnsafeSet)
46}
47
48// Equal says whether two 2-tuples contain the same values in the same order.
49func (pair *OrderedPair) Equal(other OrderedPair) bool {
50 if pair.First == other.First &&
51 pair.Second == other.Second {
52 return true
53 }
54
55 return false
56}
57
58func (set *threadUnsafeSet) Add(i interface{}) bool {
59 _, found := (*set)[i]
60 if found {
61 return false //False if it existed already
62 }
63
64 (*set)[i] = struct{}{}
65 return true
66}
67
68func (set *threadUnsafeSet) Contains(i ...interface{}) bool {
69 for _, val := range i {
70 if _, ok := (*set)[val]; !ok {
71 return false
72 }
73 }
74 return true
75}
76
77func (set *threadUnsafeSet) IsSubset(other Set) bool {
78 _ = other.(*threadUnsafeSet)
79 for elem := range *set {
80 if !other.Contains(elem) {
81 return false
82 }
83 }
84 return true
85}
86
87func (set *threadUnsafeSet) IsProperSubset(other Set) bool {
88 return set.IsSubset(other) && !set.Equal(other)
89}
90
91func (set *threadUnsafeSet) IsSuperset(other Set) bool {
92 return other.IsSubset(set)
93}
94
95func (set *threadUnsafeSet) IsProperSuperset(other Set) bool {
96 return set.IsSuperset(other) && !set.Equal(other)
97}
98
99func (set *threadUnsafeSet) Union(other Set) Set {
100 o := other.(*threadUnsafeSet)
101
102 unionedSet := newThreadUnsafeSet()
103
104 for elem := range *set {
105 unionedSet.Add(elem)
106 }
107 for elem := range *o {
108 unionedSet.Add(elem)
109 }
110 return &unionedSet
111}
112
113func (set *threadUnsafeSet) Intersect(other Set) Set {
114 o := other.(*threadUnsafeSet)
115
116 intersection := newThreadUnsafeSet()
117 // loop over smaller set
118 if set.Cardinality() < other.Cardinality() {
119 for elem := range *set {
120 if other.Contains(elem) {
121 intersection.Add(elem)
122 }
123 }
124 } else {
125 for elem := range *o {
126 if set.Contains(elem) {
127 intersection.Add(elem)
128 }
129 }
130 }
131 return &intersection
132}
133
134func (set *threadUnsafeSet) Difference(other Set) Set {
135 _ = other.(*threadUnsafeSet)
136
137 difference := newThreadUnsafeSet()
138 for elem := range *set {
139 if !other.Contains(elem) {
140 difference.Add(elem)
141 }
142 }
143 return &difference
144}
145
146func (set *threadUnsafeSet) SymmetricDifference(other Set) Set {
147 _ = other.(*threadUnsafeSet)
148
149 aDiff := set.Difference(other)
150 bDiff := other.Difference(set)
151 return aDiff.Union(bDiff)
152}
153
154func (set *threadUnsafeSet) Clear() {
155 *set = newThreadUnsafeSet()
156}
157
158func (set *threadUnsafeSet) Remove(i interface{}) {
159 delete(*set, i)
160}
161
162func (set *threadUnsafeSet) Cardinality() int {
163 return len(*set)
164}
165
166func (set *threadUnsafeSet) Each(cb func(interface{}) bool) {
167 for elem := range *set {
168 if cb(elem) {
169 break
170 }
171 }
172}
173
174func (set *threadUnsafeSet) Iter() <-chan interface{} {
175 ch := make(chan interface{})
176 go func() {
177 for elem := range *set {
178 ch <- elem
179 }
180 close(ch)
181 }()
182
183 return ch
184}
185
186func (set *threadUnsafeSet) Iterator() *Iterator {
187 iterator, ch, stopCh := newIterator()
188
189 go func() {
190 L:
191 for elem := range *set {
192 select {
193 case <-stopCh:
194 break L
195 case ch <- elem:
196 }
197 }
198 close(ch)
199 }()
200
201 return iterator
202}
203
204func (set *threadUnsafeSet) Equal(other Set) bool {
205 _ = other.(*threadUnsafeSet)
206
207 if set.Cardinality() != other.Cardinality() {
208 return false
209 }
210 for elem := range *set {
211 if !other.Contains(elem) {
212 return false
213 }
214 }
215 return true
216}
217
218func (set *threadUnsafeSet) Clone() Set {
219 clonedSet := newThreadUnsafeSet()
220 for elem := range *set {
221 clonedSet.Add(elem)
222 }
223 return &clonedSet
224}
225
226func (set *threadUnsafeSet) String() string {
227 items := make([]string, 0, len(*set))
228
229 for elem := range *set {
230 items = append(items, fmt.Sprintf("%v", elem))
231 }
232 return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
233}
234
235// String outputs a 2-tuple in the form "(A, B)".
236func (pair OrderedPair) String() string {
237 return fmt.Sprintf("(%v, %v)", pair.First, pair.Second)
238}
239
240func (set *threadUnsafeSet) Pop() interface{} {
241 for item := range *set {
242 delete(*set, item)
243 return item
244 }
245 return nil
246}
247
248func (set *threadUnsafeSet) PowerSet() Set {
249 powSet := NewThreadUnsafeSet()
250 nullset := newThreadUnsafeSet()
251 powSet.Add(&nullset)
252
253 for es := range *set {
254 u := newThreadUnsafeSet()
255 j := powSet.Iter()
256 for er := range j {
257 p := newThreadUnsafeSet()
258 if reflect.TypeOf(er).Name() == "" {
259 k := er.(*threadUnsafeSet)
260 for ek := range *(k) {
261 p.Add(ek)
262 }
263 } else {
264 p.Add(er)
265 }
266 p.Add(es)
267 u.Add(&p)
268 }
269
270 powSet = powSet.Union(&u)
271 }
272
273 return powSet
274}
275
276func (set *threadUnsafeSet) CartesianProduct(other Set) Set {
277 o := other.(*threadUnsafeSet)
278 cartProduct := NewThreadUnsafeSet()
279
280 for i := range *set {
281 for j := range *o {
282 elem := OrderedPair{First: i, Second: j}
283 cartProduct.Add(elem)
284 }
285 }
286
287 return cartProduct
288}
289
290func (set *threadUnsafeSet) ToSlice() []interface{} {
291 keys := make([]interface{}, 0, set.Cardinality())
292 for elem := range *set {
293 keys = append(keys, elem)
294 }
295
296 return keys
297}
298
299// MarshalJSON creates a JSON array from the set, it marshals all elements
300func (set *threadUnsafeSet) MarshalJSON() ([]byte, error) {
301 items := make([]string, 0, set.Cardinality())
302
303 for elem := range *set {
304 b, err := json.Marshal(elem)
305 if err != nil {
306 return nil, err
307 }
308
309 items = append(items, string(b))
310 }
311
312 return []byte(fmt.Sprintf("[%s]", strings.Join(items, ","))), nil
313}
314
315// UnmarshalJSON recreates a set from a JSON array, it only decodes
316// primitive types. Numbers are decoded as json.Number.
317func (set *threadUnsafeSet) UnmarshalJSON(b []byte) error {
318 var i []interface{}
319
320 d := json.NewDecoder(bytes.NewReader(b))
321 d.UseNumber()
322 err := d.Decode(&i)
323 if err != nil {
324 return err
325 }
326
327 for _, v := range i {
328 switch t := v.(type) {
329 case []interface{}, map[string]interface{}:
330 continue
331 default:
332 set.Add(t)
333 }
334 }
335
336 return nil
337}