Don Newton | 379ae25 | 2019-04-01 12:17:06 -0400 | [diff] [blame^] | 1 | // Copyright (C) MongoDB, Inc. 2017-present. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | // not use this file except in compliance with the License. You may obtain |
| 5 | // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | |
| 7 | package session |
| 8 | |
| 9 | import ( |
| 10 | "sync" |
| 11 | |
| 12 | "github.com/mongodb/mongo-go-driver/x/bsonx" |
| 13 | "github.com/mongodb/mongo-go-driver/x/network/description" |
| 14 | ) |
| 15 | |
| 16 | // Node represents a server session in a linked list |
| 17 | type Node struct { |
| 18 | *Server |
| 19 | next *Node |
| 20 | prev *Node |
| 21 | } |
| 22 | |
| 23 | // Pool is a pool of server sessions that can be reused. |
| 24 | type Pool struct { |
| 25 | descChan <-chan description.Topology |
| 26 | head *Node |
| 27 | tail *Node |
| 28 | timeout uint32 |
| 29 | mutex sync.Mutex // mutex to protect list and sessionTimeout |
| 30 | |
| 31 | checkedOut int // number of sessions checked out of pool |
| 32 | } |
| 33 | |
| 34 | func (p *Pool) createServerSession() (*Server, error) { |
| 35 | s, err := newServerSession() |
| 36 | if err != nil { |
| 37 | return nil, err |
| 38 | } |
| 39 | |
| 40 | p.checkedOut++ |
| 41 | return s, nil |
| 42 | } |
| 43 | |
| 44 | // NewPool creates a new server session pool |
| 45 | func NewPool(descChan <-chan description.Topology) *Pool { |
| 46 | p := &Pool{ |
| 47 | descChan: descChan, |
| 48 | } |
| 49 | |
| 50 | return p |
| 51 | } |
| 52 | |
| 53 | // assumes caller has mutex to protect the pool |
| 54 | func (p *Pool) updateTimeout() { |
| 55 | select { |
| 56 | case newDesc := <-p.descChan: |
| 57 | p.timeout = newDesc.SessionTimeoutMinutes |
| 58 | default: |
| 59 | // no new description waiting |
| 60 | } |
| 61 | } |
| 62 | |
| 63 | // GetSession retrieves an unexpired session from the pool. |
| 64 | func (p *Pool) GetSession() (*Server, error) { |
| 65 | p.mutex.Lock() // prevent changing the linked list while seeing if sessions have expired |
| 66 | defer p.mutex.Unlock() |
| 67 | |
| 68 | // empty pool |
| 69 | if p.head == nil && p.tail == nil { |
| 70 | return p.createServerSession() |
| 71 | } |
| 72 | |
| 73 | p.updateTimeout() |
| 74 | for p.head != nil { |
| 75 | // pull session from head of queue and return if it is valid for at least 1 more minute |
| 76 | if p.head.expired(p.timeout) { |
| 77 | p.head = p.head.next |
| 78 | continue |
| 79 | } |
| 80 | |
| 81 | // found unexpired session |
| 82 | session := p.head.Server |
| 83 | if p.head.next != nil { |
| 84 | p.head.next.prev = nil |
| 85 | } |
| 86 | if p.tail == p.head { |
| 87 | p.tail = nil |
| 88 | p.head = nil |
| 89 | } else { |
| 90 | p.head = p.head.next |
| 91 | } |
| 92 | |
| 93 | p.checkedOut++ |
| 94 | return session, nil |
| 95 | } |
| 96 | |
| 97 | // no valid session found |
| 98 | p.tail = nil // empty list |
| 99 | return p.createServerSession() |
| 100 | } |
| 101 | |
| 102 | // ReturnSession returns a session to the pool if it has not expired. |
| 103 | func (p *Pool) ReturnSession(ss *Server) { |
| 104 | if ss == nil { |
| 105 | return |
| 106 | } |
| 107 | |
| 108 | p.mutex.Lock() |
| 109 | defer p.mutex.Unlock() |
| 110 | |
| 111 | p.checkedOut-- |
| 112 | p.updateTimeout() |
| 113 | // check sessions at end of queue for expired |
| 114 | // stop checking after hitting the first valid session |
| 115 | for p.tail != nil && p.tail.expired(p.timeout) { |
| 116 | if p.tail.prev != nil { |
| 117 | p.tail.prev.next = nil |
| 118 | } |
| 119 | p.tail = p.tail.prev |
| 120 | } |
| 121 | |
| 122 | // session expired |
| 123 | if ss.expired(p.timeout) { |
| 124 | return |
| 125 | } |
| 126 | |
| 127 | newNode := &Node{ |
| 128 | Server: ss, |
| 129 | next: nil, |
| 130 | prev: nil, |
| 131 | } |
| 132 | |
| 133 | // empty list |
| 134 | if p.tail == nil { |
| 135 | p.head = newNode |
| 136 | p.tail = newNode |
| 137 | return |
| 138 | } |
| 139 | |
| 140 | // at least 1 valid session in list |
| 141 | newNode.next = p.head |
| 142 | p.head.prev = newNode |
| 143 | p.head = newNode |
| 144 | } |
| 145 | |
| 146 | // IDSlice returns a slice of session IDs for each session in the pool |
| 147 | func (p *Pool) IDSlice() []bsonx.Doc { |
| 148 | p.mutex.Lock() |
| 149 | defer p.mutex.Unlock() |
| 150 | |
| 151 | ids := []bsonx.Doc{} |
| 152 | for node := p.head; node != nil; node = node.next { |
| 153 | ids = append(ids, node.SessionID) |
| 154 | } |
| 155 | |
| 156 | return ids |
| 157 | } |
| 158 | |
| 159 | // String implements the Stringer interface |
| 160 | func (p *Pool) String() string { |
| 161 | p.mutex.Lock() |
| 162 | defer p.mutex.Unlock() |
| 163 | |
| 164 | s := "" |
| 165 | for head := p.head; head != nil; head = head.next { |
| 166 | s += head.SessionID.String() + "\n" |
| 167 | } |
| 168 | |
| 169 | return s |
| 170 | } |
| 171 | |
| 172 | // CheckedOut returns number of sessions checked out from pool. |
| 173 | func (p *Pool) CheckedOut() int { |
| 174 | return p.checkedOut |
| 175 | } |