This article, the first in a series, lists three easy ways to optimize your Spark code. They can be summed up as follows:

  1. Use ReduceByKey over GroupByKey
  2. Be Wary of Actions
  3. Gracefully Deal with Bad Quality Data

Use ReduceByKey over GroupByKey

Let's look at two different ways to compute word counts, one using reduceByKey and the other using groupByKey:

val words = Array("one", "two", "two", "three", "three", "three")
val wordPairsRDD = sc.parallelize(words).map(word => (word, 1))

val wordCountsWithReduce = wordPairsRDD
  .reduceByKey(_ + _)

val wordCountsWithGroup = wordPairsRDD
  .groupByKey()
  .map(t => (t._1, t._2.sum))

While both of these functions will produce the correct answer, the reduceByKey example works much better on a large dataset. That's because Spark knows it can combine output with a common key on each partition before shuffling the data.

With reduceByKey, pairs on the same machine with the same key are combined (using the lambda function passed into reduceByKey) before the data is shuffled. The lambda function is then called again to reduce the values from each partition to one final result.

On the other hand, when calling groupByKey, all the key-value pairs are shuffled around. This is a lot of unnecessary data to transfer over the network.
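The difference can be sketched without a cluster. The plain-Python model below is only an illustration under stated assumptions: the six word pairs from the Scala example are split across two hypothetical partitions, and the helpers combine_locally, reduce_by_key_model and group_by_key_model are invented for this sketch, not Spark APIs. It counts how many records each strategy would send through the shuffle.

```python
from collections import defaultdict

# Hypothetical layout: the six word pairs from the Scala example,
# split across two partitions (an assumption for illustration only).
partitions = [
    [("one", 1), ("two", 1), ("two", 1)],
    [("three", 1), ("three", 1), ("three", 1)],
]

def combine_locally(partition):
    """Map-side combine: sum the values for each key within one partition."""
    partial = defaultdict(int)
    for key, value in partition:
        partial[key] += value
    return list(partial.items())

def reduce_by_key_model(parts):
    """Model of reduceByKey: combine per partition, then merge the partial sums."""
    shuffled = [pair for p in parts for pair in combine_locally(p)]
    totals = defaultdict(int)
    for key, value in shuffled:
        totals[key] += value
    return dict(totals), len(shuffled)

def group_by_key_model(parts):
    """Model of groupByKey: every pair is shuffled, then the values are summed."""
    shuffled = [pair for p in parts for pair in p]
    groups = defaultdict(list)
    for key, value in shuffled:
        groups[key].append(value)
    return {key: sum(values) for key, values in groups.items()}, len(shuffled)

reduce_counts, reduce_shuffled = reduce_by_key_model(partitions)
group_counts, group_shuffled = group_by_key_model(partitions)
print(reduce_counts == group_counts)    # True: both give the same word counts
print(reduce_shuffled, group_shuffled)  # 3 6
```

Both models produce the same counts, but in this layout the reduceByKey model moves three records through the shuffle where the groupByKey model moves six.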

To determine which machine to shuffle a pair to, Spark calls a partitioning function on the key of the pair. Spark spills data to disk when more data is shuffled onto a single executor machine than can fit in memory. However, it flushes data to disk one key at a time, so if a single key has more key-value pairs than can fit in memory, an out-of-memory exception occurs. This will be handled more gracefully in a later release of Spark so that the job can still proceed, but it should still be avoided: when Spark needs to spill to disk, performance is severely impacted.
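A rough illustration of why a single large key hurts: the sketch below assigns each pair to a partition by hashing its key modulo an assumed partition count of 4. zlib.crc32 is used as a deterministic stand-in for Spark's own partitioner (Python's built-in string hash is randomized per process), and the skewed input is hypothetical. Every pair with the key "hot" lands on the same partition, which is the kind of concentration that can exhaust a single executor's memory.

```python
import zlib
from collections import Counter

NUM_PARTITIONS = 4  # assumed number of shuffle partitions

def partition_for(key, num_partitions=NUM_PARTITIONS):
    """Deterministic stand-in for a hash partitioner: hash(key) mod partitions."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions

# Hypothetical skewed input: one "hot" key dominates the data set.
pairs = [("hot", 1)] * 1000 + [("a", 1), ("b", 1), ("c", 1)]

# Without a map-side combine, every pair is shuffled to the partition its key maps to.
records_per_partition = Counter(partition_for(key) for key, _ in pairs)
hot_partition = partition_for("hot")

print(records_per_partition[hot_partition] >= 1000)  # True: all "hot" pairs share one partition
```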

For a much larger dataset, the difference in the amount of data shuffled by reduceByKey and groupByKey becomes even more pronounced.

Here are more functions to prefer over groupByKey:

  • combineByKey can be used when you are combining elements but your return type differs from your input value type.
  • foldByKey merges the values for each key using an associative function and a neutral "zero value".
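The semantics of these two functions can be modelled in plain Python. The sketch below is not Spark code: combine_by_key_model and fold_by_key_model are hypothetical helpers that mimic the build-per-partition-then-merge behaviour. The per-key average shows why combineByKey suits a return type that differs from the input (integer scores in, (sum, count) tuples as combiners), and fold_by_key_model is layered on top of it in roughly the way PySpark builds foldByKey on combineByKey.

```python
# Hypothetical (key, score) pairs split across two partitions.
partitions = [
    [("a", 3), ("b", 4), ("a", 5)],
    [("b", 6), ("a", 1)],
]

def combine_by_key_model(parts, create_combiner, merge_value, merge_combiners):
    """Model of combineByKey: build per-partition combiners, then merge them."""
    per_partition = []
    for part in parts:
        combiners = {}
        for key, value in part:
            if key in combiners:
                combiners[key] = merge_value(combiners[key], value)
            else:
                combiners[key] = create_combiner(value)
        per_partition.append(combiners)
    merged = {}
    for combiners in per_partition:
        for key, combiner in combiners.items():
            merged[key] = merge_combiners(merged[key], combiner) if key in merged else combiner
    return merged

# Per-key average: values are ints, but each combiner is a (sum, count) tuple.
sums_counts = combine_by_key_model(
    partitions,
    create_combiner=lambda v: (v, 1),
    merge_value=lambda acc, v: (acc[0] + v, acc[1] + 1),
    merge_combiners=lambda x, y: (x[0] + y[0], x[1] + y[1]),
)
averages = {k: s / c for k, (s, c) in sums_counts.items()}

def fold_by_key_model(parts, zero_value, func):
    """Model of foldByKey: fold each partition starting from zero_value, then merge."""
    return combine_by_key_model(
        parts,
        create_combiner=lambda v: func(zero_value, v),
        merge_value=func,
        merge_combiners=func,
    )

totals = fold_by_key_model(partitions, 0, lambda x, y: x + y)
print(averages)  # {'a': 3.0, 'b': 5.0}
print(totals)    # {'a': 9, 'b': 10}
```

Because the zero value can be applied once per partition for each key, it should be neutral for the function (0 for addition); otherwise the result would depend on how the data happens to be partitioned.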

Be Wary of Actions

If your RDD is so large that its elements won't all fit in memory on the driver machine, don't do this:

val values = myVeryLargeRDD.collect()

Calling collect will attempt to copy every single element of the RDD to the single driver program, which will then run out of memory and crash.

Instead, you can make sure the number of elements you return is capped by calling take or takeSample, or perhaps filtering or sampling your RDD.
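To see why take caps memory use on the driver, here is a simplified plain-Python model in which an RDD is a list of 100 hypothetical partitions of one million integers each. take_model is an invented helper: it scans partitions in order and stops once n elements have been gathered. Spark's actual take is more elaborate (it scans a growing number of partitions per round), so treat this as a sketch of the idea rather than its implementation.

```python
from itertools import islice

# Hypothetical RDD modelled as a list of lazy partitions (ranges of ints).
partitions = [range(i * 1_000_000, (i + 1) * 1_000_000) for i in range(100)]

def take_model(parts, n):
    """Simplified model of take(n): scan partitions in order and stop once
    n elements have been gathered, so the driver holds at most n elements."""
    result, scanned = [], 0
    for part in parts:
        if len(result) >= n:
            break
        scanned += 1
        result.extend(islice(part, n - len(result)))
    return result, scanned

first_five, partitions_scanned = take_model(partitions, 5)
print(first_five)          # [0, 1, 2, 3, 4]
print(partitions_scanned)  # 1: the other 99 partitions are never read
```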

Similarly, be cautious with these other actions unless you are sure your dataset is small enough to fit in memory:

  • countByKey
  • countByValue
  • collectAsMap

If you really do need every one of these values of the RDD and the data is too big to fit into memory, you can write out the RDD to files or export the RDD to a database that is large enough to hold all the data.

Gracefully Deal with Bad Quality Data

When dealing with vast amounts of data, a common problem is that a small amount of the data is malformed or corrupt. Using a filter transformation, you can easily discard bad inputs, or you can use a map transformation if it's possible to fix the bad input. Perhaps the best option is to use a flatMap function, in which you can try to fix the input but fall back to discarding it if you can't.
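The three strategies can be compared on a small, hypothetical set of numeric strings, using Python list comprehensions as stand-ins for the filter, map and flatMap transformations. The helpers is_valid, fix and try_parse are invented for this example.

```python
# Hypothetical raw records: some clean, some fixable, one unrecoverable.
raw = ["42", " 7 ", "13%", "n/a"]

def is_valid(record):
    """Filter-style predicate: keep only records that are clean integer strings."""
    return record.isdigit()

def fix(record):
    """Map-style repair: exactly one output per input (here, trimming whitespace)."""
    return record.strip()

def try_parse(record):
    """flatMap-style: return [value] if the record parses after repair,
    otherwise [] so that the record is dropped."""
    candidate = record.strip().rstrip("%")
    try:
        return [int(candidate)]
    except ValueError:
        return []

filtered = [r for r in raw if is_valid(r)]            # like rdd.filter(is_valid)
mapped = [fix(r) for r in raw]                        # like rdd.map(fix)
flat_mapped = [v for r in raw for v in try_parse(r)]  # like rdd.flatMap(try_parse)

print(filtered)     # ['42']
print(mapped)       # ['42', '7', '13%', 'n/a']
print(flat_mapped)  # [42, 7, 13]
```

The map-style repair has to return one output per input, so it cannot drop the unrecoverable "n/a" record, while the flatMap-style function repairs what it can and discards the rest.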

Let's consider the JSON strings below as input:

input_rdd = sc.parallelize(["{\"value\": 1}",  # Good
                            "bad_json",  # Bad
                            "{\"value\": 2}",  # Good
                            "{\"value\": 3"  # Missing an ending brace.
                            ])

If we tried to load this set of JSON strings into a sqlContext, it would fail because of the malformed inputs:

sqlContext.jsonRDD(input_rdd).registerTempTable("valueTable")
# The above command will throw an error.

Instead, let's try fixing the input with this Python function:

import json

def try_correct_json(json_string):
    try:
        # First check if the JSON is okay.
        json.loads(json_string)
        return [json_string]
    except ValueError:
        try:
            # If not, try correcting it by adding an ending brace.
            try_to_correct_json = json_string + "}"
            json.loads(try_to_correct_json)
            return [try_to_correct_json]
        except ValueError:
            # The malformed JSON input can't be recovered, so drop this input.
            return []

Now we can apply that function to fix our input and try again. This time, three of the four inputs are read successfully:

corrected_input_rdd = input_rdd.flatMap(try_correct_json)
sqlContext.jsonRDD(corrected_input_rdd).registerTempTable("valueTable")
sqlContext.sql("select * from valueTable").collect()

# Returns [Row(value=1), Row(value=2), Row(value=3)]