blob: 0d75cb109a09afa7c73341ca65e7ee9ba9e19da1 [file] [log] [blame]
Prince Pereirac1c21d62021-04-22 08:38:15 +00001/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpc
20
21import (
22 "context"
23 "errors"
24 "fmt"
25 "io"
26 "math"
27 "net"
28 "net/http"
29 "reflect"
30 "runtime"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "time"
35
36 "golang.org/x/net/trace"
37
38 "google.golang.org/grpc/codes"
39 "google.golang.org/grpc/credentials"
40 "google.golang.org/grpc/encoding"
41 "google.golang.org/grpc/encoding/proto"
42 "google.golang.org/grpc/grpclog"
43 "google.golang.org/grpc/internal/binarylog"
44 "google.golang.org/grpc/internal/channelz"
45 "google.golang.org/grpc/internal/grpcsync"
46 "google.golang.org/grpc/internal/transport"
47 "google.golang.org/grpc/keepalive"
48 "google.golang.org/grpc/metadata"
49 "google.golang.org/grpc/peer"
50 "google.golang.org/grpc/stats"
51 "google.golang.org/grpc/status"
52 "google.golang.org/grpc/tap"
53)
54
// Default message size limits. They seed defaultServerOptions and apply
// unless overridden with MaxRecvMsgSize / MaxSendMsgSize.
const (
	// defaultServerMaxReceiveMessageSize is 4 MiB.
	defaultServerMaxReceiveMessageSize = 4 << 20
	// defaultServerMaxSendMessageSize is the largest int32 value.
	defaultServerMaxSendMessageSize = math.MaxInt32
)
59
60var statusOK = status.New(codes.OK, "")
61
62type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error)
63
64// MethodDesc represents an RPC service's method specification.
65type MethodDesc struct {
66 MethodName string
67 Handler methodHandler
68}
69
70// ServiceDesc represents an RPC service's specification.
71type ServiceDesc struct {
72 ServiceName string
73 // The pointer to the service interface. Used to check whether the user
74 // provided implementation satisfies the interface requirements.
75 HandlerType interface{}
76 Methods []MethodDesc
77 Streams []StreamDesc
78 Metadata interface{}
79}
80
81// service consists of the information of the server serving this service and
82// the methods in this service.
83type service struct {
84 server interface{} // the server for service methods
85 md map[string]*MethodDesc
86 sd map[string]*StreamDesc
87 mdata interface{}
88}
89
90// Server is a gRPC server to serve RPC requests.
91type Server struct {
92 opts serverOptions
93
94 mu sync.Mutex // guards following
95 lis map[net.Listener]bool
96 conns map[transport.ServerTransport]bool
97 serve bool
98 drain bool
99 cv *sync.Cond // signaled when connections close for GracefulStop
100 m map[string]*service // service name -> service info
101 events trace.EventLog
102
103 quit *grpcsync.Event
104 done *grpcsync.Event
105 channelzRemoveOnce sync.Once
106 serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop
107
108 channelzID int64 // channelz unique identification number
109 czData *channelzData
110}
111
112type serverOptions struct {
113 creds credentials.TransportCredentials
114 codec baseCodec
115 cp Compressor
116 dc Decompressor
117 unaryInt UnaryServerInterceptor
118 streamInt StreamServerInterceptor
119 inTapHandle tap.ServerInHandle
120 statsHandler stats.Handler
121 maxConcurrentStreams uint32
122 maxReceiveMessageSize int
123 maxSendMessageSize int
124 unknownStreamDesc *StreamDesc
125 keepaliveParams keepalive.ServerParameters
126 keepalivePolicy keepalive.EnforcementPolicy
127 initialWindowSize int32
128 initialConnWindowSize int32
129 writeBufferSize int
130 readBufferSize int
131 connectionTimeout time.Duration
132 maxHeaderListSize *uint32
133 headerTableSize *uint32
134}
135
136var defaultServerOptions = serverOptions{
137 maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
138 maxSendMessageSize: defaultServerMaxSendMessageSize,
139 connectionTimeout: 120 * time.Second,
140 writeBufferSize: defaultWriteBufSize,
141 readBufferSize: defaultReadBufSize,
142}
143
144// A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
145type ServerOption interface {
146 apply(*serverOptions)
147}
148
149// EmptyServerOption does not alter the server configuration. It can be embedded
150// in another structure to build custom server options.
151//
152// This API is EXPERIMENTAL.
153type EmptyServerOption struct{}
154
155func (EmptyServerOption) apply(*serverOptions) {}
156
157// funcServerOption wraps a function that modifies serverOptions into an
158// implementation of the ServerOption interface.
159type funcServerOption struct {
160 f func(*serverOptions)
161}
162
163func (fdo *funcServerOption) apply(do *serverOptions) {
164 fdo.f(do)
165}
166
167func newFuncServerOption(f func(*serverOptions)) *funcServerOption {
168 return &funcServerOption{
169 f: f,
170 }
171}
172
173// WriteBufferSize determines how much data can be batched before doing a write on the wire.
174// The corresponding memory allocation for this buffer will be twice the size to keep syscalls low.
175// The default value for this buffer is 32KB.
176// Zero will disable the write buffer such that each write will be on underlying connection.
177// Note: A Send call may not directly translate to a write.
178func WriteBufferSize(s int) ServerOption {
179 return newFuncServerOption(func(o *serverOptions) {
180 o.writeBufferSize = s
181 })
182}
183
184// ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most
185// for one read syscall.
186// The default value for this buffer is 32KB.
187// Zero will disable read buffer for a connection so data framer can access the underlying
188// conn directly.
189func ReadBufferSize(s int) ServerOption {
190 return newFuncServerOption(func(o *serverOptions) {
191 o.readBufferSize = s
192 })
193}
194
195// InitialWindowSize returns a ServerOption that sets window size for stream.
196// The lower bound for window size is 64K and any value smaller than that will be ignored.
197func InitialWindowSize(s int32) ServerOption {
198 return newFuncServerOption(func(o *serverOptions) {
199 o.initialWindowSize = s
200 })
201}
202
203// InitialConnWindowSize returns a ServerOption that sets window size for a connection.
204// The lower bound for window size is 64K and any value smaller than that will be ignored.
205func InitialConnWindowSize(s int32) ServerOption {
206 return newFuncServerOption(func(o *serverOptions) {
207 o.initialConnWindowSize = s
208 })
209}
210
211// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
212func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
213 if kp.Time > 0 && kp.Time < time.Second {
214 grpclog.Warning("Adjusting keepalive ping interval to minimum period of 1s")
215 kp.Time = time.Second
216 }
217
218 return newFuncServerOption(func(o *serverOptions) {
219 o.keepaliveParams = kp
220 })
221}
222
223// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
224func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
225 return newFuncServerOption(func(o *serverOptions) {
226 o.keepalivePolicy = kep
227 })
228}
229
230// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
231//
232// This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
233func CustomCodec(codec Codec) ServerOption {
234 return newFuncServerOption(func(o *serverOptions) {
235 o.codec = codec
236 })
237}
238
239// RPCCompressor returns a ServerOption that sets a compressor for outbound
240// messages. For backward compatibility, all outbound messages will be sent
241// using this compressor, regardless of incoming message compression. By
242// default, server messages will be sent using the same compressor with which
243// request messages were sent.
244//
245// Deprecated: use encoding.RegisterCompressor instead.
246func RPCCompressor(cp Compressor) ServerOption {
247 return newFuncServerOption(func(o *serverOptions) {
248 o.cp = cp
249 })
250}
251
252// RPCDecompressor returns a ServerOption that sets a decompressor for inbound
253// messages. It has higher priority than decompressors registered via
254// encoding.RegisterCompressor.
255//
256// Deprecated: use encoding.RegisterCompressor instead.
257func RPCDecompressor(dc Decompressor) ServerOption {
258 return newFuncServerOption(func(o *serverOptions) {
259 o.dc = dc
260 })
261}
262
263// MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
264// If this is not set, gRPC uses the default limit.
265//
266// Deprecated: use MaxRecvMsgSize instead.
267func MaxMsgSize(m int) ServerOption {
268 return MaxRecvMsgSize(m)
269}
270
271// MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
272// If this is not set, gRPC uses the default 4MB.
273func MaxRecvMsgSize(m int) ServerOption {
274 return newFuncServerOption(func(o *serverOptions) {
275 o.maxReceiveMessageSize = m
276 })
277}
278
279// MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send.
280// If this is not set, gRPC uses the default `math.MaxInt32`.
281func MaxSendMsgSize(m int) ServerOption {
282 return newFuncServerOption(func(o *serverOptions) {
283 o.maxSendMessageSize = m
284 })
285}
286
287// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
288// of concurrent streams to each ServerTransport.
289func MaxConcurrentStreams(n uint32) ServerOption {
290 return newFuncServerOption(func(o *serverOptions) {
291 o.maxConcurrentStreams = n
292 })
293}
294
295// Creds returns a ServerOption that sets credentials for server connections.
296func Creds(c credentials.TransportCredentials) ServerOption {
297 return newFuncServerOption(func(o *serverOptions) {
298 o.creds = c
299 })
300}
301
302// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
303// server. Only one unary interceptor can be installed. The construction of multiple
304// interceptors (e.g., chaining) can be implemented at the caller.
305func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
306 return newFuncServerOption(func(o *serverOptions) {
307 if o.unaryInt != nil {
308 panic("The unary server interceptor was already set and may not be reset.")
309 }
310 o.unaryInt = i
311 })
312}
313
314// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
315// server. Only one stream interceptor can be installed.
316func StreamInterceptor(i StreamServerInterceptor) ServerOption {
317 return newFuncServerOption(func(o *serverOptions) {
318 if o.streamInt != nil {
319 panic("The stream server interceptor was already set and may not be reset.")
320 }
321 o.streamInt = i
322 })
323}
324
325// InTapHandle returns a ServerOption that sets the tap handle for all the server
326// transport to be created. Only one can be installed.
327func InTapHandle(h tap.ServerInHandle) ServerOption {
328 return newFuncServerOption(func(o *serverOptions) {
329 if o.inTapHandle != nil {
330 panic("The tap handle was already set and may not be reset.")
331 }
332 o.inTapHandle = h
333 })
334}
335
336// StatsHandler returns a ServerOption that sets the stats handler for the server.
337func StatsHandler(h stats.Handler) ServerOption {
338 return newFuncServerOption(func(o *serverOptions) {
339 o.statsHandler = h
340 })
341}
342
343// UnknownServiceHandler returns a ServerOption that allows for adding a custom
344// unknown service handler. The provided method is a bidi-streaming RPC service
345// handler that will be invoked instead of returning the "unimplemented" gRPC
346// error whenever a request is received for an unregistered service or method.
347// The handling function and stream interceptor (if set) have full access to
348// the ServerStream, including its Context.
349func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
350 return newFuncServerOption(func(o *serverOptions) {
351 o.unknownStreamDesc = &StreamDesc{
352 StreamName: "unknown_service_handler",
353 Handler: streamHandler,
354 // We need to assume that the users of the streamHandler will want to use both.
355 ClientStreams: true,
356 ServerStreams: true,
357 }
358 })
359}
360
361// ConnectionTimeout returns a ServerOption that sets the timeout for
362// connection establishment (up to and including HTTP/2 handshaking) for all
363// new connections. If this is not set, the default is 120 seconds. A zero or
364// negative value will result in an immediate timeout.
365//
366// This API is EXPERIMENTAL.
367func ConnectionTimeout(d time.Duration) ServerOption {
368 return newFuncServerOption(func(o *serverOptions) {
369 o.connectionTimeout = d
370 })
371}
372
373// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size
374// of header list that the server is prepared to accept.
375func MaxHeaderListSize(s uint32) ServerOption {
376 return newFuncServerOption(func(o *serverOptions) {
377 o.maxHeaderListSize = &s
378 })
379}
380
381// HeaderTableSize returns a ServerOption that sets the size of dynamic
382// header table for stream.
383//
384// This API is EXPERIMENTAL.
385func HeaderTableSize(s uint32) ServerOption {
386 return newFuncServerOption(func(o *serverOptions) {
387 o.headerTableSize = &s
388 })
389}
390
391// NewServer creates a gRPC server which has no service registered and has not
392// started to accept requests yet.
393func NewServer(opt ...ServerOption) *Server {
394 opts := defaultServerOptions
395 for _, o := range opt {
396 o.apply(&opts)
397 }
398 s := &Server{
399 lis: make(map[net.Listener]bool),
400 opts: opts,
401 conns: make(map[transport.ServerTransport]bool),
402 m: make(map[string]*service),
403 quit: grpcsync.NewEvent(),
404 done: grpcsync.NewEvent(),
405 czData: new(channelzData),
406 }
407 s.cv = sync.NewCond(&s.mu)
408 if EnableTracing {
409 _, file, line, _ := runtime.Caller(1)
410 s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
411 }
412
413 if channelz.IsOn() {
414 s.channelzID = channelz.RegisterServer(&channelzServer{s}, "")
415 }
416 return s
417}
418
419// printf records an event in s's event log, unless s has been stopped.
420// REQUIRES s.mu is held.
421func (s *Server) printf(format string, a ...interface{}) {
422 if s.events != nil {
423 s.events.Printf(format, a...)
424 }
425}
426
427// errorf records an error in s's event log, unless s has been stopped.
428// REQUIRES s.mu is held.
429func (s *Server) errorf(format string, a ...interface{}) {
430 if s.events != nil {
431 s.events.Errorf(format, a...)
432 }
433}
434
435// RegisterService registers a service and its implementation to the gRPC
436// server. It is called from the IDL generated code. This must be called before
437// invoking Serve.
438func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
439 ht := reflect.TypeOf(sd.HandlerType).Elem()
440 st := reflect.TypeOf(ss)
441 if !st.Implements(ht) {
442 grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
443 }
444 s.register(sd, ss)
445}
446
447func (s *Server) register(sd *ServiceDesc, ss interface{}) {
448 s.mu.Lock()
449 defer s.mu.Unlock()
450 s.printf("RegisterService(%q)", sd.ServiceName)
451 if s.serve {
452 grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
453 }
454 if _, ok := s.m[sd.ServiceName]; ok {
455 grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
456 }
457 srv := &service{
458 server: ss,
459 md: make(map[string]*MethodDesc),
460 sd: make(map[string]*StreamDesc),
461 mdata: sd.Metadata,
462 }
463 for i := range sd.Methods {
464 d := &sd.Methods[i]
465 srv.md[d.MethodName] = d
466 }
467 for i := range sd.Streams {
468 d := &sd.Streams[i]
469 srv.sd[d.StreamName] = d
470 }
471 s.m[sd.ServiceName] = srv
472}
473
// MethodInfo contains the information of an RPC including its method name and type.
type MethodInfo struct {
	// Name is the method name only, without the service name or package name.
	Name string
	// IsClientStream reports whether the client sends a stream of messages.
	IsClientStream bool
	// IsServerStream reports whether the server sends a stream of messages.
	IsServerStream bool
}

// ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service.
type ServiceInfo struct {
	// Methods lists both unary and streaming methods, in no particular order.
	Methods []MethodInfo
	// Metadata is the metadata specified in ServiceDesc when registering service.
	Metadata interface{}
}
490
491// GetServiceInfo returns a map from service names to ServiceInfo.
492// Service names include the package names, in the form of <package>.<service>.
493func (s *Server) GetServiceInfo() map[string]ServiceInfo {
494 ret := make(map[string]ServiceInfo)
495 for n, srv := range s.m {
496 methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
497 for m := range srv.md {
498 methods = append(methods, MethodInfo{
499 Name: m,
500 IsClientStream: false,
501 IsServerStream: false,
502 })
503 }
504 for m, d := range srv.sd {
505 methods = append(methods, MethodInfo{
506 Name: m,
507 IsClientStream: d.ClientStreams,
508 IsServerStream: d.ServerStreams,
509 })
510 }
511
512 ret[n] = ServiceInfo{
513 Methods: methods,
514 Metadata: srv.mdata,
515 }
516 }
517 return ret
518}
519
520// ErrServerStopped indicates that the operation is now illegal because of
521// the server being stopped.
522var ErrServerStopped = errors.New("grpc: the server has been stopped")
523
524func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
525 if s.opts.creds == nil {
526 return rawConn, nil, nil
527 }
528 return s.opts.creds.ServerHandshake(rawConn)
529}
530
531type listenSocket struct {
532 net.Listener
533 channelzID int64
534}
535
536func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric {
537 return &channelz.SocketInternalMetric{
538 SocketOptions: channelz.GetSocketOption(l.Listener),
539 LocalAddr: l.Listener.Addr(),
540 }
541}
542
543func (l *listenSocket) Close() error {
544 err := l.Listener.Close()
545 if channelz.IsOn() {
546 channelz.RemoveEntry(l.channelzID)
547 }
548 return err
549}
550
// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them.
// Serve returns when lis.Accept fails with fatal errors. lis will be closed when
// this method returns.
// Serve will return a non-nil error unless Stop or GracefulStop is called.
//
// Temporary Accept errors are retried with exponential backoff starting at
// 5ms and capped at 1s; the backoff resets after every successful Accept.
func (s *Server) Serve(lis net.Listener) error {
	s.mu.Lock()
	s.printf("serving")
	// Marks the server as serving; register refuses new services after this.
	s.serve = true
	if s.lis == nil {
		// Serve called after Stop or GracefulStop.
		s.mu.Unlock()
		lis.Close()
		return ErrServerStopped
	}

	// Deferred functions run LIFO, so this one runs last: after the listener
	// cleanup below. It releases this Serve's WaitGroup slot and, if shutdown
	// was requested, waits for it to finish.
	s.serveWG.Add(1)
	defer func() {
		s.serveWG.Done()
		if s.quit.HasFired() {
			// Stop or GracefulStop called; block until done and return nil.
			<-s.done.Done()
		}
	}()

	// The wrapper (not lis itself) is what is stored in s.lis, so closing it
	// through s.lis also removes its channelz entry when channelz is on.
	ls := &listenSocket{Listener: lis}
	s.lis[ls] = true

	if channelz.IsOn() {
		ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String())
	}
	s.mu.Unlock()

	// Close and forget the listener on return, unless someone else already
	// removed it from s.lis (or cleared s.lis entirely).
	defer func() {
		s.mu.Lock()
		if s.lis != nil && s.lis[ls] {
			ls.Close()
			delete(s.lis, ls)
		}
		s.mu.Unlock()
	}()

	var tempDelay time.Duration // how long to sleep on accept failure

	for {
		rawConn, err := lis.Accept()
		if err != nil {
			// Errors exposing Temporary() == true are retried with backoff.
			if ne, ok := err.(interface {
				Temporary() bool
			}); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				s.mu.Lock()
				s.printf("Accept error: %v; retrying in %v", err, tempDelay)
				s.mu.Unlock()
				// Sleep for the backoff, but wake early if shutdown starts.
				timer := time.NewTimer(tempDelay)
				select {
				case <-timer.C:
				case <-s.quit.Done():
					timer.Stop()
					return nil
				}
				continue
			}
			s.mu.Lock()
			s.printf("done serving; Accept = %v", err)
			s.mu.Unlock()

			// An Accept failure caused by shutdown is not reported as an error.
			if s.quit.HasFired() {
				return nil
			}
			return err
		}
		tempDelay = 0
		// Start a new goroutine to deal with rawConn so we don't stall this Accept
		// loop goroutine.
		//
		// Make sure we account for the goroutine so GracefulStop doesn't nil out
		// s.conns before this conn can be added.
		s.serveWG.Add(1)
		go func() {
			s.handleRawConn(rawConn)
			s.serveWG.Done()
		}()
	}
}
644
645// handleRawConn forks a goroutine to handle a just-accepted connection that
646// has not had any I/O performed on it yet.
647func (s *Server) handleRawConn(rawConn net.Conn) {
648 if s.quit.HasFired() {
649 rawConn.Close()
650 return
651 }
652 rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
653 conn, authInfo, err := s.useTransportAuthenticator(rawConn)
654 if err != nil {
655 // ErrConnDispatched means that the connection was dispatched away from
656 // gRPC; those connections should be left open.
657 if err != credentials.ErrConnDispatched {
658 s.mu.Lock()
659 s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
660 s.mu.Unlock()
661 grpclog.Warningf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
662 rawConn.Close()
663 }
664 rawConn.SetDeadline(time.Time{})
665 return
666 }
667
668 // Finish handshaking (HTTP2)
669 st := s.newHTTP2Transport(conn, authInfo)
670 if st == nil {
671 return
672 }
673
674 rawConn.SetDeadline(time.Time{})
675 if !s.addConn(st) {
676 return
677 }
678 go func() {
679 s.serveStreams(st)
680 s.removeConn(st)
681 }()
682}
683
684// newHTTP2Transport sets up a http/2 transport (using the
685// gRPC http2 server transport in transport/http2_server.go).
686func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport {
687 config := &transport.ServerConfig{
688 MaxStreams: s.opts.maxConcurrentStreams,
689 AuthInfo: authInfo,
690 InTapHandle: s.opts.inTapHandle,
691 StatsHandler: s.opts.statsHandler,
692 KeepaliveParams: s.opts.keepaliveParams,
693 KeepalivePolicy: s.opts.keepalivePolicy,
694 InitialWindowSize: s.opts.initialWindowSize,
695 InitialConnWindowSize: s.opts.initialConnWindowSize,
696 WriteBufferSize: s.opts.writeBufferSize,
697 ReadBufferSize: s.opts.readBufferSize,
698 ChannelzParentID: s.channelzID,
699 MaxHeaderListSize: s.opts.maxHeaderListSize,
700 HeaderTableSize: s.opts.headerTableSize,
701 }
702 st, err := transport.NewServerTransport("http2", c, config)
703 if err != nil {
704 s.mu.Lock()
705 s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
706 s.mu.Unlock()
707 c.Close()
708 grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
709 return nil
710 }
711
712 return st
713}
714
715func (s *Server) serveStreams(st transport.ServerTransport) {
716 defer st.Close()
717 var wg sync.WaitGroup
718 st.HandleStreams(func(stream *transport.Stream) {
719 wg.Add(1)
720 go func() {
721 defer wg.Done()
722 s.handleStream(st, stream, s.traceInfo(st, stream))
723 }()
724 }, func(ctx context.Context, method string) context.Context {
725 if !EnableTracing {
726 return ctx
727 }
728 tr := trace.New("grpc.Recv."+methodFamily(method), method)
729 return trace.NewContext(ctx, tr)
730 })
731 wg.Wait()
732}
733
734var _ http.Handler = (*Server)(nil)
735
736// ServeHTTP implements the Go standard library's http.Handler
737// interface by responding to the gRPC request r, by looking up
738// the requested gRPC method in the gRPC server s.
739//
740// The provided HTTP request must have arrived on an HTTP/2
741// connection. When using the Go standard library's server,
742// practically this means that the Request must also have arrived
743// over TLS.
744//
745// To share one port (such as 443 for https) between gRPC and an
746// existing http.Handler, use a root http.Handler such as:
747//
748// if r.ProtoMajor == 2 && strings.HasPrefix(
749// r.Header.Get("Content-Type"), "application/grpc") {
750// grpcServer.ServeHTTP(w, r)
751// } else {
752// yourMux.ServeHTTP(w, r)
753// }
754//
755// Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally
756// separate from grpc-go's HTTP/2 server. Performance and features may vary
757// between the two paths. ServeHTTP does not support some gRPC features
758// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
759// and subject to change.
760func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
761 st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler)
762 if err != nil {
763 http.Error(w, err.Error(), http.StatusInternalServerError)
764 return
765 }
766 if !s.addConn(st) {
767 return
768 }
769 defer s.removeConn(st)
770 s.serveStreams(st)
771}
772
773// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
774// If tracing is not enabled, it returns nil.
775func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
776 if !EnableTracing {
777 return nil
778 }
779 tr, ok := trace.FromContext(stream.Context())
780 if !ok {
781 return nil
782 }
783
784 trInfo = &traceInfo{
785 tr: tr,
786 firstLine: firstLine{
787 client: false,
788 remoteAddr: st.RemoteAddr(),
789 },
790 }
791 if dl, ok := stream.Context().Deadline(); ok {
792 trInfo.firstLine.deadline = time.Until(dl)
793 }
794 return trInfo
795}
796
797func (s *Server) addConn(st transport.ServerTransport) bool {
798 s.mu.Lock()
799 defer s.mu.Unlock()
800 if s.conns == nil {
801 st.Close()
802 return false
803 }
804 if s.drain {
805 // Transport added after we drained our existing conns: drain it
806 // immediately.
807 st.Drain()
808 }
809 s.conns[st] = true
810 return true
811}
812
813func (s *Server) removeConn(st transport.ServerTransport) {
814 s.mu.Lock()
815 defer s.mu.Unlock()
816 if s.conns != nil {
817 delete(s.conns, st)
818 s.cv.Broadcast()
819 }
820}
821
822func (s *Server) channelzMetric() *channelz.ServerInternalMetric {
823 return &channelz.ServerInternalMetric{
824 CallsStarted: atomic.LoadInt64(&s.czData.callsStarted),
825 CallsSucceeded: atomic.LoadInt64(&s.czData.callsSucceeded),
826 CallsFailed: atomic.LoadInt64(&s.czData.callsFailed),
827 LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&s.czData.lastCallStartedTime)),
828 }
829}
830
831func (s *Server) incrCallsStarted() {
832 atomic.AddInt64(&s.czData.callsStarted, 1)
833 atomic.StoreInt64(&s.czData.lastCallStartedTime, time.Now().UnixNano())
834}
835
836func (s *Server) incrCallsSucceeded() {
837 atomic.AddInt64(&s.czData.callsSucceeded, 1)
838}
839
840func (s *Server) incrCallsFailed() {
841 atomic.AddInt64(&s.czData.callsFailed, 1)
842}
843
844func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
845 data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
846 if err != nil {
847 grpclog.Errorln("grpc: server failed to encode response: ", err)
848 return err
849 }
850 compData, err := compress(data, cp, comp)
851 if err != nil {
852 grpclog.Errorln("grpc: server failed to compress response: ", err)
853 return err
854 }
855 hdr, payload := msgHeader(data, compData)
856 // TODO(dfawley): should we be checking len(data) instead?
857 if len(payload) > s.opts.maxSendMessageSize {
858 return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
859 }
860 err = t.Write(stream, hdr, payload, opts)
861 if err == nil && s.opts.statsHandler != nil {
862 s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
863 }
864 return err
865}
866
867func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
868 sh := s.opts.statsHandler
869 if sh != nil || trInfo != nil || channelz.IsOn() {
870 if channelz.IsOn() {
871 s.incrCallsStarted()
872 }
873 var statsBegin *stats.Begin
874 if sh != nil {
875 beginTime := time.Now()
876 statsBegin = &stats.Begin{
877 BeginTime: beginTime,
878 }
879 sh.HandleRPC(stream.Context(), statsBegin)
880 }
881 if trInfo != nil {
882 trInfo.tr.LazyLog(&trInfo.firstLine, false)
883 }
884 // The deferred error handling for tracing, stats handler and channelz are
885 // combined into one function to reduce stack usage -- a defer takes ~56-64
886 // bytes on the stack, so overflowing the stack will require a stack
887 // re-allocation, which is expensive.
888 //
889 // To maintain behavior similar to separate deferred statements, statements
890 // should be executed in the reverse order. That is, tracing first, stats
891 // handler second, and channelz last. Note that panics *within* defers will
892 // lead to different behavior, but that's an acceptable compromise; that
893 // would be undefined behavior territory anyway.
894 defer func() {
895 if trInfo != nil {
896 if err != nil && err != io.EOF {
897 trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
898 trInfo.tr.SetError()
899 }
900 trInfo.tr.Finish()
901 }
902
903 if sh != nil {
904 end := &stats.End{
905 BeginTime: statsBegin.BeginTime,
906 EndTime: time.Now(),
907 }
908 if err != nil && err != io.EOF {
909 end.Error = toRPCErr(err)
910 }
911 sh.HandleRPC(stream.Context(), end)
912 }
913
914 if channelz.IsOn() {
915 if err != nil && err != io.EOF {
916 s.incrCallsFailed()
917 } else {
918 s.incrCallsSucceeded()
919 }
920 }
921 }()
922 }
923
924 binlog := binarylog.GetMethodLogger(stream.Method())
925 if binlog != nil {
926 ctx := stream.Context()
927 md, _ := metadata.FromIncomingContext(ctx)
928 logEntry := &binarylog.ClientHeader{
929 Header: md,
930 MethodName: stream.Method(),
931 PeerAddr: nil,
932 }
933 if deadline, ok := ctx.Deadline(); ok {
934 logEntry.Timeout = time.Until(deadline)
935 if logEntry.Timeout < 0 {
936 logEntry.Timeout = 0
937 }
938 }
939 if a := md[":authority"]; len(a) > 0 {
940 logEntry.Authority = a[0]
941 }
942 if peer, ok := peer.FromContext(ctx); ok {
943 logEntry.PeerAddr = peer.Addr
944 }
945 binlog.Log(logEntry)
946 }
947
948 // comp and cp are used for compression. decomp and dc are used for
949 // decompression. If comp and decomp are both set, they are the same;
950 // however they are kept separate to ensure that at most one of the
951 // compressor/decompressor variable pairs are set for use later.
952 var comp, decomp encoding.Compressor
953 var cp Compressor
954 var dc Decompressor
955
956 // If dc is set and matches the stream's compression, use it. Otherwise, try
957 // to find a matching registered compressor for decomp.
958 if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
959 dc = s.opts.dc
960 } else if rc != "" && rc != encoding.Identity {
961 decomp = encoding.GetCompressor(rc)
962 if decomp == nil {
963 st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
964 t.WriteStatus(stream, st)
965 return st.Err()
966 }
967 }
968
969 // If cp is set, use it. Otherwise, attempt to compress the response using
970 // the incoming message compression method.
971 //
972 // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
973 if s.opts.cp != nil {
974 cp = s.opts.cp
975 stream.SetSendCompress(cp.Type())
976 } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
977 // Legacy compressor not specified; attempt to respond with same encoding.
978 comp = encoding.GetCompressor(rc)
979 if comp != nil {
980 stream.SetSendCompress(rc)
981 }
982 }
983
984 var payInfo *payloadInfo
985 if sh != nil || binlog != nil {
986 payInfo = &payloadInfo{}
987 }
988 d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
989 if err != nil {
990 if st, ok := status.FromError(err); ok {
991 if e := t.WriteStatus(stream, st); e != nil {
992 grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
993 }
994 }
995 return err
996 }
997 if channelz.IsOn() {
998 t.IncrMsgRecv()
999 }
1000 df := func(v interface{}) error {
1001 if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
1002 return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
1003 }
1004 if sh != nil {
1005 sh.HandleRPC(stream.Context(), &stats.InPayload{
1006 RecvTime: time.Now(),
1007 Payload: v,
1008 WireLength: payInfo.wireLength,
1009 Data: d,
1010 Length: len(d),
1011 })
1012 }
1013 if binlog != nil {
1014 binlog.Log(&binarylog.ClientMessage{
1015 Message: d,
1016 })
1017 }
1018 if trInfo != nil {
1019 trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
1020 }
1021 return nil
1022 }
1023 ctx := NewContextWithServerTransportStream(stream.Context(), stream)
1024 reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
1025 if appErr != nil {
1026 appStatus, ok := status.FromError(appErr)
1027 if !ok {
1028 // Convert appErr if it is not a grpc status error.
1029 appErr = status.Error(codes.Unknown, appErr.Error())
1030 appStatus, _ = status.FromError(appErr)
1031 }
1032 if trInfo != nil {
1033 trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
1034 trInfo.tr.SetError()
1035 }
1036 if e := t.WriteStatus(stream, appStatus); e != nil {
1037 grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
1038 }
1039 if binlog != nil {
1040 if h, _ := stream.Header(); h.Len() > 0 {
1041 // Only log serverHeader if there was header. Otherwise it can
1042 // be trailer only.
1043 binlog.Log(&binarylog.ServerHeader{
1044 Header: h,
1045 })
1046 }
1047 binlog.Log(&binarylog.ServerTrailer{
1048 Trailer: stream.Trailer(),
1049 Err: appErr,
1050 })
1051 }
1052 return appErr
1053 }
1054 if trInfo != nil {
1055 trInfo.tr.LazyLog(stringer("OK"), false)
1056 }
1057 opts := &transport.Options{Last: true}
1058
1059 if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
1060 if err == io.EOF {
1061 // The entire stream is done (for unary RPC only).
1062 return err
1063 }
1064 if s, ok := status.FromError(err); ok {
1065 if e := t.WriteStatus(stream, s); e != nil {
1066 grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
1067 }
1068 } else {
1069 switch st := err.(type) {
1070 case transport.ConnectionError:
1071 // Nothing to do here.
1072 default:
1073 panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
1074 }
1075 }
1076 if binlog != nil {
1077 h, _ := stream.Header()
1078 binlog.Log(&binarylog.ServerHeader{
1079 Header: h,
1080 })
1081 binlog.Log(&binarylog.ServerTrailer{
1082 Trailer: stream.Trailer(),
1083 Err: appErr,
1084 })
1085 }
1086 return err
1087 }
1088 if binlog != nil {
1089 h, _ := stream.Header()
1090 binlog.Log(&binarylog.ServerHeader{
1091 Header: h,
1092 })
1093 binlog.Log(&binarylog.ServerMessage{
1094 Message: reply,
1095 })
1096 }
1097 if channelz.IsOn() {
1098 t.IncrMsgSent()
1099 }
1100 if trInfo != nil {
1101 trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
1102 }
1103 // TODO: Should we be logging if writing status failed here, like above?
1104 // Should the logging be in WriteStatus? Should we ignore the WriteStatus
1105 // error or allow the stats handler to see it?
1106 err = t.WriteStatus(stream, statusOK)
1107 if binlog != nil {
1108 binlog.Log(&binarylog.ServerTrailer{
1109 Trailer: stream.Trailer(),
1110 Err: appErr,
1111 })
1112 }
1113 return err
1114}
1115
// processStreamingRPC serves one streaming RPC on stream using the handler in
// sd. It is also used for the server's unknown-service handler, in which case
// srv is nil and the handler receives a nil service implementation.
//
// Work done, in order:
//   - channelz "calls started" counter and stats.Begin (when enabled);
//   - construction of the serverStream handed to the handler;
//   - installation of deferred bookkeeping (trace finish, stats.End, channelz
//     succeeded/failed counters);
//   - binary logging of the client header (when a method logger exists);
//   - selection of the decompressor/compressor for this stream;
//   - invocation of the handler, through s.opts.streamInt when configured;
//   - writing the final RPC status and the binary-log server trailer.
//
// The result is named so the deferred closure sees the value of whichever
// return statement ran. A nil error or io.EOF counts as success there.
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
	if channelz.IsOn() {
		s.incrCallsStarted()
	}
	sh := s.opts.statsHandler
	// statsBegin is kept so the deferred stats.End can reuse its BeginTime.
	var statsBegin *stats.Begin
	if sh != nil {
		beginTime := time.Now()
		statsBegin = &stats.Begin{
			BeginTime: beginTime,
		}
		sh.HandleRPC(stream.Context(), statsBegin)
	}
	// Attach the transport stream to the context so that SetHeader,
	// SendHeader, SetTrailer and Method work from inside the handler.
	ctx := NewContextWithServerTransportStream(stream.Context(), stream)
	ss := &serverStream{
		ctx:                   ctx,
		t:                     t,
		s:                     stream,
		p:                     &parser{r: stream},
		codec:                 s.getCodec(stream.ContentSubtype()),
		maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
		maxSendMessageSize:    s.opts.maxSendMessageSize,
		trInfo:                trInfo,
		statsHandler:          sh,
	}

	// Only pay for a defer when some consumer of the completion event exists.
	if sh != nil || trInfo != nil || channelz.IsOn() {
		// See comment in processUnaryRPC on defers.
		defer func() {
			if trInfo != nil {
				// Every access to ss.trInfo.tr in this function happens under
				// ss.mu. The trace is cleared after Finish; presumably this
				// keeps other users of ss from logging to a finished trace
				// (TODO confirm against serverStream).
				ss.mu.Lock()
				if err != nil && err != io.EOF {
					ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
					ss.trInfo.tr.SetError()
				}
				ss.trInfo.tr.Finish()
				ss.trInfo.tr = nil
				ss.mu.Unlock()
			}

			if sh != nil {
				end := &stats.End{
					BeginTime: statsBegin.BeginTime,
					EndTime:   time.Now(),
				}
				if err != nil && err != io.EOF {
					end.Error = toRPCErr(err)
				}
				sh.HandleRPC(stream.Context(), end)
			}

			if channelz.IsOn() {
				if err != nil && err != io.EOF {
					s.incrCallsFailed()
				} else {
					s.incrCallsSucceeded()
				}
			}
		}()
	}

	// Binary-log the client header: incoming metadata, method name, the
	// remaining deadline (clamped at zero), :authority and the peer address.
	ss.binlog = binarylog.GetMethodLogger(stream.Method())
	if ss.binlog != nil {
		md, _ := metadata.FromIncomingContext(ctx)
		logEntry := &binarylog.ClientHeader{
			Header:     md,
			MethodName: stream.Method(),
			PeerAddr:   nil,
		}
		if deadline, ok := ctx.Deadline(); ok {
			logEntry.Timeout = time.Until(deadline)
			if logEntry.Timeout < 0 {
				logEntry.Timeout = 0
			}
		}
		if a := md[":authority"]; len(a) > 0 {
			logEntry.Authority = a[0]
		}
		if peer, ok := peer.FromContext(ss.Context()); ok {
			logEntry.PeerAddr = peer.Addr
		}
		ss.binlog.Log(logEntry)
	}

	// If dc is set and matches the stream's compression, use it. Otherwise, try
	// to find a matching registered compressor for decomp.
	if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
		ss.dc = s.opts.dc
	} else if rc != "" && rc != encoding.Identity {
		ss.decomp = encoding.GetCompressor(rc)
		if ss.decomp == nil {
			// The client used an encoding we cannot decode. Reject the RPC;
			// the WriteStatus error is not checked on this path.
			st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
			t.WriteStatus(ss.s, st)
			return st.Err()
		}
	}

	// If cp is set, use it. Otherwise, attempt to compress the response using
	// the incoming message compression method.
	//
	// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
	if s.opts.cp != nil {
		ss.cp = s.opts.cp
		stream.SetSendCompress(s.opts.cp.Type())
	} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
		// Legacy compressor not specified; attempt to respond with same encoding.
		ss.comp = encoding.GetCompressor(rc)
		if ss.comp != nil {
			stream.SetSendCompress(rc)
		}
	}

	if trInfo != nil {
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
	}
	var appErr error
	// server stays nil for the unknown-service handler (srv == nil).
	var server interface{}
	if srv != nil {
		server = srv.server
	}
	if s.opts.streamInt == nil {
		appErr = sd.Handler(server, ss)
	} else {
		info := &StreamServerInfo{
			FullMethod:     stream.Method(),
			IsClientStream: sd.ClientStreams,
			IsServerStream: sd.ServerStreams,
		}
		appErr = s.opts.streamInt(server, ss, info, sd.Handler)
	}
	if appErr != nil {
		// Handler errors that are not gRPC status errors are reported to the
		// client as codes.Unknown carrying the error's text.
		appStatus, ok := status.FromError(appErr)
		if !ok {
			appStatus = status.New(codes.Unknown, appErr.Error())
			appErr = appStatus.Err()
		}
		if trInfo != nil {
			ss.mu.Lock()
			ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
			ss.trInfo.tr.SetError()
			ss.mu.Unlock()
		}
		t.WriteStatus(ss.s, appStatus)
		if ss.binlog != nil {
			ss.binlog.Log(&binarylog.ServerTrailer{
				Trailer: ss.s.Trailer(),
				Err:     appErr,
			})
		}
		// TODO: Should we log an error from WriteStatus here and below?
		return appErr
	}
	if trInfo != nil {
		ss.mu.Lock()
		ss.trInfo.tr.LazyLog(stringer("OK"), false)
		ss.mu.Unlock()
	}
	// On success the WriteStatus error, if any, becomes the RPC's result and
	// is therefore seen by the deferred bookkeeping.
	err = t.WriteStatus(ss.s, statusOK)
	if ss.binlog != nil {
		ss.binlog.Log(&binarylog.ServerTrailer{
			Trailer: ss.s.Trailer(),
			Err:     appErr,
		})
	}
	return err
}
1282
1283func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
1284 sm := stream.Method()
1285 if sm != "" && sm[0] == '/' {
1286 sm = sm[1:]
1287 }
1288 pos := strings.LastIndex(sm, "/")
1289 if pos == -1 {
1290 if trInfo != nil {
1291 trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
1292 trInfo.tr.SetError()
1293 }
1294 errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
1295 if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
1296 if trInfo != nil {
1297 trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
1298 trInfo.tr.SetError()
1299 }
1300 grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
1301 }
1302 if trInfo != nil {
1303 trInfo.tr.Finish()
1304 }
1305 return
1306 }
1307 service := sm[:pos]
1308 method := sm[pos+1:]
1309
1310 srv, knownService := s.m[service]
1311 if knownService {
1312 if md, ok := srv.md[method]; ok {
1313 s.processUnaryRPC(t, stream, srv, md, trInfo)
1314 return
1315 }
1316 if sd, ok := srv.sd[method]; ok {
1317 s.processStreamingRPC(t, stream, srv, sd, trInfo)
1318 return
1319 }
1320 }
1321 // Unknown service, or known server unknown method.
1322 if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
1323 s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
1324 return
1325 }
1326 var errDesc string
1327 if !knownService {
1328 errDesc = fmt.Sprintf("unknown service %v", service)
1329 } else {
1330 errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
1331 }
1332 if trInfo != nil {
1333 trInfo.tr.LazyPrintf("%s", errDesc)
1334 trInfo.tr.SetError()
1335 }
1336 if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
1337 if trInfo != nil {
1338 trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
1339 trInfo.tr.SetError()
1340 }
1341 grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
1342 }
1343 if trInfo != nil {
1344 trInfo.tr.Finish()
1345 }
1346}
1347
// streamKey is the context key under which NewContextWithServerTransportStream
// stores the ServerTransportStream, and from which
// ServerTransportStreamFromContext reads it. The type is unexported, so code
// outside this package cannot construct an equal key and collide with it.
type streamKey struct{}
1350
1351// NewContextWithServerTransportStream creates a new context from ctx and
1352// attaches stream to it.
1353//
1354// This API is EXPERIMENTAL.
1355func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
1356 return context.WithValue(ctx, streamKey{}, stream)
1357}
1358
// ServerTransportStream is a minimal interface that a transport stream must
// implement. This can be used to mock an actual transport stream for tests of
// handler code that use, for example, grpc.SetHeader (which requires some
// stream to be in context).
//
// See also NewContextWithServerTransportStream.
//
// This API is EXPERIMENTAL.
type ServerTransportStream interface {
	// Method returns the full method name of the RPC, in the form
	// "/service/method".
	Method() string
	// SetHeader records header metadata to be sent later. grpc.SetHeader
	// delegates here and documents repeated calls as merging.
	SetHeader(md metadata.MD) error
	// SendHeader sends the header metadata. grpc.SendHeader delegates here
	// and documents that it may be called at most once.
	SendHeader(md metadata.MD) error
	// SetTrailer records trailer metadata to be sent when the RPC ends.
	// grpc.SetTrailer delegates here.
	SetTrailer(md metadata.MD) error
}
1373
1374// ServerTransportStreamFromContext returns the ServerTransportStream saved in
1375// ctx. Returns nil if the given context has no stream associated with it
1376// (which implies it is not an RPC invocation context).
1377//
1378// This API is EXPERIMENTAL.
1379func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
1380 s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
1381 return s
1382}
1383
// Stop stops the gRPC server. It immediately closes all open
// connections and listeners.
// It cancels all active RPCs on the server side and the corresponding
// pending RPCs on the client side will get notified by connection
// errors.
//
// Before returning, Stop waits on s.serveWG and then fires s.done.
func (s *Server) Stop() {
	// Signal shutdown before tearing anything down.
	s.quit.Fire()

	// Runs last: wait for everything tracked by serveWG (presumably the
	// Serve loops — confirm), then mark the server fully done.
	defer func() {
		s.serveWG.Wait()
		s.done.Fire()
	}()

	// channelzRemoveOnce is shared with GracefulStop, so the channelz entry is
	// removed at most once whichever stop runs first.
	s.channelzRemoveOnce.Do(func() {
		if channelz.IsOn() {
			channelz.RemoveEntry(s.channelzID)
		}
	})

	// Take ownership of the listener and connection sets under the lock. The
	// Close calls below then run without holding s.mu. Presumably this avoids
	// deadlocking with code that takes s.mu during close — confirm.
	s.mu.Lock()
	listeners := s.lis
	s.lis = nil
	st := s.conns
	s.conns = nil
	// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
	// (GracefulStop waits on s.cv until s.conns is empty; it was just niled.)
	s.cv.Broadcast()
	s.mu.Unlock()

	for lis := range listeners {
		lis.Close()
	}
	for c := range st {
		c.Close()
	}

	s.mu.Lock()
	if s.events != nil {
		s.events.Finish()
		s.events = nil
	}
	s.mu.Unlock()
}
1426
// GracefulStop stops the gRPC server gracefully. It stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
//
// If Stop runs concurrently, its broadcast on s.cv (after setting s.conns to
// nil) ends the wait below early.
func (s *Server) GracefulStop() {
	s.quit.Fire()
	defer s.done.Fire()

	// Shared with Stop: the channelz entry is removed at most once.
	s.channelzRemoveOnce.Do(func() {
		if channelz.IsOn() {
			channelz.RemoveEntry(s.channelzID)
		}
	})
	s.mu.Lock()
	// s.conns is set to nil by Stop and at the end of a previous
	// GracefulStop, so a nil map means shutdown already happened.
	if s.conns == nil {
		s.mu.Unlock()
		return
	}

	// Stop accepting new connections.
	for lis := range s.lis {
		lis.Close()
	}
	s.lis = nil
	// Drain existing connections exactly once; s.drain records that it has
	// been done.
	if !s.drain {
		for st := range s.conns {
			st.Drain()
		}
		s.drain = true
	}

	// Wait for serving threads to be ready to exit. Only then can we be sure no
	// new conns will be created.
	// s.mu must be released while waiting; presumably the serving goroutines
	// need it to finish — confirm.
	s.mu.Unlock()
	s.serveWG.Wait()
	s.mu.Lock()

	// Block until every connection has gone away. s.cv is signalled by Stop
	// and presumably by connection removal elsewhere in this file — confirm.
	for len(s.conns) != 0 {
		s.cv.Wait()
	}
	s.conns = nil
	if s.events != nil {
		s.events.Finish()
		s.events = nil
	}
	s.mu.Unlock()
}
1472
1473// contentSubtype must be lowercase
1474// cannot return nil
1475func (s *Server) getCodec(contentSubtype string) baseCodec {
1476 if s.opts.codec != nil {
1477 return s.opts.codec
1478 }
1479 if contentSubtype == "" {
1480 return encoding.GetCodec(proto.Name)
1481 }
1482 codec := encoding.GetCodec(contentSubtype)
1483 if codec == nil {
1484 return encoding.GetCodec(proto.Name)
1485 }
1486 return codec
1487}
1488
1489// SetHeader sets the header metadata.
1490// When called multiple times, all the provided metadata will be merged.
1491// All the metadata will be sent out when one of the following happens:
1492// - grpc.SendHeader() is called;
1493// - The first response is sent out;
1494// - An RPC status is sent out (error or success).
1495func SetHeader(ctx context.Context, md metadata.MD) error {
1496 if md.Len() == 0 {
1497 return nil
1498 }
1499 stream := ServerTransportStreamFromContext(ctx)
1500 if stream == nil {
1501 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
1502 }
1503 return stream.SetHeader(md)
1504}
1505
1506// SendHeader sends header metadata. It may be called at most once.
1507// The provided md and headers set by SetHeader() will be sent.
1508func SendHeader(ctx context.Context, md metadata.MD) error {
1509 stream := ServerTransportStreamFromContext(ctx)
1510 if stream == nil {
1511 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
1512 }
1513 if err := stream.SendHeader(md); err != nil {
1514 return toRPCErr(err)
1515 }
1516 return nil
1517}
1518
1519// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
1520// When called more than once, all the provided metadata will be merged.
1521func SetTrailer(ctx context.Context, md metadata.MD) error {
1522 if md.Len() == 0 {
1523 return nil
1524 }
1525 stream := ServerTransportStreamFromContext(ctx)
1526 if stream == nil {
1527 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
1528 }
1529 return stream.SetTrailer(md)
1530}
1531
1532// Method returns the method string for the server context. The returned
1533// string is in the format of "/service/method".
1534func Method(ctx context.Context) (string, bool) {
1535 s := ServerTransportStreamFromContext(ctx)
1536 if s == nil {
1537 return "", false
1538 }
1539 return s.Method(), true
1540}
1541
// channelzServer wraps a *Server and exposes its channelz metrics through the
// ChannelzMetric method. It is presumably the object registered with channelz
// for this server — confirm at the registration site.
type channelzServer struct {
	s *Server
}
1545
1546func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
1547 return c.s.channelzMetric()
1548}