A Quick Look at MapReduce's Reduce Phase

This is the third big-data post on this blog. Over the National Day holiday I kept digging through the MapReduce source code, and this post is a quick summary of the reduce phase.

Overview

As with the map phase, the entry point of the reduce phase is ReduceTask's run method:

@Override
@SuppressWarnings("unchecked")
public void run(JobConf job, final TaskUmbilicalProtocol umbilical)
  throws IOException, InterruptedException, ClassNotFoundException {
  job.setBoolean(JobContext.SKIP_RECORDS, isSkipping());

  if (isMapOrReduce()) {
    copyPhase = getProgress().addPhase("copy");
    sortPhase  = getProgress().addPhase("sort");
    reducePhase = getProgress().addPhase("reduce");
  }
  // start thread that will handle communication with parent
  TaskReporter reporter = startReporter(umbilical);
  
  boolean useNewApi = job.getUseNewReducer();
  initialize(job, getJobID(), reporter, useNewApi);

  // check if it is a cleanupJobTask
  if (jobCleanup) {
    runJobCleanupTask(umbilical, reporter);
    return;
  }
  if (jobSetup) {
    runJobSetupTask(umbilical, reporter);
    return;
  }
  if (taskCleanup) {
    runTaskCleanupTask(umbilical, reporter);
    return;
  }
  
  // Initialize the codec
  codec = initCodec();
  RawKeyValueIterator rIter = null;
  ShuffleConsumerPlugin shuffleConsumerPlugin = null;
  
  Class combinerClass = conf.getCombinerClass();
  CombineOutputCollector combineCollector = 
    (null != combinerClass) ? 
   new CombineOutputCollector(reduceCombineOutputCounter, reporter, conf) : null;

  Class<? extends ShuffleConsumerPlugin> clazz =
        job.getClass(MRConfig.SHUFFLE_CONSUMER_PLUGIN, Shuffle.class, ShuffleConsumerPlugin.class);
         
  shuffleConsumerPlugin = ReflectionUtils.newInstance(clazz, job);
  LOG.info("Using ShuffleConsumerPlugin: " + shuffleConsumerPlugin);

  ShuffleConsumerPlugin.Context shuffleContext = 
    new ShuffleConsumerPlugin.Context(getTaskID(), job, FileSystem.getLocal(job), umbilical, 
                super.lDirAlloc, reporter, codec, 
                combinerClass, combineCollector, 
                spilledRecordsCounter, reduceCombineInputCounter,
                shuffledMapsCounter,
                reduceShuffleBytes, failedShuffleCounter,
                mergedMapOutputsCounter,
                taskStatus, copyPhase, sortPhase, this,
                mapOutputFile, localMapFiles);
  shuffleConsumerPlugin.init(shuffleContext);

  rIter = shuffleConsumerPlugin.run();

  // free up the data structures
  mapOutputFilesOnDisk.clear();
  
  sortPhase.complete();                         // sort is complete
  setPhase(TaskStatus.Phase.REDUCE); 
  statusUpdate(umbilical);
  Class keyClass = job.getMapOutputKeyClass();
  Class valueClass = job.getMapOutputValueClass();
  RawComparator comparator = job.getOutputValueGroupingComparator();

  if (useNewApi) {
    runNewReducer(job, umbilical, reporter, rIter, comparator, 
                  keyClass, valueClass);
  } else {
    runOldReducer(job, umbilical, reporter, rIter, comparator, 
                  keyClass, valueClass);
  }

  shuffleConsumerPlugin.close();
  done(umbilical, reporter);
}

Right at the top of the run method you can clearly see the three things the reduce phase has to do:

  1. copy
  2. sort
  3. reduce

Let's walk through them quickly in that order.
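Before that, one detail from run worth calling out: the shuffle implementation itself is pluggable. run loads it through MRConfig.SHUFFLE_CONSUMER_PLUGIN and falls back to the built-in Shuffle class. A minimal sketch of swapping in a custom implementation follows; MyShufflePlugin is hypothetical and would have to implement ShuffleConsumerPlugin, and the property constant is as defined in Hadoop 2.x:

import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.ShuffleConsumerPlugin;
import org.apache.hadoop.mapreduce.MRConfig;

// MRConfig.SHUFFLE_CONSUMER_PLUGIN resolves to
// "mapreduce.job.reduce.shuffle.consumer.plugin.class"
JobConf job = new JobConf();
job.setClass(MRConfig.SHUFFLE_CONSUMER_PLUGIN,
             MyShufflePlugin.class,          // hypothetical custom plugin
             ShuffleConsumerPlugin.class);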

The copy phase

Let's start again from ReduceTask's run method:

@Override
@SuppressWarnings("unchecked")
public void run(JobConf job, final TaskUmbilicalProtocol umbilical)
  throws IOException, InterruptedException, ClassNotFoundException {
  job.setBoolean(JobContext.SKIP_RECORDS, isSkipping());

  if (isMapOrReduce()) {
    copyPhase = getProgress().addPhase("copy");
    sortPhase  = getProgress().addPhase("sort");
    reducePhase = getProgress().addPhase("reduce");
  }
  
  .....
    
  shuffleConsumerPlugin.init(shuffleContext);

  rIter = shuffleConsumerPlugin.run();

  .....
}

The first step calls shuffleConsumerPlugin's init method to do some initialization:

@Override
public void init(ShuffleConsumerPlugin.Context context) {
  this.context = context;

  this.reduceId = context.getReduceId();
  this.jobConf = context.getJobConf();
  this.umbilical = context.getUmbilical();
  this.reporter = context.getReporter();
  this.metrics = new ShuffleClientMetrics(reduceId, jobConf);
  this.copyPhase = context.getCopyPhase();
  this.taskStatus = context.getStatus();
  this.reduceTask = context.getReduceTask();
  this.localMapFiles = context.getLocalMapFiles();
  
  scheduler = new ShuffleSchedulerImpl(jobConf, taskStatus, reduceId,
      this, copyPhase, context.getShuffledMapsCounter(),
      context.getReduceShuffleBytes(), context.getFailedShuffleCounter());
  merger = createMergeManager(context);
}

You can see that init creates two key objects: scheduler and merger. The scheduler decides which hosts to fetch map outputs from and tracks successes and failures, while the merger manages the fetched data and carries out the second phase, the sort.
Back in ReduceTask's run method, the next call is shuffleConsumerPlugin's run:

@Override
public RawKeyValueIterator run() throws IOException, InterruptedException {
  
  .....

  // Start the map-completion events fetcher thread
  final EventFetcher<K,V> eventFetcher = 
    new EventFetcher<K,V>(reduceId, umbilical, scheduler, this,
        maxEventsToFetch);
  eventFetcher.start();
  
  // Start the map-output fetcher threads
  boolean isLocal = localMapFiles != null;
  final int numFetchers = isLocal ? 1 :
    jobConf.getInt(MRJobConfig.SHUFFLE_PARALLEL_COPIES, 5);
  Fetcher<K,V>[] fetchers = new Fetcher[numFetchers];
  if (isLocal) {
    fetchers[0] = new LocalFetcher<K,V>(jobConf, reduceId, scheduler,
        merger, reporter, metrics, this, reduceTask.getShuffleSecret(),
        localMapFiles);
    fetchers[0].start();
  } else {
    for (int i=0; i < numFetchers; ++i) {
      fetchers[i] = new Fetcher<K,V>(jobConf, reduceId, scheduler, merger, 
                                     reporter, metrics, this, 
                                     reduceTask.getShuffleSecret());
      fetchers[i].start();
    }
  }
 
  .....
  
  return kvIter;
}

This method is fairly long, and it covers both the copy and sort phases mentioned earlier. Let's look at the copy part first. It creates two kinds of threads: EventFetcher and Fetcher. The fetcher in turn comes in two flavors, LocalFetcher and Fetcher, chosen by whether the map outputs are local (the uber/local case, i.e. localMapFiles != null). Let's look at EventFetcher first:

@Override
public void run() {
  int failures = 0;
  LOG.info(reduce + " Thread started: " + getName());
  
  try {
    while (!stopped && !Thread.currentThread().isInterrupted()) {
      try {
        int numNewMaps = getMapCompletionEvents();
        failures = 0;
        if (numNewMaps > 0) {
          LOG.info(reduce + ": " + "Got " + numNewMaps + " new map-outputs");
        }
        LOG.debug("GetMapEventsThread about to sleep for " + SLEEP_TIME);
        if (!Thread.currentThread().isInterrupted()) {
          Thread.sleep(SLEEP_TIME);
        }
      } catch (InterruptedException e) {
        LOG.info("EventFetcher is interrupted.. Returning");
        return;
      } catch (IOException ie) {
        LOG.info("Exception in getting events", ie);
        // check to see whether to abort
        if (++failures >= MAX_RETRIES) {
          throw new IOException("too many failures downloading events", ie);
        }
        // sleep for a bit
        if (!Thread.currentThread().isInterrupted()) {
          Thread.sleep(RETRY_PERIOD);
        }
      }
    }
  } catch (InterruptedException e) {
    return;
  } catch (Throwable t) {
    exceptionReporter.reportException(t);
    return;
  }
}

Its run method shows that EventFetcher's job is to poll the parent over RPC for map-completion events, that is, to learn which map tasks have finished and where their output lives, and to feed that information to the scheduler.
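The heart of that loop is getMapCompletionEvents. Trimmed down (reconstructed from the Hadoop 2.x sources, with logging and assertions dropped), it asks the parent process for a batch of TaskCompletionEvents starting at the last index it has seen, hands each event to the scheduler, which records or obsoletes that map's output location, and counts the newly succeeded maps:

// Trimmed from EventFetcher.getMapCompletionEvents (Hadoop 2.x).
protected int getMapCompletionEvents() throws IOException, InterruptedException {
  int numNewMaps = 0;
  TaskCompletionEvent[] events = null;
  do {
    // RPC to the parent: up to maxEventsToFetch events, starting at fromEventIdx.
    MapTaskCompletionEventsUpdate update = umbilical.getMapCompletionEvents(
        (org.apache.hadoop.mapred.JobID) reduce.getJobID(),
        fromEventIdx, maxEventsToFetch,
        (org.apache.hadoop.mapred.TaskAttemptID) reduce);
    events = update.getMapTaskCompletionEvents();
    fromEventIdx += events.length;

    for (TaskCompletionEvent event : events) {
      scheduler.resolve(event);  // record SUCCEEDED outputs, obsolete FAILED/KILLED ones
      if (TaskCompletionEvent.Status.SUCCEEDED == event.getTaskStatus()) {
        ++numNewMaps;
      }
    }
  } while (events.length == maxEventsToFetch);  // a full batch means more may be waiting
  return numNewMaps;
}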
Now let's look at Fetcher:

public void run() {
  try {
    while (!stopped && !Thread.currentThread().isInterrupted()) {
      MapHost host = null;
      try {
        // If merge is on, block
        merger.waitForResource();

        // Get a host to shuffle from
        host = scheduler.getHost();
        metrics.threadBusy();

        // Shuffle
        copyFromHost(host);
      } finally {
        if (host != null) {
          scheduler.freeHost(host);
          metrics.threadFree();            
        }
      }
    }
  } catch (InterruptedException ie) {
    return;
  } catch (Throwable t) {
    exceptionReporter.reportException(t);
  }
}

The fetcher first blocks if a merge is in progress (merger.waitForResource()); otherwise it asks the scheduler for a host and calls copyFromHost, which pulls the map outputs stored on that host over HTTP.
Let's take a closer look at the copyFromHost method:

@VisibleForTesting
protected void copyFromHost(MapHost host) throws IOException {
  // reset retryStartTime for a new host
  retryStartTime = 0;
  // Get completed maps on 'host'
  List<TaskAttemptID> maps = scheduler.getMapsForHost(host);
  
  // Sanity check to catch hosts with only 'OBSOLETE' maps, 
  // especially at the tail of large jobs
  if (maps.size() == 0) {
    return;
  }
  
  if(LOG.isDebugEnabled()) {
    LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: "
      + maps);
  }
  
  // List of maps to be fetched yet
  Set<TaskAttemptID> remaining = new HashSet<TaskAttemptID>(maps);
  
  // Construct the url and connect
  URL url = getMapOutputURL(host, maps);
  DataInputStream input = openShuffleUrl(host, remaining, url);
  if (input == null) {
    return;
  }
  
  try {
    // Loop through available map-outputs and fetch them
    // On any error, faildTasks is not null and we exit
    // after putting back the remaining maps to the 
    // yet_to_be_fetched list and marking the failed tasks.
    TaskAttemptID[] failedTasks = null;
    while (!remaining.isEmpty() && failedTasks == null) {
      try {
        failedTasks = copyMapOutput(host, input, remaining, fetchRetryEnabled);
      } catch (IOException e) {
        //
        // Setup connection again if disconnected by NM
        connection.disconnect();
        // Get map output from remaining tasks only.
        url = getMapOutputURL(host, remaining);
        input = openShuffleUrl(host, remaining, url);
        if (input == null) {
          return;
        }
      }
    }
    
    if(failedTasks != null && failedTasks.length > 0) {
      LOG.warn("copyMapOutput failed for tasks "+Arrays.toString(failedTasks));
      scheduler.hostFailed(host.getHostName());
      for(TaskAttemptID left: failedTasks) {
        scheduler.copyFailed(left, host, true, false);
      }
    }

    // Sanity check
    if (failedTasks == null && !remaining.isEmpty()) {
      throw new IOException("server didn't return all expected map outputs: "
          + remaining.size() + " left.");
    }
    input.close();
    input = null;
  } finally {
    if (input != null) {
      IOUtils.cleanup(LOG, input);
      input = null;
    }
    for (TaskAttemptID left : remaining) {
      scheduler.putBackKnownMapOutput(host, left);
    }
  }
}

The most important call in there is copyMapOutput:

private TaskAttemptID[] copyMapOutput(MapHost host,
                              DataInputStream input,
                              Set<TaskAttemptID> remaining,
                              boolean canRetry) throws IOException {
  MapOutput<K,V> mapOutput = null;
  TaskAttemptID mapId = null;
  long decompressedLength = -1;
  long compressedLength = -1;
  
  try {
    long startTime = Time.monotonicNow();
    int forReduce = -1;
    //Read the shuffle header
    try {
      ShuffleHeader header = new ShuffleHeader();
      header.readFields(input);
      mapId = TaskAttemptID.forName(header.mapId);
      compressedLength = header.compressedLength;
      decompressedLength = header.uncompressedLength;
      forReduce = header.forReduce;
    } catch (IllegalArgumentException e) {
      badIdErrs.increment(1);
      LOG.warn("Invalid map id ", e);
      //Don't know which one was bad, so consider all of them as bad
      return remaining.toArray(new TaskAttemptID[remaining.size()]);
    }

    InputStream is = input;
    is = CryptoUtils.wrapIfNecessary(jobConf, is, compressedLength);
    compressedLength -= CryptoUtils.cryptoPadding(jobConf);
    decompressedLength -= CryptoUtils.cryptoPadding(jobConf);
    
    // Do some basic sanity verification
    if (!verifySanity(compressedLength, decompressedLength, forReduce,
        remaining, mapId)) {
      return new TaskAttemptID[] {mapId};
    }
    
    if(LOG.isDebugEnabled()) {
      LOG.debug("header: " + mapId + ", len: " + compressedLength + 
          ", decomp len: " + decompressedLength);
    }
    
    // Get the location for the map output - either in-memory or on-disk
    try {
      mapOutput = merger.reserve(mapId, decompressedLength, id);
    } catch (IOException ioe) {
      // kill this reduce attempt
      ioErrs.increment(1);
      scheduler.reportLocalError(ioe);
      return EMPTY_ATTEMPT_ID_ARRAY;
    }
    
    // Check if we can shuffle *now* ...
    if (mapOutput == null) {
      LOG.info("fetcher#" + id + " - MergeManager returned status WAIT ...");
      //Not an error but wait to process data.
      return EMPTY_ATTEMPT_ID_ARRAY;
    } 
    
    // The codec for lz0,lz4,snappy,bz2,etc. throw java.lang.InternalError
    // on decompression failures. Catching and re-throwing as IOException
    // to allow fetch failure logic to be processed
    try {
      // Go!
      LOG.info("fetcher#" + id + " about to shuffle output of map "
          + mapOutput.getMapId() + " decomp: " + decompressedLength
          + " len: " + compressedLength + " to " + mapOutput.getDescription());
      mapOutput.shuffle(host, is, compressedLength, decompressedLength,
          metrics, reporter);
    } catch (java.lang.InternalError e) {
      LOG.warn("Failed to shuffle for fetcher#"+id, e);
      throw new IOException(e);
    }
    
    // Inform the shuffle scheduler
    long endTime = Time.monotonicNow();
    // Reset retryStartTime as map task make progress if retried before.
    retryStartTime = 0;
    
    scheduler.copySucceeded(mapId, host, compressedLength, 
                            startTime, endTime, mapOutput);
    // Note successful shuffle
    remaining.remove(mapId);
    metrics.successFetch();
    return null;
  } catch (IOException ioe) {
    
    if (canRetry) {
      checkTimeoutOrRetry(host, ioe);
    } 
    
    ioErrs.increment(1);
    if (mapId == null || mapOutput == null) {
      LOG.warn("fetcher#" + id + " failed to read map header" + 
               mapId + " decomp: " + 
               decompressedLength + ", " + compressedLength, ioe);
      if(mapId == null) {
        return remaining.toArray(new TaskAttemptID[remaining.size()]);
      } else {
        return new TaskAttemptID[] {mapId};
      }
    }
      
    LOG.warn("Failed to shuffle output of " + mapId + 
             " from " + host.getHostName(), ioe); 

    // Inform the shuffle-scheduler
    mapOutput.abort();
    metrics.failedFetch();
    return new TaskAttemptID[] {mapId};
  }
}

This method is rather involved, so let's focus on the key step:

mapOutput = merger.reserve(mapId, decompressedLength, id);

This calls the reserve method of the MergeManagerImpl created earlier in init to obtain a MapOutput:

@Override
public synchronized MapOutput<K,V> reserve(TaskAttemptID mapId, 
                                           long requestedSize,
                                           int fetcher
                                           ) throws IOException {
  if (!canShuffleToMemory(requestedSize)) {
    LOG.info(mapId + ": Shuffling to disk since " + requestedSize + 
             " is greater than maxSingleShuffleLimit (" + 
             maxSingleShuffleLimit + ")");
    return new OnDiskMapOutput<K,V>(mapId, reduceId, this, requestedSize,
                                    jobConf, mapOutputFile, fetcher, true);
  }
  
  if (usedMemory > memoryLimit) {
    LOG.debug(mapId + ": Stalling shuffle since usedMemory (" + usedMemory
        + ") is greater than memoryLimit (" + memoryLimit + ")." + 
        " CommitMemory is (" + commitMemory + ")"); 
    return null;
  }
  
  // Allow the in-memory shuffle to progress
  LOG.debug(mapId + ": Proceeding with shuffle since usedMemory ("
      + usedMemory + ") is lesser than memoryLimit (" + memoryLimit + ")."
      + "CommitMemory is (" + commitMemory + ")"); 
  return unconditionalReserve(mapId, requestedSize, true);
}

Notice the check at the top:

private boolean canShuffleToMemory(long requestedSize) {
  return (requestedSize < maxSingleShuffleLimit); 
}

In other words: is this map output small enough? If it is below the threshold, it is copied into memory; otherwise it goes straight to a file on disk. maxSingleShuffleLimit is derived from configuration, so we can tune it to our own workload.
This tells us the copy phase has two modes, in-memory copy and on-disk copy, selected by the threshold we configure. The benefit is that small map outputs land in memory and can be sorted there directly during the subsequent sort phase, cutting down on file I/O.
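For reference, here is a simplified sketch of how MergeManagerImpl derives maxSingleShuffleLimit from configuration (property names and defaults as in Hadoop 2.x; the real constructor has a few more knobs):

// Simplified from the MergeManagerImpl constructor (Hadoop 2.x defaults).
// Fraction of the reducer's heap usable as shuffle buffer, default 0.70:
final float maxInMemCopyUse = jobConf.getFloat(
    "mapreduce.reduce.shuffle.input.buffer.percent", 0.70f);
final long memoryLimit =
    (long) (Runtime.getRuntime().maxMemory() * maxInMemCopyUse);

// A single map output may take at most this fraction of that buffer,
// default 0.25 -- anything larger is shuffled straight to disk:
final float singleShuffleMemoryLimitPercent = jobConf.getFloat(
    "mapreduce.reduce.shuffle.memory.limit.percent", 0.25f);
final long maxSingleShuffleLimit =
    (long) (memoryLimit * singleShuffleMemoryLimitPercent);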
Once a MapOutput is obtained, its shuffle method is called to do the actual data copy. I won't walk through that code; it boils down to draining the InputStream into memory or into a file.
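Still, here is a toy illustration of the two targets. The real classes are InMemoryMapOutput and OnDiskMapOutput; this sketch only shows their shape, not the actual Hadoop code:

import java.io.*;

// Toy versions of the two shuffle targets, just to show the idea.
class InMemoryTarget {
  byte[] memory;
  void shuffle(InputStream in, int length) throws IOException {
    memory = new byte[length];
    new DataInputStream(in).readFully(memory);  // whole map output into RAM
  }
}

class OnDiskTarget {
  void shuffle(InputStream in, long length, File file) throws IOException {
    try (OutputStream out = new FileOutputStream(file)) {
      byte[] buf = new byte[64 * 1024];
      long remaining = length;
      while (remaining > 0) {
        int n = in.read(buf, 0, (int) Math.min(buf.length, remaining));
        if (n < 0) throw new EOFException("premature end of map output");
        out.write(buf, 0, n);                   // stream to the local disk
        remaining -= n;
      }
    }
  }
}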
After the copy completes, the scheduler's copySucceeded method is invoked:

public synchronized void copySucceeded(TaskAttemptID mapId,
                                       MapHost host,
                                       long bytes,
                                       long startMillis,
                                       long endMillis,
                                       MapOutput output
                                       ) throws IOException {
  .....
  
  if (!finishedMaps[mapIndex]) {
    output.commit();
    finishedMaps[mapIndex] = true;
    shuffledMapsCounter.increment(1);
    if (--remainingMaps == 0) {
      notifyAll();
    }

    .....
  }
}

The two crucial steps in this method update two pieces of state: finishedMaps and remainingMaps.
From what we have seen so far, the copy phase's main job is to fetch the data that finished map tasks left on local disk, deciding by the rules above whether each map output is copied into memory or onto disk.

The sort phase

Back in shuffleConsumerPlugin's run method, here is what happens after the fetchers are started:

@Override
public RawKeyValueIterator run() throws IOException, InterruptedException {
  
  .....
  
  // Wait for shuffle to complete successfully
  while (!scheduler.waitUntilDone(PROGRESS_FREQUENCY)) {
    reporter.progress();
    
    synchronized (this) {
      if (throwable != null) {
        throw new ShuffleError("error in shuffle in " + throwingThreadName,
                               throwable);
      }
    }
  }

  // Stop the event-fetcher thread
  eventFetcher.shutDown();
  
  // Stop the map-output fetcher threads
  for (Fetcher<K,V> fetcher : fetchers) {
    fetcher.shutDown();
  }
  
  // stop the scheduler
  scheduler.close();

  copyPhase.complete(); // copy is already complete
  taskStatus.setPhase(TaskStatus.Phase.SORT);
  reduceTask.statusUpdate(umbilical);

  // Finish the on-going merges...
  RawKeyValueIterator kvIter = null;
  try {
    kvIter = merger.close();
  } catch (Throwable e) {
    throw new ShuffleError("Error while doing final merge " , e);
  }

  .....
  
  return kvIter;
}

Note the while loop guarded by:

scheduler.waitUntilDone(PROGRESS_FREQUENCY)

Its implementation:

@Override
public synchronized boolean waitUntilDone(int millis
                                          ) throws InterruptedException {
  if (remainingMaps > 0) {
    wait(millis);
    return remainingMaps == 0;
  }
  return true;
}

The name makes this easy to understand: it checks whether any maps remain unprocessed, and only once all of them are done (remainingMaps == 0) does execution move on. From the copy-phase analysis we already know that after each successful copy, the scheduler's copySucceeded decrements remainingMaps and calls notifyAll(), which wakes the wait(millis) here.
The actual sort happens in the merger's close method:

kvIter = merger.close();

That method is fairly involved, so I won't step through it here. The main logic: the in-memory data produced during the copy phase is sorted and spilled to files, and those files, together with the on-disk outputs from the copy phase, are combined with a heap-based k-way merge into one sorted stream, as sketched below.
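To make "heap-merge the sorted segments" concrete, here is a toy k-way merge over already-sorted segments. Hadoop implements the same idea in its Merger with a priority queue of segments; this is only an illustration, not the Hadoop code:

import java.util.*;

class KWayMergeSketch {
  // Merge already-sorted segments into one sorted list. The heap always
  // yields the segment whose current head element is smallest.
  static List<String> merge(List<Iterator<String>> segments) {
    PriorityQueue<Map.Entry<String, Iterator<String>>> heap =
        new PriorityQueue<>(Map.Entry.comparingByKey());
    for (Iterator<String> seg : segments) {
      if (seg.hasNext()) {
        heap.add(new AbstractMap.SimpleEntry<>(seg.next(), seg));
      }
    }
    List<String> out = new ArrayList<>();
    while (!heap.isEmpty()) {
      Map.Entry<String, Iterator<String>> head = heap.poll();
      out.add(head.getKey());               // emit the smallest current key
      Iterator<String> seg = head.getValue();
      if (seg.hasNext()) {                  // refill from the same segment
        heap.add(new AbstractMap.SimpleEntry<>(seg.next(), seg));
      }
    }
    return out;
  }

  public static void main(String[] args) {
    List<Iterator<String>> segments = Arrays.asList(
        Arrays.asList("a", "c", "e").iterator(),   // e.g. a sorted in-memory spill
        Arrays.asList("b", "d").iterator());       // e.g. a sorted on-disk file
    System.out.println(merge(segments));           // [a, b, c, d, e]
  }
}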

The reduce phase

With the copy and sort phases done, let's return to ReduceTask's run method:

@Override
@SuppressWarnings("unchecked")
public void run(JobConf job, final TaskUmbilicalProtocol umbilical)
  throws IOException, InterruptedException, ClassNotFoundException {
  job.setBoolean(JobContext.SKIP_RECORDS, isSkipping());

  .....

  if (useNewApi) {
    runNewReducer(job, umbilical, reporter, rIter, comparator, 
                  keyClass, valueClass);
  } else {
    runOldReducer(job, umbilical, reporter, rIter, comparator, 
                  keyClass, valueClass);
  }

  shuffleConsumerPlugin.close();
  done(umbilical, reporter);
}

This part is almost identical to the map phase: runNewReducer reads the merged data and invokes our own Reducer class on each key group, so I won't go through that code in detail.
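For completeness, here is (slightly simplified) the new-API Reducer.run loop that ends up driving our reduce method, from the Hadoop sources:

public void run(Context context) throws IOException, InterruptedException {
  setup(context);
  try {
    // nextKey() advances over the sorted stream produced by the merge,
    // grouping values by the grouping comparator seen earlier.
    while (context.nextKey()) {
      reduce(context.getCurrentKey(), context.getValues(), context);
    }
  } finally {
    cleanup(context);
  }
}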