Spark mapPartitions

Understanding mapPartitions in Spark


Problem: Given a Parquet file containing employee data, find the maximum bonus earned by each employee and save the result back to Parquet (GitHub).

1. Parquet file (a huge file on HDFS) with the following Avro schema:

|-- emp_id: integer (nullable = false)
|-- emp_name: string (nullable = false)
|-- emp_country: string (nullable = false)
|-- emp_bonus: string (nullable = true)
|-- subordinates: map (nullable = true)
|    |-- key: string
|    |-- value: string (valueContainsNull = false)


Dependencies:


Code:


import com.databricks.spark.avro.SchemaConverters;
import org.apache.avro.AvroRuntimeException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.TreeSet;

/**
 * Spark job that reads Employee records from Parquet, computes the maximum bonus of each
 * employee inside {@code mapPartitions}, and writes the result back to Parquet.
 *
 * <p>Job parameters are supplied as Hadoop {@code -Dkey=value} options (parsed by ToolRunner)
 * and read from {@link #getConf()} using the keys declared below.
 */
public class EmployeeMaxBonusMapPartition extends Configured implements Tool, Closeable, Serializable {

    private static final long serialVersionUID = 1L;

    /** Configuration key holding the input Parquet path. */
    public static final String INPUT_PATH = "spark.input.path";
    /** Configuration key holding the output Parquet path. */
    public static final String OUTPUT_PATH = "spark.output.path";
    /**
     * Configuration key for the boolean "run Spark in local mode" switch (read with default false).
     * It must be a non-empty key: an empty key cannot be supplied as {@code -Dkey=true}.
     */
    public static final String IS_RUN_LOCALLY = "spark.is.run.local";
    /** Configuration key whose value is applied as the Hadoop {@code fs.defaultFS}, e.g. {@code file:///}. */
    public static final String DEFAULT_FS = "spark.default.fs";
    /** Configuration key intended for the partition count (not read anywhere in this class as shown). */
    public static final String NUM_PARTITIONS = "spark.num.partitions";

    // run() hands the lambda e -> convert(e) to Spark. Because it calls an instance method it
    // captures `this`, so this whole object is serialized with the closure. The Spark handles are
    // transient so they stay on the driver and are not serialized.
    private transient SQLContext sqlContext;
    private transient JavaSparkContext javaSparkContext;

    protected <T> JavaSparkContext getJavaSparkContext(final boolean isRunLocal,
                                                       final String defaultFs,
                                                       final Class<T> tClass) {
        final SparkConf sparkConf = new SparkConf()
                //Set spark conf here , 
                //After one gets spark context you can set hadoop configuration for InputFormats
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");

        if (isRunLocal) {

        final JavaSparkContext sparkContext = new JavaSparkContext(sparkConf);

        if (defaultFs != null) {
            sparkContext.hadoopConfiguration().set("fs.defaultFS", defaultFs);

        return sparkContext;

    // Convert Row to Avro POJO (Employee)
    public Employee convert(Row row) {
        try {
            // Employee Schema => ParquetRow Schema =>Row Schema

            Employee avroInstance = new Employee();

            for (StructField field : row.schema().fields()) {

                //row.fieldIndex => pos of the field Name in the schema
                avroInstance.put(, row.get(row.fieldIndex(;


            return avroInstance;

        } catch (Exception e) {
            throw new AvroRuntimeException("Avro POJO  building failed ", e);

    /**
     * Tool entry point. Reads the job parameters from the Hadoop configuration, creates the Spark
     * and SQL contexts, loads the input, applies the per-partition max-bonus logic and builds the
     * output DataFrame.
     *
     * NOTE(review): this listing is incomplete. The SchemaConverters call, the Parquet read, the
     * start of the transformation chain and the Parquet write are missing. TODO restore.
     *
     * @param args arguments left over after ToolRunner has consumed the -D options
     * @return 0, the only return value visible here
     * @throws Exception any failure from the Spark/Hadoop calls below
     */
    public int run(String[] args) throws Exception {

        // ToolRunner has already parsed the -Dkey=value options into this Configuration.
        Configuration conf = getConf();
        String inputPath = conf.get(INPUT_PATH);

        String outputPath = conf.get(OUTPUT_PATH);

        // Create the central Spark context; the SQLContext below wraps it.
        // The local-mode flag defaults to false when IS_RUN_LOCALLY is not set.
        javaSparkContext = getJavaSparkContext(conf.getBoolean(IS_RUN_LOCALLY, Boolean.FALSE),
                                                        conf.get(DEFAULT_FS), this.getClass());
        sqlContext = new SQLContext(javaSparkContext);

        // Nothing has been read and no Spark job has started yet, so Hadoop settings can still be
        // changed here via javaSparkContext.hadoopConfiguration().set(key, value).
        // A custom InputFormat can be used via javaSparkContext.newAPIHadoopFile(), which yields an RDD.

        // Convert the Avro schema into a Spark SQL StructType for the output DataFrame.
        // NOTE(review): the rest of this statement is missing from the listing. TODO restore.
        final StructType outPutSchemaStructType = (StructType) SchemaConverters

        // Read the input Parquet file; its schema corresponds to the Avro schema.
        // NOTE(review): the right-hand side of this assignment is missing. TODO restore.
        DataFrame inputDf =

        // Convert the DataFrame into a JavaRDD<Row>; each Row has the Parquet row's schema.
        JavaRDD<Row> rowJavaRDD = inputDf.javaRDD();

        // Row schema == Parquet row schema == Avro (Employee) schema; convert() relies on this.

        // NOTE(review): the start of this transformation chain is missing: its receiver, the
        // assignment target and, given MapPartitionSpark's input type
        // Tuple2<Integer, Iterable<Employee>>, presumably a keyBy/groupByKey on emp_id. TODO restore.

                // Convert each Row into an Employee object.

                // e -> convert(e) calls an instance method, so the lambda captures `this` and Spark
                // has to serialize the whole outer object. A static nested function class avoids that.

                .map(e -> convert(e))

                // Key by emp_id so that all records of one employee end up together.
                // NOTE(review): the keying/grouping call itself is missing here.


                //.combineByKey(new EmployeeMaxSalary.CreateCombiner(),
                // NOTE(review): the next line continues the commented-out combineByKey call above but
                // is NOT commented out itself, so it is parsed as code. TODO comment it out.
                         new EmployeeMaxSalary.MergeValue(), new EmployeeMaxSalary.MergeCombiner())

                // One call per partition; see MapPartitionSpark.
                .mapPartitions(new MapPartitionSpark());

        // NOTE(review): this builds the output from rowJavaRDD (the unprocessed input rows), not from
        // the mapPartitions result above. Confirm which RDD was intended.
        DataFrame outputDf = sqlContext.createDataFrame(rowJavaRDD, outPutSchemaStructType);

        // Intended: save outputDf as Parquet (presumably under outputPath).
        // NOTE(review): the write call is missing from the listing, so outputDf is never saved. TODO restore.

        return 0;

    /**
     * mapPartitions function that, for every employee key in a partition, computes the maximum
     * bonus and emits one Employee per key.
     *
     * <p>Input elements are (emp_id, all Employee records with that id) pairs.
     * NOTE(review): the loop body that fills employeeBonusSet, and the code that stores the bonus
     * on {@code output} and adds it to the result list, are missing from this listing. TODO restore.
     */
    public static class MapPartitionSpark implements
                         FlatMapFunction<Iterator<Tuple2<Integer, Iterable<Employee>>>, Object> {

        // Function objects passed to transformations are created on the driver, serialized, and
        // sent to the executors. As a static nested class this one carries no reference to an
        // enclosing instance.

        // transient: not serialized with the function object; call() creates it on the executor.
        private transient TreeSet<Long> employeeBonusSet;

        // Do not declare static fields here. Several tasks can run concurrently inside the same
        // executor JVM, and shared static state would not be thread-safe.

        /**
         * Processes one whole partition. Spark calls this once per partition (of a map- or
         * reduce-side stage), not once per record.
         *
         * @param tuple2Iterator iterator over every (emp_id, employees) group in the partition
         * @return the output records for the entire partition
         * @throws Exception as declared by FlatMapFunction
         */
        public Iterable<Object> call
                     (Iterator<Tuple2<Integer, Iterable<Employee>>> tuple2Iterator) throws Exception {

            // tuple2Iterator walks the whole partition, not a single record. Each element pairs a key
            // (the employee id) with all values grouped under that key.

            // Code placed before the loop runs once per partition, so this is the place for
            // per-partition setup.

            employeeBonusSet = new TreeSet<>();

            // Per-partition clients (e.g. HBase or Aerospike) can be created here.

            // All output of the partition is buffered in this list until call() returns.
            ArrayList<Object> outputOfMapPartition = new ArrayList<>();

            // One iteration per key (employee id) in the partition.
            while (tuple2Iterator.hasNext()) {

                // The current key and the values associated with it.
                // NOTE(review): the right-hand side (presumably tuple2Iterator.next()) is missing. TODO restore.
                Tuple2<Integer, Iterable<Employee>> integerIterableTuple2 =;

                // employeeBonusSet is shared by all keys of the partition rather than allocated per key.
                // NOTE(review): nothing visible clears it between keys. Unless the missing loop body
                // does, bonuses from earlier employees in this partition affect later ones.

                for (Employee employee : integerIterableTuple2._2()) {

                // NOTE(review): loop body (collecting each employee's bonus) is missing. TODO restore.


                // Use the first employee of the group as the output record that carries the max bonus.
                Employee output = integerIterableTuple2._2().iterator().next();


            // Per-partition clients can be shut down here.
            // Unlike map(), which returns once per row, this returns only once per partition,
            // after the whole partition has been processed.
            return outputOfMapPartition;

    public void close() throws IOException {

    public static void main(String[] args) throws Exception { EmployeeMaxBonusMapPartition(), args);


Integration test (GitHub):



import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.hadoop.fs.Path;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import parquet.avro.AvroParquetReader;
import parquet.avro.AvroParquetWriter;
import parquet.hadoop.ParquetReader;
import parquet.hadoop.ParquetWriter;
import parquet.hadoop.metadata.CompressionCodecName;


/**
 * Integration test around the Avro/Parquet Spark job. setUp() prepares Parquet input under a
 * per-run temporary directory, testSuccess() reads the job's Parquet output back and compares it
 * with the input record, and cleanup() deletes the directory.
 */
public class ReadWriteAvroParquetFilesTest {

    /** Per-run scratch root; the timestamp keeps repeated runs from reusing a directory. */
    private static final String BASEDIR =
            "/tmp/ParquetFilesTest/avroparquetInputFile/" + System.currentTimeMillis() + "/";

    private static final Logger LOG = LoggerFactory.getLogger(ReadWriteAvroParquetFilesTest.class);

    /** Record prepared in setUp() and used as the expected value in testSuccess(). */
    private Employee employee;

    // Input and output directories under BASEDIR; both are assigned in setUp().
    private String input;
    private String output;

    /**
     * Prepares the per-run input/output directories and the Employee record, and starts writing
     * the input Parquet file.
     *
     * NOTE(review): org.junit.Before is imported but this method is not annotated, so JUnit would
     * not call it. The Employee fields are never populated, and the writer builder chain is
     * truncated (no build/write/close is visible). TODO restore.
     *
     * @throws IOException declared for the Parquet file I/O
     */
    public void setUp() throws IOException {

        input = BASEDIR + "input/";
        output = BASEDIR + "output/";

        // Record to write as input; testSuccess() compares the job output against it.
        employee = new Employee();

        // Write the input Parquet file. The original note says GZIP compression; the codec setting
        // is not visible here (presumably CompressionCodecName.GZIP, which is imported).
        ParquetWriter<Object> writer = AvroParquetWriter
                                          .builder(new Path(input + "1.gz.parquet"))
        // NOTE(review): the rest of the builder chain and the write/close calls are missing. TODO restore.

    /**
     * Intended to run the job over the prepared input, read the record back from the Parquet
     * output, and check that its emp_id matches the input record.
     *
     * NOTE(review): the arguments target ReadWriteAvroParquetFiles, not
     * EmployeeMaxBonusMapPartition, and the call that runs the tool with {@code args} is missing.
     * Only emp_id is asserted, so the computed maximum bonus is not verified. org.junit.Test is
     * imported but this method is not annotated.
     */
    public void testSuccess() throws Exception {

        // Job parameters as Hadoop -Dkey=value options, to be parsed by ToolRunner.
        String[] args = new String[]{"-D" + ReadWriteAvroParquetFiles.INPUT_PATH + "=" + input,
                "-D" + ReadWriteAvroParquetFiles.OUTPUT_PATH + "=" + output,
                "-D" + ReadWriteAvroParquetFiles.IS_RUN_LOCALLY + "=true",
                "-D" + ReadWriteAvroParquetFiles.DEFAULT_FS + "=file:///",
                "-D" + ReadWriteAvroParquetFiles.NUM_PARTITIONS + "=1"};

        // NOTE(review): running the tool with args is not shown in this listing. TODO restore.

        // Read the job output back as generic Avro records.
        ParquetReader<GenericRecord> reader = AvroParquetReader
                                                 .builder(new Path(output))
        // Add .withConf(FS.getConf()) to read from a different (e.g. HDFS) file system instead of
        // the local one, which is the default.

        // NOTE(review): the read call whose result is cast here is missing. TODO restore.
        GenericData.Record event = (GenericData.Record);
        // Presumably converts the generic record into the specific Employee type (the AvroUtils
        // method name is missing from the listing).
        Employee outputEvent = AvroUtils
                               (event, Employee.getClassSchema()), Employee.getClassSchema());
        // NOTE(review): the logging call that consumes the message below is missing. TODO restore.
        reader.close();"Data read from Sparkoutput is {}", outputEvent.toString());
        Assert.assertEquals(employee.getEmpId(), outputEvent.getEmpId());

    public void cleanup() throws IOException {
        FileUtils.deleteDirectory(new File(BASEDIR));