Context: I need to filter a dataframe based on whether its values appear in another dataframe's column, using the isin function.
For Python users working with pandas, that would be: isin().
For R users, that would be: %in%.
So I have a simple spark dataframe with id and value columns:
l = [(1, 12), (1, 44), (1, 3), (2, 54), (3, 18), (3, 11), (4, 13), (5, 78)]
df = spark.createDataFrame(l, ['id', 'value'])
df.show()
+---+-----+
| id|value|
+---+-----+
|  1|   12|
|  1|   44|
|  1|    3|
|  2|   54|
|  3|   18|
|  3|   11|
|  4|   13|
|  5|   78|
+---+-----+
I want to get all ids that appear multiple times. To do that, I first build a dataframe of the ids that appear only once in df:
from pyspark.sql.functions import col

unique_ids = df.groupBy('id').count().where(col('count') < 2)
unique_ids.show()
+---+-----+
| id|count|
+---+-----+
|  5|    1|
|  2|    1|
|  4|    1|
+---+-----+
So the logical operation would be:
df = df[~df.id.isin(unique_ids.id)]
# This is the same as:
df = df[df.id.isin(unique_ids.id) == False]
However, I get an empty dataframe:
df.show()
+---+-----+
| id|value|
+---+-----+
+---+-----+
The same “error” shows up the other way around with the non-negated filter:
df[df.id.isin(unique_ids.id)]
returns all the rows of df.
Answer
The expression df.id.isin(unique_ids.id) == False evaluates Column<b'((id IN (id)) = false)'>, which is never true because id is always in id. Conversely, the expression df.id.isin(unique_ids.id) evaluates Column<b'(id IN (id))'>, which is always true, so it returns the whole dataframe. The cause is that unique_ids.id is a Column, not a list.
isin(*cols) expects a list of values as its argument, not a column of another dataframe, so to make it work this way you should first collect the ids into a Python list:
ids = unique_ids.rdd.map(lambda x: x.id).collect()
df[df.id.isin(ids)].collect()  # or show...
and you will obtain:
[Row(id=2, value=54), Row(id=4, value=13), Row(id=5, value=78)]
In any case, I think it would be better if you join both data frames:
df_ = df.join(unique_ids, on='id')
getting:
df_.show()
+---+-----+-----+
| id|value|count|
+---+-----+-----+
|  5|   78|    1|
|  2|   54|    1|
|  4|   13|    1|
+---+-----+-----+