/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20package thrift
21
22import (
23 "errors"
24 "fmt"
25 "io"
26 "sync"
27 "sync/atomic"
28 "time"
29)
30
// ErrAbandonRequest is a special error that server handler implementations
// can return to indicate that the request has been abandoned.
//
// TSimpleServer checks for this error and closes the client connection
// instead of writing the response/error back to the client.
//
// It shall only be used when the server handler implementation knows that the
// client has already abandoned the request (by checking that the passed in
// context is already canceled, for example).
var ErrAbandonRequest = errors.New("request abandoned")
41
// ServerConnectivityCheckInterval defines the ticker interval used by the
// connectivity check in thrift compiled TProcessorFunc implementations.
//
// It is defined as a variable instead of a constant so that thrift server
// implementations can change its value to control the behavior.
//
// If it is changed to <=0, the feature will be disabled.
var ServerConnectivityCheckInterval = time.Millisecond * 5
50
// TSimpleServer is not a typical "simple" server: it does not block on a
// connection after accepting it. It behaves more like a TThreadedServer,
// handling each accepted connection in its own goroutine.
//
// This works when Go clients implement something like a connection pool on
// the client side.
type TSimpleServer struct {
	// closed is non-zero once Stop has been called. It is read and written
	// through sync/atomic.
	closed int32
	// wg counts the per-connection goroutines started by innerAccept so
	// that Stop can wait for them to finish.
	wg sync.WaitGroup
	// mu serializes Stop with the bookkeeping innerAccept does after an
	// Accept call returns.
	mu sync.Mutex

	// Factories and transport supplied to the constructor; each one is
	// exposed through the accessor method of the same name.
	processorFactory       TProcessorFactory
	serverTransport        TServerTransport
	inputTransportFactory  TTransportFactory
	outputTransportFactory TTransportFactory
	inputProtocolFactory   TProtocolFactory
	outputProtocolFactory  TProtocolFactory

	// Headers to auto forward in THeaderProtocol
	forwardHeaders []string

	// logger receives a message when processing a connection fails.
	// Serve replaces it with the result of fallbackLogger.
	logger Logger
}
73
74func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
75 return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport)
76}
77
78func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer {
79 return NewTSimpleServerFactory4(NewTProcessorFactory(processor),
80 serverTransport,
81 transportFactory,
82 protocolFactory,
83 )
84}
85
86func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
87 return NewTSimpleServerFactory6(NewTProcessorFactory(processor),
88 serverTransport,
89 inputTransportFactory,
90 outputTransportFactory,
91 inputProtocolFactory,
92 outputProtocolFactory,
93 )
94}
95
96func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer {
97 return NewTSimpleServerFactory6(processorFactory,
98 serverTransport,
99 NewTTransportFactory(),
100 NewTTransportFactory(),
101 NewTBinaryProtocolFactoryDefault(),
102 NewTBinaryProtocolFactoryDefault(),
103 )
104}
105
106func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer {
107 return NewTSimpleServerFactory6(processorFactory,
108 serverTransport,
109 transportFactory,
110 transportFactory,
111 protocolFactory,
112 protocolFactory,
113 )
114}
115
116func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
117 return &TSimpleServer{
118 processorFactory: processorFactory,
119 serverTransport: serverTransport,
120 inputTransportFactory: inputTransportFactory,
121 outputTransportFactory: outputTransportFactory,
122 inputProtocolFactory: inputProtocolFactory,
123 outputProtocolFactory: outputProtocolFactory,
124 }
125}
126
127func (p *TSimpleServer) ProcessorFactory() TProcessorFactory {
128 return p.processorFactory
129}
130
131func (p *TSimpleServer) ServerTransport() TServerTransport {
132 return p.serverTransport
133}
134
135func (p *TSimpleServer) InputTransportFactory() TTransportFactory {
136 return p.inputTransportFactory
137}
138
139func (p *TSimpleServer) OutputTransportFactory() TTransportFactory {
140 return p.outputTransportFactory
141}
142
143func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory {
144 return p.inputProtocolFactory
145}
146
147func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory {
148 return p.outputProtocolFactory
149}
150
151func (p *TSimpleServer) Listen() error {
152 return p.serverTransport.Listen()
153}
154
155// SetForwardHeaders sets the list of header keys that will be auto forwarded
156// while using THeaderProtocol.
157//
158// "forward" means that when the server is also a client to other upstream
159// thrift servers, the context object user gets in the processor functions will
160// have both read and write headers set, with write headers being forwarded.
161// Users can always override the write headers by calling SetWriteHeaderList
162// before calling thrift client functions.
163func (p *TSimpleServer) SetForwardHeaders(headers []string) {
164 size := len(headers)
165 if size == 0 {
166 p.forwardHeaders = nil
167 return
168 }
169
170 keys := make([]string, size)
171 copy(keys, headers)
172 p.forwardHeaders = keys
173}
174
175// SetLogger sets the logger used by this TSimpleServer.
176//
177// If no logger was set before Serve is called, a default logger using standard
178// log library will be used.
179func (p *TSimpleServer) SetLogger(logger Logger) {
180 p.logger = logger
181}
182
183func (p *TSimpleServer) innerAccept() (int32, error) {
184 client, err := p.serverTransport.Accept()
185 p.mu.Lock()
186 defer p.mu.Unlock()
187 closed := atomic.LoadInt32(&p.closed)
188 if closed != 0 {
189 return closed, nil
190 }
191 if err != nil {
192 return 0, err
193 }
194 if client != nil {
195 p.wg.Add(1)
196 go func() {
197 defer p.wg.Done()
198 if err := p.processRequests(client); err != nil {
199 p.logger(fmt.Sprintf("error processing request: %v", err))
200 }
201 }()
202 }
203 return 0, nil
204}
205
206func (p *TSimpleServer) AcceptLoop() error {
207 for {
208 closed, err := p.innerAccept()
209 if err != nil {
210 return err
211 }
212 if closed != 0 {
213 return nil
214 }
215 }
216}
217
218func (p *TSimpleServer) Serve() error {
219 p.logger = fallbackLogger(p.logger)
220
221 err := p.Listen()
222 if err != nil {
223 return err
224 }
225 p.AcceptLoop()
226 return nil
227}
228
229func (p *TSimpleServer) Stop() error {
230 p.mu.Lock()
231 defer p.mu.Unlock()
232 if atomic.LoadInt32(&p.closed) != 0 {
233 return nil
234 }
235 atomic.StoreInt32(&p.closed, 1)
236 p.serverTransport.Interrupt()
237 p.wg.Wait()
238 return nil
239}
240
241// If err is actually EOF, return nil, otherwise return err as-is.
242func treatEOFErrorsAsNil(err error) error {
243 if err == nil {
244 return nil
245 }
246 if errors.Is(err, io.EOF) {
247 return nil
248 }
249 var te TTransportException
250 if errors.As(err, &te) && te.TypeId() == END_OF_FILE {
251 return nil
252 }
253 return err
254}
255
256func (p *TSimpleServer) processRequests(client TTransport) (err error) {
257 defer func() {
258 err = treatEOFErrorsAsNil(err)
259 }()
260
261 processor := p.processorFactory.GetProcessor(client)
262 inputTransport, err := p.inputTransportFactory.GetTransport(client)
263 if err != nil {
264 return err
265 }
266 inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
267 var outputTransport TTransport
268 var outputProtocol TProtocol
269
270 // for THeaderProtocol, we must use the same protocol instance for
271 // input and output so that the response is in the same dialect that
272 // the server detected the request was in.
273 headerProtocol, ok := inputProtocol.(*THeaderProtocol)
274 if ok {
275 outputProtocol = inputProtocol
276 } else {
277 oTrans, err := p.outputTransportFactory.GetTransport(client)
278 if err != nil {
279 return err
280 }
281 outputTransport = oTrans
282 outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport)
283 }
284
285 if inputTransport != nil {
286 defer inputTransport.Close()
287 }
288 if outputTransport != nil {
289 defer outputTransport.Close()
290 }
291 for {
292 if atomic.LoadInt32(&p.closed) != 0 {
293 return nil
294 }
295
296 ctx := SetResponseHelper(
297 defaultCtx,
298 TResponseHelper{
299 THeaderResponseHelper: NewTHeaderResponseHelper(outputProtocol),
300 },
301 )
302 if headerProtocol != nil {
303 // We need to call ReadFrame here, otherwise we won't
304 // get any headers on the AddReadTHeaderToContext call.
305 //
306 // ReadFrame is safe to be called multiple times so it
307 // won't break when it's called again later when we
308 // actually start to read the message.
309 if err := headerProtocol.ReadFrame(ctx); err != nil {
310 return err
311 }
312 ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders())
313 ctx = SetWriteHeaderList(ctx, p.forwardHeaders)
314 }
315
316 ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
317 if errors.Is(err, ErrAbandonRequest) {
318 return client.Close()
319 }
320 if errors.As(err, new(TTransportException)) && err != nil {
321 return err
322 }
323 var tae TApplicationException
324 if errors.As(err, &tae) && tae.TypeId() == UNKNOWN_METHOD {
325 continue
326 }
327 if !ok {
328 break
329 }
330 }
331 return nil
332}