
Why do DataFrame and Dataset run much slower than RDD for a cross join?

I did some tests using a cross join and found that the RDD version runs much faster than the DataFrame and Dataset versions. I am using Spark 2.1.0.
Is there a reason behind this, or am I doing something wrong?
I have always heard that DataFrame and Dataset should be faster.
Here are the results:
Run for iteration No. 0
Time to parallelize the RDD     1403 ms


Output of test testRDD iteration No. 0
Total filtered result: num of items 7599395
Total time spent on convert           query      4614 sum      1136


Output of test testDataFrameJoin iteration No.   0 hashJoin = false
Total filtered result: num of items 7599395
Total time spent on convert      3013 query     13602 sum      2846


Output of test testDataFrameJoin iteration No.   0 hashJoin = true
Total filtered result: num of items 7599395
Total time spent on convert       261 query     12256 sum      2773


Output of test testDataSetJoin iteration No. 0
Total filtered result: num of items 7599395
Total time spent on convert       260 query     12089  sum       688
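For reference, runDataSetJoin() in the code below has a sameGeneration.explain(true) call commented out; enabling it is how the join strategy Spark actually picks could be checked. A minimal fragment, assuming the same sameGeneration Dataset that the code builds:

        // Sketch only, not part of the timed runs: print the logical and physical plans
        // of the Dataset join so the join strategy Spark picks (broadcast nested loop,
        // cartesian product, ...) shows up in the driver log.
        sameGeneration.explain(true);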
Here is the code:
package myTest;


import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.*;
import scala.Tuple2;
import org.apache.spark.api.java.function.Function;
 
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import static java.lang.Math.abs;
import static java.lang.Math.ceil;
import static org.apache.spark.sql.functions.broadcast;
import static org.apache.spark.sql.functions.expr;

 
public class CrossJoinTest implements Serializable {
    static int num1 = 50000;
    static int num2 = 1000;

    static SparkConf conf;
    static JavaSparkContext sc;
    static List<Person> personList1=null;
    static List<Person> personList2=null;
    static int itr=0;

    static int upper = 89;
    static int lower = 13;
    static int difference = 5;
    static String exprStr1 = "people1.age >= " + lower + " AND people1.age <= " + upper  ;

    static String exprStr2 = " ABS(people1.age-people2.age) < " + difference;


    public static void setUp() { 
        conf = new SparkConf().setAppName("NMS Tuning Engine").setMaster("local[1]");
        sc = new JavaSparkContext(conf);
    }

    public static void tearDown() { 
        sc.stop();
    }


    static private JavaRDD<Person> personRDD1;
    static private JavaRDD<Person> personRDD2;

    public static void initData (int itr) {


        personList1 = new ArrayList<>();
        personList2 = new ArrayList<>();

        for (int i=0; i< num1; i++) {
            Person p1 = new Person();
            p1.setAge((i + itr)/(ceil(num1/100)) );
            p1.setName("atAge1:"+(i+itr));
            personList1.add(p1);
        }

        for (int i=0; i< num2; i++) {
            Person p1 = new Person();
            p1.setAge((i+itr)/(ceil(num2/100)) );
            p1.setName("RDD"+(i+itr));
            personList2.add(p1);
        }

        long time1 = System.currentTimeMillis();

        personRDD1 = sc.parallelize (personList1);
        personRDD2 = sc.parallelize (personList2);

        personRDD1.cache();
        personRDD2.cache();

        // Touch every partition once so both RDDs are materialized and cached here,
        // keeping that one-time cost out of the join timings below.
        personRDD1.foreachPartition(x-> {
            while (x.hasNext()) {
                Person xx = x.next();
                double a = xx.getAge();
                return;
            }
        });
        personRDD2.foreachPartition(x-> {
            while (x.hasNext()) {
                Person xx = x.next();
                double a = xx.getAge();
                return;
            }
        });

        long time2 = System.currentTimeMillis();

        System.err.println(String.format("Time to parallelize the RDD %8d ms", time2-time1));
    }



    public static void main(String[] args) {
        setUp();

        for (int k=0; k<1; k++) {
            itr = k;
            System.err.println("--------------- \nRun for iteration No. " + k);

            initData(k);

            runRDD ();

            runDataFrameJoin (false);
            runDataFrameJoin (true);

            runDataSetJoin();
        }

        tearDown ();
    }

    // RDD-side join predicate: keep pairs whose ages differ by less than `difference`.
    static class FilterPeople implements Function<Tuple2<Person, Person>, Boolean> {
        @Override
        public Boolean call (Tuple2<Person, Person> twoPeople) {
            return abs(twoPeople._1.getAge() - twoPeople._2.getAge()) < difference;
        }
    }


    public static void runRDD ()   {


        long startSql = System.currentTimeMillis();

        // RDD version: pre-filter on the age range, take the cartesian product with the
        // second RDD, then filter the pairs on the age-difference predicate.
        JavaPairRDD<Person, Person> sameGeneration = personRDD1.filter(x->x.getAge()>=lower && x.getAge()<=upper)
                .cartesian(personRDD2).filter(new FilterPeople());

        sameGeneration.cache();

        long numfCont = sameGeneration.count();
        long cntSql = System.currentTimeMillis();

        sameGeneration.foreachPartition(x-> {
            while (x.hasNext()) {
                Tuple2<Person, Person> xx = x.next();
                double a = xx._1.getAge() + xx._2.getAge();
                return;
            }
        });


        long endSql = System.currentTimeMillis();


        long t1 = cntSql - startSql;
        long t2 = endSql - cntSql;

        System.err.println("\nOutput of test testRDD iteration No. " + itr);
        System.err.println("Total filtered result: num of items " + numfCont);
        System.err.println(String.format("Total time spent on convert           query %9d sum %9d" ,  t1, t2));
    }


    public static void runDataSetJoin () {



        long startSql = System.currentTimeMillis();

        SparkSession ss = SparkSession.builder().getOrCreate();

        Dataset<Person> personDataset1 = ss.createDataFrame(personRDD1, Person.class).as(Encoders.bean(Person.class));
        Dataset<Person> personDataset2 = ss.createDataFrame(personRDD2, Person.class).as(Encoders.bean(Person.class)).as("people2");


        personDataset1.count();
        personDataset1.cache();
        personDataset2.count();
        personDataset2.cache();

        long convert = System.currentTimeMillis();

        Dataset<Tuple2<Person, Person>> sameGeneration = personDataset1.filter(x->x.getAge()>=lower && x.getAge()<=upper).as("people1").
                joinWith(personDataset2, expr(exprStr2));

        sameGeneration.cache();

        long numfCont = sameGeneration.count();
        long cntSql = System.currentTimeMillis();

        //sameGeneration.explain(true);

        sameGeneration.foreachPartition(x-> {
            while (x.hasNext()) {
                Tuple2<Person, Person> xx = x.next();
                double a = xx._1.getAge() + xx._2.getAge();
                return;
            }
        });

        long endSql = System.currentTimeMillis();


        long t0 = convert - startSql;
        long t1 = cntSql - convert;
        long t2 = endSql - cntSql;

        System.err.println("\nOutput of test testDataSetJoin iteration No. " + itr);
        System.err.println("Total filtered result: num of items " + numfCont);
        System.err.println(String.format("Total time spent on convert %9d query %9d  sum %9d" , t0, t1, t2));

    }


    public static void runDataFrameJoin (boolean hashJoin) {

        SparkSession spark = SparkSession.builder().getOrCreate();
        spark.conf().set ("spark.sql.crossJoin.enabled", true);

        long startSql = System.currentTimeMillis();

        Dataset<Row> schemaPeople1 =  (spark.createDataFrame(personRDD1, Person.class).select("name", "age").as("people1"));
        Dataset<Row> schemaPeople2 =  spark.createDataFrame(personRDD2, Person.class).select("name", "age").as ("people2");


        schemaPeople1.count();
        schemaPeople1.cache();
        schemaPeople2.count();
        schemaPeople2.cache();

        long convert = System.currentTimeMillis();

        schemaPeople1 =   schemaPeople1.filter(exprStr1).as ("people1");
        Dataset<Row> sameGeneration = null;

        if (hashJoin)
            sameGeneration = schemaPeople1.join( broadcast(schemaPeople2), expr(exprStr2));
        else
            sameGeneration = schemaPeople1.join(schemaPeople2, expr(exprStr2));

        sameGeneration.cache();

        long numfCont = sameGeneration.count();
        long cntSql = System.currentTimeMillis();


        sameGeneration.javaRDD().foreachPartition(new VoidFunction<Iterator<Row>>() {
            public void call (Iterator<Row> rows) {
                 while (rows.hasNext()) {
                    Row row = rows.next();
                    double age = row.getDouble(1) + row.getDouble(3);
                }
            }
        }
        );

        long endSql = System.currentTimeMillis();

        long t0 = convert - startSql;
        long t1 = cntSql - convert;
        long t2 = endSql - cntSql;

        System.err.println(String.format("\nOutput of test testDataFrameJoin iteration No. %3d hashJoin = %b", itr, hashJoin));
        System.err.println("Total filtered result: num of items " + numfCont);
        System.err.println(String.format("Total time spent on convert %9d query %9d sum %9d " , t0, t1, t2));

    }

}
Here is the Person class:
package myTest;

import java.io.Serializable;
 
public  class Person implements Serializable {
    private String name;
    private double age;
    // private List<Integer> code;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public double getAge() {
        return age;
    }

    public void setAge(double age) {
        this.age = age;
    }
}
Here is the script I use to run the job:
#!/usr/bin/bash

SPARK_MAJOR_VERSION=2
SPARK_CONF_DIR="/home/test/spark-defaults.conf"

export SPARK_MAJOR_VERSION
export SPARK_CONF_DIR

# spark-submit options must come before the application JAR;
# anything placed after the JAR is passed to the application as arguments.
spark-submit --master local \
        --class myTest.CrossJoinTest \
        --executor-cores 4 \
        --num-executors 2 \
        --executor-memory 4g \
        --driver-cores 2 \
        --driver-memory 4g \
        ./target/test-project-1.0.0.jar

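Not part of the runs above, but as a sanity check I would also dump the effective Spark configuration at startup, to see which of the spark-submit settings actually take effect when the job runs in local mode. A minimal sketch that could go at the end of setUp():

        // Hypothetical sanity check, added for illustration only: print the effective
        // SparkConf so it is easy to see which spark-submit flags actually took effect.
        System.err.println(sc.getConf().toDebugString());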