Recursion & Tail-Call Optimization
Learning Objectives
By the end of this checkpoint, you will:
- Understand the difference between tail and non-tail recursion
- Use accumulators to make recursion tail-optimized
- Know when to reverse the accumulator
- Implement basic list operations recursively
Setup
# Kino is needed for the self-assessment form at the end of this checkpoint
Mix.install([{:kino, "~> 0.12"}])
Concept: Recursion in Elixir
Recursion is a fundamental technique in functional programming. A recursive function calls itself with modified arguments until it reaches a base case.
defmodule SimpleRecursion do
  # Base case
  def countdown(0), do: IO.puts("Blast off! 🚀")

  # Recursive case
  def countdown(n) when n > 0 do
    IO.puts(n)
    countdown(n - 1)
  end
end

SimpleRecursion.countdown(5)
The Problem: Stack Growth
Non-tail-recursive functions keep one stack frame alive per element, so their memory use grows with the input:
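defmodule BadList do
  def sum([]), do: 0
  def sum([h | t]), do: h + sum(t)
end

# This works fine for small lists
IO.puts("Small list sum: #{BadList.sum([1, 2, 3, 4, 5])}")

# Uncomment to try a large list. Every element adds a stack frame that stays
# alive until the recursion unwinds. On the BEAM the stack grows on demand,
# so this still finishes, but it needs memory proportional to the list length.
# BadList.sum(1..100_000 |> Enum.to_list())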
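Why is this a problem?
The expression h + sum(t) performs the addition AFTER the recursive call returns, so each call must keep its frame (including h) until the inner call finishes. A 100,000-element list therefore needs 100,000 live frames at the deepest point. In languages with a fixed-size call stack this overflows. On the BEAM, process stacks grow on demand, so the cost shows up as memory use proportional to the input instead, and extremely deep recursion can exhaust available memory.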
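To observe this stack growth directly, here is a minimal sketch that reads the current process's stack size at the base case using `Process.info(self(), :stack_size)`. The result is reported in words. `StackProbe` and its functions are hypothetical names for this illustration. Exact numbers depend on the Erlang/OTP version and on how the code is evaluated, so only the relative sizes are meaningful.

```elixir
defmodule StackProbe do
  # Hypothetical module: body- and tail-recursive sums that also report the
  # process stack size (in words) observed when the base case is reached.

  # Body-recursive: every frame stays alive until its addition can happen.
  def body_sum([]), do: {0, stack_words()}

  def body_sum([h | t]) do
    {sum, words} = body_sum(t)
    {h + sum, words}
  end

  # Tail-recursive: each call replaces the previous frame.
  def tail_sum(list), do: tail_sum(list, 0)

  defp tail_sum([], acc), do: {acc, stack_words()}
  defp tail_sum([h | t], acc), do: tail_sum(t, acc + h)

  defp stack_words do
    {:stack_size, words} = Process.info(self(), :stack_size)
    words
  end
end

small = Enum.to_list(1..1_000)
large = Enum.to_list(1..100_000)

{_, body_small} = StackProbe.body_sum(small)
{_, body_large} = StackProbe.body_sum(large)
{_, tail_large} = StackProbe.tail_sum(large)

IO.puts("body-recursive, 1k items:   #{body_small} words")
IO.puts("body-recursive, 100k items: #{body_large} words")
IO.puts("tail-recursive, 100k items: #{tail_large} words")
```

The body-recursive stack should grow roughly in proportion to the list length. The tail-recursive one should stay about the same regardless of input size.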
The Solution: Tail Recursion
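A function is tail-recursive if the recursive call is the LAST operation it performs. The BEAM then reuses the current stack frame instead of pushing a new one (last-call optimization), so the recursion runs like a loop in constant stack space.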
defmodule GoodList do
  # Public API - easy to use
  def sum(list), do: do_sum(list, 0)

  # Private tail-recursive implementation
  defp do_sum([], acc), do: acc
  defp do_sum([h | t], acc), do: do_sum(t, acc + h)
end

# Stack usage stays constant no matter how long the list is
IO.puts("Small list: #{GoodList.sum([1, 2, 3, 4, 5])}")

large_list = 1..100_000 |> Enum.to_list()
IO.puts("Large list (100k items): #{GoodList.sum(large_list)}")
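The optimization is not limited to a function calling itself. On the BEAM, any call in last position reuses the current frame, including a call to a different function. Here is a minimal sketch with two hypothetical functions that call each other and run a million calls deep in constant stack space:

```elixir
defmodule Parity do
  # Hypothetical mutually recursive pair. Each call is the last operation
  # of its clause, so neither function grows the stack.
  def even?(0), do: true
  def even?(n) when n > 0, do: odd?(n - 1)

  def odd?(0), do: false
  def odd?(n) when n > 0, do: even?(n - 1)
end

IO.puts("Is 1_000_000 even? #{Parity.even?(1_000_000)}")
```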
Interactive Exercise 2.1: Understanding the Difference
Let’s compare the two styles side by side:
defmodule Visualization do
  # Non-tail recursive - builds up pending multiplications
  def non_tail_fact(0), do: 1
  def non_tail_fact(n) when n > 0, do: n * non_tail_fact(n - 1)

  # Tail recursive - uses accumulator
  def tail_fact(n), do: tail_fact(n, 1)

  defp tail_fact(0, acc), do: acc
  defp tail_fact(n, acc) when n > 0, do: tail_fact(n - 1, n * acc)
end

n = 5
IO.puts("Non-tail factorial(#{n}): #{Visualization.non_tail_fact(n)}")
IO.puts("Tail factorial(#{n}): #{Visualization.tail_fact(n)}")

# Compare with a larger number
n = 10
IO.puts("\nNon-tail factorial(#{n}): #{Visualization.non_tail_fact(n)}")
IO.puts("Tail factorial(#{n}): #{Visualization.tail_fact(n)}")
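Both versions return the same results. The difference lies in when the work happens. The minimal tracing sketch below prints each call. The `TraceFact` module and its `depth` argument are hypothetical additions for illustration.

The non-tail version descends all the way to 0 before doing any multiplication, then multiplies on the way back out. The tail version carries the partial product forward in `acc`, so nothing is left to do once it reaches 0.

```elixir
defmodule TraceFact do
  # Non-tail: each level must wait for the inner call before multiplying.
  # `depth` only controls the indentation of the trace.
  def non_tail(0, depth) do
    IO.puts(indent(depth) <> "non_tail(0) = 1")
    1
  end

  def non_tail(n, depth) when n > 0 do
    IO.puts(indent(depth) <> "non_tail(#{n}) = #{n} * non_tail(#{n - 1})")
    result = n * non_tail(n - 1, depth + 1)
    IO.puts(indent(depth) <> "non_tail(#{n}) = #{result}")
    result
  end

  # Tail: the partial product travels in `acc`; no pending work remains.
  def tail(n), do: tail(n, 1)

  defp tail(0, acc) do
    IO.puts("tail(0, #{acc}) -> #{acc}")
    acc
  end

  defp tail(n, acc) when n > 0 do
    IO.puts("tail(#{n}, #{acc})")
    tail(n - 1, n * acc)
  end

  defp indent(depth), do: String.duplicate("  ", depth)
end

TraceFact.non_tail(3, 0)
TraceFact.tail(3)
```

For n = 3 the tail version prints tail(3, 1), tail(2, 3), tail(1, 6) and finally tail(0, 6) -> 6. The answer is complete the moment the base case is reached.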
Interactive Exercise 2.2: Convert to Tail Recursion
Implement a tail-recursive length/1 function:
defmodule MyList do
  # Non-tail-recursive version (don't use this!)
  def length_bad([]), do: 0
  def length_bad([_ | t]), do: 1 + length_bad(t)

  # Tail-recursive version - try writing it yourself before reading this one!
  def length(list), do: do_length(list, 0)

  defp do_length([], acc), do: acc
  defp do_length([_ | t], acc), do: do_length(t, acc + 1)
end

# Test both versions
test_list = [1, 2, 3, 4, 5]
IO.puts("Bad length: #{MyList.length_bad(test_list)}")
IO.puts("Good length: #{MyList.length(test_list)}")

# Test with larger list
large_list = 1..10_000 |> Enum.to_list()
IO.puts("Large list length: #{MyList.length(large_list)}")
Interactive Exercise 2.3: Implement map/2
Implement a tail-recursive map/2 function:
defmodule MyMap do
  @doc """
  Maps a function over a list.

  ## Examples

      iex> MyMap.map([1, 2, 3], fn x -> x * 2 end)
      [2, 4, 6]
  """
  def map(list, func), do: do_map(list, func, [])

  defp do_map([], _func, acc), do: Enum.reverse(acc)

  defp do_map([h | t], func, acc) do
    do_map(t, func, [func.(h) | acc])
  end
end

# Test the implementation
result = MyMap.map([1, 2, 3, 4, 5], fn x -> x * 2 end)
IO.inspect(result, label: "Double each number")

result = MyMap.map(["hello", "world"], &String.upcase/1)
IO.inspect(result, label: "Uppercase strings")
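Important: notice that we reverse the accumulator at the end. We build the list backwards because prepending with [x | acc] is O(1), whereas appending with acc ++ [x] is O(n) because it copies acc. Appending on every step would make the whole map quadratic.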
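To see why the reversal is needed, and what the alternatives cost, here is a sketch of three hypothetical variants of map:

- Prepending without reversing is fast, but the result comes out backwards.
- Appending with `++` keeps the order, but copies the whole accumulator on every step, so the total work is quadratic.
- A body-recursive version with no accumulator keeps the order and does linear work, at the cost of one stack frame per element.

```elixir
defmodule AccOrder do
  # Prepend only: O(n) total, but the first element ends up last.
  def map_prepend(list, f), do: prepend(list, f, [])

  defp prepend([], _f, acc), do: acc
  defp prepend([h | t], f, acc), do: prepend(t, f, [f.(h) | acc])

  # Append with ++: right order, but each ++ copies acc, so O(n^2) total.
  def map_append(list, f), do: append(list, f, [])

  defp append([], _f, acc), do: acc
  defp append([h | t], f, acc), do: append(t, f, acc ++ [f.(h)])

  # Body-recursive: right order and O(n), but the stack grows with the list.
  def map_body([], _f), do: []
  def map_body([h | t], f), do: [f.(h) | map_body(t, f)]
end

times_ten = fn x -> x * 10 end
IO.inspect(AccOrder.map_prepend([1, 2, 3], times_ten), label: "prepend, no reverse")
IO.inspect(AccOrder.map_append([1, 2, 3], times_ten), label: "append with ++")
IO.inspect(AccOrder.map_body([1, 2, 3], times_ten), label: "body-recursive")
```

The first line prints [30, 20, 10]. The other two print [10, 20, 30].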
Interactive Exercise 2.4: Implement filter/2
Now implement a tail-recursive filter/2 function:
defmodule MyFilter do
  @doc """
  Filters a list based on a predicate.

  ## Examples

      iex> MyFilter.filter([1, 2, 3, 4], fn x -> rem(x, 2) == 0 end)
      [2, 4]
  """
  def filter(list, predicate), do: do_filter(list, predicate, [])

  defp do_filter([], _pred, acc), do: Enum.reverse(acc)

  defp do_filter([h | t], pred, acc) do
    if pred.(h) do
      do_filter(t, pred, [h | acc])
    else
      do_filter(t, pred, acc)
    end
  end
end

# Test the implementation
evens = MyFilter.filter([1, 2, 3, 4, 5, 6], fn x -> rem(x, 2) == 0 end)
IO.inspect(evens, label: "Even numbers")

long_words =
  MyFilter.filter(["hi", "hello", "hey", "greetings"], fn w -> String.length(w) > 3 end)

IO.inspect(long_words, label: "Words longer than 3 chars")
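An accumulator does not have to be a single value. Here is a sketch of a hypothetical `partition/2`, similar in spirit to `Enum.split_with/2`. It carries two accumulators at once, one for matching elements and one for the rest, and reverses both at the end for the same reason `filter/2` does:

```elixir
defmodule MyPartition do
  # Returns {matching, rest}, preserving the original order in both lists.
  def partition(list, pred), do: do_partition(list, pred, [], [])

  defp do_partition([], _pred, yes, no), do: {Enum.reverse(yes), Enum.reverse(no)}

  defp do_partition([h | t], pred, yes, no) do
    if pred.(h) do
      do_partition(t, pred, [h | yes], no)
    else
      do_partition(t, pred, yes, [h | no])
    end
  end
end

{evens, odds} = MyPartition.partition([1, 2, 3, 4, 5, 6], fn x -> rem(x, 2) == 0 end)
IO.inspect(evens, label: "Even")
IO.inspect(odds, label: "Odd")
```

Both recursive calls are the last expression of an `if` branch, so both are tail calls. The same holds for branches of `case` and `cond`.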
Advanced: When NOT to Use Accumulators
Sometimes the non-tail-recursive version is clearer and won’t cause problems:
defmodule TreeTraversal do
  # Binary tree node
  defmodule Node do
    defstruct [:value, :left, :right]
  end

  # Non-tail recursive tree traversal is often clearer
  def inorder(nil), do: []

  def inorder(%Node{value: v, left: l, right: r}) do
    inorder(l) ++ [v] ++ inorder(r)
  end
end

# Create a small tree
tree = %TreeTraversal.Node{
  value: 4,
  left: %TreeTraversal.Node{
    value: 2,
    left: %TreeTraversal.Node{value: 1, left: nil, right: nil},
    right: %TreeTraversal.Node{value: 3, left: nil, right: nil}
  },
  right: %TreeTraversal.Node{
    value: 6,
    left: %TreeTraversal.Node{value: 5, left: nil, right: nil},
    right: %TreeTraversal.Node{value: 7, left: nil, right: nil}
  }
}

IO.inspect(TreeTraversal.inorder(tree), label: "Inorder traversal")
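For tree traversal, the non-tail version is fine because:
- A balanced tree with n nodes is only about log2(n) levels deep, so the recursion depth stays small
- The code is much clearer
- The cost is acceptable for modest trees, although ++ copies its left operand, so this traversal does O(n log n) work on a balanced tree (O(n²) on a left-leaning degenerate one) rather than O(n)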
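If the cost of `++` matters, a common alternative keeps the recursive structure but threads an accumulator through it. In this sketch, the `AccTree` module is hypothetical and defines its own `Node` struct so the block runs standalone.

The right subtree is flattened first and each value is prepended, so the result is already in order. No `++` or `Enum.reverse/1` is needed. Only the outer call is a tail call. The inner call still uses stack proportional to the tree's height.

```elixir
defmodule AccTree do
  defmodule Node do
    defstruct [:value, :left, :right]
  end

  # Linear-time inorder traversal: build the list right-to-left.
  def inorder(tree), do: do_inorder(tree, [])

  defp do_inorder(nil, acc), do: acc

  defp do_inorder(%Node{value: v, left: l, right: r}, acc) do
    # Flatten the right subtree onto acc, prepend v, then flatten the left
    # subtree onto that. The outer call is a tail call; the inner one is not.
    do_inorder(l, [v | do_inorder(r, acc)])
  end
end

leaf = fn v -> %AccTree.Node{value: v} end

tree = %AccTree.Node{
  value: 4,
  left: %AccTree.Node{value: 2, left: leaf.(1), right: leaf.(3)},
  right: %AccTree.Node{value: 6, left: leaf.(5), right: leaf.(7)}
}

IO.inspect(AccTree.inorder(tree), label: "Inorder (accumulator)")
```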
Practice: Implement reduce/3
The ultimate recursive function - implement your own reduce:
defmodule MyReduce do
  @doc """
  Reduces a list to a single value.

  ## Examples

      iex> MyReduce.reduce([1, 2, 3, 4], 0, fn x, acc -> x + acc end)
      10
      iex> MyReduce.reduce([1, 2, 3], 1, fn x, acc -> x * acc end)
      6
  """
  def reduce(list, initial, func), do: do_reduce(list, initial, func)

  defp do_reduce([], acc, _func), do: acc

  defp do_reduce([h | t], acc, func) do
    do_reduce(t, func.(h, acc), func)
  end
end

# Test reduce
sum = MyReduce.reduce([1, 2, 3, 4, 5], 0, fn x, acc -> x + acc end)
IO.puts("Sum: #{sum}")

product = MyReduce.reduce([1, 2, 3, 4, 5], 1, fn x, acc -> x * acc end)
IO.puts("Product: #{product}")

# Build a string
sentence =
  MyReduce.reduce(["Hello", "from", "Elixir"], "", fn word, acc ->
    if acc == "", do: word, else: acc <> " " <> word
  end)

IO.puts("Sentence: #{sentence}")
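`reduce/3` is general enough to express the earlier exercises. The sketch below builds `map/2` and `filter/2` on top of it, using the same prepend-then-reverse pattern. The module name is hypothetical, and it repeats a minimal `reduce/3` so the block runs on its own.

```elixir
defmodule ViaReduce do
  # Same shape as MyReduce.reduce/3: func receives (element, acc).
  def reduce([], acc, _func), do: acc
  def reduce([h | t], acc, func), do: reduce(t, func.(h, acc), func)

  def map(list, func) do
    list
    |> reduce([], fn x, acc -> [func.(x) | acc] end)
    |> Enum.reverse()
  end

  def filter(list, pred) do
    list
    |> reduce([], fn x, acc -> if pred.(x), do: [x | acc], else: acc end)
    |> Enum.reverse()
  end
end

IO.inspect(ViaReduce.map([1, 2, 3], fn x -> x * 2 end), label: "map via reduce")
IO.inspect(ViaReduce.filter([1, 2, 3, 4], fn x -> rem(x, 2) == 0 end), label: "filter via reduce")
```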
Self-Assessment
form =
  Kino.Control.form(
    [
      tail_vs_non: Kino.Input.checkbox("I understand the difference between tail and non-tail recursion"),
      accumulators: Kino.Input.checkbox("I can use accumulators to make recursion tail-optimized"),
      reverse_acc: Kino.Input.checkbox("I know when to reverse the accumulator"),
      implement_ops: Kino.Input.checkbox("I can implement basic list operations recursively"),
      recognize_tco: Kino.Input.checkbox("I can recognize tail call optimization opportunities")
    ],
    submit: "Check Progress"
  )

Kino.render(form)

# The frame holds the progress message and is re-rendered on each submit
progress = Kino.Frame.new()

Kino.listen(form, fn event ->
  completed = event.data |> Map.values() |> Enum.count(& &1)
  total = map_size(event.data)

  progress_message =
    if completed == total do
      "🎉 Excellent! You've mastered Checkpoint 2!"
    else
      "Keep going! #{completed}/#{total} objectives complete"
    end

  Kino.Frame.render(progress, Kino.Markdown.new("### Progress: #{progress_message}"))
end)

progress
Key Takeaways
- Tail recursion is when the recursive call is the last operation
- The BEAM optimizes tail calls into loops (no stack growth!)
- Accumulators carry state through recursive calls
- Lists built in reverse need Enum.reverse/1 at the end
- Non-tail recursion is fine for small/shallow data structures
- Most list operations can be implemented with tail recursion
Next Steps
Great work! Continue to the next checkpoint:
Continue to Checkpoint 3: Enum vs Stream →
Or return to Checkpoint 1: Pattern Matching