I have the following dataframe:
    d = [
        {'id': 3, 'ratio': 1.3, 'vol1': 100},
        {'id': 5, 'ratio': 0.3, 'vol1': 200},
        {'id': 1, 'ratio': 1.1, 'vol1': 300},
        {'id': 8, 'ratio': 0.8, 'vol1': 400},
        {'id': 2, 'ratio': 2.0, 'vol1': 500},
        {'id': 4, 'ratio': 0.0, 'vol1': 600},
    ]
    data = spark.createDataFrame(d)
I need to add a column new_col_cond to it
whose values depend on multiple external lists/arrays (I have also tried dictionaries), for example:
    q1 = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    q1_n = np.array(q1).reshape(-1)  # numpy array from above
    q2 = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    q2_n = np.array(q2).reshape(-1)
The new column depends on the value of ratio: it picks a value from one array or the other (q1 when ratio is below 1, q2 when it is above 1), using id as the index. I have tried:
    data = data.withColumn(
        'new_col_cond',
        when(col('ratio') < 1, q1[col('id')])
        .when(col('ratio') > 1, q2[col('id')])
    )  # also with numpy arrays
but this raises errors. I assume the main cause is using a column as the index into the array, but I am not sure how else to pass the index. Given the conditional nature of the column, I have not tried a join (the data has millions of rows and the lists have thousands of entries).
Due to the size of the dataset, I am steering away from Pandas and UDFs. The resulting dataframe should look like this:
    +---+-----+----+------------+
    | id|ratio|vol1|new_col_cond|
    +---+-----+----+------------+
    |  3|  1.3| 100|           4|
    |  5|  0.3| 200|          60|
    |  1|  1.1| 300|           2|
    |  8|  0.8| 400|          90|
    |  2|  2.0| 500|           3|
    |  4|  0.0| 600|          50|
    +---+-----+----+------------+
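To pin down the intended rule, here is a minimal plain-Python sketch (not a Spark solution) that reproduces the expected column above. It assumes id is a 0-based index into the lists and that a row with ratio exactly equal to 1 gets no value; the question does not specify that case. The helper name `lookup` is hypothetical.

```python
# Plain-Python reference for the expected values (illustration only, no Spark).
q1 = [10, 20, 30, 40, 50, 60, 70, 80, 90]
q2 = [1, 2, 3, 4, 5, 6, 7, 8, 9]

# (id, ratio) pairs taken from the example dataframe, in the same order.
rows = [(3, 1.3), (5, 0.3), (1, 1.1), (8, 0.8), (2, 2.0), (4, 0.0)]


def lookup(row_id, ratio):
    """Pick from q1 when ratio < 1, from q2 when ratio > 1 (0-based index)."""
    if ratio < 1:
        return q1[row_id]
    if ratio > 1:
        return q2[row_id]
    return None  # assumption: ratio == 1 is left undefined


expected = [lookup(row_id, ratio) for row_id, ratio in rows]
print(expected)
```

The printed list matches the new_col_cond column of the table row by row.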
Any help in solving this issue is appreciated.
Answer
Create ArrayType column expressions from the numpy arrays and use them in your condition like this:
    from pyspark.sql import functions as F

    # Turn each numpy array into an array column expression of integer literals
    q1_n = F.array(*[F.lit(int(x)) for x in q1_n])
    q2_n = F.array(*[F.lit(int(x)) for x in q2_n])

    result = data.withColumn(
        'new_col_cond',
        F.when(F.col('ratio') < 1, q1_n[F.col('id')])
         .when(F.col('ratio') > 1, q2_n[F.col('id')])
    )

    result.show()

    #+---+-----+----+------------+
    #| id|ratio|vol1|new_col_cond|
    #+---+-----+----+------------+
    #|  3|  1.3| 100|           4|
    #|  5|  0.3| 200|          60|
    #|  1|  1.1| 300|           2|
    #|  8|  0.8| 400|          90|
    #|  2|  2.0| 500|           3|
    #|  4|  0.0| 600|          50|
    #+---+-----+----+------------+