blob: 366ab7bf9ff022c29b36f4d10b964a077880083d [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7//+build gssapi
8//+build linux darwin
9
10package gssapi
11
12/*
13#cgo linux CFLAGS: -DGOOS_linux
14#cgo linux LDFLAGS: -lgssapi_krb5 -lkrb5
15#cgo darwin CFLAGS: -DGOOS_darwin
16#cgo darwin LDFLAGS: -framework GSS
17#include "gss_wrapper.h"
18*/
19import "C"
import (
	"fmt"
	"net"
	"runtime"
	"sort"
	"strings"
	"unsafe"
)
27
28// New creates a new SaslClient.
29func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
30 serviceName := "mongodb"
31
32 for key, value := range props {
33 switch strings.ToUpper(key) {
34 case "CANONICALIZE_HOST_NAME":
35 return nil, fmt.Errorf("CANONICALIZE_HOST_NAME is not supported when using gssapi on %s", runtime.GOOS)
36 case "SERVICE_REALM":
37 return nil, fmt.Errorf("SERVICE_REALM is not supported when using gssapi on %s", runtime.GOOS)
38 case "SERVICE_NAME":
39 serviceName = value
40 default:
41 return nil, fmt.Errorf("unknown mechanism property %s", key)
42 }
43 }
44
45 hostname, _, err := net.SplitHostPort(target)
46 if err != nil {
47 return nil, fmt.Errorf("invalid endpoint (%s) specified: %s", target, err)
48 }
49
50 servicePrincipalName := fmt.Sprintf("%s@%s", serviceName, hostname)
51
52 return &SaslClient{
53 servicePrincipalName: servicePrincipalName,
54 username: username,
55 password: password,
56 passwordSet: passwordSet,
57 }, nil
58}
59
60type SaslClient struct {
61 servicePrincipalName string
62 username string
63 password string
64 passwordSet bool
65
66 // state
67 state C.gssapi_client_state
68 contextComplete bool
69 done bool
70}
71
72func (sc *SaslClient) Close() {
73 C.gssapi_client_destroy(&sc.state)
74}
75
76func (sc *SaslClient) Start() (string, []byte, error) {
77 const mechName = "GSSAPI"
78
79 cservicePrincipalName := C.CString(sc.servicePrincipalName)
80 defer C.free(unsafe.Pointer(cservicePrincipalName))
81 var cusername *C.char
82 var cpassword *C.char
83 if sc.username != "" {
84 cusername = C.CString(sc.username)
85 defer C.free(unsafe.Pointer(cusername))
86 if sc.passwordSet {
87 cpassword = C.CString(sc.password)
88 defer C.free(unsafe.Pointer(cpassword))
89 }
90 }
91 status := C.gssapi_client_init(&sc.state, cservicePrincipalName, cusername, cpassword)
92
93 if status != C.GSSAPI_OK {
94 return mechName, nil, sc.getError("unable to initialize client")
95 }
96
97 return mechName, nil, nil
98}
99
100func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
101
102 var buf unsafe.Pointer
103 var bufLen C.size_t
104 var outBuf unsafe.Pointer
105 var outBufLen C.size_t
106
107 if sc.contextComplete {
108 if sc.username == "" {
109 var cusername *C.char
110 status := C.gssapi_client_username(&sc.state, &cusername)
111 if status != C.GSSAPI_OK {
112 return nil, sc.getError("unable to acquire username")
113 }
114 defer C.free(unsafe.Pointer(cusername))
115 sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
116 }
117
118 bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
119 buf = unsafe.Pointer(&bytes[0])
120 bufLen = C.size_t(len(bytes))
121 status := C.gssapi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
122 if status != C.GSSAPI_OK {
123 return nil, sc.getError("unable to wrap authz")
124 }
125
126 sc.done = true
127 } else {
128 if len(challenge) > 0 {
129 buf = unsafe.Pointer(&challenge[0])
130 bufLen = C.size_t(len(challenge))
131 }
132
133 status := C.gssapi_client_negotiate(&sc.state, buf, bufLen, &outBuf, &outBufLen)
134 switch status {
135 case C.GSSAPI_OK:
136 sc.contextComplete = true
137 case C.GSSAPI_CONTINUE:
138 default:
139 return nil, sc.getError("unable to negotiate with server")
140 }
141 }
142
143 if outBuf != nil {
144 defer C.free(outBuf)
145 }
146
147 return C.GoBytes(outBuf, C.int(outBufLen)), nil
148}
149
150func (sc *SaslClient) Completed() bool {
151 return sc.done
152}
153
154func (sc *SaslClient) getError(prefix string) error {
155 var desc *C.char
156
157 status := C.gssapi_error_desc(sc.state.maj_stat, sc.state.min_stat, &desc)
158 if status != C.GSSAPI_OK {
159 if desc != nil {
160 C.free(unsafe.Pointer(desc))
161 }
162
163 return fmt.Errorf("%s: (%v, %v)", prefix, sc.state.maj_stat, sc.state.min_stat)
164 }
165 defer C.free(unsafe.Pointer(desc))
166
167 return fmt.Errorf("%s: %v(%v,%v)", prefix, C.GoString(desc), int32(sc.state.maj_stat), int32(sc.state.min_stat))
168}