Support Questions


Scala - How to find duplicated columns (all values identical) in a Spark dataframe?


Hi all, I want to count the duplicated columns in a Spark dataframe. For example, in my sample data col2 and col4 contain exactly the same values in every row; that is the case I'm interested in, so such a pair should add 1 to the count.

BTW, my dataset contains 20000K+ rows.

Is there any function that could solve this?

Any ideas will be appreciated, thank you.


Master Guru


--> You can try using a case statement (when/otherwise) to assign 1 to every row where col2 matches col4, and 0 otherwise.

--> Then aggregate on the new column and sum it.


scala> val df=Seq((1,3,999,4,999),(2,2,888,5,888),(3,1,777,6,777)).toDF("id","col1","col2","col3","col4")
scala> df.withColumn("cw",when('col2 === 'col4,1).otherwise(0)).agg(sum('cw) as "su").show()
+---+
| su|
+---+
|  3|
+---+

- All 3 rows have matching col2/col4 values, so the sum is 3 in this case.




Thanks for your reply, and I'm sorry I didn't describe it clearly; my fault. What I actually want to count is all the duplicated columns in the dataframe, not a specific column and not duplicated rows.
From your suggestion, maybe I can find the answer by comparing each pair of columns and checking whether the number of rows where they match equals the total row count.
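To sketch that idea (not tested on your data; an existing DataFrame `df` and SparkSession are assumed): compare each pair of columns and count the pair as duplicated when the number of rows where the two columns agree equals the total row count.

```scala
// Sketch of the pairwise comparison: a column pair is "duplicated"
// when its values agree on every row. Assumes an existing DataFrame `df`.
import org.apache.spark.sql.functions.col

val total = df.count()

val duplicatedPairs = df.columns.toSeq.combinations(2).count {
  case Seq(a, b) =>
    // <=> is Spark's null-safe equality; === would treat null == null as false.
    df.filter(col(a) <=> col(b)).count() == total
}

println(s"duplicated column pairs: $duplicatedPairs")
```

Note this launches one Spark job per column pair, which can get slow with many columns.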

Master Guru


If you are looking to count all duplicated rows, you can use one of these methods.

1.Using dropDuplicates function:

scala> val df1=Seq((1,"q"),(2,"c"),(3,"d"),(1,"q"),(2,"c"),(3,"e")).toDF("id","n")
scala> println("duplicated counts:" +  (df1.count - df1.dropDuplicates.count))
duplicated counts:2

dropDuplicates removes 2 rows, which means that in total 4 rows are involved in duplication (two distinct rows, each appearing twice).

2.Using groupBy on all columns:

scala> import org.apache.spark.sql.functions._
scala> val cols=df1.columns
scala> df1.groupBy(cols.head,cols.tail:_*).agg(count("*").alias("cnt")).filter('cnt > 1).select(sum("cnt")).show()
+--------+
|sum(cnt)|
+--------+
|       4|
+--------+

3.Using window functions:

scala> import org.apache.spark.sql.expressions.Window
scala> val wdw=Window.partitionBy(cols.head,cols.tail:_*)
wdw: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@1f6df7ac

scala> df1.withColumn("cnt",count("*").over(wdw)).filter('cnt > 1).count()
res80: Long = 4
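Coming back to the original question of duplicated columns: with 20000K+ rows, filtering once per column pair can be expensive, so one possible alternative (a sketch, assuming a DataFrame `df`) is to compute the match count for every pair in a single aggregation pass:

```scala
// Sketch: check all column pairs for full duplication in one Spark job.
// Assumes an existing DataFrame `df`; <=> is null-safe equality.
import org.apache.spark.sql.functions.{col, sum, when}

val total = df.count()

// One sum-of-matches expression per column pair.
val pairExprs = df.columns.toSeq.combinations(2).toSeq.map {
  case Seq(a, b) =>
    sum(when(col(a) <=> col(b), 1).otherwise(0)).alias(s"${a}__${b}")
}

// A single aggregation computes the match count for every pair at once.
val matches = df.select(pairExprs: _*).first()

// A pair is duplicated when its match count equals the row count.
val duplicatedPairs = pairExprs.indices.count(i => matches.getLong(i) == total)
println(s"duplicated column pairs: $duplicatedPairs")
```

This trades one job per pair for a single wide aggregation, which should scan the data only once.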