Tuesday, June 9, 2009

Tuning mergesorts for speed, part 2(a)

These are the routines that I was using to generate 3-way and 4-way direction-switching forecasting mergesorts... The main routine is generateFastMergesort, which takes a parameter indicating the order of merge (e.g. 3 for 3-way). To use them #define debug to std::cout (if you generate the (r)-mergesort its (r)-merge routines will call the (r-1)-way merge routines.

I've only tested it for 2,3,4 and 5-way mergesort generation (I stopped at that point because the 5-way mergesort was ridiculously long, and considerably slower than the 4-way mergesort. Even the 4-way mergesort is slower than the 3-way mergesort, on my machine anyway, for sorting integers possibly because the machine's instruction cache was not big enough). While that was true for single-threaded routines, it mightn't be so for multi-core routines. For multi-core sorting routines, main memory access really does become the bottleneck, and a 4-way mergesort should gain ground (and perhaps win out).

void generateForecastBranching(int r, int backward, int bestSoFar, int indent, int tryFrom)
{
  //
  // at this point we already know whether merging is proceeding left-to-right or
  // right-to-left (if righ-to-left, backward will be 1), but that is all.
  // this function generates code to go at the very front of an r-way merge.  it 
  // generates (some of) the goto statements for branching to r different labels, 
  // of the form
  //
  //   e[ex]_[dir]:
  //
  // where [ex] is the input to be exhausted first (0 through r-1) and dir
  // is either back or fwd.
  //
  for (int i=0;i<indent;++i) debug<<" ";
  if (tryFrom==r)
  {
    debug << "goto e" << bestSoFar << (backward ? "_back;\n" : "_fwd;\n");
  }
  else
  {
    if (backward)
      debug << "if (stop" << tryFrom << "[1] < stop" << bestSoFar<< "[1])\n";
    else
      debug << "if (stop" << bestSoFar << "[-1] <= stop" << tryFrom << "[-1])\n";
    generateForecastBranching(r, backward, bestSoFar, indent+2, tryFrom+1);
    for (int i=0;i<indent;++i) debug<<" ";
    debug << "else\n";
    generateForecastBranching(r, backward, tryFrom, indent+2, tryFrom+1);
  }
}

void generateFrontBranching( int *perm, int r, int e , int d, int indent, int p, int lo, int hi )
{
  //
  //here, we know which input will be exhausted first, and we're generating (some)
  //of the code for branching to one of the factorial(r) read states.  
  //r = radix, e= list exhausted first, d=direction (1 == backwards, 0 == forwards)
  //p = input to place, lo = earliest it could come before of elements ordered already
  //hi = latest it could come after, of elements ordered already
  //(this is actually the ickiest bit of generating a mergesort, and the code used
  // here is not particularly efficient when r is large.
  //This generates the code that branches out from labels of the form:
  //
  //   e[ex]_[dir]:
  //
  //(see previous function) to labels of the form
  //
  //   e[ex]_[dir]_[perm]:
  //
  //where [perm] indicates the relative order in which the "next" elements 
  //from each of the inputs are to be copied to the destination 
  //([perm] consists of r letters, A for input 0, B for input 1, and so on).
  //
  for (int i=0;i<indent;++i) debug<<" ";
  if (p==r)
  {
    debug << "goto e" << e << ( d ? "_back_" : "_fwd_" ) ;
    for (int i=0; i<r;++i)
      debug << (const char)('A' + perm[i]);
    debug << ";\n";
    return;
  }
  else
  {
    int m = (lo + hi)/2;
    if (d==0)
      debug << "if (*start" << p << " < *start" << perm[m] << ")\n";
    else
      debug << "if (*start" << perm[m] << " <= *start" << p << ")\n";
    if (lo==m) 
    {
      //p comes before perm[m]
      for (int i=p-1;i>=m;--i) perm[i+1]=perm[i];
      perm[m]=p;
      generateFrontBranching(perm, r, e, d, indent+2, p+1, 0, p);
      for (int i=m;i<p;++i) perm[i]=perm[i+1]; 
      perm[p]=p;
    }
    else
      generateFrontBranching(perm, r, e, d, indent+2, p, lo, m-1);

    for (int i=0;i<indent;++i) debug<<" ";
    debug << "else\n";
    if (hi==m)
    {
      //p comes after perm[m]
      for (int i=p-1;i>=m+1;--i) perm[i+1]=perm[i];
      perm[m+1]=p;
      generateFrontBranching(perm, r, e, d, indent+2, p+1, 0, p);
      for (int i=m+1;i<p;++i) perm[i]=perm[i+1];
      perm[p]=p;
    }
    else
      generateFrontBranching(perm, r, e, d, indent+2, p, m+1, hi);
  }
}

void branchOut(int e, int* perm, int r, int backward, int indent, int lo, int hi)
{
  //
  // generate (some of) the branching code to determine which of the other
  // factorial(r) states to go to, given that we already know which
  // of the inputs is going to be exhausted, e, and what state we were
  // in previously (on perm[0] indicates which of the inputs we just
  // read from, and perm[1..(r-1)] indicate the order in which the
  // other (r-1) elements were in the state we were just in.
  // lo and hi are indexes into perm, and indicate which other inputs
  // haven't been compared with the next element from the input that
  // we just read from.
  //
  // This generates the code that branches after the copy-and-move-on
  // code after a
  //
  // e[ex]_[dir]_[perm]:
  //
  // label to each of the possible r successor states.  For example, in
  // the left-to-right merging part of a three-way merge, where input 0
  // will be exhausted first, and a record from input 0 was last read,
  // and the next record from input 1 is to copied to the output before 
  // the next record from input 2, it generates the code to branch from
  //
  // e0_fwd_ABC:
  //
  // to whichever of the successor states
  //
  // e0_fwd_ABC:
  // e0_fwd_BAC:
  // e0_fwd_BCA:
  //
  // should be reached next.
  //
  int mid = (lo+hi)/2;
  const char *szDirection = backward ? "_back_" : "_fwd_";

  if (lo<=hi)
  {
    for (int i=0;i<indent;++i) debug<<" ";
    int le = (perm[0]<perm[mid]) ? (1-backward) : backward;
    debug << "if (*start" << (backward ? perm[mid] : perm[0]) 
        << ( le ? "<=" : "<") 
        << "*start" << (backward ? perm[0] : perm[mid])
        << ")\n";
  }
  if (lo>=hi)
  {
    for (int i=0;i<indent+2;++i) debug<<" ";
    debug << "goto e" << e << szDirection; 
    for (int i=1;i<r;++i) 
    {
      if (lo==i) debug << (const char)('A' + perm[0]);
      debug << (const char)('A' + perm[i]);
    }
    if (lo==r) debug << (const char)('A' + perm[0]);
    debug << ";\n";
    if (lo==hi)
    {
      for (int i=0;i<indent;++i) debug<<" ";
      debug << "else\n";
      for (int i=0;i<indent+2;++i) debug<<" ";
      debug << "goto e" << e << szDirection;
      for (int i=1;i<r;++i) 
      {
        debug << (const char)('A' + perm[i]);
        if (lo==i) debug << (const char)('A' + perm[0]);
      }
      debug << ";\n";
    }
    return;
  }
  branchOut(e, perm, r, backward, indent+2, lo, mid-1);
  for (int i=0;i<indent;++i) debug<<" ";
  debug << "else\n";
  branchOut(e, perm, r, backward, indent+2, mid+1, hi);    
}

void generateFastMergesort(int r)
{
  //
  //Notes:  Generates a (take a deep breath)...
  //          forecasting switchback order-in-program-counter r-way mergesort routine
  //

  for (int d=0;d<2;++d)  //Twice; once for Forward, once for Backward merge routine
  {  
    const char *szDirection = d ? "Backward" : "Forward";  
    const char *szBump      = d ? "--" : "++";

    debug << "template <class T> int mergeRadix" << r << "Fast" << szDirection << "(";
    for (int i=0;i<r;++i) debug << "T* start" << i << ", T* stop" << i << ",";  
    debug << "T* dest)\n{\n";
    int *perm = new int [r];

    for (int i=0;i<r;++i) perm[i]=i;
    generateForecastBranching( r, d, 0, 2, 1);

    for (int e=0;e<r;++e) //e==input that runs out first
    {
      debug << "  e" << e << (d ? "_back:\n" : "_fwd:\n");
      for (int i=0;i<r;++i) perm[i]=i;
      generateFrontBranching( perm, r, e, d, 4, 1, 0, 0);

      for (int i=0;i<r;++i) perm[i]=i;
      int nStates=1;
      for (int i=1;i<=r;++i) nStates*=i;
      for (;nStates>0;--nStates)
      {
        debug << "  e" << e << "_" << (d ? "back" : "fwd" ) << "_";
        for (int i=0;i<r;++i) debug << (const char)('A' + perm[i]);
        debug << ":\n";
        debug << "    *dest" << szBump << " = *start" << perm[0] << szBump << ";\n";
          //copy next record to output area

        if (e==perm[0]) //if the record comes from the input to be exhausted first
        {          
          //check if the input has been exhausted
          debug << "    if (start" << e << "==stop" << e << ")\n";
          //and if so, hand over the remaining inputs to a merge routine that
          //will deal with the remaining r-1 inputs
          debug << "      return mergeRadix" << (r-1) << "Fast" << szDirection << "(";
          for (int i=0; i<r; ++i) if (i!=e)
            debug << "start" << i << ", stop" << i << ", ";
          debug << "dest);\n";
        }
        //otherwise, figure out what the next state should be and branch to it
        branchOut( e, perm, r, d, 4, 1, r-1);

        //next permutation
        for (int x=0;x<r-1;x++)
        {    
          int *permFromX = perm+x;
          rotateArrayRight(permFromX, r-x, 1);
          if (perm[x]!=x) break;
        }
      }
    }
    delete [] perm;

    debug << "}\n\n";
  }

  //generate stock standard code for the (external) mergesort, which is given a work area...
  debug << "template <class T> void mergeSortExternalRadix" << r << "Fast(T *src, int count, T *dest, bool isSourceInput)\n{\n";
  debug << "  if (count<64)\n";
  debug << "  {\n";
  debug << "    if (isSourceInput) for (int i=0;i<count;++i) dest[i]=src[i];\n";
  debug << "    insertionSort(dest, count);\n";
  debug << "  }\n";
  debug << "  else\n";
  debug << "  {\n";
  debug << "    int step = count / " << r << ";\n";
  debug << "    for (int subList=0; subList<" << r << "; ++subList)\n";
  debug << "    {\n";
  debug << "      mergeSortExternalRadix" << r << "Fast(dest + (step*subList) "
        << ", (subList<" << (r-1) << ") ? step : (count - step*" << (r-1) << ")"
        << ", src + (step*subList), !isSourceInput);\n";
  debug << "    }\n";
  debug << "    if (isSourceInput)\n";
  debug << "      mergeRadix" << r << "FastForward(src, ";
  for (int x=1;x<r;++x)
  {
    debug << "src + (step*" << x << "), src + (step*" << x << "), ";
  }
  debug << "src + count, dest);\n";
  debug << "    else\n";
  debug << "      mergeRadix" << r << "FastBackward(";
  for (int x=1;x<r;++x)
  {
    debug << "src + (step*" << x << ")-1, src+(step*" << (x-1) << ")-1, ";
  }
  debug << "src+count-1, src+(step*" << (r-1) << ")-1, dest + count - 1);\n";
  debug << "  }\n";
  debug << "}\n\n";

  //and stock-standard code for the templated mergesort function that generates a work area,
  //calls the external mergesort function, and then throws away the work area
  debug << "template <class T> void mergeSortRadix" << r << "Fast(T *arr, int count)\n{\n";
  debug << "  if (count<64) { insertionSort(arr,count); return; }\n";
  debug << "  T *a2 = new T[count];\n";
  debug << "  mergeSortExternalRadix" << r << "Fast(a2, count, arr, false);\n";
  debug << "  delete [] a2;\n";
  debug << "}\n\n";
}

Generally, it's okay to use less tuned (r-1)-way merge routines, as they do not have much to do, in the average case. For that reason the tuned 3-way mergesort, that I have been mucking about with, calls de-tuned 2-way merge routines, like so:

template<class T> void mergeRadix2External(T *a, T *aStop, T *b, T *bStop, T *dest)
{
  while (a<aStop && b<bStop)
  {
    if (*a<=*b)
    {
      *dest++ = *a++;
    }
    else
    {
      *dest++ = *b++;
    }
  }
  while (a<aStop) *dest++=*a++;
  while (b<bStop) *dest++=*b++;
}

#define mergeRadix2FastForward(a,aStop,b,bStop,dest)  ( mergeRadix2External(a,aStop,b,bStop,dest) , 0 )
template <class T> int mergeRadix2FastBackward(T* a, T* aStop, T* b, T* bStop, T* dest)
{
  while (aStop<a && bStop<b)
  {
    if (*a<=*b)
    {
      *dest-- = *b--;
    }
    else
    {
      *dest-- = *a--;
    }
  }
  while (aStop<a) *dest-- = *a--;
  while (bStop<b) *dest-- = *b--;
  return 0;
}

Come to think of it, maybe that was part of the problem with the 4-way merge. Probably it should have been calling the untuned 3-way merge routine (because if it calls the tuned 3-way merge routine that will be forecasting etc. on very very short lists, usually one or two elements each, which is... well... sort of pointless).

No comments:

Post a Comment