blob: f7654d0d07ebbfb69c6298611539a26e79ba4d03 [file] [log] [blame]
khenaidoo7d3c5582021-08-11 18:09:44 -04001package client
2
3import (
4 "encoding/json"
5 "fmt"
6 "sort"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/jcmturner/gokrb5/v8/iana/nametype"
12 "github.com/jcmturner/gokrb5/v8/krberror"
13 "github.com/jcmturner/gokrb5/v8/messages"
14 "github.com/jcmturner/gokrb5/v8/types"
15)
16
// sessions holds the client's TGT sessions, keyed on the realm name.
// Access to Entries is guarded by mux.
type sessions struct {
	// Entries maps a realm name to the TGT session for that realm.
	Entries map[string]*session
	// mux guards Entries.
	mux sync.RWMutex
}
22
23// destroy erases all sessions
24func (s *sessions) destroy() {
25 s.mux.Lock()
26 defer s.mux.Unlock()
27 for k, e := range s.Entries {
28 e.destroy()
29 delete(s.Entries, k)
30 }
31}
32
33// update replaces a session with the one provided or adds it as a new one
34func (s *sessions) update(sess *session) {
35 s.mux.Lock()
36 defer s.mux.Unlock()
37 // if a session already exists for this, cancel its auto renew.
38 if i, ok := s.Entries[sess.realm]; ok {
39 if i != sess {
40 // Session in the sessions cache is not the same as one provided.
41 // Cancel the one in the cache and add this one.
42 i.mux.Lock()
43 defer i.mux.Unlock()
44 i.cancel <- true
45 s.Entries[sess.realm] = sess
46 return
47 }
48 }
49 // No session for this realm was found so just add it
50 s.Entries[sess.realm] = sess
51}
52
53// get returns the session for the realm specified
54func (s *sessions) get(realm string) (*session, bool) {
55 s.mux.RLock()
56 defer s.mux.RUnlock()
57 sess, ok := s.Entries[realm]
58 return sess, ok
59}
60
61// session holds the TGT details for a realm
62type session struct {
63 realm string
64 authTime time.Time
65 endTime time.Time
66 renewTill time.Time
67 tgt messages.Ticket
68 sessionKey types.EncryptionKey
69 sessionKeyExpiration time.Time
70 cancel chan bool
71 mux sync.RWMutex
72}
73
// jsonSession is the JSON-serialisable view of a session. It exposes only the
// realm and the timing information, never the ticket or session key.
type jsonSession struct {
	Realm                                                string
	AuthTime, EndTime, RenewTill, SessionKeyExpiration time.Time
}
82
83// AddSession adds a session for a realm with a TGT to the client's session cache.
84// A goroutine is started to automatically renew the TGT before expiry.
85func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
86 if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
87 // Not a TGT
88 return
89 }
90 realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
91 s := &session{
92 realm: realm,
93 authTime: dep.AuthTime,
94 endTime: dep.EndTime,
95 renewTill: dep.RenewTill,
96 tgt: tgt,
97 sessionKey: dep.Key,
98 sessionKeyExpiration: dep.KeyExpiration,
99 }
100 cl.sessions.update(s)
101 cl.enableAutoSessionRenewal(s)
102 cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
103}
104
105// update overwrites the session details with those from the TGT and decrypted encPart
106func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
107 s.mux.Lock()
108 defer s.mux.Unlock()
109 s.authTime = dep.AuthTime
110 s.endTime = dep.EndTime
111 s.renewTill = dep.RenewTill
112 s.tgt = tgt
113 s.sessionKey = dep.Key
114 s.sessionKeyExpiration = dep.KeyExpiration
115}
116
117// destroy will cancel any auto renewal of the session and set the expiration times to the current time
118func (s *session) destroy() {
119 s.mux.Lock()
120 defer s.mux.Unlock()
121 if s.cancel != nil {
122 s.cancel <- true
123 }
124 s.endTime = time.Now().UTC()
125 s.renewTill = s.endTime
126 s.sessionKeyExpiration = s.endTime
127}
128
129// valid informs if the TGT is still within the valid time window
130func (s *session) valid() bool {
131 s.mux.RLock()
132 defer s.mux.RUnlock()
133 t := time.Now().UTC()
134 if t.Before(s.endTime) && s.authTime.Before(t) {
135 return true
136 }
137 return false
138}
139
140// tgtDetails is a thread safe way to get the session's realm, TGT and session key values
141func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
142 s.mux.RLock()
143 defer s.mux.RUnlock()
144 return s.realm, s.tgt, s.sessionKey
145}
146
147// timeDetails is a thread safe way to get the session's validity time values
148func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
149 s.mux.RLock()
150 defer s.mux.RUnlock()
151 return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
152}
153
154// JSON return information about the held sessions in a JSON format.
155func (s *sessions) JSON() (string, error) {
156 s.mux.RLock()
157 defer s.mux.RUnlock()
158 var js []jsonSession
159 keys := make([]string, 0, len(s.Entries))
160 for k := range s.Entries {
161 keys = append(keys, k)
162 }
163 sort.Strings(keys)
164 for _, k := range keys {
165 r, at, et, rt, kt := s.Entries[k].timeDetails()
166 j := jsonSession{
167 Realm: r,
168 AuthTime: at,
169 EndTime: et,
170 RenewTill: rt,
171 SessionKeyExpiration: kt,
172 }
173 js = append(js, j)
174 }
175 b, err := json.MarshalIndent(js, "", " ")
176 if err != nil {
177 return "", err
178 }
179 return string(b), nil
180}
181
182// enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
183func (cl *Client) enableAutoSessionRenewal(s *session) {
184 var timer *time.Timer
185 s.mux.Lock()
186 s.cancel = make(chan bool, 1)
187 s.mux.Unlock()
188 go func(s *session) {
189 for {
190 s.mux.RLock()
191 w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
192 s.mux.RUnlock()
193 if w < 0 {
194 return
195 }
196 timer = time.NewTimer(w)
197 select {
198 case <-timer.C:
199 renewal, err := cl.refreshSession(s)
200 if err != nil {
201 cl.Log("error refreshing session: %v", err)
202 }
203 if !renewal && err == nil {
204 // end this goroutine as there will have been a new login and new auto renewal goroutine created.
205 return
206 }
207 case <-s.cancel:
208 // cancel has been called. Stop the timer and exit.
209 timer.Stop()
210 return
211 }
212 }
213 }(s)
214}
215
216// renewTGT renews the client's TGT session.
217func (cl *Client) renewTGT(s *session) error {
218 realm, tgt, skey := s.tgtDetails()
219 spn := types.PrincipalName{
220 NameType: nametype.KRB_NT_SRV_INST,
221 NameString: []string{"krbtgt", realm},
222 }
223 _, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
224 if err != nil {
225 return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
226 }
227 s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
228 cl.sessions.update(s)
229 cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
230 return nil
231}
232
233// refreshSession updates either through renewal or creating a new login.
234// The boolean indicates if the update was a renewal.
235func (cl *Client) refreshSession(s *session) (bool, error) {
236 s.mux.RLock()
237 realm := s.realm
238 renewTill := s.renewTill
239 s.mux.RUnlock()
240 cl.Log("refreshing TGT session for %s", realm)
241 if time.Now().UTC().Before(renewTill) {
242 err := cl.renewTGT(s)
243 return true, err
244 }
245 err := cl.realmLogin(realm)
246 return false, err
247}
248
249// ensureValidSession makes sure there is a valid session for the realm
250func (cl *Client) ensureValidSession(realm string) error {
251 s, ok := cl.sessions.get(realm)
252 if ok {
253 s.mux.RLock()
254 d := s.endTime.Sub(s.authTime) / 6
255 if s.endTime.Sub(time.Now().UTC()) > d {
256 s.mux.RUnlock()
257 return nil
258 }
259 s.mux.RUnlock()
260 _, err := cl.refreshSession(s)
261 return err
262 }
263 return cl.realmLogin(realm)
264}
265
266// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
267func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
268 err = cl.ensureValidSession(realm)
269 if err != nil {
270 return
271 }
272 s, ok := cl.sessions.get(realm)
273 if !ok {
274 err = fmt.Errorf("could not find TGT session for %s", realm)
275 return
276 }
277 _, tgt, sessionKey = s.tgtDetails()
278 return
279}
280
281// sessionTimes provides the timing information with regards to a session for the realm specified.
282func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
283 s, ok := cl.sessions.get(realm)
284 if !ok {
285 err = fmt.Errorf("could not find TGT session for %s", realm)
286 return
287 }
288 _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
289 return
290}
291
292// spnRealm resolves the realm name of a service principal name
293func (cl *Client) spnRealm(spn types.PrincipalName) string {
294 return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
295}