blob: ec6c513d3bd66c37c110fca96760583605b790b0 [file] [log] [blame]
package client
import (
"fmt"
"strings"
"sync"
"time"
"gopkg.in/jcmturner/gokrb5.v7/iana/nametype"
"gopkg.in/jcmturner/gokrb5.v7/krberror"
"gopkg.in/jcmturner/gokrb5.v7/messages"
"gopkg.in/jcmturner/gokrb5.v7/types"
)
// sessions holds the client's TGT sessions, keyed on the realm name.
type sessions struct {
// Entries maps a realm name to the session cached for that realm.
Entries map[string]*session
// mux is taken by the methods on sessions before they read or write Entries.
mux sync.RWMutex
}
// destroy tears down every cached session and empties the cache.
// Each session has its own destroy called before it is removed from Entries.
func (s *sessions) destroy() {
	s.mux.Lock()
	defer s.mux.Unlock()
	for realm := range s.Entries {
		s.Entries[realm].destroy()
		delete(s.Entries, realm)
	}
}
// update stores sess in the cache under its realm, replacing any session
// already held for that realm.
//
// When a different session is being replaced, its auto-renewal goroutine is
// signalled to stop through its cancel channel. Passing the session that is
// already cached leaves its auto renewal untouched.
func (s *sessions) update(sess *session) {
	s.mux.Lock()
	defer s.mux.Unlock()
	if old, ok := s.Entries[sess.realm]; ok && old != sess {
		// Session in the cache is not the same as the one provided.
		// Cancel the cached one before it is replaced.
		old.mux.Lock()
		defer old.mux.Unlock()
		// cancel stays nil until enableAutoSessionRenewal has run for the
		// session. A send on a nil channel blocks forever, which would
		// deadlock the cache while both locks are held.
		if old.cancel != nil {
			old.cancel <- true
		}
	}
	s.Entries[sess.realm] = sess
}
// get looks up the cached session for realm.
// The boolean reports whether an entry for that realm exists.
func (s *sessions) get(realm string) (*session, bool) {
	s.mux.RLock()
	entry, found := s.Entries[realm]
	s.mux.RUnlock()
	return entry, found
}
// session holds the TGT details for a realm.
// The fields are read and written under mux by the methods on session.
type session struct {
// realm is the realm name the session belongs to; it is the key used in the sessions cache.
realm string
// authTime, endTime and renewTill are copied from the decrypted KDC reply part.
authTime time.Time
endTime time.Time
renewTill time.Time
// tgt is the ticket granting ticket for the realm.
tgt messages.Ticket
// sessionKey and sessionKeyExpiration are copied from the decrypted KDC reply part.
sessionKey types.EncryptionKey
sessionKeyExpiration time.Time
// cancel signals the auto-renewal goroutine to stop.
// It is nil until enableAutoSessionRenewal creates it.
cancel chan bool
// mux guards the fields above.
mux sync.RWMutex
}
// addSession adds a session for a realm with a TGT to the client's session cache.
// A goroutine is started to automatically renew the TGT before expiry.
//
// tgt is the ticket to cache and dep is the decrypted KDC reply part that
// supplies its times and session key. A ticket whose service name is empty or
// whose first component is not "krbtgt" (compared case-insensitively) is
// ignored.
func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
	name := tgt.SName.NameString
	// An empty service name cannot identify a TGT, and indexing it would panic.
	if len(name) == 0 || !strings.EqualFold(name[0], "krbtgt") {
		// Not a TGT
		return
	}
	// The realm is taken from the last component of the service name.
	realm := name[len(name)-1]
	s := &session{
		realm:                realm,
		authTime:             dep.AuthTime,
		endTime:              dep.EndTime,
		renewTill:            dep.RenewTill,
		tgt:                  tgt,
		sessionKey:           dep.Key,
		sessionKeyExpiration: dep.KeyExpiration,
	}
	cl.sessions.update(s)
	cl.enableAutoSessionRenewal(s)
	cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
}
// update replaces the session's ticket, session key and time values with those
// taken from tgt and the decrypted reply part dep. The realm is left unchanged.
func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
	s.mux.Lock()
	defer s.mux.Unlock()
	// Ticket and key material.
	s.tgt = tgt
	s.sessionKey = dep.Key
	s.sessionKeyExpiration = dep.KeyExpiration
	// Validity window.
	s.authTime = dep.AuthTime
	s.endTime = dep.EndTime
	s.renewTill = dep.RenewTill
}
// destroy signals any auto-renewal goroutine to stop and expires the session
// by setting its end, renew-till and session key expiration times to the
// current UTC time.
func (s *session) destroy() {
	s.mux.Lock()
	defer s.mux.Unlock()
	// cancel is only set once auto renewal has been enabled.
	if s.cancel != nil {
		s.cancel <- true
	}
	now := time.Now().UTC()
	s.endTime, s.renewTill, s.sessionKeyExpiration = now, now, now
}
// valid reports whether the current UTC time lies strictly after the
// session's authTime and strictly before its endTime.
func (s *session) valid() bool {
	s.mux.RLock()
	defer s.mux.RUnlock()
	now := time.Now().UTC()
	return now.Before(s.endTime) && s.authTime.Before(now)
}
// tgtDetails returns the session's realm, TGT and session key, read together
// under the session's read lock.
func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
	s.mux.RLock()
	realm, ticket, key := s.realm, s.tgt, s.sessionKey
	s.mux.RUnlock()
	return realm, ticket, key
}
// timeDetails returns the session's realm followed by its authTime, endTime,
// renewTill and sessionKeyExpiration, read together under the session's read lock.
func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
	s.mux.RLock()
	realm := s.realm
	auth, end, renew, keyExp := s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
	s.mux.RUnlock()
	return realm, auth, end, renew, keyExp
}
// enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
//
// It creates the session's cancel channel and starts a goroutine that
// repeatedly waits five sixths of the time remaining until the session's
// endTime and then calls refreshSession. The goroutine exits when:
//   - the computed wait is negative,
//   - a value is received on the session's cancel channel, or
//   - refreshSession succeeds without performing a renewal.
//
// An error from refreshSession is logged and the loop carries on.
func (cl *Client) enableAutoSessionRenewal(s *session) {
	s.mux.Lock()
	s.cancel = make(chan bool, 1)
	s.mux.Unlock()

	renewLoop := func(sess *session) {
		for {
			sess.mux.RLock()
			remaining := sess.endTime.Sub(time.Now().UTC())
			sess.mux.RUnlock()

			wait := (remaining * 5) / 6
			if wait < 0 {
				return
			}

			timer := time.NewTimer(wait)
			select {
			case <-sess.cancel:
				// cancel has been called. Stop the timer and exit.
				timer.Stop()
				return
			case <-timer.C:
			}

			renewed, err := cl.refreshSession(sess)
			if err != nil {
				cl.Log("error refreshing session: %v", err)
				continue
			}
			if !renewed {
				// end this goroutine as there will have been a new login and new auto renewal goroutine created.
				return
			}
		}
	}
	go renewLoop(s)
}
// renewTGT renews the TGT held by session s.
//
// It performs a TGS exchange for the krbtgt principal of the session's realm
// using the current TGT and session key. On success the session is updated
// from the reply and stored back into the client's session cache. A failed
// exchange is returned wrapped as a krberror.KRBMsgError.
func (cl *Client) renewTGT(s *session) error {
	realm, tgt, skey := s.tgtDetails()
	krbtgt := types.PrincipalName{
		NameType:   nametype.KRB_NT_SRV_INST,
		NameString: []string{"krbtgt", realm},
	}
	_, rep, err := cl.TGSREQGenerateAndExchange(krbtgt, cl.Credentials.Domain(), tgt, skey, true)
	if err != nil {
		return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
	}
	renewed := rep.DecryptedEncPart
	s.update(rep.Ticket, renewed)
	cl.sessions.update(s)
	cl.Log("TGT session renewed for %s (EndTime: %v)", realm, renewed.EndTime)
	return nil
}
// refreshSession refreshes session s, by renewing its TGT while the current
// UTC time is before renewTill and otherwise by logging in to the realm again.
// The boolean is true when the renewal path was taken and false when a new
// login was attempted. The error is that of whichever path ran.
func (cl *Client) refreshSession(s *session) (bool, error) {
	s.mux.RLock()
	realm, renewTill := s.realm, s.renewTill
	s.mux.RUnlock()

	cl.Log("refreshing TGT session for %s", realm)
	if !time.Now().UTC().Before(renewTill) {
		// Past the renewable window, so a fresh login is required.
		return false, cl.realmLogin(realm)
	}
	return true, cl.renewTGT(s)
}
// ensureValidSession makes sure there is a usable session for the realm.
//
// With no cached session it logs in to the realm. With a cached session it
// does nothing while more than one sixth of the authTime-to-endTime span
// remains before endTime, and otherwise refreshes the session.
func (cl *Client) ensureValidSession(realm string) error {
	s, found := cl.sessions.get(realm)
	if !found {
		return cl.realmLogin(realm)
	}

	s.mux.RLock()
	threshold := s.endTime.Sub(s.authTime) / 6
	remaining := s.endTime.Sub(time.Now().UTC())
	s.mux.RUnlock()

	if remaining > threshold {
		return nil
	}
	_, err := cl.refreshSession(s)
	return err
}
// sessionTGT returns the TGT and session key for a realm.
// It first ensures there is a valid session for the realm and returns any
// error from doing so. An error is also returned if no session is cached for
// the realm afterwards.
func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
	if err = cl.ensureValidSession(realm); err != nil {
		return tgt, sessionKey, err
	}
	s, found := cl.sessions.get(realm)
	if !found {
		err = fmt.Errorf("could not find TGT session for %s", realm)
		return tgt, sessionKey, err
	}
	_, tgt, sessionKey = s.tgtDetails()
	return tgt, sessionKey, nil
}
// sessionTimes returns the authTime, endTime, renewTill and session key
// expiration of the cached session for a realm. An error is returned when no
// session is cached for the realm. The session is not refreshed.
func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
	s, found := cl.sessions.get(realm)
	if !found {
		err = fmt.Errorf("could not find TGT session for %s", realm)
		return authTime, endTime, renewTime, sessionExp, err
	}
	_, authTime, endTime, renewTime, sessionExp = s.timeDetails()
	return authTime, endTime, renewTime, sessionExp, nil
}
// spnRealm resolves the realm of a service principal name by passing the last
// component of its name string to the client configuration's ResolveRealm.
func (cl *Client) spnRealm(spn types.PrincipalName) string {
	parts := spn.NameString
	last := parts[len(parts)-1]
	return cl.Config.ResolveRealm(last)
}