Created 07-01-2019 02:28 AM
Hi all, I want to count the duplicated columns in a spark dataframe, for example:
| id | col1 | col2 | col3 | col4 |
| --- | --- | --- | --- | --- |
| 1 | 3 | 999 | 4 | 999 |
| 2 | 2 | 888 | 5 | 888 |
| 3 | 1 | 777 | 6 | 777 |
In this case, the values of col2 and col4 are the same, which is what I'm interested in, so the count should increase by 1.
BTW, my dataset contains 20000K+ rows.
Is there any function that could solve this?
Any idea will be appreciated, thank you.
Created 07-01-2019 03:38 AM
--> You can try using a case statement to assign 1 to every row where col2 matches col4, and 0 otherwise.
--> Then aggregate on the new column and sum it.
Example:
scala> val df=Seq((1,3,999,4,999),(2,2,888,5,888),(3,1,777,6,777)).toDF("id","col1","col2","col3","col4")
scala> df.withColumn("cw",when('col2 === 'col4,1).otherwise(0)).agg(sum('cw) as "su").show()
+---+
| su|
+---+
| 3|
+---+
As all 3 rows have matching values in col2 and col4, the count is 3 in this case.
If the answer is helpful to resolve the issue, login and click on the Accept button below to close this thread. This will help other community users to find answers quickly :-).
Created 07-01-2019 07:04 AM
@Shu
Thanks for your reply, and I'm sorry I didn't describe it clearly, it's my fault. Actually, what I want to count is all the duplicated columns in a dataframe, not a specified column and not the duplicated rows.
From your suggestion, maybe I can find the answer by comparing each pair of columns and counting a pair as duplicated when the number of matching rows equals the row count, something like the sketch below.
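A minimal sketch of that idea (assuming the dataframe is named df as in the question and that the id column should be excluded from the comparison; this is not from the thread, just an illustration):
scala> import org.apache.spark.sql.functions.col
scala> val dataCols = df.columns.filterNot(_ == "id")  // assumption: skip the id column
scala> val total = df.count()
scala> // a column pair counts as duplicated when its values match on every row
scala> val dupPairs = dataCols.combinations(2).filter { case Array(a, b) => df.filter(col(a) === col(b)).count() == total }.toList
scala> println("duplicated column pairs: " + dupPairs.size)
Note that this runs one comparison job per column pair, so with 20000K+ rows it may help to cache df first (df.cache()).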
Created 07-01-2019 02:19 PM
If you want to count all duplicated rows, you can use one of these methods.
1. Using the dropDuplicates function:
scala> val df1=Seq((1,"q"),(2,"c"),(3,"d"),(1,"q"),(2,"c"),(3,"e")).toDF("id","n")
scala> println("duplicated counts:" + (df1.count - df1.dropDuplicates.count))
duplicated counts:2
There are 2 duplicated rows in the dataframe, which means that in total 4 rows are involved in duplication.
2. Using groupBy on all columns:
scala> import org.apache.spark.sql.functions._
scala> val cols=df1.columns
scala> df1.groupBy(cols.head,cols.tail:_*).agg(count("*").alias("cnt")).filter('cnt > 1).select(sum("cnt")).show()
+--------+
|sum(cnt)|
+--------+
| 4|
+--------+
3. Using window functions:
scala> import org.apache.spark.sql.expressions.Window
scala> val wdw=Window.partitionBy(cols.head,cols.tail:_*)
wdw: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@1f6df7ac
scala> df1.withColumn("cnt",count("*").over(wdw)).filter('cnt > 1).count()
res80: Long = 4