Spark mapPartitions
Understanding mapPartitions in Spark
Problem: Given a Parquet file containing Employee data, find the maximum bonus earned by each employee and save the result back as Parquet (Github).
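Before walking through the solution, here is a minimal, self-contained sketch (using made-up integer data, unrelated to the Employee classes below) contrasting map, which is invoked once per element, with mapPartitions, which is invoked once per partition and receives an iterator over that whole partition:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MapVsMapPartitions {

    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setAppName("MapVsMapPartitions").setMaster("local[*]"));

        // 6 elements spread over 2 partitions
        JavaRDD<Integer> numbers = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);

        // map: the function is invoked once for every element
        JavaRDD<Integer> doubled = numbers.map(x -> x * 2);

        // mapPartitions: the function is invoked once per partition and is handed
        // an Iterator over all elements of that partition (Spark 1.6 expects an Iterable back)
        JavaRDD<Integer> doubledPerPartition = numbers.mapPartitions(partition -> {
            List<Integer> output = new ArrayList<>();
            while (partition.hasNext()) {
                output.add(partition.next() * 2);
            }
            return output; // returned once per partition, not once per element
        });

        System.out.println(doubled.collect());              // [2, 4, 6, 8, 10, 12]
        System.out.println(doubledPerPartition.collect());  // [2, 4, 6, 8, 10, 12]

        sc.close();
    }
}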
1. Parquet file (a large file on HDFS), Avro schema:
|-- emp_id: integer (nullable = false)
|-- emp_name: string (nullable = false)
|-- emp_country: string (nullable = false)
|-- emp_bonus: string (nullable = true)
|-- subordinates: map (nullable = true)
|    |-- key: string
|    |-- value: string (valueContainsNull = false)
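The tree above is just the Spark SQL view of the Employee Avro schema. A quick way to reproduce this listing (a sketch, assuming the Avro-generated Employee class that comes from the avro-schema dependency below) is to convert the Avro schema with spark-avro's SchemaConverters and print the resulting StructType:

import com.big.data.avro.schema.Employee;
import com.databricks.spark.avro.SchemaConverters;
import org.apache.spark.sql.types.StructType;

public class PrintEmployeeSchema {

    public static void main(String[] args) {
        // Avro schema -> Spark SQL StructType (the same conversion the job below uses)
        StructType employeeStructType = (StructType) SchemaConverters
                .toSqlType(Employee.getClassSchema())
                .dataType();

        // Prints the "|-- field: type (nullable = ...)" tree shown above
        employeeStructType.printTreeString();
    }
}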
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.10</artifactId>
        <version>1.6.0-cdh5.9.0</version>
        <scope>compile</scope>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.10</artifactId>
        <version>1.6.0-cdh5.9.0</version>
        <scope>compile</scope>
    </dependency>
    <dependency>
        <groupId>com.big.data</groupId>
        <artifactId>avro-schema</artifactId>
        <version>${project.version}</version>
    </dependency>
    <dependency>
        <groupId>com.databricks</groupId>
        <artifactId>spark-avro_2.10</artifactId>
        <version>3.0.0</version>
    </dependency>
    <dependency>
        <groupId>com.twitter</groupId>
        <artifactId>parquet-avro</artifactId>
        <version>1.5.0-cdh5.9.0</version>
        <scope>compile</scope>
    </dependency>
    <dependency>
        <groupId>com.googlecode.json-simple</groupId>
        <artifactId>json-simple</artifactId>
        <version>1.1.1</version>
    </dependency>
    <dependency>
        <groupId>com.databricks</groupId>
        <artifactId>spark-csv_2.10</artifactId>
        <version>1.5.0</version>
        <scope>compile</scope>
    </dependency>
</dependencies>
<repositories>
    <repository>
        <id>cloudera</id>
        <url>https://repository.cloudera.com/artifactory/cloudera-repos/</url>
    </repository>
</repositories>
Code:
package com.big.data.spark;

import com.big.data.avro.schema.Employee;
import com.databricks.spark.avro.SchemaConverters;
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.Schema;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.TreeSet;

public class EmployeeMaxBonusMapPartition extends Configured implements Tool, Closeable, Serializable {

    public static final String INPUT_PATH = "spark.input.path";
    public static final String OUTPUT_PATH = "spark.output.path";
    public static final String IS_RUN_LOCALLY = "spark.is.run.local";
    public static final String DEFAULT_FS = "spark.default.fs";
    public static final String NUM_PARTITIONS = "spark.num.partitions";

    // Note how the use of a lambda inside a transformation forces the outer class to be serialized.
    // This example throws light on how lambda functions are serialized along with their enclosing class.
    private transient SQLContext sqlContext;
    private transient JavaSparkContext javaSparkContext;

    protected <T> JavaSparkContext getJavaSparkContext(final boolean isRunLocal,
                                                       final String defaultFs,
                                                       final Class<T> tClass) {
        final SparkConf sparkConf = new SparkConf()
                // Set Spark configuration here.
                // Once the SparkContext is available, Hadoop configuration for InputFormats can be set on it.
                .setAppName(tClass.getSimpleName())
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");

        if (isRunLocal) {
            sparkConf.setMaster("local[*]");
        }

        final JavaSparkContext sparkContext = new JavaSparkContext(sparkConf);

        if (defaultFs != null) {
            sparkContext.hadoopConfiguration().set("fs.defaultFS", defaultFs);
        }

        return sparkContext;
    }

    // Convert a Row to the Avro POJO (Employee)
    public Employee convert(Row row) {
        try {
            // Employee schema => Parquet row schema => Row schema
            Employee avroInstance = new Employee();
            for (StructField field : row.schema().fields()) {
                // row.fieldIndex => position of the field name in the schema
                avroInstance.put(field.name(), row.get(row.fieldIndex(field.name())));
            }
            return avroInstance;
        } catch (Exception e) {
            throw new AvroRuntimeException("Avro POJO building failed ", e);
        }
    }

    @Override
    public int run(String[] args) throws Exception {

        // The arguments passed have been split into key/value pairs by ToolRunner
        Configuration conf = getConf();
        String inputPath = conf.get(INPUT_PATH);
        String outputPath = conf.get(OUTPUT_PATH);

        // Get the Spark context. This is the central context, which can be wrapped in any other context.
        javaSparkContext = getJavaSparkContext(conf.getBoolean(IS_RUN_LOCALLY, Boolean.FALSE),
                                               conf.get(DEFAULT_FS),
                                               this.getClass());
        sqlContext = new SQLContext(javaSparkContext);

        // No input has been read and no job has been started yet.
        // To set any configuration use javaSparkContext.hadoopConfiguration().set(key, value);
        // To use a custom InputFormat use javaSparkContext.newAPIHadoopFile() and get an RDD back.

        // Avro schema to StructType conversion
        final StructType outPutSchemaStructType = (StructType) SchemaConverters
                .toSqlType(Employee.getClassSchema())
                .dataType();

        // Read data from the Parquet file; the schema of the data is taken from the Avro schema
        DataFrame inputDf = sqlContext.read()
                                      .format(Employee.class.getCanonicalName())
                                      .parquet(inputPath);

        // Convert the DataFrame into a JavaRDD.
        // Each row read from the Parquet file is converted into a Row object.
        // Row has the same schema as the Parquet file row.
        JavaRDD<Row> rowJavaRDD = inputDf.javaRDD();

        // Row has the same schema as the Parquet row,
        // and the Parquet row has the same schema as the Avro object
        JavaRDD<Row> maxBonusRowRdd = rowJavaRDD
                // Convert each Row to an Employee object.
                // Using a method call (e -> convert(e)) instead of a static class
                // means the outer class has to be serialized:
                // lambda functions internally need to be serialized, which causes this.
                .map(e -> convert(e))
                // Key by empId so that all the objects of an employee are collected on the reducer
                .keyBy(Employee::getEmpId)
                .groupByKey()
                //.combineByKey(new EmployeeMaxSalary.CreateCombiner(), new EmployeeMaxSalary.MergeValue(), new EmployeeMaxSalary.MergeCombiner())
                .mapPartitions(new MapPartitionSpark());

        // Convert the JavaRDD to a DataFrame and save it as a Parquet file
        DataFrame outputDf = sqlContext.createDataFrame(maxBonusRowRdd, outPutSchemaStructType);
        outputDf.write()
                .format(Employee.class.getCanonicalName())
                .parquet(outputPath);

        return 0;
    }

    public static class MapPartitionSpark
            implements FlatMapFunction<Iterator<Tuple2<Integer, Iterable<Employee>>>, Row> {

        // Functions used inside a transformation are instantiated on the driver,
        // and the serialized object is sent to the executors.
        // Marking a field transient keeps it from being serialized.
        private transient TreeSet<Long> employeeBonusSet;

        // Please do not declare any field as static:
        // multiple tasks can be spawned inside the same JVM (executor) in Spark,
        // which makes static state thread-unsafe.

        // call() is invoked only once for each partition
        // (a partition of a map or a reduce task)
        @Override
        public Iterable<Row> call(Iterator<Tuple2<Integer, Iterable<Employee>>> tuple2Iterator) throws Exception {

            // tuple2Iterator points to the whole partition (all the records in the partition),
            // not to a single record.
            // Each Tuple2<Integer, Iterable<Employee>> holds, for a given key (empId => Integer),
            // the group of values (Iterable<Employee>) associated with it.
            // The partition can belong to a map task or a reduce task.

            // All the code before the loop over tuple2Iterator
            // is executed only once, because call() is invoked once per partition.
            employeeBonusSet = new TreeSet<>();
            // Any service, such as an HBase or Aerospike client, can be instantiated here.

            ArrayList<Row> outputOfMapPartition = new ArrayList<>();

            // The while loop iterates over all the keys in the partition
            while (tuple2Iterator.hasNext()) {

                // One key in the partition and the values associated with it
                Tuple2<Integer, Iterable<Employee>> integerIterableTuple2 = tuple2Iterator.next();

                // The TreeSet is reused and not created for every key
                employeeBonusSet.clear();
                for (Employee employee : integerIterableTuple2._2()) {
                    employeeBonusSet.add(employee.getBonus());
                }

                // Pick one employee of the group and set the maximum bonus on it
                Employee output = integerIterableTuple2._2().iterator().next();
                output.setBonus(employeeBonusSet.last());

                // Convert the Employee back to a Row; the Avro field order matches the output schema
                Object[] values = new Object[output.getSchema().getFields().size()];
                for (Schema.Field field : output.getSchema().getFields()) {
                    values[field.pos()] = output.get(field.pos());
                }
                outputOfMapPartition.add(RowFactory.create(values));
            }

            // Any service opened above can be shut down here.

            // You return only once from mapPartitions,
            // when the data for the whole partition has been processed.
            // Unlike the map function you do not return once per row,
            // rather you return once per partition.
            return outputOfMapPartition;
        }
    }

    @Override
    public void close() throws IOException {
        IOUtils.closeQuietly(javaSparkContext);
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new EmployeeMaxBonusMapPartition(), args);
    }
}
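The same once-per-partition shape is the standard way to amortize expensive setup such as the HBase or Aerospike clients mentioned in the comments above. Below is a minimal, standalone sketch of that pattern; LookupClient is a hypothetical placeholder for any such service client and is not part of the code above:

import org.apache.spark.api.java.function.FlatMapFunction;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

// Hypothetical client used only to illustrate per-partition setup/teardown;
// substitute a real HBase/Aerospike/HTTP client as needed.
class LookupClient {
    void connect() { /* open the connection */ }
    String enrich(String record) { return record.toUpperCase(); }
    void close() { /* release the connection */ }
}

public class EnrichPerPartition implements FlatMapFunction<Iterator<String>, String> {

    @Override
    public Iterable<String> call(Iterator<String> partition) {
        // Runs once per partition: open the client here, not once per record
        LookupClient client = new LookupClient();
        client.connect();

        List<String> output = new ArrayList<>();
        while (partition.hasNext()) {
            output.add(client.enrich(partition.next()));
        }

        // Runs once per partition: shut the client down before returning
        client.close();

        // A single return for the whole partition
        return output;
    }
}

// Usage: JavaRDD<String> enriched = inputRdd.mapPartitions(new EnrichPerPartition());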
Integration Test (Github):
package com.big.data.spark;

import com.big.data.avro.AvroUtils;
import com.big.data.avro.schema.Employee;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.fs.Path;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import parquet.avro.AvroParquetReader;
import parquet.avro.AvroParquetWriter;
import parquet.hadoop.ParquetReader;
import parquet.hadoop.ParquetWriter;
import parquet.hadoop.metadata.CompressionCodecName;

import java.io.File;
import java.io.IOException;

public class ReadWriteAvroParquetFilesTest {

    private static final Logger LOG = LoggerFactory.getLogger(ReadWriteAvroParquetFilesTest.class);
    private static final String BASEDIR = "/tmp/ParquetFilesTest/avroparquetInputFile/" + System.currentTimeMillis() + "/";

    private String input;
    private String output;
    private Employee employee;

    @Before
    public void setUp() throws IOException {
        input = BASEDIR + "input/";
        output = BASEDIR + "output/";

        employee = new Employee();
        employee.setEmpId(1);
        employee.setEmpName("Maverick");
        employee.setEmpCountry("DE");

        // Write a Parquet file with GZIP compression
        ParquetWriter<Object> writer = AvroParquetWriter
                .builder(new Path(input + "1.gz.parquet"))
                .withCompressionCodec(CompressionCodecName.GZIP)
                .withSchema(Employee.getClassSchema())
                .build();
        writer.write(employee);
        writer.close();
    }

    @Test
    public void testSuccess() throws Exception {
        String[] args = new String[]{"-D" + ReadWriteAvroParquetFiles.INPUT_PATH + "=" + input,
                                     "-D" + ReadWriteAvroParquetFiles.OUTPUT_PATH + "=" + output,
                                     "-D" + ReadWriteAvroParquetFiles.IS_RUN_LOCALLY + "=true",
                                     "-D" + ReadWriteAvroParquetFiles.DEFAULT_FS + "=file:///",
                                     "-D" + ReadWriteAvroParquetFiles.NUM_PARTITIONS + "=1"};
        ReadWriteAvroParquetFiles.main(args);

        // Use .withConf(FS.getConf()) for reading from a different HDFS; by default the fs is local
        ParquetReader<GenericRecord> reader = AvroParquetReader
                .builder(new Path(output))
                .build();

        GenericData.Record event = (GenericData.Record) reader.read();
        Employee outputEvent = AvroUtils.convertByteArraytoAvroPojo(
                AvroUtils.convertAvroPOJOtoByteArray(event, Employee.getClassSchema()),
                Employee.getClassSchema());
        reader.close();

        LOG.info("Data read from Spark output is {}", outputEvent.toString());
        Assert.assertEquals(employee.getEmpId(), outputEvent.getEmpId());
    }

    @After
    public void cleanup() throws IOException {
        FileUtils.deleteDirectory(new File(BASEDIR));
    }
}
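Outside of the test harness the job can be launched the same way through its ToolRunner entry point. A sketch (the local paths below are placeholders) wiring up the configuration keys defined on EmployeeMaxBonusMapPartition:

package com.big.data.spark;

import org.apache.hadoop.util.ToolRunner;

public class RunEmployeeMaxBonusLocally {

    public static void main(String[] args) throws Exception {
        // -Dkey=value pairs are parsed by ToolRunner/GenericOptionsParser into the job Configuration
        String[] jobArgs = new String[]{
                "-D" + EmployeeMaxBonusMapPartition.INPUT_PATH + "=file:///tmp/employee/input",
                "-D" + EmployeeMaxBonusMapPartition.OUTPUT_PATH + "=file:///tmp/employee/output",
                "-D" + EmployeeMaxBonusMapPartition.IS_RUN_LOCALLY + "=true",
                "-D" + EmployeeMaxBonusMapPartition.DEFAULT_FS + "=file:///"};
        ToolRunner.run(new EmployeeMaxBonusMapPartition(), jobArgs);
    }
}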