Saturday, November 20, 2010

Merge Sort Improvement?



The C++ STL uses a trick to improve the performance of merge sort: when the number of elements in a call drops below some small threshold (around 50 or so), it sorts those elements with insertion sort instead of recursing further. However, this doesn't seem to help much when we write the merge sort procedure by hand.

This probably helps the STL because, given its object-oriented nature, most of its methods carry considerable overhead, so making more recursive calls might turn out to be more costly than a small O(k^2) subroutine. As we have seen in the previous posts, the recursive calls of the merge sort procedure form a binary-tree-like structure, and the deeper it goes, the more nodes it has: the count roughly doubles at each level. Moreover, pushing a new function call onto the call stack always carries some overhead. So they just change the algorithm a bit so that the bottom few levels are cut from the tree, which removes a huge number of nodes, and with them a lot of overhead, and in turn improves performance.
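To make the node-count argument concrete, here is a minimal sketch that counts how many calls the recursion makes with and without a cutoff. It is not part of the original tester: `CallCounter` and `countCalls` are names made up for illustration, and the cutoff test mirrors the `r - p < k` condition used in this post.

```java
// Hypothetical helper: counts the calls merge sort would make on a[p..r]
// without sorting anything. The names here are illustrative only.
class CallCounter {

    // cutoff < 0 disables the insertion-sort cutoff (plain merge sort).
    static long countCalls(int p, int r, int cutoff) {
        if (cutoff >= 0 && r - p < cutoff) return 1; // leaf: handed to insertion sort
        if (p >= r) return 1;                        // leaf: zero or one element
        int q = p + (r - p) / 2;
        // This call plus everything below it in the recursion tree.
        return 1 + countCalls(p, q, cutoff) + countCalls(q + 1, r, cutoff);
    }

    public static void main(String[] args) {
        int n = 1 << 20;
        System.out.println("plain merge sort: " + countCalls(0, n - 1, -1));
        System.out.println("cutoff k = 15:    " + countCalls(0, n - 1, 15));
    }
}
```

For n = 2^20 this reports 2097151 calls for plain merge sort and 262143 with k = 15, so the cutoff removes the bottom three levels of the tree, which hold most of its nodes.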


So they change the algorithm of MERGE-SORT() a bit:

MERGE-SORT(A, p, r):
    if r - p < k                    ; k is the tolerance limit
        INSERTION-SORT(A, p, r)
        return
    if p < r
        then q := (r + p) / 2
             MERGE-SORT(A, p, q)
             MERGE-SORT(A, q+1, r)
             MERGE(A, p, q, r)

Now, what happens when we use it manually? I tried it over a huge range of input sizes with k = 15, which means that once a subarray shrinks to 15 elements or fewer (r - p < 15), it is sorted with insertion sort. I found an interesting result in this experiment. Insertion sort seems better within a certain range, but beyond that it keeps getting slower. So the value of k is a trade-off between the overhead of recursive calls and the running time of insertion sort. This graph shows the result of my experiment; it will certainly vary if you try to do the same.
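One way to explore this trade-off yourself is to sweep several values of k over the same data. The following is only a sketch under my own assumptions: the class name `CutoffSweep`, the candidate cutoffs, the array size, and the random seed are all arbitrary choices, and single-run timings like these are noisy (JIT warm-up in particular), so treat the numbers as rough indications.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sweep over cutoff values; names and constants are illustrative.
class CutoffSweep {

    // Merge sort that hands subarrays with r - p < k to insertion sort.
    static void hybridSort(int[] a, int p, int r, int k) {
        if (r - p < k) {
            insertionSort(a, p, r);
            return;
        }
        if (p < r) {
            int q = p + (r - p) / 2;
            hybridSort(a, p, q, k);
            hybridSort(a, q + 1, r, k);
            merge(a, p, q, r);
        }
    }

    // Merges the sorted runs a[p..q] and a[q+1..r].
    static void merge(int[] a, int p, int q, int r) {
        int[] left = Arrays.copyOfRange(a, p, q + 1);
        int[] right = Arrays.copyOfRange(a, q + 1, r + 1);
        for (int k = p, i = 0, j = 0; k <= r; k++) {
            if (j >= right.length || (i < left.length && left[i] <= right[j])) a[k] = left[i++];
            else a[k] = right[j++];
        }
    }

    static void insertionSort(int[] a, int p, int r) {
        for (int i = p + 1; i <= r; i++) {
            int t = a[i];
            int j = i - 1;
            while (j >= p && a[j] > t) {
                a[j + 1] = a[j];
                j--;
            }
            a[j + 1] = t;
        }
    }

    public static void main(String[] args) {
        int[] data = new Random(1).ints(1_000_000).toArray();
        for (int k : new int[] {0, 5, 15, 50, 200}) { // k = 0 means plain merge sort
            int[] copy = data.clone();
            long start = System.nanoTime();
            hybridSort(copy, 0, copy.length - 1, k);
            long end = System.nanoTime();
            System.out.println("k = " + k + ": " + (end - start) / 1_000_000.0 + " ms");
        }
    }
}
```

Each run sorts a fresh copy of the same array, so the only thing that changes between measurements is the cutoff.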


Here is the simple Java tester I used for this comparison. Have a look:

import java.util.*;
import java.io.*;

public class Main {

    // Set to true to hand small subarrays to insertion sort.
    static final boolean useInsertion = false;
    // Cutoff: subarrays with r - p < limit go to insertion sort.
    static final int limit = 15;

    public static void main(String[] args) throws IOException {
        // Input format: n, followed by n integers.
        Scanner stdin = new Scanner(new FileReader("in.txt"));
        PrintWriter out = new PrintWriter(new FileWriter("out.txt"));
        int n = stdin.nextInt();
        int[] a = new int[n];
        for (int i = 0; i < n; i++) a[i] = stdin.nextInt();
        out.println("Elements: " + n);
        // Time only the sort itself, not the file I/O.
        long start = System.nanoTime();
        mergeSort(a, 0, n - 1);
        long end = System.nanoTime();
        for (int i = 0; i < n; i++) out.println(a[i] + " ");
        out.println("Elapsed Time: " + (double) (end - start) / 1000000.0 + "ms.");
        out.flush();
        stdin.close();
        out.close();
    }

    // Sorts a[p..r], both bounds inclusive.
    static void mergeSort(int[] a, int p, int r) {
        if (useInsertion && r - p < limit) {
            insertionSort(a, p, r);
            return;
        }
        if (p < r) {
            int q = (p + r) / 2;
            mergeSort(a, p, q);
            mergeSort(a, q + 1, r);
            merge(a, p, q, r);
        }
    }

    // Merges the sorted runs a[p..q] and a[q+1..r].
    static void merge(int[] a, int p, int q, int r) {
        int n1 = q - p + 1, n2 = r - q;
        int[] L = new int[n1];
        int[] R = new int[n2];
        for (int i = 0; i < n1; i++) L[i] = a[p + i];
        for (int j = 0; j < n2; j++) R[j] = a[q + j + 1];
        for (int k = p, i = 0, j = 0; k <= r; k++) {
            // Take from L when R is exhausted or L's head is not larger.
            if (j >= n2 || (i < n1 && L[i] <= R[j])) a[k] = L[i++];
            else a[k] = R[j++];
        }
    }

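    // Sorts a[p..r], both bounds inclusive.
    static void insertionSort(int[] a, int p, int r) {
        for (int i = p + 1; i <= r; i++) {
            int t = a[i];
            // j is declared outside the inner loop because it is needed after it ends.
            int j;
            for (j = i - 1; j >= p; j--) {
                if (t > a[j]) break;
                a[j + 1] = a[j];
            }
            a[j + 1] = t;
        }
    }
}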

Change static final boolean useInsertion = false; to true to enable insertion sort, and set static final int limit = 15; to a suitable cutoff. Subarrays with r - p < limit, that is, with at most limit elements, are handed to insertion sort.
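The tester reads its data from in.txt as a count n followed by n integers. If you need such a file, a small generator along these lines should do. This is a sketch, not something from my original setup: `InputGenerator`, `writeInput`, the default size, and the seed are all made up for illustration.

```java
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Random;

// Hypothetical generator for the tester's input file; names are illustrative.
class InputGenerator {

    // Writes n, then n pseudo-random ints, one value per line.
    static void writeInput(PrintWriter out, int n, long seed) {
        Random rng = new Random(seed); // fixed seed keeps runs comparable
        out.println(n);
        for (int i = 0; i < n; i++) out.println(rng.nextInt());
        out.flush();
    }

    public static void main(String[] args) throws IOException {
        // Element count from the command line, or an arbitrary default.
        int n = args.length > 0 ? Integer.parseInt(args[0]) : 100000;
        try (PrintWriter out = new PrintWriter(new FileWriter("in.txt"))) {
            writeInput(out, n, 1L);
        }
    }
}
```

Using the same seed for the runs with and without insertion sort means both variants sort exactly the same data.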

Keep digging!
