blob: 779d1c6ef66e889b3a8141ab5a9f45cfcf7e8c04 [file] [log] [blame]
Scott Baker2c1c4822019-10-16 11:02:41 -07001package rfc3961
2
3/*
4Implementation of the n-fold algorithm as defined in RFC 3961.
5
6n-fold is an algorithm that takes m input bits and "stretches" them
7to form n output bits with equal contribution from each input bit to
8the output, as described in [Blumenthal96]:
9
10We first define a primitive called n-folding, which takes a
11variable-length input block and produces a fixed-length output
12sequence. The intent is to give each input bit approximately
13equal weight in determining the value of each output bit. Note
14that whenever we need to treat a string of octets as a number, the
15assumed representation is Big-Endian -- Most Significant Byte
16first.
17
18To n-fold a number X, replicate the input value to a length that
19is the least common multiple of n and the length of X. Before
20each repetition, the input is rotated to the right by 13 bit
21positions. The successive n-bit chunks are added together using
221's-complement addition (that is, with end-around carry) to yield
23a n-bit result....
24*/
25
26/* Credits
This Go implementation of n-fold used the following project as a reference for implementation details.
Although that project is written in Java, it served as a helpful reference implementation of the RFC.
You can find the source code of their open source project along with license information below.
We acknowledge and are grateful to these developers for their contributions to open source.
31
Project: Apache Directory (http://directory.apache.org/)
33https://svn.apache.org/repos/asf/directory/apacheds/tags/1.5.1/kerberos-shared/src/main/java/org/apache/directory/server/kerberos/shared/crypto/encryption/NFold.java
34License: http://www.apache.org/licenses/LICENSE-2.0
35*/
36
37// Nfold expands the key to ensure it is not smaller than one cipher block.
38// Defined in RFC 3961.
39//
40// m input bytes that will be "stretched" to the least common multiple of n bits and the bit length of m.
41func Nfold(m []byte, n int) []byte {
42 k := len(m) * 8
43
44 //Get the lowest common multiple of the two bit sizes
45 lcm := lcm(n, k)
46 relicate := lcm / k
47 var sumBytes []byte
48
49 for i := 0; i < relicate; i++ {
50 rotation := 13 * i
51 sumBytes = append(sumBytes, rotateRight(m, rotation)...)
52 }
53
54 nfold := make([]byte, n/8)
55 sum := make([]byte, n/8)
56 for i := 0; i < lcm/n; i++ {
57 for j := 0; j < n/8; j++ {
58 sum[j] = sumBytes[j+(i*len(sum))]
59 }
60 nfold = onesComplementAddition(nfold, sum)
61 }
62 return nfold
63}
64
// onesComplementAddition returns the ones'-complement (end-around carry) sum
// of two big-endian byte strings. The result has len(n1) bytes. Only the
// first len(n1) bytes of n2 are read, so n2 must be at least that long
// (otherwise indexing panics). Neither input is modified.
func onesComplementAddition(n1, n2 []byte) []byte {
	out := make([]byte, len(n1))
	carry := 0
	for i := len(n1) - 1; i >= 0; i-- {
		s := int(n1[i]) + int(n2[i]) + carry
		out[i] = byte(s)
		carry = s >> 8
	}
	// End-around carry: a carry out of the most significant byte is added
	// back at the least significant end. This cannot overflow again, because
	// whenever a carry occurred the truncated sum is at most 2^bits - 2.
	for i := len(out) - 1; carry != 0 && i >= 0; i-- {
		s := int(out[i]) + carry
		out[i] = byte(s)
		carry = s >> 8
	}
	return out
}
90}
91
// rotateRight returns a copy of b rotated right by step bits. b is viewed
// as a big-endian bit string: bit 0 is the most significant bit of b[0].
// step is reduced modulo the bit length, so any integer is accepted, and a
// negative step rotates left. b is not modified; an empty b yields an empty
// slice.
func rotateRight(b []byte, step int) []byte {
	n := len(b)
	out := make([]byte, n)
	if n == 0 {
		return out
	}
	bitLen := n * 8
	step %= bitLen
	if step < 0 {
		step += bitLen
	}
	byteShift := step / 8
	bitShift := uint(step % 8)
	for q := range out {
		// Output byte q draws from source byte src (low part) and the byte
		// before it (bits that spill across the byte boundary).
		src := (q - byteShift + n) % n
		if bitShift == 0 {
			out[q] = b[src]
			continue
		}
		prev := (src - 1 + n) % n
		out[q] = b[src]>>bitShift | b[prev]<<(8-bitShift)
	}
	return out
}
100}
101
102func lcm(x, y int) int {
103 return (x * y) / gcd(x, y)
104}
105
// gcd returns the greatest common divisor of x and y by Euclid's algorithm:
// gcd(x, 0) is x, and otherwise gcd(x, y) equals gcd(y, x mod y).
func gcd(x, y int) int {
	if y == 0 {
		return x
	}
	return gcd(y, x%y)
}
111}
112
// getBit reports bit p of *b as 0 or 1. Bits are numbered big-endian:
// bit 0 is the most significant bit of (*b)[0], bit 8 the MSB of (*b)[1].
func getBit(b *[]byte, p int) int {
	shift := 7 - uint(p%8)
	return int(((*b)[p/8] >> shift) & 1)
}
119}
120
// setBit sets bit p of *b to v, using the same big-endian numbering as
// getBit (bit 0 is the most significant bit of (*b)[0]). Any non-zero v
// sets the bit and zero clears it. The previous implementation could only
// OR bits in, so writing 0 over a 1 was silently ignored.
func setBit(b *[]byte, p, v int) {
	mask := byte(1) << (7 - uint(p%8))
	if v != 0 {
		(*b)[p/8] |= mask
	} else {
		(*b)[p/8] &^= mask
	}
}
128}