How to use a list comprehension on an array column in PySpark?

I have a pyspark dataframe that looks like this.

+--------------------+-------+--------------------+
|              ID    |country|               attrs|
+--------------------+-------+--------------------+
|ffae10af            |     US|[1,2,3,4...]        |
|3de27656            |     US|[1,7,2,4...]        |
|75ce4e58            |     US|[1,2,1,4...]        |
|908df65c            |     US|[1,8,3,0...]        |
|f0503257            |     US|[1,2,3,2...]        |
|2tBxD6j             |     US|[1,2,3,4...]        |
|33811685            |     US|[1,5,3,5...]        |
|aad21639            |     US|[7,8,9,4...]        |
|e3d9e3bb            |     US|[1,10,9,4...]       |
|463f6f69            |     US|[12,2,13,4...]      |
+--------------------+-------+--------------------+

I also have a set that looks like this:

reference_set = {1,2,100,500,821}

What I want to do is create a new list as a column in the dataframe, perhaps using a list comprehension like this: [attr for attr in attrs if attr in reference_set]

So my final dataframe should look something like this:

+--------------------+-------+--------------------+
|              ID    |country|      filtered_attrs|
+--------------------+-------+--------------------+
|ffae10af            |     US|[1,2]               |
|3de27656            |     US|[1,2]               |
|75ce4e58            |     US|[1,2]               |
|908df65c            |     US|[1]                 |
|f0503257            |     US|[1,2]               |
|2tBxD6j             |     US|[1,2]               |
|33811685            |     US|[1]                 |
|aad21639            |     US|[]                  |
|e3d9e3bb            |     US|[1]                 |
|463f6f69            |     US|[2]                 |
+--------------------+-------+--------------------+

How can I do this? I'm new to PySpark and can't work out the logic.

Edit: I posted my logic below. If there's a more efficient way of doing this, please let me know.


Solution 1:

You can use the built-in function array_intersect (available since Spark 2.4).

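# Functions used below (assumes an existing SparkSession named `spark`)
from pyspark.sql.functions import array_intersect, lit, split

# Sample dataframe
df = spark.createDataFrame([('ffae10af', 'US', [1,2,3,4])], ["ID", "Country", "attrs"])

reference_set = {1,2,100,500,821}

# Join the set into a comma-separated string so it can be added as an array column
set_to_string = ",".join([str(x) for x in reference_set])

# Split the string back into an array, cast it to match attrs, then intersect
(df.withColumn('reference_set', split(lit(set_to_string), ',').cast('array<bigint>'))
   .withColumn('filtered_attrs', array_intersect('attrs', 'reference_set'))
   .show(truncate=False))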

+--------+-------+------------+---------------------+--------------+
|ID      |Country|attrs       |reference_set        |filtered_attrs|
+--------+-------+------------+---------------------+--------------+
|ffae10af|US     |[1, 2, 3, 4]|[1, 2, 100, 500, 821]|[1, 2]        |
+--------+-------+------------+---------------------+--------------+