
PySpark UDF returns null when the function works on a Pandas DataFrame

I’m trying to create a user-defined function that takes a cumulative sum of an array and compares the value to another column. Here is a reproducible example:

import numpy as np
from pyspark.sql.session import SparkSession
from pyspark.sql.functions import collect_list, desc, udf
from pyspark.sql.types import ArrayType, LongType
from pyspark.sql.window import Window

# instantiate Spark
spark = SparkSession.builder.getOrCreate()

# make some test data
columns = ['loc', 'id', 'date', 'x', 'y']
vals = [
    ('a', 'b', '2016-07-01', 1, 5),
    ('a', 'b', '2016-07-02', 0, 5),
    ('a', 'b', '2016-07-03', 5, 15),
    ('a', 'b', '2016-07-04', 7, 5),
    ('a', 'b', '2016-07-05', 8, 20),
    ('a', 'b', '2016-07-06', 1, 5)
]

# create DataFrame; x_ary collects x from the latest date down to the current row
temp_sdf = (spark
      .createDataFrame(vals, columns)
      .withColumn('x_ary', collect_list('x').over(Window.partitionBy(['loc', 'id']).orderBy(desc('date')))))

temp_df = temp_sdf.toPandas()

def test_function(x_ary, y):
    # running totals of the array, then count how many are <= y
    cumsum_array = np.cumsum(x_ary)
    result = len([x for x in cumsum_array if x <= y])
    return result

# declared return type: array of longs
test_function_udf = udf(test_function, ArrayType(LongType()))

temp_df['len'] = temp_df.apply(lambda x: test_function(x['x_ary'], x['y']), axis=1)
display(temp_df)  # Databricks notebook helper; use print(temp_df) elsewhere

In Pandas, this is the output:

loc id  date        x   y   x_ary           len
a   b   2016-07-06  1   5   [1]             1
a   b   2016-07-05  8   20  [1,8]           2
a   b   2016-07-04  7   5   [1,8,7]         1
a   b   2016-07-03  5   15  [1,8,7,5]       2
a   b   2016-07-02  0   5   [1,8,7,5,0]     1
a   b   2016-07-01  1   5   [1,8,7,5,0,1]   1
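To make the len column easier to check, here is a plain-Python sketch (no Spark or Pandas) that recomputes the running totals for each row of the table above. It assumes itertools.accumulate as a stand-in for np.cumsum, and count_running_totals_at_most is a hypothetical name used only for illustration; the (x_ary, y) pairs are copied from the output table.

```python
from itertools import accumulate

# (x_ary, y) pairs copied from the Pandas output table
rows = [
    ([1], 5),
    ([1, 8], 20),
    ([1, 8, 7], 5),
    ([1, 8, 7, 5], 15),
    ([1, 8, 7, 5, 0], 5),
    ([1, 8, 7, 5, 0, 1], 5),
]

def count_running_totals_at_most(x_ary, y):
    # running totals, e.g. [1, 8, 7] -> [1, 9, 16]
    running = list(accumulate(x_ary))
    # count how many running totals do not exceed y
    return sum(1 for total in running if total <= y)

lens = [count_running_totals_at_most(x_ary, y) for x_ary, y in rows]
print(lens)  # [1, 2, 1, 2, 1, 1]
```

For example, the third row has running totals [1, 9, 16] and y = 5, so only the first total qualifies and len is 1.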

In Spark, using temp_sdf.withColumn('len', test_function_udf('x_ary', 'y')), every value in len ends up null.

Would anyone know why this is the case?

Also, changing the first line of the function to cumsum_array = np.cumsum(np.flip(x_ary)) fails in PySpark with the error AttributeError: module 'numpy' has no attribute 'flip', even though I know the function exists, since it runs fine on the Pandas DataFrame.
Can this issue be resolved, or is there a better way to flip arrays in PySpark?

Thanks in advance for your help.


Answer

test_function returns a single integer, not a list or array, but the UDF declares its return type as ArrayType(LongType()). Spark cannot convert the returned integer to the declared array type, so it silently produces null instead. Either drop the return type from udf or declare it as LongType(); then it works, as shown below:

Note: Setting the return type of a UDF is optional; if you omit it, the default return type is StringType.

Option 1:

test_function_udf = udf(test_function)  # returns StringType

Option 2:

test_function_udf = udf(test_function, LongType())  # returns LongType (integer)

temp_sdf = temp_sdf.withColumn('len', test_function_udf('x_ary', 'y'))
temp_sdf.show()
User contributions licensed under: CC BY-SA