How to join two DataFrames on a list of common columns in PySpark SQL

Suppose we have two DataFrames: df1 with columns [a, b, c, p, q, r] and df2 with columns [d, e, f, a, b, c]. The common columns are stored in a list, common_cols = ['a', 'b', 'c'].

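How do you join the two DataFrames on the columns in common_cols inside a Spark SQL query? The code below attempts this but fails. common_cols is a Python list, and Python does not substitute it into a plain string, so Spark tries to resolve a column literally named common_cols.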

common_cols = ['a', 'b', 'c']
filter_df = spark.sql("""
    select * from df1 inner join df2
    on df1.common_cols = df2.common_cols
""")

Answer

Demo setup

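from pyspark.sql import SparkSession

# Reuse the active SparkSession, or start a local one
spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([(1,2,3,4,5,6)],['a','b','c','p','q','r'])
df2 = spark.createDataFrame([(7,8,9,1,2,3)],['d','e','f','a','b','c'])
common_cols = ['a','b','c']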

df1.show()

+---+---+---+---+---+---+
|  a|  b|  c|  p|  q|  r|
+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|
+---+---+---+---+---+---+


df2.show()

+---+---+---+---+---+---+
|  d|  e|  f|  a|  b|  c|
+---+---+---+---+---+---+
|  7|  8|  9|  1|  2|  3|
+---+---+---+---+---+---+

Solution, based on the SQL join syntax with a USING clause

df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')
common_cols_csv = ','.join(common_cols)

query = f'''
select  * 
from    df1 inner join df2 using ({common_cols_csv})
'''

print(query)

select  * 
from    df1 inner join df2 using (a,b,c)

filter_df = spark.sql(query)

filter_df.show()

+---+---+---+---+---+---+---+---+---+
|  a|  b|  c|  p|  q|  r|  d|  e|  f|
+---+---+---+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|  7|  8|  9|
+---+---+---+---+---+---+---+---+---+
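The ','.join approach assumes plain column names. If a name can contain spaces, dots, or reserved words, each one can be wrapped in backticks, which Spark SQL uses to quote identifiers, with any embedded backtick doubled. A minimal sketch; quote_ident is a hypothetical helper and the column name 'order date' is only an illustration:

```python
# Hypothetical helper: quote a column name as a Spark SQL identifier.
# A literal backtick inside the name is escaped by doubling it.
def quote_ident(name):
    return '`' + name.replace('`', '``') + '`'

common_cols = ['a', 'b', 'c', 'order date']
common_cols_csv = ', '.join(quote_ident(c) for c in common_cols)

query = f'''
select  *
from    df1 inner join df2 using ({common_cols_csv})
'''
print(query)
```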