What is the C++ nth_element algorithm?

The C++ standard library provides an algorithm called nth_element() that is a useful tool for partial sorting. But how does it work and where is it useful?

First of all, the algorithm is defined as follows: given a range [first, last) and an nth element within the range, the algorithm partially sorts the range so that:

  • the nth element is changed to whatever element would appear in that position if the range [first, last) were fully sorted
  • all the elements before it are smaller than or equal to it, although their order is unspecified
  • all the elements after it are greater than or equal to it, although their order is also unspecified

Keep in mind that in C++ indexes are zero-based, so nth = 0 means the first element, nth = 1 means the second element, nth = 5 means the 6th element, and so on.

Here is a quick example:

#include <iostream>
#include <vector>
#include <print>
#include <algorithm>

int main()
{
    std::vector<int> v {9, 4, 3, 8, 1, 2, 1, 8, 7, 6};
    auto mid = v.begin() + v.size() / 2;
    std::nth_element(v.begin(), mid, v.end());

    for(auto const & e : v) 
        std::print("{} ", e);
    std::println();
}

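The output of this program is shown below (the exact order of the elements around the 6th one may differ between standard library implementations):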

2 4 3 1 1 6 8 8 7 9

The vector v has 10 elements, so mid points to the 6th element (the one at index 5, half the vector size). After partial sorting, the first 5 elements (2 4 3 1 1) are smaller than or equal to the 6th element (which happens to be the value 6), and the last 4 elements (8 8 7 9) are greater than or equal to it, although neither group is in sorted order.
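Because the order on either side of the nth element is unspecified, it is safer to rely on the guarantees than on a particular output. The following is a minimal sketch that checks the three guarantees for a given index; the helper name nth_element_guarantees_hold is made up for this illustration and is not part of the standard library.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: partitions a copy of `v` around index `n` and
// checks the three guarantees of std::nth_element.
bool nth_element_guarantees_hold(std::vector<int> v, std::size_t n)
{
    assert(n < v.size());

    // reference: what a full sort would produce
    std::vector<int> sorted = v;
    std::sort(sorted.begin(), sorted.end());

    auto nth = v.begin() + static_cast<std::ptrdiff_t>(n);
    std::nth_element(v.begin(), nth, v.end());

    // 1. the nth position holds the value a full sort would put there
    bool const same_as_sorted = (*nth == sorted[n]);
    // 2. nothing before it is greater than it
    bool const left_ok =
        std::all_of(v.begin(), nth, [&](int x) { return x <= *nth; });
    // 3. nothing after it is smaller than it
    bool const right_ok =
        std::all_of(nth + 1, v.end(), [&](int x) { return x >= *nth; });

    return same_as_sorted && left_ok && right_ok;
}
```

Calling it with the vector from the example and n = 5 should return true on any conforming implementation, whatever order the elements end up in.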

You can also use the range algorithm std::ranges::nth_element() as follows:

std::ranges::nth_element(v, mid);

Now that we’ve seen how the algorithm works, let’s look at its use cases.

Computing the median efficiently

The median is the middle value in a sorted dataset. If the size of the dataset is odd, then the median is at the (zero-based) index n/2. If the size of the dataset is even, then the median is the average (mean) value of the two middle values, at the (zero-based) indexes n/2 - 1 and n/2.

The median should not be confused with the mean (or average), which is the sum of all the elements of the dataset divided by their number. The median is more useful than the mean in various scenarios because it’s not influenced by outliers. Here is an example: the monthly income of 10 households on a street is (in sorted order) 2500, 2700, 3100, 3200, 3300, 4500, 4900, 5100, 5500, 25000. The mean is 5980 because there is an outlier in the set, but the median is only 3900 (the average of 3300 and 4500), which is very close to the mean of the first 9 values in the dataset (about 3867).
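The arithmetic in the household example can be reproduced with a few lines of code. This is only a sketch for checking the numbers: the function names mean and median_of_sorted are illustrative, and the median helper assumes its input is already sorted, as the income list is.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Arithmetic mean: the sum of all values divided by their count.
double mean(std::vector<int> const& data)
{
    assert(!data.empty());
    return std::accumulate(data.begin(), data.end(), 0.0) /
           static_cast<double>(data.size());
}

// Median of data that is ALREADY sorted in ascending order.
double median_of_sorted(std::vector<int> const& sorted)
{
    assert(!sorted.empty());
    std::size_t const n = sorted.size();
    if (n % 2 == 1)
        return sorted[n / 2];                          // single middle value
    return (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0;  // mean of the two middle values
}
```

For the ten incomes, mean() gives 5980 and median_of_sorted() gives 3900; dropping the 25000 outlier brings the mean down to roughly 3867, while the median barely moves.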

Here is a C++ function that computes the median of a vector of integers:

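// Takes the vector by value because nth_element() reorders its elements.
// Precondition: data is not empty.
double compute_median(std::vector<int> data) 
{
   size_t const mid = data.size() / 2;

   if (data.size() % 2 == 1) 
   {
      // odd size: the median is the single middle element
      std::nth_element(data.begin(), data.begin() + mid, data.end());
      return data[mid];
   }
   else 
   {
      // even size: the median is the average of the two middle elements
      std::nth_element(data.begin(), data.begin() + mid - 1, data.end());
      int val1 = data[mid - 1];
      std::nth_element(data.begin(), data.begin() + mid, data.end());
      int val2 = data[mid];
      return (val1 + val2) / 2.0;
   }
}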

Using the nth_element() algorithm to compute the median avoids sorting the dataset completely (which has O(n log n) complexity) and runs in O(n) on average.
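For even sizes, a second nth_element() call is not strictly necessary. After partitioning around the upper middle element, everything before it is less than or equal to it, so the lower middle value is the largest element of that prefix. The sketch below shows this variant; the name compute_median_one_partition is an assumption made for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sketch of a median that partitions only once.
// Takes a copy because nth_element reorders the elements.
double compute_median_one_partition(std::vector<int> data)
{
    assert(!data.empty());
    auto mid = data.begin() + data.size() / 2;
    std::nth_element(data.begin(), mid, data.end());

    if (data.size() % 2 == 1)
        return *mid;

    // Everything in [begin, mid) is <= *mid, so the lower middle value
    // is simply the largest element of that prefix.
    int const lower = *std::max_element(data.begin(), mid);
    return (static_cast<double>(lower) + *mid) / 2.0;
}
```

The max_element() scan is linear over half of the range, so the overall cost stays O(n) on average.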

There are various practical uses for the median such as the following:

  • data analysis and statistics (income distribution, as in the previously mentioned example), where the median is more robust than the mean when the data contains outliers
  • image processing: the median filter is a technique to reduce noise by replacing each pixel’s value with the median of its neighboring pixels within a defined window (for instance 3×3)
  • sensor readings: the median of several readings can be used to reduce noise from environmental sensors
  • real time systems: for quick estimates of typical latency, response time, and other metrics

Percentiles

Percentiles are basically a generalized median. A percentile is a way to describe the relative position of a value in a dataset. The pth percentile is the value below which p% of the data falls. Examples:

  • 25th percentile: 25% of the values in a dataset are smaller than this value, the other 75% are greater
  • 50th percentile: 50% of the values in a dataset are smaller than this value, the other 50% are greater; this is the median
  • 99th percentile: 99% of the values in a dataset are smaller than this value, and only 1% are greater

Common examples for percentiles:

  • in education: “you scored in the 95th percentile” means you did better than 95% of the students
  • in finance: “the 90th percentile loss” means only 10% of losses are larger
  • in system monitoring: “the 99th percentile response time” is the response time that only the slowest 1% of requests exceed

You can easily partition a range around a given percentile using nth_element(), as shown in the following snippet:

template <typename Iter>
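void partial_percentile_sort(Iter first, Iter last, double const percentile) 
{
   // percentile is a fraction between 0.0 and 1.0 (for instance 0.6 for 60%)
   size_t n = std::distance(first, last);
   if (n == 0 || percentile <= 0.0) return;

   // number of elements that fall in the requested percentile
   size_t k = static_cast<size_t>(percentile * n);
   if (k == 0) k = 1;
   if (k > n) k = n; // first + k must not go past last

   // afterwards, the first k elements are the k smallest (in unspecified order)
   std::nth_element(first, first + k, last);
}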

int main()
{
   std::vector<int> v{ 9, 1, 8, 2, 7, 3, 6, 4, 5 };

   double percentile = 0.6; // smallest 60% of elements
   partial_percentile_sort(v.begin(), v.end(), percentile);

   std::print("First 60% smallest elements (unordered): ");
   size_t k = static_cast<size_t>(percentile * v.size());
   for (size_t i = 0; i < k; ++i)
   {
      std::print("{} ", v[i]);
   }
   std::println();
}
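The snippet above partitions the range but does not return the percentile value itself. Below is a minimal sketch that does, assuming the nearest-rank definition of a percentile (other definitions interpolate between neighbors and give slightly different results). The function name percentile_value is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: value at percentile p (0 < p <= 1) using the
// nearest-rank definition. Takes a copy because nth_element reorders it.
int percentile_value(std::vector<int> data, double p)
{
    assert(!data.empty());
    assert(p > 0.0 && p <= 1.0);

    // The nearest rank is 1-based: ceil(p * n). Convert it to a 0-based index.
    auto const rank = static_cast<std::size_t>(
        std::ceil(p * static_cast<double>(data.size())));
    std::size_t const index = std::max<std::size_t>(rank, 1) - 1;

    auto nth = data.begin() + static_cast<std::ptrdiff_t>(index);
    std::nth_element(data.begin(), nth, data.end());
    return *nth;
}
```

For the vector { 9, 1, 8, 2, 7, 3, 6, 4, 5 }, a p of 0.5 selects rank 5, which is the value 5 (the median), and a p of 1.0 selects the maximum.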

Finding the smallest/largest n elements

The nth_element() algorithm is useful for determining the smallest or largest n elements in a sequence, without having to sort it entirely. If you do need the elements sorted, you can call std::sort() afterwards, but only for the first n elements of the (partially sorted) range.

Here is how to determine the smallest n elements in a range:

std::vector<int> v {9, 4, 3, 8, 1, 2, 1, 8, 7, 6};
std::size_t n = 3; // how many of the smallest elements to keep (example value)

std::nth_element(v.begin(), v.begin() + n, v.end());
std::sort(v.begin(), v.begin() + n);

Here is how to determine the largest n elements in a range:

std::vector<int> v {9, 4, 3, 8, 1, 2, 1, 8, 7, 6};
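std::size_t n = 3; // how many of the largest elements to keep (example value)

std::nth_element(v.begin(), v.begin() + n, v.end(), std::greater<>{});
std::sort(v.begin(), v.begin() + n, std::greater<>{}); // largest first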

As you can see, the only difference is the use of std::greater<>{} for comparing the elements.

Practical use cases for the smallest / largest n elements include the following:

  • leaderboards / rankings: you need to determine the first 10/25/100/etc. players in a tournament with a large number of participants; using nth_element() allows you to sort only the small number of elements you need, not the entire data set
  • search results: fetching the most relevant search results ranked by a score
  • resource monitoring: find the processes with the largest memory consumption on a server
  • data analysis and statistics: find the top earning individuals in a group, find the least performing students in a school, etc.
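The two steps (partition, then sort only the prefix) can be wrapped in a small reusable function. The following is a sketch under a few assumptions: the name top_n is made up, the input is copied rather than modified in place, and asking for more elements than exist simply returns everything.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper: returns the `n` largest values in descending order.
std::vector<int> top_n(std::vector<int> v, std::size_t n)
{
    n = std::min(n, v.size());  // cannot return more than there is
    auto const cut = v.begin() + static_cast<std::ptrdiff_t>(n);

    // Move the n largest elements to the front (unordered).
    // When cut == v.end() this call has no effect, which is fine.
    std::nth_element(v.begin(), cut, v.end(), std::greater<>{});

    // Sort only those n elements, largest first, and drop the rest.
    std::sort(v.begin(), cut, std::greater<>{});
    v.erase(cut, v.end());

    assert(std::is_sorted(v.begin(), v.end(), std::greater<>{}));
    return v;
}
```

For a leaderboard, the same idea applies to a vector of records by supplying a comparator that compares the score field.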

Quicksort and pivot selection

Quicksort is a commonly used sorting algorithm that works by selecting a pivot from the sequence of elements to sort, partitioning the sequence into the elements less than the pivot, the pivot itself, and the elements greater than the pivot, and then recursively applying quicksort to the two partitions. However, its efficiency depends on the value of the pivot:

  • if the pivot is near the median, the partitioning is balanced and the performance is O(n log n)
  • if the pivot is the min or max value, the partitioning is unbalanced and the worst case scenario occurs with performance of O(n^2)

Therefore, the selection of the pivot is very important in the performance of the algorithm:

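  • a random pivot gives good average behavior, but can still hit the worst case on an unlucky sequence of choices
  • the median of 3 or 5 sampled elements is more robust and makes the worst case much less likely
  • the median of the whole array guarantees balanced partitions, but finding it adds work at every recursion step: O(n) on average with nth_element(), which keeps the total at O(n log n) but with a larger constant factor, or O(n log n) with a full sort, which degrades the total to O(n log^2 n)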
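To make the last option concrete: nth_element() already partitions the range around the element it selects, so a single call can serve as both pivot selection and partitioning. The sketch below is for illustration only and is not a recommended way to sort; the names median_pivot_quicksort and sorted_with_median_pivot are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Illustration only: a quicksort whose pivot selection and partitioning
// are done by one nth_element call around the exact median.
template <typename Iter>
void median_pivot_quicksort(Iter first, Iter last)
{
    if (last - first <= 1) return;

    Iter mid = first + (last - first) / 2;
    // After this call *mid is the median and the range is partitioned around it.
    std::nth_element(first, mid, last);

    median_pivot_quicksort(first, mid);     // elements <= median
    median_pivot_quicksort(mid + 1, last);  // elements >= median
}

// Convenience wrapper for trying the sketch on a vector of integers.
std::vector<int> sorted_with_median_pivot(std::vector<int> v)
{
    median_pivot_quicksort(v.begin(), v.end());
    assert(std::is_sorted(v.begin(), v.end()));
    return v;
}
```

Every recursion level splits the range exactly in half, at the price of running a selection algorithm at each step.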

Here is an example of how to compute the median of three elements (the first, the middle, and the last) in a non-empty sequence of integers:

int median_of_three(std::vector<int> const & v)
{
    std::array<int,3> sample = { v.front(), v[v.size()/2], v.back() };
    auto mid = sample.begin() + 1;
    std::nth_element(sample.begin(), mid, sample.end());
    return *mid;
}

The following snippet shows an implementation of quicksort using a median of three values for the pivot:

#include <iostream>
#include <vector>
#include <print>
#include <algorithm>

template <typename Iter>
Iter median_of_three(Iter first, Iter last)
{
   Iter mid = first + (last - first) / 2;
   Iter last_elem = last - 1;

   if (*mid < *first) 
      std::iter_swap(mid, first);
   if (*last_elem < *first) 
      std::iter_swap(last_elem, first);
   if (*last_elem < *mid) 
      std::iter_swap(last_elem, mid);

   return mid;
}

template <typename Iter>
void quicksort(Iter first, Iter last)
{
   if (last - first <= 1) return;

   // Choose pivot using median of three and move it to the end
   Iter pivot_iter = median_of_three(first, last);
   std::iter_swap(pivot_iter, last - 1);
   auto pivot = *(last - 1);

   Iter left = first;
   Iter right = last - 2;

   while (true)
   {
      while (left <= right && *left < pivot) ++left;
      while (left <= right && *right > pivot) --right;
      if (left >= right) break;
      std::iter_swap(left, right);
      ++left;
      --right;
   }

   // put pivot back in place
   std::iter_swap(left, last - 1);

   // recurse on strict subranges
   quicksort(first, left);
   quicksort(left + 1, last);
}

int main()
{
   std::vector<int> v{ 9, 1, 8, 2, 7, 3, 6, 4, 5 };

   quicksort(v.begin(), v.end());

   for (int x : v)
      std::print("{} ", x);
   std::println();
}

The nth_element() algorithm has many applications, and some typical ones were presented here. Remember that, even though all the examples here used the iterator-based algorithm, you can also use the ranges equivalent. More about the algorithm can be found here.
