blob: ec6c513d3bd66c37c110fca96760583605b790b0 [file] [log] [blame]
Scott Baker8487c5d2019-10-18 12:49:46 -07001package client
2
3import (
4 "fmt"
5 "strings"
6 "sync"
7 "time"
8
9 "gopkg.in/jcmturner/gokrb5.v7/iana/nametype"
10 "gopkg.in/jcmturner/gokrb5.v7/krberror"
11 "gopkg.in/jcmturner/gokrb5.v7/messages"
12 "gopkg.in/jcmturner/gokrb5.v7/types"
13)
14
15// sessions hold TGTs and are keyed on the realm name
16type sessions struct {
17 Entries map[string]*session
18 mux sync.RWMutex
19}
20
21// destroy erases all sessions
22func (s *sessions) destroy() {
23 s.mux.Lock()
24 defer s.mux.Unlock()
25 for k, e := range s.Entries {
26 e.destroy()
27 delete(s.Entries, k)
28 }
29}
30
31// update replaces a session with the one provided or adds it as a new one
32func (s *sessions) update(sess *session) {
33 s.mux.Lock()
34 defer s.mux.Unlock()
35 // if a session already exists for this, cancel its auto renew.
36 if i, ok := s.Entries[sess.realm]; ok {
37 if i != sess {
38 // Session in the sessions cache is not the same as one provided.
39 // Cancel the one in the cache and add this one.
40 i.mux.Lock()
41 defer i.mux.Unlock()
42 i.cancel <- true
43 s.Entries[sess.realm] = sess
44 return
45 }
46 }
47 // No session for this realm was found so just add it
48 s.Entries[sess.realm] = sess
49}
50
51// get returns the session for the realm specified
52func (s *sessions) get(realm string) (*session, bool) {
53 s.mux.RLock()
54 defer s.mux.RUnlock()
55 sess, ok := s.Entries[realm]
56 return sess, ok
57}
58
59// session holds the TGT details for a realm
60type session struct {
61 realm string
62 authTime time.Time
63 endTime time.Time
64 renewTill time.Time
65 tgt messages.Ticket
66 sessionKey types.EncryptionKey
67 sessionKeyExpiration time.Time
68 cancel chan bool
69 mux sync.RWMutex
70}
71
72// AddSession adds a session for a realm with a TGT to the client's session cache.
73// A goroutine is started to automatically renew the TGT before expiry.
74func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
75 if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
76 // Not a TGT
77 return
78 }
79 realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
80 s := &session{
81 realm: realm,
82 authTime: dep.AuthTime,
83 endTime: dep.EndTime,
84 renewTill: dep.RenewTill,
85 tgt: tgt,
86 sessionKey: dep.Key,
87 sessionKeyExpiration: dep.KeyExpiration,
88 }
89 cl.sessions.update(s)
90 cl.enableAutoSessionRenewal(s)
91 cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
92}
93
94// update overwrites the session details with those from the TGT and decrypted encPart
95func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
96 s.mux.Lock()
97 defer s.mux.Unlock()
98 s.authTime = dep.AuthTime
99 s.endTime = dep.EndTime
100 s.renewTill = dep.RenewTill
101 s.tgt = tgt
102 s.sessionKey = dep.Key
103 s.sessionKeyExpiration = dep.KeyExpiration
104}
105
106// destroy will cancel any auto renewal of the session and set the expiration times to the current time
107func (s *session) destroy() {
108 s.mux.Lock()
109 defer s.mux.Unlock()
110 if s.cancel != nil {
111 s.cancel <- true
112 }
113 s.endTime = time.Now().UTC()
114 s.renewTill = s.endTime
115 s.sessionKeyExpiration = s.endTime
116}
117
118// valid informs if the TGT is still within the valid time window
119func (s *session) valid() bool {
120 s.mux.RLock()
121 defer s.mux.RUnlock()
122 t := time.Now().UTC()
123 if t.Before(s.endTime) && s.authTime.Before(t) {
124 return true
125 }
126 return false
127}
128
129// tgtDetails is a thread safe way to get the session's realm, TGT and session key values
130func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
131 s.mux.RLock()
132 defer s.mux.RUnlock()
133 return s.realm, s.tgt, s.sessionKey
134}
135
136// timeDetails is a thread safe way to get the session's validity time values
137func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
138 s.mux.RLock()
139 defer s.mux.RUnlock()
140 return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
141}
142
143// enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
144func (cl *Client) enableAutoSessionRenewal(s *session) {
145 var timer *time.Timer
146 s.mux.Lock()
147 s.cancel = make(chan bool, 1)
148 s.mux.Unlock()
149 go func(s *session) {
150 for {
151 s.mux.RLock()
152 w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
153 s.mux.RUnlock()
154 if w < 0 {
155 return
156 }
157 timer = time.NewTimer(w)
158 select {
159 case <-timer.C:
160 renewal, err := cl.refreshSession(s)
161 if err != nil {
162 cl.Log("error refreshing session: %v", err)
163 }
164 if !renewal && err == nil {
165 // end this goroutine as there will have been a new login and new auto renewal goroutine created.
166 return
167 }
168 case <-s.cancel:
169 // cancel has been called. Stop the timer and exit.
170 timer.Stop()
171 return
172 }
173 }
174 }(s)
175}
176
177// renewTGT renews the client's TGT session.
178func (cl *Client) renewTGT(s *session) error {
179 realm, tgt, skey := s.tgtDetails()
180 spn := types.PrincipalName{
181 NameType: nametype.KRB_NT_SRV_INST,
182 NameString: []string{"krbtgt", realm},
183 }
184 _, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
185 if err != nil {
186 return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
187 }
188 s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
189 cl.sessions.update(s)
190 cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
191 return nil
192}
193
194// refreshSession updates either through renewal or creating a new login.
195// The boolean indicates if the update was a renewal.
196func (cl *Client) refreshSession(s *session) (bool, error) {
197 s.mux.RLock()
198 realm := s.realm
199 renewTill := s.renewTill
200 s.mux.RUnlock()
201 cl.Log("refreshing TGT session for %s", realm)
202 if time.Now().UTC().Before(renewTill) {
203 err := cl.renewTGT(s)
204 return true, err
205 }
206 err := cl.realmLogin(realm)
207 return false, err
208}
209
210// ensureValidSession makes sure there is a valid session for the realm
211func (cl *Client) ensureValidSession(realm string) error {
212 s, ok := cl.sessions.get(realm)
213 if ok {
214 s.mux.RLock()
215 d := s.endTime.Sub(s.authTime) / 6
216 if s.endTime.Sub(time.Now().UTC()) > d {
217 s.mux.RUnlock()
218 return nil
219 }
220 s.mux.RUnlock()
221 _, err := cl.refreshSession(s)
222 return err
223 }
224 return cl.realmLogin(realm)
225}
226
227// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
228func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
229 err = cl.ensureValidSession(realm)
230 if err != nil {
231 return
232 }
233 s, ok := cl.sessions.get(realm)
234 if !ok {
235 err = fmt.Errorf("could not find TGT session for %s", realm)
236 return
237 }
238 _, tgt, sessionKey = s.tgtDetails()
239 return
240}
241
242func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
243 s, ok := cl.sessions.get(realm)
244 if !ok {
245 err = fmt.Errorf("could not find TGT session for %s", realm)
246 return
247 }
248 _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
249 return
250}
251
252// spnRealm resolves the realm name of a service principal name
253func (cl *Client) spnRealm(spn types.PrincipalName) string {
254 return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
255}