blob: 3d6f516a595b3e0a8e530f2e220ea917da33ab62 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package socks
6
7import (
8 "context"
9 "errors"
10 "io"
11 "net"
12 "strconv"
13 "time"
14)
15
// Sentinel deadlines handed to net.Conn.SetDeadline by connect.
var (
	// noDeadline is the zero time.Time. connect passes it to SetDeadline
	// once the handshake is over to remove the deadline it installed.
	noDeadline time.Time

	// aLongTimeAgo is a non-zero instant in the distant past (one second
	// after the Unix epoch). connect sets it as the deadline when the
	// context is cancelled so that pending I/O on the connection fails.
	aLongTimeAgo = time.Unix(1, 0)
)
20
21func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
22 host, port, err := splitHostPort(address)
23 if err != nil {
24 return nil, err
25 }
26 if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
27 c.SetDeadline(deadline)
28 defer c.SetDeadline(noDeadline)
29 }
30 if ctx != context.Background() {
31 errCh := make(chan error, 1)
32 done := make(chan struct{})
33 defer func() {
34 close(done)
35 if ctxErr == nil {
36 ctxErr = <-errCh
37 }
38 }()
39 go func() {
40 select {
41 case <-ctx.Done():
42 c.SetDeadline(aLongTimeAgo)
43 errCh <- ctx.Err()
44 case <-done:
45 errCh <- nil
46 }
47 }()
48 }
49
50 b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
51 b = append(b, Version5)
52 if len(d.AuthMethods) == 0 || d.Authenticate == nil {
53 b = append(b, 1, byte(AuthMethodNotRequired))
54 } else {
55 ams := d.AuthMethods
56 if len(ams) > 255 {
57 return nil, errors.New("too many authentication methods")
58 }
59 b = append(b, byte(len(ams)))
60 for _, am := range ams {
61 b = append(b, byte(am))
62 }
63 }
64 if _, ctxErr = c.Write(b); ctxErr != nil {
65 return
66 }
67
68 if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
69 return
70 }
71 if b[0] != Version5 {
72 return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
73 }
74 am := AuthMethod(b[1])
75 if am == AuthMethodNoAcceptableMethods {
76 return nil, errors.New("no acceptable authentication methods")
77 }
78 if d.Authenticate != nil {
79 if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
80 return
81 }
82 }
83
84 b = b[:0]
85 b = append(b, Version5, byte(d.cmd), 0)
86 if ip := net.ParseIP(host); ip != nil {
87 if ip4 := ip.To4(); ip4 != nil {
88 b = append(b, AddrTypeIPv4)
89 b = append(b, ip4...)
90 } else if ip6 := ip.To16(); ip6 != nil {
91 b = append(b, AddrTypeIPv6)
92 b = append(b, ip6...)
93 } else {
94 return nil, errors.New("unknown address type")
95 }
96 } else {
97 if len(host) > 255 {
98 return nil, errors.New("FQDN too long")
99 }
100 b = append(b, AddrTypeFQDN)
101 b = append(b, byte(len(host)))
102 b = append(b, host...)
103 }
104 b = append(b, byte(port>>8), byte(port))
105 if _, ctxErr = c.Write(b); ctxErr != nil {
106 return
107 }
108
109 if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
110 return
111 }
112 if b[0] != Version5 {
113 return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
114 }
115 if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
116 return nil, errors.New("unknown error " + cmdErr.String())
117 }
118 if b[2] != 0 {
119 return nil, errors.New("non-zero reserved field")
120 }
121 l := 2
122 var a Addr
123 switch b[3] {
124 case AddrTypeIPv4:
125 l += net.IPv4len
126 a.IP = make(net.IP, net.IPv4len)
127 case AddrTypeIPv6:
128 l += net.IPv6len
129 a.IP = make(net.IP, net.IPv6len)
130 case AddrTypeFQDN:
131 if _, err := io.ReadFull(c, b[:1]); err != nil {
132 return nil, err
133 }
134 l += int(b[0])
135 default:
136 return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
137 }
138 if cap(b) < l {
139 b = make([]byte, l)
140 } else {
141 b = b[:l]
142 }
143 if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
144 return
145 }
146 if a.IP != nil {
147 copy(a.IP, b)
148 } else {
149 a.Name = string(b[:len(b)-2])
150 }
151 a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
152 return &a, nil
153}
154
// splitHostPort breaks address, given in "host:port" form, into its host
// and a numeric port.
//
// It returns an error if address cannot be split by net.SplitHostPort, if
// the port is not a decimal integer, or if the port lies outside 1..65535.
func splitHostPort(address string) (string, int, error) {
	host, service, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(service)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		// Port 0 is rejected along with anything that does not fit in 16 bits.
		return "", 0, errors.New("port number out of range " + service)
	}
	return host, n, nil
}