blob: 278713ea5b41616326e04ebabefcb91d5c606ff3 [file] [log] [blame]
Scott Baker8461e152019-10-01 14:44:30 -07001// Package aescts provides AES CBC CipherText Stealing encryption and decryption methods
2package aescts
3
4import (
5 "crypto/aes"
6 "crypto/cipher"
7 "errors"
8 "fmt"
9)
10
11// Encrypt the message with the key and the initial vector.
12// Returns: next iv, ciphertext bytes, error
13func Encrypt(key, iv, plaintext []byte) ([]byte, []byte, error) {
14 l := len(plaintext)
15
16 block, err := aes.NewCipher(key)
17 if err != nil {
18 return []byte{}, []byte{}, fmt.Errorf("Error creating cipher: %v", err)
19 }
20 mode := cipher.NewCBCEncrypter(block, iv)
21
22 m := make([]byte, len(plaintext))
23 copy(m, plaintext)
24
25 /*For consistency, ciphertext stealing is always used for the last two
26 blocks of the data to be encrypted, as in [RC5]. If the data length
27 is a multiple of the block size, this is equivalent to plain CBC mode
28 with the last two ciphertext blocks swapped.*/
29 /*The initial vector carried out from one encryption for use in a
30 subsequent encryption is the next-to-last block of the encryption
31 output; this is the encrypted form of the last plaintext block.*/
32 if l <= aes.BlockSize {
33 m, _ = zeroPad(m, aes.BlockSize)
34 mode.CryptBlocks(m, m)
35 return m, m, nil
36 }
37 if l%aes.BlockSize == 0 {
38 mode.CryptBlocks(m, m)
39 iv = m[len(m)-aes.BlockSize:]
40 rb, _ := swapLastTwoBlocks(m, aes.BlockSize)
41 return iv, rb, nil
42 }
43 m, _ = zeroPad(m, aes.BlockSize)
44 rb, pb, lb, err := tailBlocks(m, aes.BlockSize)
45 if err != nil {
46 return []byte{}, []byte{}, fmt.Errorf("Error tailing blocks: %v", err)
47 }
48 var ct []byte
49 if rb != nil {
50 // Encrpt all but the lats 2 blocks and update the rolling iv
51 mode.CryptBlocks(rb, rb)
52 iv = rb[len(rb)-aes.BlockSize:]
53 mode = cipher.NewCBCEncrypter(block, iv)
54 ct = append(ct, rb...)
55 }
56 mode.CryptBlocks(pb, pb)
57 mode = cipher.NewCBCEncrypter(block, pb)
58 mode.CryptBlocks(lb, lb)
59 // Cipher Text Stealing (CTS) - Ref: https://en.wikipedia.org/wiki/Ciphertext_stealing#CBC_ciphertext_stealing
60 // Swap the last two cipher blocks
61 // Truncate the ciphertext to the length of the original plaintext
62 ct = append(ct, lb...)
63 ct = append(ct, pb...)
64 return lb, ct[:l], nil
65}
66
67// Decrypt the ciphertext with the key and the initial vector.
68func Decrypt(key, iv, ciphertext []byte) ([]byte, error) {
69 // Copy the cipher text as golang slices even when passed by value to this method can result in the backing arrays of the calling code value being updated.
70 ct := make([]byte, len(ciphertext))
71 copy(ct, ciphertext)
72 if len(ct) < aes.BlockSize {
73 return []byte{}, fmt.Errorf("Ciphertext is not large enough. It is less that one block size. Blocksize:%v; Ciphertext:%v", aes.BlockSize, len(ct))
74 }
75 // Configure the CBC
76 block, err := aes.NewCipher(key)
77 if err != nil {
78 return nil, fmt.Errorf("Error creating cipher: %v", err)
79 }
80 var mode cipher.BlockMode
81
82 //If ciphertext is multiple of blocksize we just need to swap back the last two blocks and then do CBC
83 //If the ciphertext is just one block we can't swap so we just decrypt
84 if len(ct)%aes.BlockSize == 0 {
85 if len(ct) > aes.BlockSize {
86 ct, _ = swapLastTwoBlocks(ct, aes.BlockSize)
87 }
88 mode = cipher.NewCBCDecrypter(block, iv)
89 message := make([]byte, len(ct))
90 mode.CryptBlocks(message, ct)
91 return message[:len(ct)], nil
92 }
93
94 // Cipher Text Stealing (CTS) using CBC interface. Ref: https://en.wikipedia.org/wiki/Ciphertext_stealing#CBC_ciphertext_stealing
95 // Get ciphertext of the 2nd to last (penultimate) block (cpb), the last block (clb) and the rest (crb)
96 crb, cpb, clb, _ := tailBlocks(ct, aes.BlockSize)
97 v := make([]byte, len(iv), len(iv))
98 copy(v, iv)
99 var message []byte
100 if crb != nil {
101 //If there is more than just the last and the penultimate block we decrypt it and the last bloc of this becomes the iv for later
102 rb := make([]byte, len(crb))
103 mode = cipher.NewCBCDecrypter(block, v)
104 v = crb[len(crb)-aes.BlockSize:]
105 mode.CryptBlocks(rb, crb)
106 message = append(message, rb...)
107 }
108
109 // We need to modify the cipher text
110 // Decryt the 2nd to last (penultimate) block with a the original iv
111 pb := make([]byte, aes.BlockSize)
112 mode = cipher.NewCBCDecrypter(block, iv)
113 mode.CryptBlocks(pb, cpb)
114 // number of byte needed to pad
115 npb := aes.BlockSize - len(ct)%aes.BlockSize
116 //pad last block using the number of bytes needed from the tail of the plaintext 2nd to last (penultimate) block
117 clb = append(clb, pb[len(pb)-npb:]...)
118
119 // Now decrypt the last block in the penultimate position (iv will be from the crb, if the is no crb it's zeros)
120 // iv for the penultimate block decrypted in the last position becomes the modified last block
121 lb := make([]byte, aes.BlockSize)
122 mode = cipher.NewCBCDecrypter(block, v)
123 v = clb
124 mode.CryptBlocks(lb, clb)
125 message = append(message, lb...)
126
127 // Now decrypt the penultimate block in the last position (iv will be from the modified last block)
128 mode = cipher.NewCBCDecrypter(block, v)
129 mode.CryptBlocks(cpb, cpb)
130 message = append(message, cpb...)
131
132 // Truncate to the size of the original cipher text
133 return message[:len(ct)], nil
134}
135
// tailBlocks splits b, using block size c, into three sub-slices:
//   - rb: everything before the last two blocks (nil when b spans at most two blocks),
//   - pb: the full block immediately before the last block,
//   - lb: the last block, which holds len(b)%c bytes when len(b) is not a
//     multiple of c and a full block otherwise.
//
// The returned slices share b's backing array. An error is returned when c
// is not positive or b is not longer than one block.
func tailBlocks(b []byte, c int) ([]byte, []byte, []byte, error) {
	if c <= 0 {
		return []byte{}, []byte{}, []byte{}, errors.New("Invalid block size when tailing")
	}
	if len(b) <= c {
		return []byte{}, []byte{}, []byte{}, errors.New("bytes slice is not larger than one block so cannot tail")
	}
	// Size of the last block, measured against the requested block size c.
	lbs := len(b) % c
	if lbs == 0 {
		lbs = c
	}
	lb := b[len(b)-lbs:]
	pb := b[len(b)-lbs-c : len(b)-lbs]
	if rest := len(b) - lbs - c; rest > 0 {
		return b[:rest], pb, lb, nil
	}
	return nil, pb, lb, nil
}
157
158func swapLastTwoBlocks(b []byte, c int) ([]byte, error) {
159 rb, pb, lb, err := tailBlocks(b, c)
160 if err != nil {
161 return nil, err
162 }
163 var out []byte
164 if rb != nil {
165 out = append(out, rb...)
166 }
167 out = append(out, lb...)
168 out = append(out, pb...)
169 return out, nil
170}
171
// zeroPad returns b extended with zero bytes up to the next multiple of the
// block size m; b itself is returned when its length is already a multiple.
// When padding is needed the result is a fresh allocation, so bytes in b's
// spare capacity (beyond len(b)) are never overwritten. An error is returned
// when m is not positive or b is empty.
func zeroPad(b []byte, m int) ([]byte, error) {
	if m <= 0 {
		return nil, errors.New("Invalid message block size when padding")
	}
	if len(b) == 0 {
		return nil, errors.New("Data not valid to pad: Zero size")
	}
	if r := len(b) % m; r != 0 {
		// Clamp capacity to len(b) so append must allocate instead of writing
		// zeros into spare capacity the caller may still be using.
		b = append(b[:len(b):len(b)], make([]byte, m-r)...)
	}
	return b, nil
}