blob: 9536b1e3e3538fd24b964672cc670ff7b110fcb7 [file] [log] [blame]
khenaidood948f772021-08-11 17:49:24 -04001package rfc3961
2
3// Implementation of the n-fold algorithm as defined in RFC 3961.
4
5/* Credits
6This Go implementation of n-fold used the following project for help with implementation details.
7Although their source is in Java, it was helpful as a reference implementation of the RFC.
8You can find the source code of their open source project, along with license information, below.
9We acknowledge and are grateful to these developers for their contributions to open source.
10
11Project: Apache Directory (http://directory.apache.org/)
12https://svn.apache.org/repos/asf/directory/apacheds/tags/1.5.1/kerberos-shared/src/main/java/org/apache/directory/server/kerberos/shared/crypto/encryption/NFold.java
13License: http://www.apache.org/licenses/LICENSE-2.0
14*/
15
16// Nfold expands the key to ensure it is not smaller than one cipher block.
17// Defined in RFC 3961.
18//
19// m input bytes that will be "stretched" to the least common multiple of n bits and the bit length of m.
20func Nfold(m []byte, n int) []byte {
21 k := len(m) * 8
22
23 //Get the lowest common multiple of the two bit sizes
24 lcm := lcm(n, k)
25 relicate := lcm / k
26 var sumBytes []byte
27
28 for i := 0; i < relicate; i++ {
29 rotation := 13 * i
30 sumBytes = append(sumBytes, rotateRight(m, rotation)...)
31 }
32
33 nfold := make([]byte, n/8)
34 sum := make([]byte, n/8)
35 for i := 0; i < lcm/n; i++ {
36 for j := 0; j < n/8; j++ {
37 sum[j] = sumBytes[j+(i*len(sum))]
38 }
39 nfold = onesComplementAddition(nfold, sum)
40 }
41 return nfold
42}
43
44func onesComplementAddition(n1, n2 []byte) []byte {
45 numBits := len(n1) * 8
46 out := make([]byte, numBits/8)
47 carry := 0
48 for i := numBits - 1; i > -1; i-- {
49 n1b := getBit(&n1, i)
50 n2b := getBit(&n2, i)
51 s := n1b + n2b + carry
52
53 if s == 0 || s == 1 {
54 setBit(&out, i, s)
55 carry = 0
56 } else if s == 2 {
57 carry = 1
58 } else if s == 3 {
59 setBit(&out, i, 1)
60 carry = 1
61 }
62 }
63 if carry == 1 {
64 carryArray := make([]byte, len(n1))
65 carryArray[len(carryArray)-1] = 1
66 out = onesComplementAddition(out, carryArray)
67 }
68 return out
69}
70
71func rotateRight(b []byte, step int) []byte {
72 out := make([]byte, len(b))
73 bitLen := len(b) * 8
74 for i := 0; i < bitLen; i++ {
75 v := getBit(&b, i)
76 setBit(&out, (i+step)%bitLen, v)
77 }
78 return out
79}
80
81func lcm(x, y int) int {
82 return (x * y) / gcd(x, y)
83}
84
// gcd returns the greatest common divisor of x and y using Euclid's
// algorithm. gcd(x, 0) is x, so gcd(0, 0) is 0.
func gcd(x, y int) int {
	if y == 0 {
		return x
	}
	return gcd(y, x%y)
}
91
// getBit returns the value (0 or 1) of bit p of *b.
//
// Bits are numbered most-significant first: bit 0 is the high bit of (*b)[0]
// and bit 8 is the high bit of (*b)[1]. It panics if p/8 is not a valid
// index into *b.
func getBit(b *[]byte, p int) int {
	mask := byte(0x80) >> uint(p%8)
	if (*b)[p/8]&mask == 0 {
		return 0
	}
	return 1
}
99
// setBit sets bit p of *b to v.
//
// Bits are numbered most-significant first: bit 0 is the high bit of
// (*b)[0]. A non-zero v sets the bit to 1 and v == 0 clears it to 0; no
// other bit is touched. (The previous version OR-ed v into the byte, so
// v == 0 could never clear a bit and v > 1 spilled into a neighbouring bit.)
// It panics if p/8 is not a valid index into *b.
func setBit(b *[]byte, p, v int) {
	mask := byte(0x80) >> uint(p%8)
	if v != 0 {
		(*b)[p/8] |= mask
	} else {
		(*b)[p/8] &^= mask
	}
}