David K. Bainbridge | 528b318 | 2017-01-23 08:51:59 -0800 | [diff] [blame] | 1 | // Copyright 2016 Canonical Ltd. |
| 2 | // Licensed under the LGPLv3, see LICENCE file for details. |
| 3 | |
| 4 | package utils |
| 5 | |
| 6 | import ( |
| 7 | "io" |
| 8 | "sort" |
| 9 | |
| 10 | "github.com/juju/errors" |
| 11 | ) |
| 12 | |
// SizeReaderAt combines io.ReaderAt with a Size method.
type SizeReaderAt interface {
	// Size returns the size of the data readable
	// from the reader.
	Size() int64
	io.ReaderAt
}

// NewMultiReaderAt is like io.MultiReader but produces a ReaderAt
// (and Size), instead of just a reader. The parts are laid out back to
// back: parts[i] starts at the sum of the sizes of parts[0:i], computed
// here from each part's Size at construction time.
//
// Note: this implementation was taken from a talk given
// by Brad Fitzpatrick at OSCON 2013.
//
// http://talks.golang.org/2013/oscon-dl.slide#49
// https://github.com/golang/talks/blob/master/2013/oscon-dl/server-compose.go
func NewMultiReaderAt(parts ...SizeReaderAt) SizeReaderAt {
	m := &multiReaderAt{
		parts: make([]offsetAndSource, 0, len(parts)),
	}
	var off int64
	for _, p := range parts {
		m.parts = append(m.parts, offsetAndSource{off: off, SizeReaderAt: p})
		off += p.Size()
	}
	m.size = off
	return m
}

// offsetAndSource pairs a part with the offset at which its first
// byte appears in the combined data.
type offsetAndSource struct {
	off int64
	SizeReaderAt
}

// multiReaderAt presents a sequence of parts as one contiguous
// SizeReaderAt.
type multiReaderAt struct {
	parts []offsetAndSource // in order; offsets are non-decreasing
	size  int64             // sum of all part sizes
}

// Size implements SizeReaderAt.Size.
func (m *multiReaderAt) Size() int64 {
	return m.size
}

// ReadAt implements io.ReaderAt.ReadAt over the concatenated parts.
//
// If fewer than len(p) bytes are available at off, it returns the bytes
// that were available together with io.ErrUnexpectedEOF. A negative off
// is an error. The returned n always counts every byte copied into p,
// including bytes a part delivered alongside an error.
func (m *multiReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
	if off < 0 {
		return 0, errors.New("negative offset")
	}
	wantN := len(p)

	// Skip past the requested offset: find the first part that will
	// contribute any bytes to our output. Part end offsets are
	// non-decreasing, so a binary search is valid; empty parts that end
	// exactly at off are skipped.
	first := sort.Search(len(m.parts), func(i int) bool {
		part := m.parts[i]
		return part.off+part.Size() > off
	})
	parts := m.parts[first:]

	// How far to skip in the first part.
	var needSkip int64
	if len(parts) > 0 {
		needSkip = off - parts[0].off
	}

	for len(parts) > 0 && len(p) > 0 {
		part := parts[0]
		readP := p
		if remain := part.Size() - needSkip; int64(len(readP)) > remain {
			readP = readP[:remain]
		}
		pn, perr := part.ReadAt(readP, needSkip)
		// The pn bytes are already in p, so count them before
		// looking at the error.
		n += pn
		p = p[pn:]
		// io.ReaderAt allows io.EOF alongside a full read that ends at
		// the end of the part's data; that is not a failure here.
		if perr != nil && !(perr == io.EOF && pn == len(readP)) {
			return n, perr
		}
		if pn < len(readP) {
			// A short read without an error violates io.ReaderAt;
			// stop instead of re-reading or spinning forever.
			return n, io.ErrUnexpectedEOF
		}
		// readP was filled completely: either this part is exhausted,
		// or p is now full and the loop ends. Either way move on.
		parts = parts[1:]
		needSkip = 0
	}

	if n != wantN {
		err = io.ErrUnexpectedEOF
	}
	return n, err
}
| 97 | |
| 98 | // NewMultiReaderSeeker returns an io.ReadSeeker that combines |
| 99 | // all the given readers into a single one. It assumes that |
| 100 | // all the seekers are initially positioned at the start. |
| 101 | func NewMultiReaderSeeker(readers ...io.ReadSeeker) io.ReadSeeker { |
| 102 | sreaders := make([]SizeReaderAt, len(readers)) |
| 103 | for i, r := range readers { |
| 104 | r1, err := newSizeReaderAt(r) |
| 105 | if err != nil { |
| 106 | panic(err) |
| 107 | } |
| 108 | sreaders[i] = r1 |
| 109 | } |
| 110 | return &readSeeker{ |
| 111 | r: NewMultiReaderAt(sreaders...), |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | // newSizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt. |
| 116 | // Note that it doesn't strictly adhere to the ReaderAt |
| 117 | // contract because it's not safe to call ReadAt concurrently. |
| 118 | // This doesn't matter because io.ReadSeeker doesn't |
| 119 | // need to be thread-safe and this is only used in that |
| 120 | // context. |
| 121 | func newSizeReaderAt(r io.ReadSeeker) (SizeReaderAt, error) { |
| 122 | size, err := r.Seek(0, 2) |
| 123 | if err != nil { |
| 124 | return nil, err |
| 125 | } |
| 126 | return &sizeReaderAt{ |
| 127 | r: r, |
| 128 | size: size, |
| 129 | off: size, |
| 130 | }, nil |
| 131 | } |
| 132 | |
// sizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt. It remembers
// where the underlying reader is positioned, so consecutive reads do
// not need a Seek.
type sizeReaderAt struct {
	r    io.ReadSeeker
	size int64
	// off is the last known position of r. A negative value means the
	// position is unknown (after a failed Seek), and the next ReadAt
	// will seek unconditionally.
	off int64
}

// ReadAt implements SizeReaderAt.ReadAt. It reads exactly len(buf)
// bytes starting at off, using io.ReadFull semantics. It returns io.EOF
// if no bytes were available and io.ErrUnexpectedEOF if only some were.
func (r *sizeReaderAt) ReadAt(buf []byte, off int64) (n int, err error) {
	if r.off < 0 || off != r.off {
		if _, serr := r.r.Seek(off, io.SeekStart); serr != nil {
			// io.Seeker leaves the position unspecified after an
			// error, so do not trust the cached offset any more.
			r.off = -1
			return 0, serr
		}
		r.off = off
	}
	n, err = io.ReadFull(r.r, buf)
	r.off += int64(n)
	return n, err
}

// Size implements SizeReaderAt.Size.
func (r *sizeReaderAt) Size() int64 {
	return r.size
}
| 158 | |
| 159 | // readSeeker adapts a SizeReaderAt to an io.ReadSeeker. |
| 160 | type readSeeker struct { |
| 161 | r SizeReaderAt |
| 162 | off int64 |
| 163 | } |
| 164 | |
| 165 | // Seek implements io.Seeker.Seek. |
| 166 | func (r *readSeeker) Seek(off int64, whence int) (int64, error) { |
| 167 | switch whence { |
| 168 | case 0: |
| 169 | case 1: |
| 170 | off += r.off |
| 171 | case 2: |
| 172 | off = r.r.Size() + off |
| 173 | } |
| 174 | if off < 0 { |
| 175 | return 0, errors.New("negative position") |
| 176 | } |
| 177 | r.off = off |
| 178 | return off, nil |
| 179 | } |
| 180 | |
| 181 | // Read implements io.Reader.Read. |
| 182 | func (r *readSeeker) Read(buf []byte) (int, error) { |
| 183 | n, err := r.r.ReadAt(buf, r.off) |
| 184 | r.off += int64(n) |
| 185 | if err == io.ErrUnexpectedEOF { |
| 186 | err = io.EOF |
| 187 | } |
| 188 | return n, err |
| 189 | } |