blob: b8431f9a04d1e82610dcf4e2acaf33b1c7133d0b [file] [log] [blame]
David K. Bainbridge528b3182017-01-23 08:51:59 -08001// Copyright 2016 Canonical Ltd.
2// Licensed under the LGPLv3, see LICENCE file for details.
3
4package utils
5
6import (
7 "io"
8 "sort"
9
10 "github.com/juju/errors"
11)
12
// SizeReaderAt is an io.ReaderAt that also reports how many bytes
// it can supply.
type SizeReaderAt interface {
	io.ReaderAt

	// Size returns the number of bytes that can be read
	// from the reader.
	Size() int64
}
20
21// NewMultiReaderAt is like io.MultiReader but produces a ReaderAt
22// (and Size), instead of just a reader.
23//
24// Note: this implementation was taken from a talk given
25// by Brad Fitzpatrick as OSCON 2013.
26//
27// http://talks.golang.org/2013/oscon-dl.slide#49
28// https://github.com/golang/talks/blob/master/2013/oscon-dl/server-compose.go
29func NewMultiReaderAt(parts ...SizeReaderAt) SizeReaderAt {
30 m := &multiReaderAt{
31 parts: make([]offsetAndSource, 0, len(parts)),
32 }
33 var off int64
34 for _, p := range parts {
35 m.parts = append(m.parts, offsetAndSource{off, p})
36 off += p.Size()
37 }
38 m.size = off
39 return m
40}
41
42type offsetAndSource struct {
43 off int64
44 SizeReaderAt
45}
46
47type multiReaderAt struct {
48 parts []offsetAndSource
49 size int64
50}
51
52func (m *multiReaderAt) Size() int64 {
53 return m.size
54}
55
56func (m *multiReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
57 wantN := len(p)
58
59 // Skip past the requested offset.
60 skipParts := sort.Search(len(m.parts), func(i int) bool {
61 // This function returns whether parts[i] will
62 // contribute any bytes to our output.
63 part := m.parts[i]
64 return part.off+part.Size() > off
65 })
66 parts := m.parts[skipParts:]
67
68 // How far to skip in the first part.
69 needSkip := off
70 if len(parts) > 0 {
71 needSkip -= parts[0].off
72 }
73
74 for len(parts) > 0 && len(p) > 0 {
75 readP := p
76 partSize := parts[0].Size()
77 if int64(len(readP)) > partSize-needSkip {
78 readP = readP[:partSize-needSkip]
79 }
80 pn, err0 := parts[0].ReadAt(readP, needSkip)
81 if err0 != nil {
82 return n, err0
83 }
84 n += pn
85 p = p[pn:]
86 if int64(pn)+needSkip == partSize {
87 parts = parts[1:]
88 }
89 needSkip = 0
90 }
91
92 if n != wantN {
93 err = io.ErrUnexpectedEOF
94 }
95 return
96}
97
98// NewMultiReaderSeeker returns an io.ReadSeeker that combines
99// all the given readers into a single one. It assumes that
100// all the seekers are initially positioned at the start.
101func NewMultiReaderSeeker(readers ...io.ReadSeeker) io.ReadSeeker {
102 sreaders := make([]SizeReaderAt, len(readers))
103 for i, r := range readers {
104 r1, err := newSizeReaderAt(r)
105 if err != nil {
106 panic(err)
107 }
108 sreaders[i] = r1
109 }
110 return &readSeeker{
111 r: NewMultiReaderAt(sreaders...),
112 }
113}
114
115// newSizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt.
116// Note that it doesn't strictly adhere to the ReaderAt
117// contract because it's not safe to call ReadAt concurrently.
118// This doesn't matter because io.ReadSeeker doesn't
119// need to be thread-safe and this is only used in that
120// context.
121func newSizeReaderAt(r io.ReadSeeker) (SizeReaderAt, error) {
122 size, err := r.Seek(0, 2)
123 if err != nil {
124 return nil, err
125 }
126 return &sizeReaderAt{
127 r: r,
128 size: size,
129 off: size,
130 }, nil
131}
132
// sizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt by remembering
// the stream's total size and where the underlying reader currently is.
type sizeReaderAt struct {
	r    io.ReadSeeker // underlying stream
	size int64         // total length of the stream
	off  int64         // current position of r
}

// Size implements SizeReaderAt.Size.
func (r *sizeReaderAt) Size() int64 {
	return r.size
}

// ReadAt implements SizeReaderAt.ReadAt.
//
// It seeks only when off differs from the tracked position, so
// back-to-back sequential reads avoid a Seek call. It then fills buf
// with io.ReadFull, returning that function's error when buf cannot be
// filled. ReadAt mutates the reader's position and is not safe for
// concurrent use.
func (r *sizeReaderAt) ReadAt(buf []byte, off int64) (n int, err error) {
	if r.off != off {
		if _, seekErr := r.r.Seek(off, 0); seekErr != nil {
			return 0, seekErr
		}
		r.off = off
	}
	n, err = io.ReadFull(r.r, buf)
	// Account for whatever was consumed, even on a short read.
	r.off += int64(n)
	return n, err
}
158
159// readSeeker adapts a SizeReaderAt to an io.ReadSeeker.
160type readSeeker struct {
161 r SizeReaderAt
162 off int64
163}
164
165// Seek implements io.Seeker.Seek.
166func (r *readSeeker) Seek(off int64, whence int) (int64, error) {
167 switch whence {
168 case 0:
169 case 1:
170 off += r.off
171 case 2:
172 off = r.r.Size() + off
173 }
174 if off < 0 {
175 return 0, errors.New("negative position")
176 }
177 r.off = off
178 return off, nil
179}
180
181// Read implements io.Reader.Read.
182func (r *readSeeker) Read(buf []byte) (int, error) {
183 n, err := r.r.ReadAt(buf, r.off)
184 r.off += int64(n)
185 if err == io.ErrUnexpectedEOF {
186 err = io.EOF
187 }
188 return n, err
189}