
How to write this pandas logic for a pyspark.sql.dataframe.DataFrame without using the pandas-on-Spark API?

I’m totally new to PySpark. Since PySpark doesn’t have pandas’ loc feature, how can I write this logic? I tried specifying conditions but couldn’t get the desired result. Any help would be greatly appreciated!

df['Total'] = (df['level1']+df['level2']+df['level3']+df['level4'])/df['Number']
df.loc[df['level4'] > 0, 'Total'] += 4
df.loc[((df['level3'] > 0) & (df['Total'] < 1)), 'Total'] += 3
df.loc[((df['level2'] > 0) & (df['Total'] < 1)), 'Total'] += 2
df.loc[((df['level1'] > 0) & (df['Total'] < 1)), 'Total'] += 1


Answer

For data like the following:

# imports and session setup used throughout this answer
from pyspark.sql import SparkSession
from pyspark.sql import functions as func

spark = SparkSession.builder.getOrCreate()

data_ls = [
    (1, 1, 1, 1, 10),
    (5, 5, 5, 5, 10)
]

data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['level1', 'level2', 'level3', 'level4', 'number'])

# +------+------+------+------+------+
# |level1|level2|level3|level4|number|
# +------+------+------+------+------+
# |     1|     1|     1|     1|    10|
# |     5|     5|     5|     5|    10|
# +------+------+------+------+------+

You’re actually updating the total column in each statement, not in an if-then-else way, so each later condition sees the value produced by the earlier ones. Your code can be replicated (as is) in PySpark using multiple withColumn() calls with when(), like the following.

data_sdf. \
    withColumn('total', (func.col('level1') + func.col('level2') + func.col('level3') + func.col('level4')) / func.col('number')). \
    withColumn('total', func.when(func.col('level4') > 0, func.col('total') + 4).otherwise(func.col('total'))). \
    withColumn('total', func.when((func.col('level3') > 0) & (func.col('total') < 1), func.col('total') + 3).otherwise(func.col('total'))). \
    withColumn('total', func.when((func.col('level2') > 0) & (func.col('total') < 1), func.col('total') + 2).otherwise(func.col('total'))). \
    withColumn('total', func.when((func.col('level1') > 0) & (func.col('total') < 1), func.col('total') + 1).otherwise(func.col('total'))). \
    show()

# +------+------+------+------+------+-----+
# |level1|level2|level3|level4|number|total|
# +------+------+------+------+------+-----+
# |     1|     1|     1|     1|    10|  4.4|
# |     5|     5|     5|     5|    10|  6.0|
# +------+------+------+------+------+-----+

We can merge all of the withColumn() calls into a single withColumn() with multiple chained when() statements. This works because a chained when() evaluates to the first matching branch, like a SQL CASE expression, so later conditions are only tested when the earlier ones fail.

data_sdf. \
    withColumn('total', (func.col('level1') + func.col('level2') + func.col('level3') + func.col('level4')) / func.col('number')). \
    withColumn('total',
               func.when(func.col('level4') > 0, func.col('total') + 4).
               when((func.col('level3') > 0) & (func.col('total') < 1), func.col('total') + 3).
               when((func.col('level2') > 0) & (func.col('total') < 1), func.col('total') + 2).
               when((func.col('level1') > 0) & (func.col('total') < 1), func.col('total') + 1).
               otherwise(func.col('total'))
               ). \
    show()

# +------+------+------+------+------+-----+
# |level1|level2|level3|level4|number|total|
# +------+------+------+------+------+-----+
# |     1|     1|     1|     1|    10|  4.4|
# |     5|     5|     5|     5|    10|  6.0|
# +------+------+------+------+------+-----+

It behaves like numpy.where in pandas and a CASE expression in SQL.
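In fact, if you prefer SQL syntax, the same merged logic can be written as a literal CASE expression with func.expr(). This is just a sketch against the data_sdf dataframe defined above, another way to spell the chained when(), not a different method.

# same logic as the chained when(), written as a Spark SQL CASE expression
data_sdf. \
    withColumn('total', (func.col('level1') + func.col('level2') + func.col('level3') + func.col('level4')) / func.col('number')). \
    withColumn('total', func.expr("""
        CASE
            WHEN level4 > 0 THEN total + 4
            WHEN level3 > 0 AND total < 1 THEN total + 3
            WHEN level2 > 0 AND total < 1 THEN total + 2
            WHEN level1 > 0 AND total < 1 THEN total + 1
            ELSE total
        END""")). \
    show()

For the example data this produces the same totals as above (4.4 and 6.0).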
