blob: 771633d1d99975c5cd55188e5fc958a676b9d6f2 [file] [log] [blame]
David K. Bainbridge215e0242017-09-05 23:18:24 -07001// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package transform
6
7import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io/ioutil"
12 "strconv"
13 "strings"
14 "testing"
15 "time"
16 "unicode/utf8"
17
18 "golang.org/x/text/internal/testtext"
19)
20
21type lowerCaseASCII struct{ NopResetter }
22
23func (lowerCaseASCII) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
24 n := len(src)
25 if n > len(dst) {
26 n, err = len(dst), ErrShortDst
27 }
28 for i, c := range src[:n] {
29 if 'A' <= c && c <= 'Z' {
30 c += 'a' - 'A'
31 }
32 dst[i] = c
33 }
34 return n, n, err
35}
36
37// lowerCaseASCIILookahead lowercases the string and reports ErrShortSrc as long
38// as the input is not atEOF.
39type lowerCaseASCIILookahead struct{ NopResetter }
40
41func (lowerCaseASCIILookahead) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
42 n := len(src)
43 if n > len(dst) {
44 n, err = len(dst), ErrShortDst
45 }
46 for i, c := range src[:n] {
47 if 'A' <= c && c <= 'Z' {
48 c += 'a' - 'A'
49 }
50 dst[i] = c
51 }
52 if !atEOF {
53 err = ErrShortSrc
54 }
55 return n, n, err
56}
57
58var errYouMentionedX = errors.New("you mentioned X")
59
60type dontMentionX struct{ NopResetter }
61
62func (dontMentionX) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
63 n := len(src)
64 if n > len(dst) {
65 n, err = len(dst), ErrShortDst
66 }
67 for i, c := range src[:n] {
68 if c == 'X' {
69 return i, i, errYouMentionedX
70 }
71 dst[i] = c
72 }
73 return n, n, err
74}
75
76var errAtEnd = errors.New("error after all text")
77
78type errorAtEnd struct{ NopResetter }
79
80func (errorAtEnd) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
81 n := copy(dst, src)
82 if n < len(src) {
83 return n, n, ErrShortDst
84 }
85 if atEOF {
86 return n, n, errAtEnd
87 }
88 return n, n, nil
89}
90
91type replaceWithConstant struct {
92 replacement string
93 written int
94}
95
96func (t *replaceWithConstant) Reset() {
97 t.written = 0
98}
99
100func (t *replaceWithConstant) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
101 if atEOF {
102 nDst = copy(dst, t.replacement[t.written:])
103 t.written += nDst
104 if t.written < len(t.replacement) {
105 err = ErrShortDst
106 }
107 }
108 return nDst, len(src), err
109}
110
111type addAnXAtTheEnd struct{ NopResetter }
112
113func (addAnXAtTheEnd) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
114 n := copy(dst, src)
115 if n < len(src) {
116 return n, n, ErrShortDst
117 }
118 if !atEOF {
119 return n, n, nil
120 }
121 if len(dst) == n {
122 return n, n, ErrShortDst
123 }
124 dst[n] = 'X'
125 return n + 1, n, nil
126}
127
128// doublerAtEOF is a strange Transformer that transforms "this" to "tthhiiss",
129// but only if atEOF is true.
130type doublerAtEOF struct{ NopResetter }
131
132func (doublerAtEOF) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
133 if !atEOF {
134 return 0, 0, ErrShortSrc
135 }
136 for i, c := range src {
137 if 2*i+2 >= len(dst) {
138 return 2 * i, i, ErrShortDst
139 }
140 dst[2*i+0] = c
141 dst[2*i+1] = c
142 }
143 return 2 * len(src), len(src), nil
144}
145
146// rleDecode and rleEncode implement a toy run-length encoding: "aabbbbbbbbbb"
147// is encoded as "2a10b". The decoding is assumed to not contain any numbers.
148
149type rleDecode struct{ NopResetter }
150
151func (rleDecode) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
152loop:
153 for len(src) > 0 {
154 n := 0
155 for i, c := range src {
156 if '0' <= c && c <= '9' {
157 n = 10*n + int(c-'0')
158 continue
159 }
160 if i == 0 {
161 return nDst, nSrc, errors.New("rleDecode: bad input")
162 }
163 if n > len(dst) {
164 return nDst, nSrc, ErrShortDst
165 }
166 for j := 0; j < n; j++ {
167 dst[j] = c
168 }
169 dst, src = dst[n:], src[i+1:]
170 nDst, nSrc = nDst+n, nSrc+i+1
171 continue loop
172 }
173 if atEOF {
174 return nDst, nSrc, errors.New("rleDecode: bad input")
175 }
176 return nDst, nSrc, ErrShortSrc
177 }
178 return nDst, nSrc, nil
179}
180
181type rleEncode struct {
182 NopResetter
183
184 // allowStutter means that "xxxxxxxx" can be encoded as "5x3x"
185 // instead of always as "8x".
186 allowStutter bool
187}
188
189func (e rleEncode) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
190 for len(src) > 0 {
191 n, c0 := len(src), src[0]
192 for i, c := range src[1:] {
193 if c != c0 {
194 n = i + 1
195 break
196 }
197 }
198 if n == len(src) && !atEOF && !e.allowStutter {
199 return nDst, nSrc, ErrShortSrc
200 }
201 s := strconv.Itoa(n)
202 if len(s) >= len(dst) {
203 return nDst, nSrc, ErrShortDst
204 }
205 copy(dst, s)
206 dst[len(s)] = c0
207 dst, src = dst[len(s)+1:], src[n:]
208 nDst, nSrc = nDst+len(s)+1, nSrc+n
209 }
210 return nDst, nSrc, nil
211}
212
213// trickler consumes all input bytes, but writes a single byte at a time to dst.
214type trickler []byte
215
216func (t *trickler) Reset() {
217 *t = nil
218}
219
220func (t *trickler) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
221 *t = append(*t, src...)
222 if len(*t) == 0 {
223 return 0, 0, nil
224 }
225 if len(dst) == 0 {
226 return 0, len(src), ErrShortDst
227 }
228 dst[0] = (*t)[0]
229 *t = (*t)[1:]
230 if len(*t) > 0 {
231 err = ErrShortDst
232 }
233 return 1, len(src), err
234}
235
236// delayedTrickler is like trickler, but delays writing output to dst. This is
237// highly unlikely to be relevant in practice, but it seems like a good idea
238// to have some tolerance as long as progress can be detected.
239type delayedTrickler []byte
240
241func (t *delayedTrickler) Reset() {
242 *t = nil
243}
244
245func (t *delayedTrickler) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
246 if len(*t) > 0 && len(dst) > 0 {
247 dst[0] = (*t)[0]
248 *t = (*t)[1:]
249 nDst = 1
250 }
251 *t = append(*t, src...)
252 if len(*t) > 0 {
253 err = ErrShortDst
254 }
255 return nDst, len(src), err
256}
257
258type testCase struct {
259 desc string
260 t Transformer
261 src string
262 dstSize int
263 srcSize int
264 ioSize int
265 wantStr string
266 wantErr error
267 wantIter int // number of iterations taken; 0 means we don't care.
268}
269
270func (t testCase) String() string {
271 return tstr(t.t) + "; " + t.desc
272}
273
274func tstr(t Transformer) string {
275 if stringer, ok := t.(fmt.Stringer); ok {
276 return stringer.String()
277 }
278 s := fmt.Sprintf("%T", t)
279 return s[1+strings.Index(s, "."):]
280}
281
282func (c chain) String() string {
283 buf := &bytes.Buffer{}
284 buf.WriteString("Chain(")
285 for i, l := range c.link[:len(c.link)-1] {
286 if i != 0 {
287 fmt.Fprint(buf, ", ")
288 }
289 buf.WriteString(tstr(l.t))
290 }
291 buf.WriteString(")")
292 return buf.String()
293}
294
// testCases is shared by TestReader, TestWriter, TestChain and, through
// testString, by TestBytes, TestAppend and TestString. Each entry runs t over
// src with internal buffers of dstSize and srcSize bytes. It expects the
// output wantStr and the final error wantErr.
var testCases = []testCase{
	// lowerCaseASCII with various buffer sizes.
	{
		desc:    "empty",
		t:       lowerCaseASCII{},
		src:     "",
		dstSize: 100,
		srcSize: 100,
		wantStr: "",
	},

	{
		desc:    "basic",
		t:       lowerCaseASCII{},
		src:     "Hello WORLD.",
		dstSize: 100,
		srcSize: 100,
		wantStr: "hello world.",
	},

	{
		desc:    "small dst",
		t:       lowerCaseASCII{},
		src:     "Hello WORLD.",
		dstSize: 3,
		srcSize: 100,
		wantStr: "hello world.",
	},

	{
		desc:    "small src",
		t:       lowerCaseASCII{},
		src:     "Hello WORLD.",
		dstSize: 100,
		srcSize: 4,
		wantStr: "hello world.",
	},

	{
		desc:    "small buffers",
		t:       lowerCaseASCII{},
		src:     "Hello WORLD.",
		dstSize: 3,
		srcSize: 4,
		wantStr: "hello world.",
	},

	{
		desc:    "very small buffers",
		t:       lowerCaseASCII{},
		src:     "Hello WORLD.",
		dstSize: 1,
		srcSize: 1,
		wantStr: "hello world.",
	},

	// lowerCaseASCIILookahead reports ErrShortSrc until EOF.
	{
		desc:    "small dst with lookahead",
		t:       lowerCaseASCIILookahead{},
		src:     "Hello WORLD.",
		dstSize: 3,
		srcSize: 100,
		wantStr: "hello world.",
	},

	{
		desc:    "small src with lookahead",
		t:       lowerCaseASCIILookahead{},
		src:     "Hello WORLD.",
		dstSize: 100,
		srcSize: 4,
		wantStr: "hello world.",
	},

	{
		desc:    "small buffers with lookahead",
		t:       lowerCaseASCIILookahead{},
		src:     "Hello WORLD.",
		dstSize: 3,
		srcSize: 4,
		wantStr: "hello world.",
	},

	{
		desc:    "very small buffers with lookahead",
		t:       lowerCaseASCIILookahead{},
		src:     "Hello WORLD.",
		dstSize: 1,
		srcSize: 2,
		wantStr: "hello world.",
	},

	// Errors returned by the transformer itself are passed through.
	{
		desc:    "user error",
		t:       dontMentionX{},
		src:     "The First Rule of Transform Club: don't mention Mister X, ever.",
		dstSize: 100,
		srcSize: 100,
		wantStr: "The First Rule of Transform Club: don't mention Mister ",
		wantErr: errYouMentionedX,
	},

	{
		desc:    "user error at end",
		t:       errorAtEnd{},
		src:     "All goes well until it doesn't.",
		dstSize: 100,
		srcSize: 100,
		wantStr: "All goes well until it doesn't.",
		wantErr: errAtEnd,
	},

	{
		desc:    "user error at end, incremental",
		t:       errorAtEnd{},
		src:     "All goes well until it doesn't.",
		dstSize: 10,
		srcSize: 10,
		wantStr: "All goes well until it doesn't.",
		wantErr: errAtEnd,
	},

	// Output produced only at EOF, independent of the input.
	{
		desc:    "replace entire non-empty string with one byte",
		t:       &replaceWithConstant{replacement: "X"},
		src:     "none of this will be copied",
		dstSize: 1,
		srcSize: 10,
		wantStr: "X",
	},

	{
		desc:    "replace entire empty string with one byte",
		t:       &replaceWithConstant{replacement: "X"},
		src:     "",
		dstSize: 1,
		srcSize: 10,
		wantStr: "X",
	},

	{
		desc:    "replace entire empty string with seven bytes",
		t:       &replaceWithConstant{replacement: "ABCDEFG"},
		src:     "",
		dstSize: 3,
		srcSize: 10,
		wantStr: "ABCDEFG",
	},

	// Inputs around initialBufSize, with one extra byte appended at EOF.
	{
		desc:    "add an X (initialBufSize-1)",
		t:       addAnXAtTheEnd{},
		src:     aaa[:initialBufSize-1],
		dstSize: 10,
		srcSize: 10,
		wantStr: aaa[:initialBufSize-1] + "X",
	},

	{
		desc:    "add an X (initialBufSize+0)",
		t:       addAnXAtTheEnd{},
		src:     aaa[:initialBufSize+0],
		dstSize: 10,
		srcSize: 10,
		wantStr: aaa[:initialBufSize+0] + "X",
	},

	{
		desc:    "add an X (initialBufSize+1)",
		t:       addAnXAtTheEnd{},
		src:     aaa[:initialBufSize+1],
		dstSize: 10,
		srcSize: 10,
		wantStr: aaa[:initialBufSize+1] + "X",
	},

	{
		desc:    "small buffers",
		t:       dontMentionX{},
		src:     "The First Rule of Transform Club: don't mention Mister X, ever.",
		dstSize: 10,
		srcSize: 10,
		wantStr: "The First Rule of Transform Club: don't mention Mister ",
		wantErr: errYouMentionedX,
	},

	{
		desc:    "very small buffers",
		t:       dontMentionX{},
		src:     "The First Rule of Transform Club: don't mention Mister X, ever.",
		dstSize: 1,
		srcSize: 1,
		wantStr: "The First Rule of Transform Club: don't mention Mister ",
		wantErr: errYouMentionedX,
	},

	{
		desc:    "only transform at EOF",
		t:       doublerAtEOF{},
		src:     "this",
		dstSize: 100,
		srcSize: 100,
		wantStr: "tthhiiss",
	},

	// rleDecode.
	{
		desc:    "basic",
		t:       rleDecode{},
		src:     "1a2b3c10d11e0f1g",
		dstSize: 100,
		srcSize: 100,
		wantStr: "abbcccddddddddddeeeeeeeeeeeg",
	},

	{
		desc:    "long",
		t:       rleDecode{},
		src:     "12a23b34c45d56e99z",
		dstSize: 100,
		srcSize: 100,
		wantStr: strings.Repeat("a", 12) +
			strings.Repeat("b", 23) +
			strings.Repeat("c", 34) +
			strings.Repeat("d", 45) +
			strings.Repeat("e", 56) +
			strings.Repeat("z", 99),
	},

	{
		desc:    "tight buffers",
		t:       rleDecode{},
		src:     "1a2b3c10d11e0f1g",
		dstSize: 11,
		srcSize: 3,
		wantStr: "abbcccddddddddddeeeeeeeeeeeg",
	},

	{
		desc:    "short dst",
		t:       rleDecode{},
		src:     "1a2b3c10d11e0f1g",
		dstSize: 10,
		srcSize: 3,
		wantStr: "abbcccdddddddddd",
		wantErr: ErrShortDst,
	},

	{
		desc:    "short src",
		t:       rleDecode{},
		src:     "1a2b3c10d11e0f1g",
		dstSize: 11,
		srcSize: 2,
		ioSize:  2,
		wantStr: "abbccc",
		wantErr: ErrShortSrc,
	},

	// rleEncode.
	{
		desc:    "basic",
		t:       rleEncode{},
		src:     "abbcccddddddddddeeeeeeeeeeeg",
		dstSize: 100,
		srcSize: 100,
		wantStr: "1a2b3c10d11e1g",
	},

	{
		desc: "long",
		t:    rleEncode{},
		src: strings.Repeat("a", 12) +
			strings.Repeat("b", 23) +
			strings.Repeat("c", 34) +
			strings.Repeat("d", 45) +
			strings.Repeat("e", 56) +
			strings.Repeat("z", 99),
		dstSize: 100,
		srcSize: 100,
		wantStr: "12a23b34c45d56e99z",
	},

	{
		desc:    "tight buffers",
		t:       rleEncode{},
		src:     "abbcccddddddddddeeeeeeeeeeeg",
		dstSize: 3,
		srcSize: 12,
		wantStr: "1a2b3c10d11e1g",
	},

	{
		desc:    "short dst",
		t:       rleEncode{},
		src:     "abbcccddddddddddeeeeeeeeeeeg",
		dstSize: 2,
		srcSize: 12,
		wantStr: "1a2b3c",
		wantErr: ErrShortDst,
	},

	{
		desc:    "short src",
		t:       rleEncode{},
		src:     "abbcccddddddddddeeeeeeeeeeeg",
		dstSize: 3,
		srcSize: 11,
		ioSize:  11,
		wantStr: "1a2b3c10d",
		wantErr: ErrShortSrc,
	},

	{
		desc:    "allowStutter = false",
		t:       rleEncode{allowStutter: false},
		src:     "aaaabbbbbbbbccccddddd",
		dstSize: 10,
		srcSize: 10,
		wantStr: "4a8b4c5d",
	},

	{
		desc:    "allowStutter = true",
		t:       rleEncode{allowStutter: true},
		src:     "aaaabbbbbbbbccccddddd",
		dstSize: 10,
		srcSize: 10,
		ioSize:  10,
		wantStr: "4a6b2b4c4d1d",
	},

	// Transformers that emit output more slowly than they consume input.
	{
		desc:    "trickler",
		t:       &trickler{},
		src:     "abcdefghijklm",
		dstSize: 3,
		srcSize: 15,
		wantStr: "abcdefghijklm",
	},

	{
		desc:    "delayedTrickler",
		t:       &delayedTrickler{},
		src:     "abcdefghijklm",
		dstSize: 3,
		srcSize: 15,
		wantStr: "abcdefghijklm",
	},
}
642
643func TestReader(t *testing.T) {
644 for _, tc := range testCases {
645 testtext.Run(t, tc.desc, func(t *testing.T) {
646 r := NewReader(strings.NewReader(tc.src), tc.t)
647 // Differently sized dst and src buffers are not part of the
648 // exported API. We override them manually.
649 r.dst = make([]byte, tc.dstSize)
650 r.src = make([]byte, tc.srcSize)
651 got, err := ioutil.ReadAll(r)
652 str := string(got)
653 if str != tc.wantStr || err != tc.wantErr {
654 t.Errorf("\ngot %q, %v\nwant %q, %v", str, err, tc.wantStr, tc.wantErr)
655 }
656 })
657 }
658}
659
660func TestWriter(t *testing.T) {
661 tests := append(testCases, chainTests()...)
662 for _, tc := range tests {
663 sizes := []int{1, 2, 3, 4, 5, 10, 100, 1000}
664 if tc.ioSize > 0 {
665 sizes = []int{tc.ioSize}
666 }
667 for _, sz := range sizes {
668 testtext.Run(t, fmt.Sprintf("%s/%d", tc.desc, sz), func(t *testing.T) {
669 bb := &bytes.Buffer{}
670 w := NewWriter(bb, tc.t)
671 // Differently sized dst and src buffers are not part of the
672 // exported API. We override them manually.
673 w.dst = make([]byte, tc.dstSize)
674 w.src = make([]byte, tc.srcSize)
675 src := make([]byte, sz)
676 var err error
677 for b := tc.src; len(b) > 0 && err == nil; {
678 n := copy(src, b)
679 b = b[n:]
680 m := 0
681 m, err = w.Write(src[:n])
682 if m != n && err == nil {
683 t.Errorf("did not consume all bytes %d < %d", m, n)
684 }
685 }
686 if err == nil {
687 err = w.Close()
688 }
689 str := bb.String()
690 if str != tc.wantStr || err != tc.wantErr {
691 t.Errorf("\ngot %q, %v\nwant %q, %v", str, err, tc.wantStr, tc.wantErr)
692 }
693 })
694 }
695 }
696}
697
698func TestNop(t *testing.T) {
699 testCases := []struct {
700 str string
701 dstSize int
702 err error
703 }{
704 {"", 0, nil},
705 {"", 10, nil},
706 {"a", 0, ErrShortDst},
707 {"a", 1, nil},
708 {"a", 10, nil},
709 }
710 for i, tc := range testCases {
711 dst := make([]byte, tc.dstSize)
712 nDst, nSrc, err := Nop.Transform(dst, []byte(tc.str), true)
713 want := tc.str
714 if tc.dstSize < len(want) {
715 want = want[:tc.dstSize]
716 }
717 if got := string(dst[:nDst]); got != want || err != tc.err || nSrc != nDst {
718 t.Errorf("%d:\ngot %q, %d, %v\nwant %q, %d, %v", i, got, nSrc, err, want, nDst, tc.err)
719 }
720 }
721}
722
723func TestDiscard(t *testing.T) {
724 testCases := []struct {
725 str string
726 dstSize int
727 }{
728 {"", 0},
729 {"", 10},
730 {"a", 0},
731 {"ab", 10},
732 }
733 for i, tc := range testCases {
734 nDst, nSrc, err := Discard.Transform(make([]byte, tc.dstSize), []byte(tc.str), true)
735 if nDst != 0 || nSrc != len(tc.str) || err != nil {
736 t.Errorf("%d:\ngot %q, %d, %v\nwant 0, %d, nil", i, nDst, nSrc, err, len(tc.str))
737 }
738 }
739}
740
741// mkChain creates a Chain transformer. x must be alternating between transformer
742// and bufSize, like T, (sz, T)*
743func mkChain(x ...interface{}) *chain {
744 t := []Transformer{}
745 for i := 0; i < len(x); i += 2 {
746 t = append(t, x[i].(Transformer))
747 }
748 c := Chain(t...).(*chain)
749 for i, j := 1, 1; i < len(x); i, j = i+2, j+1 {
750 c.link[j].b = make([]byte, x[i].(int))
751 }
752 return c
753}
754
// chainTests returns test cases whose transformer is a multi-link chain built
// with mkChain. They are used by TestWriter, TestChain and testString. Cases
// that run out of space in a buffer between links expect errShortInternal.
// Cases where the outermost dst or src is short expect ErrShortDst or
// ErrShortSrc.
func chainTests() []testCase {
	return []testCase{
		{
			desc:     "nil error",
			t:        mkChain(rleEncode{}, 100, lowerCaseASCII{}),
			src:      "ABB",
			dstSize:  100,
			srcSize:  100,
			wantStr:  "1a2b",
			wantErr:  nil,
			wantIter: 1,
		},

		{
			desc:    "short dst buffer",
			t:       mkChain(lowerCaseASCII{}, 3, rleDecode{}),
			src:     "1a2b3c10d11e0f1g",
			dstSize: 10,
			srcSize: 3,
			wantStr: "abbcccdddddddddd",
			wantErr: ErrShortDst,
		},

		{
			desc:    "short internal dst buffer",
			t:       mkChain(lowerCaseASCII{}, 3, rleDecode{}, 10, Nop),
			src:     "1a2b3c10d11e0f1g",
			dstSize: 100,
			srcSize: 3,
			wantStr: "abbcccdddddddddd",
			wantErr: errShortInternal,
		},

		{
			desc:    "short internal dst buffer from input",
			t:       mkChain(rleDecode{}, 10, Nop),
			src:     "1a2b3c10d11e0f1g",
			dstSize: 100,
			srcSize: 3,
			wantStr: "abbcccdddddddddd",
			wantErr: errShortInternal,
		},

		{
			desc:    "empty short internal dst buffer",
			t:       mkChain(lowerCaseASCII{}, 3, rleDecode{}, 10, Nop),
			src:     "4a7b11e0f1g",
			dstSize: 100,
			srcSize: 3,
			wantStr: "aaaabbbbbbb",
			wantErr: errShortInternal,
		},

		{
			desc:    "empty short internal dst buffer from input",
			t:       mkChain(rleDecode{}, 10, Nop),
			src:     "4a7b11e0f1g",
			dstSize: 100,
			srcSize: 3,
			wantStr: "aaaabbbbbbb",
			wantErr: errShortInternal,
		},

		{
			desc:     "short internal src buffer after full dst buffer",
			t:        mkChain(Nop, 5, rleEncode{}, 10, Nop),
			src:      "cccccddddd",
			dstSize:  100,
			srcSize:  100,
			wantStr:  "",
			wantErr:  errShortInternal,
			wantIter: 1,
		},

		{
			desc:    "short internal src buffer after short dst buffer; test lastFull",
			t:       mkChain(rleDecode{}, 5, rleEncode{}, 4, Nop),
			src:     "2a1b4c6d",
			dstSize: 100,
			srcSize: 100,
			wantStr: "2a1b",
			wantErr: errShortInternal,
		},

		{
			desc:     "short internal src buffer after successful complete fill",
			t:        mkChain(Nop, 3, rleDecode{}),
			src:      "123a4b",
			dstSize:  4,
			srcSize:  3,
			wantStr:  "",
			wantErr:  errShortInternal,
			wantIter: 1,
		},

		{
			desc:    "short internal src buffer after short dst buffer; test lastFull",
			t:       mkChain(rleDecode{}, 5, rleEncode{}),
			src:     "2a1b4c6d",
			dstSize: 4,
			srcSize: 100,
			wantStr: "2a1b",
			wantErr: errShortInternal,
		},

		{
			desc:    "short src buffer",
			t:       mkChain(rleEncode{}, 5, Nop),
			src:     "abbcccddddeeeee",
			dstSize: 4,
			srcSize: 4,
			ioSize:  4,
			wantStr: "1a2b3c",
			wantErr: ErrShortSrc,
		},

		{
			desc:     "process all in one go",
			t:        mkChain(rleEncode{}, 5, Nop),
			src:      "abbcccddddeeeeeffffff",
			dstSize:  100,
			srcSize:  100,
			wantStr:  "1a2b3c4d5e6f",
			wantErr:  nil,
			wantIter: 1,
		},

		{
			desc:    "complete processing downstream after error",
			t:       mkChain(dontMentionX{}, 2, rleDecode{}, 5, Nop),
			src:     "3a4b5eX",
			dstSize: 100,
			srcSize: 100,
			ioSize:  100,
			wantStr: "aaabbbbeeeee",
			wantErr: errYouMentionedX,
		},

		{
			desc:    "return downstream fatal errors first (followed by short dst)",
			t:       mkChain(dontMentionX{}, 8, rleDecode{}, 4, Nop),
			src:     "3a4b5eX",
			dstSize: 100,
			srcSize: 100,
			ioSize:  100,
			wantStr: "aaabbbb",
			wantErr: errShortInternal,
		},

		{
			desc:    "return downstream fatal errors first (followed by short src)",
			t:       mkChain(dontMentionX{}, 5, Nop, 1, rleDecode{}),
			src:     "1a5bX",
			dstSize: 100,
			srcSize: 100,
			ioSize:  100,
			wantStr: "",
			wantErr: errShortInternal,
		},

		{
			desc:    "short internal",
			t:       mkChain(Nop, 11, rleEncode{}, 3, Nop),
			src:     "abbcccddddddddddeeeeeeeeeeeg",
			dstSize: 3,
			srcSize: 100,
			wantStr: "1a2b3c10d",
			wantErr: errShortInternal,
		},
	}
}
926
927func doTransform(tc testCase) (res string, iter int, err error) {
928 tc.t.Reset()
929 dst := make([]byte, tc.dstSize)
930 out, in := make([]byte, 0, 2*len(tc.src)), []byte(tc.src)
931 for {
932 iter++
933 src, atEOF := in, true
934 if len(src) > tc.srcSize {
935 src, atEOF = src[:tc.srcSize], false
936 }
937 nDst, nSrc, err := tc.t.Transform(dst, src, atEOF)
938 out = append(out, dst[:nDst]...)
939 in = in[nSrc:]
940 switch {
941 case err == nil && len(in) != 0:
942 case err == ErrShortSrc && nSrc > 0:
943 case err == ErrShortDst && (nDst > 0 || nSrc > 0):
944 default:
945 return string(out), iter, err
946 }
947 }
948}
949
950func TestChain(t *testing.T) {
951 if c, ok := Chain().(nop); !ok {
952 t.Errorf("empty chain: %v; want Nop", c)
953 }
954
955 // Test Chain for a single Transformer.
956 for _, tc := range testCases {
957 tc.t = Chain(tc.t)
958 str, _, err := doTransform(tc)
959 if str != tc.wantStr || err != tc.wantErr {
960 t.Errorf("%s:\ngot %q, %v\nwant %q, %v", tc, str, err, tc.wantStr, tc.wantErr)
961 }
962 }
963
964 tests := chainTests()
965 sizes := []int{1, 2, 3, 4, 5, 7, 10, 100, 1000}
966 addTest := func(tc testCase, t *chain) {
967 if t.link[0].t != tc.t && tc.wantErr == ErrShortSrc {
968 tc.wantErr = errShortInternal
969 }
970 if t.link[len(t.link)-2].t != tc.t && tc.wantErr == ErrShortDst {
971 tc.wantErr = errShortInternal
972 }
973 tc.t = t
974 tests = append(tests, tc)
975 }
976 for _, tc := range testCases {
977 for _, sz := range sizes {
978 tt := tc
979 tt.dstSize = sz
980 addTest(tt, mkChain(tc.t, tc.dstSize, Nop))
981 addTest(tt, mkChain(tc.t, tc.dstSize, Nop, 2, Nop))
982 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Nop))
983 if sz >= tc.dstSize && (tc.wantErr != ErrShortDst || sz == tc.dstSize) {
984 addTest(tt, mkChain(Nop, tc.srcSize, tc.t))
985 addTest(tt, mkChain(Nop, 100, Nop, tc.srcSize, tc.t))
986 }
987 }
988 }
989 for _, tc := range testCases {
990 tt := tc
991 tt.dstSize = 1
992 tt.wantStr = ""
993 addTest(tt, mkChain(tc.t, tc.dstSize, Discard))
994 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Discard))
995 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Nop, tc.dstSize, Discard))
996 }
997 for _, tc := range testCases {
998 tt := tc
999 tt.dstSize = 100
1000 tt.wantStr = strings.Replace(tc.src, "0f", "", -1)
1001 // Chain encoders and decoders.
1002 if _, ok := tc.t.(rleEncode); ok && tc.wantErr == nil {
1003 addTest(tt, mkChain(tc.t, tc.dstSize, Nop, 1000, rleDecode{}))
1004 addTest(tt, mkChain(tc.t, tc.dstSize, Nop, tc.dstSize, rleDecode{}))
1005 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Nop, 100, rleDecode{}))
1006 // decoding needs larger destinations
1007 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, rleDecode{}, 100, Nop))
1008 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Nop, 100, rleDecode{}, 100, Nop))
1009 } else if _, ok := tc.t.(rleDecode); ok && tc.wantErr == nil {
1010 // The internal buffer size may need to be the sum of the maximum segment
1011 // size of the two encoders!
1012 addTest(tt, mkChain(tc.t, 2*tc.dstSize, rleEncode{}))
1013 addTest(tt, mkChain(tc.t, tc.dstSize, Nop, 101, rleEncode{}))
1014 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Nop, 100, rleEncode{}))
1015 addTest(tt, mkChain(Nop, tc.srcSize, tc.t, tc.dstSize, Nop, 200, rleEncode{}, 100, Nop))
1016 }
1017 }
1018 for _, tc := range tests {
1019 str, iter, err := doTransform(tc)
1020 mi := tc.wantIter != 0 && tc.wantIter != iter
1021 if str != tc.wantStr || err != tc.wantErr || mi {
1022 t.Errorf("%s:\ngot iter:%d, %q, %v\nwant iter:%d, %q, %v", tc, iter, str, err, tc.wantIter, tc.wantStr, tc.wantErr)
1023 }
1024 break
1025 }
1026}
1027
1028func TestRemoveFunc(t *testing.T) {
1029 filter := RemoveFunc(func(r rune) bool {
1030 return strings.IndexRune("ab\u0300\u1234,", r) != -1
1031 })
1032 tests := []testCase{
1033 {
1034 src: ",",
1035 wantStr: "",
1036 },
1037
1038 {
1039 src: "c",
1040 wantStr: "c",
1041 },
1042
1043 {
1044 src: "\u2345",
1045 wantStr: "\u2345",
1046 },
1047
1048 {
1049 src: "tschüß",
1050 wantStr: "tschüß",
1051 },
1052
1053 {
1054 src: ",до,свидания,",
1055 wantStr: "досвидания",
1056 },
1057
1058 {
1059 src: "a\xbd\xb2=\xbc ⌘",
1060 wantStr: "\uFFFD\uFFFD=\uFFFD ⌘",
1061 },
1062
1063 {
1064 // If we didn't replace illegal bytes with RuneError, the result
1065 // would be \u0300 or the code would need to be more complex.
1066 src: "\xcc\u0300\x80",
1067 wantStr: "\uFFFD\uFFFD",
1068 },
1069
1070 {
1071 src: "\xcc\u0300\x80",
1072 dstSize: 3,
1073 wantStr: "\uFFFD\uFFFD",
1074 wantIter: 2,
1075 },
1076
1077 {
1078 // Test a long buffer greater than the internal buffer size
1079 src: "hello\xcc\xcc\xccworld",
1080 srcSize: 13,
1081 wantStr: "hello\uFFFD\uFFFD\uFFFDworld",
1082 wantIter: 1,
1083 },
1084
1085 {
1086 src: "\u2345",
1087 dstSize: 2,
1088 wantStr: "",
1089 wantErr: ErrShortDst,
1090 },
1091
1092 {
1093 src: "\xcc",
1094 dstSize: 2,
1095 wantStr: "",
1096 wantErr: ErrShortDst,
1097 },
1098
1099 {
1100 src: "\u0300",
1101 dstSize: 2,
1102 srcSize: 1,
1103 wantStr: "",
1104 wantErr: ErrShortSrc,
1105 },
1106
1107 {
1108 t: RemoveFunc(func(r rune) bool {
1109 return r == utf8.RuneError
1110 }),
1111 src: "\xcc\u0300\x80",
1112 wantStr: "\u0300",
1113 },
1114 }
1115
1116 for _, tc := range tests {
1117 tc.desc = tc.src
1118 if tc.t == nil {
1119 tc.t = filter
1120 }
1121 if tc.dstSize == 0 {
1122 tc.dstSize = 100
1123 }
1124 if tc.srcSize == 0 {
1125 tc.srcSize = 100
1126 }
1127 str, iter, err := doTransform(tc)
1128 mi := tc.wantIter != 0 && tc.wantIter != iter
1129 if str != tc.wantStr || err != tc.wantErr || mi {
1130 t.Errorf("%+q:\ngot iter:%d, %+q, %v\nwant iter:%d, %+q, %v", tc.src, iter, str, err, tc.wantIter, tc.wantStr, tc.wantErr)
1131 }
1132
1133 tc.src = str
1134 idem, _, _ := doTransform(tc)
1135 if str != idem {
1136 t.Errorf("%+q: found %+q; want %+q", tc.src, idem, str)
1137 }
1138 }
1139}
1140
1141func testString(t *testing.T, f func(Transformer, string) (string, int, error)) {
1142 for _, tt := range append(testCases, chainTests()...) {
1143 if tt.desc == "allowStutter = true" {
1144 // We don't have control over the buffer size, so we eliminate tests
1145 // that depend on a specific buffer size being set.
1146 continue
1147 }
1148 if tt.wantErr == ErrShortDst || tt.wantErr == ErrShortSrc {
1149 // The result string will be different.
1150 continue
1151 }
1152 testtext.Run(t, tt.desc, func(t *testing.T) {
1153 got, n, err := f(tt.t, tt.src)
1154 if tt.wantErr != err {
1155 t.Errorf("error: got %v; want %v", err, tt.wantErr)
1156 }
1157 // Check that err == nil implies that n == len(tt.src). Note that vice
1158 // versa isn't necessarily true.
1159 if err == nil && n != len(tt.src) {
1160 t.Errorf("err == nil: got %d bytes, want %d", n, err)
1161 }
1162 if got != tt.wantStr {
1163 t.Errorf("string: got %q; want %q", got, tt.wantStr)
1164 }
1165 })
1166 }
1167}
1168
1169func TestBytes(t *testing.T) {
1170 testString(t, func(z Transformer, s string) (string, int, error) {
1171 b, n, err := Bytes(z, []byte(s))
1172 return string(b), n, err
1173 })
1174}
1175
1176func TestAppend(t *testing.T) {
1177 // Create a bunch of subtests for different buffer sizes.
1178 testCases := [][]byte{
1179 nil,
1180 make([]byte, 0, 0),
1181 make([]byte, 0, 1),
1182 make([]byte, 1, 1),
1183 make([]byte, 1, 5),
1184 make([]byte, 100, 100),
1185 make([]byte, 100, 200),
1186 }
1187 for _, tc := range testCases {
1188 testString(t, func(z Transformer, s string) (string, int, error) {
1189 b, n, err := Append(z, tc, []byte(s))
1190 return string(b[len(tc):]), n, err
1191 })
1192 }
1193}
1194
1195func TestString(t *testing.T) {
1196 testtext.Run(t, "transform", func(t *testing.T) { testString(t, String) })
1197
1198 // Overrun the internal destination buffer.
1199 for i, s := range []string{
1200 aaa[:1*initialBufSize-1],
1201 aaa[:1*initialBufSize+0],
1202 aaa[:1*initialBufSize+1],
1203 AAA[:1*initialBufSize-1],
1204 AAA[:1*initialBufSize+0],
1205 AAA[:1*initialBufSize+1],
1206 AAA[:2*initialBufSize-1],
1207 AAA[:2*initialBufSize+0],
1208 AAA[:2*initialBufSize+1],
1209 aaa[:1*initialBufSize-2] + "A",
1210 aaa[:1*initialBufSize-1] + "A",
1211 aaa[:1*initialBufSize+0] + "A",
1212 aaa[:1*initialBufSize+1] + "A",
1213 } {
1214 testtext.Run(t, fmt.Sprint("dst buffer test using lower/", i), func(t *testing.T) {
1215 got, _, _ := String(lowerCaseASCII{}, s)
1216 if want := strings.ToLower(s); got != want {
1217 t.Errorf("got %s (%d); want %s (%d)", got, len(got), want, len(want))
1218 }
1219 })
1220 }
1221
1222 // Overrun the internal source buffer.
1223 for i, s := range []string{
1224 aaa[:1*initialBufSize-1],
1225 aaa[:1*initialBufSize+0],
1226 aaa[:1*initialBufSize+1],
1227 aaa[:2*initialBufSize+1],
1228 aaa[:2*initialBufSize+0],
1229 aaa[:2*initialBufSize+1],
1230 } {
1231 testtext.Run(t, fmt.Sprint("src buffer test using rleEncode/", i), func(t *testing.T) {
1232 got, _, _ := String(rleEncode{}, s)
1233 if want := fmt.Sprintf("%da", len(s)); got != want {
1234 t.Errorf("got %s (%d); want %s (%d)", got, len(got), want, len(want))
1235 }
1236 })
1237 }
1238
1239 // Test allocations for non-changing strings.
1240 // Note we still need to allocate a single buffer.
1241 for i, s := range []string{
1242 "",
1243 "123456789",
1244 aaa[:initialBufSize-1],
1245 aaa[:initialBufSize+0],
1246 aaa[:initialBufSize+1],
1247 aaa[:10*initialBufSize],
1248 } {
1249 testtext.Run(t, fmt.Sprint("alloc/", i), func(t *testing.T) {
1250 if n := testtext.AllocsPerRun(5, func() { String(&lowerCaseASCIILookahead{}, s) }); n > 1 {
1251 t.Errorf("#allocs was %f; want 1", n)
1252 }
1253 })
1254 }
1255}
1256
1257// TestBytesAllocation tests that buffer growth stays limited with the trickler
1258// transformer, which behaves oddly but within spec. In case buffer growth is
1259// not correctly handled, the test will either panic with a failed allocation or
1260// thrash. To ensure the tests terminate under the last condition, we time out
1261// after some sufficiently long period of time.
1262func TestBytesAllocation(t *testing.T) {
1263 done := make(chan bool)
1264 go func() {
1265 in := bytes.Repeat([]byte{'a'}, 1000)
1266 tr := trickler(make([]byte, 1))
1267 Bytes(&tr, in)
1268 done <- true
1269 }()
1270 select {
1271 case <-done:
1272 case <-time.After(3 * time.Second):
1273 t.Error("time out, likely due to excessive allocation")
1274 }
1275}
1276
1277// TestStringAllocation tests that buffer growth stays limited with the trickler
1278// transformer, which behaves oddly but within spec. In case buffer growth is
1279// not correctly handled, the test will either panic with a failed allocation or
1280// thrash. To ensure the tests terminate under the last condition, we time out
1281// after some sufficiently long period of time.
1282func TestStringAllocation(t *testing.T) {
1283 done := make(chan bool)
1284 go func() {
1285 tr := trickler(make([]byte, 1))
1286 String(&tr, aaa[:1000])
1287 done <- true
1288 }()
1289 select {
1290 case <-done:
1291 case <-time.After(3 * time.Second):
1292 t.Error("time out, likely due to excessive allocation")
1293 }
1294}
1295
1296func BenchmarkStringLowerEmpty(b *testing.B) {
1297 for i := 0; i < b.N; i++ {
1298 String(&lowerCaseASCIILookahead{}, "")
1299 }
1300}
1301
1302func BenchmarkStringLowerIdentical(b *testing.B) {
1303 for i := 0; i < b.N; i++ {
1304 String(&lowerCaseASCIILookahead{}, aaa[:4096])
1305 }
1306}
1307
1308func BenchmarkStringLowerChanged(b *testing.B) {
1309 for i := 0; i < b.N; i++ {
1310 String(&lowerCaseASCIILookahead{}, AAA[:4096])
1311 }
1312}
1313
// aaa and AAA are 4096-byte runs of 'a' and 'A'. Tests and benchmarks slice
// prefixes of them as input.
var (
	aaa = strings.Repeat("a", 1<<12)
	AAA = strings.ToUpper(aaa)
)