// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v3rpc

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/coreos/etcd/auth"
	"github.com/coreos/etcd/etcdserver"
	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
	"github.com/coreos/etcd/mvcc"
	"github.com/coreos/etcd/mvcc/mvccpb"
)

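// watchServer bundles the etcd server state needed to serve watch streams:
// cluster and member identity, the request size limit, the raft timer used
// for terms in response headers, the watchable mvcc store, and an AuthGetter
// for permission checks.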
type watchServer struct {
	clusterID int64
	memberID  int64

	maxRequestBytes int

	raftTimer etcdserver.RaftTimer
	watchable mvcc.WatchableKV

	ag AuthGetter
}

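// NewWatchServer returns a new pb.WatchServer backed by the given etcd server.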
func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
	return &watchServer{
		clusterID:       int64(s.Cluster().ID()),
		memberID:        int64(s.ID()),
		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),
		raftTimer:       s,
		watchable:       s.Watchable(),
		ag:              s,
	}
}

var (
	// External tests can read this with GetProgressReportInterval()
	// and change it to a small value to finish fast with
	// SetProgressReportInterval().
	progressReportInterval   = 10 * time.Minute
	progressReportIntervalMu sync.RWMutex
)

func GetProgressReportInterval() time.Duration {
	progressReportIntervalMu.RLock()
	defer progressReportIntervalMu.RUnlock()
	return progressReportInterval
}

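// SetProgressReportInterval updates the progress report interval; it is
// intended for tests that need progress notifications to arrive quickly.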
func SetProgressReportInterval(newTimeout time.Duration) {
	progressReportIntervalMu.Lock()
	defer progressReportIntervalMu.Unlock()
	progressReportInterval = newTimeout
}

const (
	// We send ctrl responses inside the read loop. We do not want the
	// send to block the read, but we still want the ctrl responses we
	// send to be serialized. Thus we use a buffered chan to solve the
	// problem. A small buffer should be OK for most cases, since we
	// expect ctrl requests to be infrequent.
	ctrlStreamBufLen = 16
)

// serverWatchStream is an etcd server side stream. It receives requests
// from the client side gRPC stream, receives watch events from
// mvcc.WatchStream, and creates responses that are forwarded to the gRPC
// stream. It also forwards control messages such as watch created and
// canceled.
type serverWatchStream struct {
	clusterID int64
	memberID  int64

	maxRequestBytes int

	raftTimer etcdserver.RaftTimer

	watchable mvcc.WatchableKV

	gRPCStream  pb.Watch_WatchServer
	watchStream mvcc.WatchStream
	ctrlStream  chan *pb.WatchResponse

	// mu protects progress, prevKV
	mu sync.RWMutex
	// progress tracks the watchID that stream might need to send
	// progress to.
	// TODO: combine progress and prevKV into a single struct?
	progress map[mvcc.WatchID]bool
	prevKV   map[mvcc.WatchID]bool
	// records fragmented watch IDs
	fragment map[mvcc.WatchID]bool

	// closec indicates the stream is closed.
	closec chan struct{}

	// wg waits for the send loop to complete
	wg sync.WaitGroup

	ag AuthGetter
}

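// Watch implements pb.WatchServer. It builds a serverWatchStream for the
// incoming gRPC stream, starts the send and receive loops in separate
// goroutines, and waits until the receive loop fails or the stream context
// is done before closing the stream.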
func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		clusterID: ws.clusterID,
		memberID:  ws.memberID,

		maxRequestBytes: ws.maxRequestBytes,

		raftTimer: ws.raftTimer,

		watchable: ws.watchable,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),
		// chan for sending control responses like watcher created and canceled.
		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),
		progress:   make(map[mvcc.WatchID]bool),
		prevKV:     make(map[mvcc.WatchID]bool),
		fragment:   make(map[mvcc.WatchID]bool),
		closec:     make(chan struct{}),

		ag: ws.ag,
	}

	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	errc := make(chan error, 1)
	// Ideally recvLoop would also use sws.wg to signal its completion
	// but when stream.Context().Done() is closed, the stream's recv
	// may continue to block since it uses a different context, leading to
	// deadlock when calling sws.close().
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			if isClientCtxErr(stream.Context().Err(), rerr) {
				plog.Debugf("failed to receive watch request from gRPC stream (%q)", rerr.Error())
			} else {
				plog.Warningf("failed to receive watch request from gRPC stream (%q)", rerr.Error())
			}
			errc <- rerr
		}
	}()
	select {
	case err = <-errc:
		close(sws.ctrlStream)
	case <-stream.Context().Done():
		err = stream.Context().Err()
		// the only server-side cancellation is noleader for now.
		if err == context.Canceled {
			err = rpctypes.ErrGRPCNoLeader
		}
	}
	sws.close()
	return err
}

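// isWatchPermitted checks whether the stream's auth identity is allowed to
// watch the key range in the create request; it returns false if the auth
// info cannot be resolved or range permission is denied.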
func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
	if err != nil {
		return false
	}
	if authInfo == nil {
		// if auth is enabled, IsRangePermitted() can cause an error
		authInfo = &auth.AuthInfo{}
	}

	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
}

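// recvLoop reads WatchRequests from the gRPC stream and dispatches them:
// create requests start new watchers on the mvcc watch stream, cancel
// requests stop them, and progress requests ask for an immediate progress
// notification. It returns nil on a clean EOF from the client.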
func (sws *serverWatchStream) recvLoop() error {
	for {
		req, err := sws.gRPCStream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		switch uv := req.RequestUnion.(type) {
		case *pb.WatchRequest_CreateRequest:
			if uv.CreateRequest == nil {
				break
			}

			creq := uv.CreateRequest
			if len(creq.Key) == 0 {
				// \x00 is the smallest key
				creq.Key = []byte{0}
			}
			if len(creq.RangeEnd) == 0 {
				// force nil since watchstream.Watch distinguishes
				// between nil and []byte{} for single key / >=
				creq.RangeEnd = nil
			}
			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
				// support >= key queries
				creq.RangeEnd = []byte{}
			}

			if !sws.isWatchPermitted(creq) {
				wr := &pb.WatchResponse{
					Header:       sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId:      -1,
					Canceled:     true,
					Created:      true,
					CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(),
				}

				select {
				case sws.ctrlStream <- wr:
					continue
				case <-sws.closec:
					return nil
				}
			}

			filters := FiltersFromRequest(creq)

			wsrev := sws.watchStream.Rev()
			rev := creq.StartRevision
			if rev == 0 {
				rev = wsrev + 1
			}
			id := sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev, filters...)
			if id != -1 {
				sws.mu.Lock()
				if creq.ProgressNotify {
					sws.progress[id] = true
				}
				if creq.PrevKv {
					sws.prevKV[id] = true
				}
				if creq.Fragment {
					sws.fragment[id] = true
				}
				sws.mu.Unlock()
			}
			wr := &pb.WatchResponse{
				Header:   sws.newResponseHeader(wsrev),
				WatchId:  int64(id),
				Created:  true,
				Canceled: id == -1,
			}
			select {
			case sws.ctrlStream <- wr:
			case <-sws.closec:
				return nil
			}
		case *pb.WatchRequest_CancelRequest:
			if uv.CancelRequest != nil {
				id := uv.CancelRequest.WatchId
				err := sws.watchStream.Cancel(mvcc.WatchID(id))
				if err == nil {
					sws.ctrlStream <- &pb.WatchResponse{
						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
						WatchId:  id,
						Canceled: true,
					}
					sws.mu.Lock()
					delete(sws.progress, mvcc.WatchID(id))
					delete(sws.prevKV, mvcc.WatchID(id))
					delete(sws.fragment, mvcc.WatchID(id))
					sws.mu.Unlock()
				}
			}
		case *pb.WatchRequest_ProgressRequest:
			if uv.ProgressRequest != nil {
				sws.ctrlStream <- &pb.WatchResponse{
					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels
				}
			}
		default:
			// we probably should not shut down the entire stream when
			// we receive an invalid command, so just do nothing instead.
			continue
		}
	}
}

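// sendLoop streams watch events and control responses back to the client.
// Events that arrive before their watcher's creation has been announced are
// buffered in pending and flushed once the create response is sent. It also
// issues periodic progress requests for watchers that asked for progress
// notifications, and drains any remaining events on exit.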
func (sws *serverWatchStream) sendLoop() {
	// watch ids that are currently active
	ids := make(map[mvcc.WatchID]struct{})
	// watch responses pending on a watch id creation message
	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)

	interval := GetProgressReportInterval()
	progressTicker := time.NewTicker(interval)

	defer func() {
		progressTicker.Stop()
		// drain the chan to clean up pending events
		for ws := range sws.watchStream.Chan() {
			mvcc.ReportEventReceived(len(ws.Events))
		}
		for _, wrs := range pending {
			for _, ws := range wrs {
				mvcc.ReportEventReceived(len(ws.Events))
			}
		}
	}()

	for {
		select {
		case wresp, ok := <-sws.watchStream.Chan():
			if !ok {
				return
			}

			// TODO: evs is []mvccpb.Event type
			// either return []*mvccpb.Event from the mvcc package
			// or define protocol buffer with []mvccpb.Event.
			evs := wresp.Events
			events := make([]*mvccpb.Event, len(evs))
			sws.mu.RLock()
			needPrevKV := sws.prevKV[wresp.WatchID]
			sws.mu.RUnlock()
			for i := range evs {
				events[i] = &evs[i]

				if needPrevKV {
					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
					r, err := sws.watchable.Range(evs[i].Kv.Key, nil, opt)
					if err == nil && len(r.KVs) != 0 {
						events[i].PrevKv = &(r.KVs[0])
					}
				}
			}

			canceled := wresp.CompactRevision != 0
			wr := &pb.WatchResponse{
				Header:          sws.newResponseHeader(wresp.Revision),
				WatchId:         int64(wresp.WatchID),
				Events:          events,
				CompactRevision: wresp.CompactRevision,
				Canceled:        canceled,
			}

			if _, hasId := ids[wresp.WatchID]; !hasId {
				// buffer if id not yet announced
				wrs := append(pending[wresp.WatchID], wr)
				pending[wresp.WatchID] = wrs
				continue
			}

			mvcc.ReportEventReceived(len(evs))

			sws.mu.RLock()
			fragmented, ok := sws.fragment[wresp.WatchID]
			sws.mu.RUnlock()

			var serr error
			if !fragmented && !ok {
				serr = sws.gRPCStream.Send(wr)
			} else {
				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
			}

			if serr != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
					plog.Debugf("failed to send watch response to gRPC stream (%q)", serr.Error())
				} else {
					plog.Warningf("failed to send watch response to gRPC stream (%q)", serr.Error())
				}
				return
			}

			sws.mu.Lock()
			if len(evs) > 0 && sws.progress[wresp.WatchID] {
				// elide next progress update if sent a key update
				sws.progress[wresp.WatchID] = false
			}
			sws.mu.Unlock()

		case c, ok := <-sws.ctrlStream:
			if !ok {
				return
			}

			if err := sws.gRPCStream.Send(c); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					plog.Debugf("failed to send watch control response to gRPC stream (%q)", err.Error())
				} else {
					plog.Warningf("failed to send watch control response to gRPC stream (%q)", err.Error())
				}
				return
			}

			// track id creation
			wid := mvcc.WatchID(c.WatchId)
			if c.Canceled {
				delete(ids, wid)
				continue
			}
			if c.Created {
				// flush buffered events
				ids[wid] = struct{}{}
				for _, v := range pending[wid] {
					mvcc.ReportEventReceived(len(v.Events))
					if err := sws.gRPCStream.Send(v); err != nil {
						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
							plog.Debugf("failed to send pending watch response to gRPC stream (%q)", err.Error())
						} else {
							plog.Warningf("failed to send pending watch response to gRPC stream (%q)", err.Error())
						}
						return
					}
				}
				delete(pending, wid)
			}
		case <-progressTicker.C:
			sws.mu.Lock()
			for id, ok := range sws.progress {
				if ok {
					sws.watchStream.RequestProgress(id)
				}
				sws.progress[id] = true
			}
			sws.mu.Unlock()
		case <-sws.closec:
			return
		}
	}
}

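// sendFragments splits a WatchResponse whose encoded size exceeds
// maxRequestBytes into multiple responses with Fragment set, sending each
// piece with sendFunc; the final piece has Fragment unset. Responses that
// fit within the limit, or that carry fewer than two events, are sent as-is.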
func sendFragments(
	wr *pb.WatchResponse,
	maxRequestBytes int,
	sendFunc func(*pb.WatchResponse) error) error {
	// no need to fragment if the total response size is smaller than
	// the max request limit or the response contains only one event
	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
		return sendFunc(wr)
	}

	ow := *wr
	ow.Events = make([]*mvccpb.Event, 0)
	ow.Fragment = true

	var idx int
	for {
		cur := ow
		for _, ev := range wr.Events[idx:] {
			cur.Events = append(cur.Events, ev)
			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
				cur.Events = cur.Events[:len(cur.Events)-1]
				break
			}
			idx++
		}
		if idx == len(wr.Events) {
			// last response has no more fragments
			cur.Fragment = false
		}
		if err := sendFunc(&cur); err != nil {
			return err
		}
		if !cur.Fragment {
			break
		}
	}
	return nil
}

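// close shuts down the watch stream: it closes the underlying mvcc watch
// stream, signals the send loop via closec, and waits for it to finish.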
func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}

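// newResponseHeader builds a response header stamped with the cluster ID,
// member ID, the given revision, and the current raft term.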
func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
	return &pb.ResponseHeader{
		ClusterId: uint64(sws.clusterID),
		MemberId:  uint64(sws.memberID),
		Revision:  rev,
		RaftTerm:  sws.raftTimer.Term(),
	}
}

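// filterNoDelete is the mvcc.FilterFunc used for the NODELETE filter; it
// matches DELETE events so they can be dropped from the watch stream.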
func filterNoDelete(e mvccpb.Event) bool {
	return e.Type == mvccpb.DELETE
}

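// filterNoPut is the mvcc.FilterFunc used for the NOPUT filter; it matches
// PUT events so they can be dropped from the watch stream.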
func filterNoPut(e mvccpb.Event) bool {
	return e.Type == mvccpb.PUT
}

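// FiltersFromRequest translates the filter enums in a WatchCreateRequest
// into the corresponding mvcc filter functions; unrecognized filters are
// ignored.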
func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
	filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
	for _, ft := range creq.Filters {
		switch ft {
		case pb.WatchCreateRequest_NOPUT:
			filters = append(filters, filterNoPut)
		case pb.WatchCreateRequest_NODELETE:
			filters = append(filters, filterNoDelete)
		default:
		}
	}
	return filters
}