Scala - How to find duplicated columns with all values in a Spark DataFrame?

Explorer

Hi all, I want to count the duplicated columns in a Spark DataFrame, for example:


id  col1  col2  col3  col4
1   3     999   4     999
2   2     888   5     888
3   1     777   6     777

In this case, col2 and col4 have the same values in every row, which is what I'm interested in, so the count should go up by 1.

BTW, my dataset contains 20,000K+ rows.

Is there any function that could solve this?

Any idea will be appreciated, thank you.

3 REPLIES

Master Guru

@DADA206

--> You can try using a case statement: assign 1 to every row where col2 matches col4, otherwise assign 0.

--> Then aggregate on the new column and sum it.

Example:

scala> import org.apache.spark.sql.functions._
scala> val df=Seq((1,3,999,4,999),(2,2,888,5,888),(3,1,777,6,777)).toDF("id","col1","col2","col3","col4")
scala> df.withColumn("cw",when('col2 === 'col4,1).otherwise(0)).agg(sum('cw) as "su").show()
+---+
| su|
+---+
|  3|
+---+

- As all 3 rows have matching values in col2 and col4, the count is 3 in this case.

If the answer helped resolve the issue, log in and click the Accept button below to close this thread. This will help other community users find answers quickly :-).

Explorer

@Shu
Thanks for your reply, and I'm sorry I didn't describe it clearly, that's my fault. What I actually want to count is all the duplicated columns in the DataFrame, not one specified pair of columns and not duplicated rows.
From your suggestion, maybe I can get the answer by comparing each pair of columns and counting the pairs where the number of matching rows equals the row count, as in the sketch below.
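
For example, a rough sketch of that pairwise comparison (this is just my own idea, not from the reply above; it assumes every column except id should be compared, and uses Spark's null-safe equality <=> so that two nulls count as a match), done in a single aggregation pass so the 20,000K+ rows are only scanned once:

scala> import org.apache.spark.sql.functions._
scala> val dataCols = df.columns.filter(_ != "id")   // assumption: id is not a candidate column
scala> val pairs = dataCols.combinations(2).toSeq
scala> val exprs = pairs.map{ case Array(a, b) => sum(when(col(a) <=> col(b), 0).otherwise(1)).alias(a + "_vs_" + b) }   // 0/1 mismatch counter per pair
scala> val counts = df.agg(exprs.head, exprs.tail:_*).first
scala> val dupPairs = pairs.zipWithIndex.collect{ case (p, i) if counts.getLong(i) == 0 => p.mkString("=") }   // pairs with zero mismatches
scala> println("duplicated column pairs: " + dupPairs.mkString(", "))
duplicated column pairs: col2=col4

dupPairs.size would then be the count I was after (1 here, for col2=col4).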

Master Guru

@DADA206

If you want to count all duplicated rows, you can use one of these methods.

1. Using the dropDuplicates function:

scala> val df1=Seq((1,"q"),(2,"c"),(3,"d"),(1,"q"),(2,"c"),(3,"e")).toDF("id","n")
scala> println("duplicated counts:" +  (df1.count - df1.dropDuplicates.count))
duplicated counts:2

There are 2 extra duplicated rows in the DataFrame (6 total minus 4 distinct), which means in total 4 rows are duplicated.
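
If you also want to see which rows are duplicated rather than just how many, a small sketch using exceptAll (an assumption on my side that you are on Spark 2.4+, where Dataset.exceptAll is available) would be:

scala> df1.exceptAll(df1.dropDuplicates).show()   // multiset difference: leaves exactly the extra copies

For this data it should show one copy each of (1,q) and (2,c); row order may vary.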

2. Using groupBy on all columns:

scala> import org.apache.spark.sql.functions._
scala> val cols=df1.columns
scala> df1.groupBy(cols.head,cols.tail:_*).agg(count("*").alias("cnt")).filter('cnt > 1).select(sum("cnt")).show()
+--------+
|sum(cnt)|
+--------+
|       4|
+--------+

3. Using window functions:

scala> import org.apache.spark.sql.expressions.Window
scala> val wdw=Window.partitionBy(cols.head,cols.tail:_*)
wdw: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@1f6df7ac

scala> df1.withColumn("cnt",count("*").over(wdw)).filter('cnt > 1).count()
res80: Long = 4
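
If you instead need the number of distinct rows that have duplicates (groups rather than total rows), a small variation of the same window approach (a sketch, using the wdw spec defined above) is:

scala> df1.withColumn("cnt",count("*").over(wdw)).filter('cnt > 1).dropDuplicates.count()   // collapse repeated copies after the filter

which should return 2 here, one per duplicated group ((1,q) and (2,c)).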