I enjoyed reading the blog post The Fibonacci Problem, which shows (among other things) a tail-recursive algorithm for the Fibonacci sequence.
The original code is in Golang, which has no tail-recursion optimization. I wanted to compare the Golang code to Kotlin, see Kotlin’s tailrec
in action, and measure the entire thing with jmh.
I retained the original function names (up to capitalization) so the original Golang code:
func FibNaive(n int) int {
    if n < 2 {
        return n
    }
    return FibNaive(n-1) + FibNaive(n-2)
}
looks like so in my version in Kotlin:
fun fibNaive(n: Int): Int =
    when (n) {
        0, 1 -> n
        else -> fibNaive(n - 1) + fibNaive(n - 2)
    }
Sources are here. To build and run, execute:
mvn clean install && java -jar target/benchmarks.jar
FibCached
Here, we improve on the naive approach by caching the intermediate results:
fun fibCached(
    n: Int,
    cache: MutableMap<Int, Int> = mutableMapOf(Pair(0, 0), Pair(1, 1))): Int =
    cache.getOrPut(n) {
        cache.getOrPut(n - 1) { fibCached(n - 1, cache) } +
        cache.getOrPut(n - 2) { fibCached(n - 2, cache) }
    }
This gives us over five orders of magnitude improvement over fibNaive at n = 40 (at n = 12 the two are roughly on par).
The initial implementation worked around Kotlin not understanding my intent; details are on StackOverflow.
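To see where the speedup comes from, one option is to count how often each function is entered. The sketch below is not part of the original sources: the counters and the names fibNaiveCounted and fibCachedCounted are hypothetical instrumentation added around the same logic.

```kotlin
// Illustrative counters; they are not part of the original functions.
var naiveCalls = 0L
var cachedCalls = 0L

// Same logic as fibNaive, plus a call counter.
fun fibNaiveCounted(n: Int): Int {
    naiveCalls++
    return when (n) {
        0, 1 -> n
        else -> fibNaiveCounted(n - 1) + fibNaiveCounted(n - 2)
    }
}

// Same logic as fibCached, plus a call counter.
fun fibCachedCounted(
    n: Int,
    cache: MutableMap<Int, Int> = mutableMapOf(0 to 0, 1 to 1)
): Int {
    cachedCalls++
    return cache.getOrPut(n) {
        cache.getOrPut(n - 1) { fibCachedCounted(n - 1, cache) } +
            cache.getOrPut(n - 2) { fibCachedCounted(n - 2, cache) }
    }
}
```

For n = 20 the naive version is entered 21891 times, while the cached version is entered 19 times (once for each value from 20 down to 2). The first count grows exponentially with n, the second linearly.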
FibVectorSum
The key observation here is that we can replace two recursive calls, each returning a scalar (i.e. f(n-1) + f(n-2)), with a single recursive
call that returns a vector (a Pair, really).
 1 fun T(p: Pair<Int, Int>): Pair<Int, Int> = Pair(p.first + p.second, p.first)
 2
 3 fun fibVecSum(n: Int): Int {
 4
 5     fun fibVec(n: Int): Pair<Int, Int> =
 6         when (n) {
 7             1 -> Pair(1, 0)
 8             else -> T(fibVec(n - 1))
 9         }
10
11
12     return when (n) {
13         0, 1 -> n
14
15         else -> {
16             val (a, b) = fibVec(n - 1)
17             a + b
18         }
19     }
20 }
This gives us close to an order of magnitude improvement over fibCached (about 9x at n = 12 and about 6x at n = 40 in the benchmarks below).
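To make the pair transformation concrete, here is a small sketch that records every intermediate pair. The names step and tracePairs are hypothetical; step does the same thing as T above.

```kotlin
// Same transformation as T in the post: (F(k + 1), F(k)) -> (F(k + 2), F(k + 1)).
fun step(p: Pair<Int, Int>): Pair<Int, Int> = Pair(p.first + p.second, p.first)

// Hypothetical helper: starts at (1, 0) and applies step `count` times,
// keeping every intermediate pair so the progression can be inspected.
fun tracePairs(count: Int): List<Pair<Int, Int>> {
    val result = mutableListOf(Pair(1, 0))
    repeat(count) { result.add(step(result.last())) }
    return result
}
```

tracePairs(5) yields (1, 0), (1, 1), (2, 1), (3, 2), (5, 3), (8, 5): each pair holds two consecutive Fibonacci numbers, so a single recursive call carries everything the next step needs.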
Incidentally, there is a more Kotlin-idiomatic way to implement the same approach using the generateSequence function. It doesn’t match
the narrative of the original post, though:
fun fibVecSumKotlin(n: Int): Int {
    fun genPairSequence(): Sequence<Pair<Int, Int>> =
        generateSequence(Pair(1, 0), { T(it) })
    return genPairSequence().take(n + 1).last().second
}
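A possible variation on the same idea, offered as a sketch rather than as the post's code: a destructuring lambda can stand in for the separate T helper, and elementAt(n) picks the n-th pair directly. The name fibSeqSketch is hypothetical, and this variant has not been benchmarked here.

```kotlin
// Hypothetical variant of fibVecSumKotlin without the T helper.
fun fibSeqSketch(n: Int): Int =
    // The element at index k of the sequence is the pair (F(k + 1), F(k)).
    generateSequence(1 to 0) { (a, b) -> (a + b) to a }
        .elementAt(n)
        .second
```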
FibTailVecSum
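The remaining issue in FibVectorSum preventing it from being tail recursive is the transformation at line 8: T is applied to the result
of the recursive call, so the call is not the last operation. This can be fixed by accumulating the sum into the arguments of the recursive call: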
 1 fun fibTailVecSum(n: Int): Int {
 2
 3     tailrec fun fibTailVec(acc: Int, a: Int, b: Int): Pair<Int, Int> =
 4         when (acc) {
 5             1 -> Pair(a, b)
 6             else -> fibTailVec(acc - 1, a + b, a)
 7         }
 8
 9     return when (n) {
10         0, 1 -> n
11         else -> {
12             val (a, b) = fibTailVec(n - 1, 1, 0)
13             a + b
14         }
15     }
16 }
Note the tailrec modifier on the recursive function at line 3, which lets the Kotlin compiler turn the recursion into a loop. Once again we have an order of magnitude improvement over the previous take, fibVecSum.
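One way to see tailrec in action is to recurse far deeper than the JVM stack would normally allow. The sketch below is an illustration, not code from the original sources: fibModTail and MOD are hypothetical names, and the values are reduced modulo MOD only so that very large inputs do not overflow.

```kotlin
// Arbitrary modulus chosen for this illustration, so that results stay within Long range.
val MOD = 1_000_000_007L

// Same shape as fibTailVec: the recursive call is the last operation,
// so the compiler can replace it with a jump back to the start of the function.
// Invariant: after k steps, (a, b) == (F(k + 1) mod MOD, F(k) mod MOD).
tailrec fun fibModTail(steps: Int, a: Long, b: Long): Long =
    if (steps == 0) b else fibModTail(steps - 1, (a + b) % MOD, a)
```

Calling fibModTail(1_000_000, 1L, 0L) completes without exhausting the stack. Removing the tailrec modifier from the same function would make a call that deep likely to fail with a StackOverflowError under default JVM stack settings.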
FibIterative
In the iterative version, fibTailVec is replaced by a while loop:
fun fibIterative(n: Int): Int {
    if (n < 2) {
        return n
    }
    var acc = n
    var a = 1
    var b = 0
    var tmp = 0
    while (acc > 2) {
        acc--
        tmp = a
        a += b
        b = tmp
    }
    return a + b
}
The code is more verbose than the Go version because Kotlin has no “multiple assignments” as in n, a, b = n-1, a+b, a.
It is also the fastest variant in the benchmarks below.
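If the temporary variable bothers you, the also scope function can approximate Go's multiple assignment. This is a sketch of that idiom, not the benchmarked code: fibIterativeAlso is a hypothetical name, and its performance was not measured here.

```kotlin
// Hypothetical variant of fibIterative that avoids the explicit tmp variable.
fun fibIterativeAlso(n: Int): Int {
    if (n < 2) {
        return n
    }
    var a = 1
    var b = 0
    // After k iterations, (a, b) == (F(k + 1), F(k)).
    repeat(n - 1) {
        // (a + b) is evaluated first, then the also block copies the old a into b,
        // and only afterwards is the sum assigned to a.
        a = (a + b).also { b = a }
    }
    return a
}
```

Whether this reads better than the explicit tmp variable is a matter of taste; the loop body does the same work in the same order.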
Benchmarks
Benchmarks were generated using jmh with the following configuration:
@State(Scope.Benchmark)
@Fork(1)
@Warmup(iterations = 15)
@Measurement(iterations = 15)
@BenchmarkMode(AverageTime)
@OutputTimeUnit(NANOSECONDS)
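For orientation, here is a minimal sketch of how these annotations might sit on a Kotlin benchmark class. The class name FibBenchmarkSketch, the method name, and the @Param field are assumptions for illustration (the (n) column in the results suggests a parameter named n with values 12 and 40); the actual benchmark class is in the linked sources and may differ. The sketch needs the JMH annotations on the classpath.

```kotlin
import java.util.concurrent.TimeUnit
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.BenchmarkMode
import org.openjdk.jmh.annotations.Fork
import org.openjdk.jmh.annotations.Measurement
import org.openjdk.jmh.annotations.Mode
import org.openjdk.jmh.annotations.OutputTimeUnit
import org.openjdk.jmh.annotations.Param
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import org.openjdk.jmh.annotations.Warmup

// Function under test, copied from the FibIterative section.
fun fibIterative(n: Int): Int {
    if (n < 2) return n
    var acc = n
    var a = 1
    var b = 0
    var tmp = 0
    while (acc > 2) {
        acc--
        tmp = a
        a += b
        b = tmp
    }
    return a + b
}

// Hypothetical benchmark class; its name and layout are assumptions.
// JMH generates subclasses of benchmark classes, so the Kotlin class must be open.
@State(Scope.Benchmark)
@Fork(1)
@Warmup(iterations = 15)
@Measurement(iterations = 15)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
open class FibBenchmarkSketch {

    // JMH runs every benchmark method once per listed value.
    @Param("12", "40")
    var n: Int = 0

    // Returning the result hands it to JMH, which guards against dead-code elimination.
    @Benchmark
    fun iterative(): Int = fibIterative(n)
}
```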
Benchmark | (n) | Score | Units | Normalized by fibIterative |
---|---|---|---|---|
fibNaive | 12 | 505.787 | ns/op | 109 |
fibNaive | 40 | 359727438.400 | ns/op | 32702494 |
fibCached | 12 | 490.736 | ns/op | 106 |
fibCached | 40 | 1393.079 | ns/op | 123 |
fibVecSum | 12 | 56.979 | ns/op | 12 |
fibVecSum | 40 | 242.230 | ns/op | 21 |
fibVecSumKotlin | 12 | 99.435 | ns/op | 21 |
fibVecSumKotlin | 40 | 273.117 | ns/op | 24 |
fibTailVecSum | 12 | 8.447 | ns/op | 2 |
fibTailVecSum | 40 | 14.800 | ns/op | 1 |
fibIterative | 12 | 4.666 | ns/op | 1 |
fibIterative | 40 | 11.372 | ns/op | 1 |
fibIterativeTabulated | 12 | 81.520 | ns/op | 17 |
fibIterativeTabulated | 40 | 257.801 | ns/op | 22 |