How to use a list comprehension on an array column in PySpark?
I have a PySpark dataframe that looks like this:
+--------+-------+--------------+
|ID      |country|attrs         |
+--------+-------+--------------+
|ffae10af|US     |[1,2,3,4...]  |
|3de27656|US     |[1,7,2,4...]  |
|75ce4e58|US     |[1,2,1,4...]  |
|908df65c|US     |[1,8,3,0...]  |
|f0503257|US     |[1,2,3,2...]  |
|2tBxD6j |US     |[1,2,3,4...]  |
|33811685|US     |[1,5,3,5...]  |
|aad21639|US     |[7,8,9,4...]  |
|e3d9e3bb|US     |[1,10,9,4...] |
|463f6f69|US     |[12,2,13,4...]|
+--------+-------+--------------+
I also have a set that looks like this:
reference_set = {1, 2, 100, 500, 821}
What I want to do is create a new array column in the dataframe, perhaps using a list comprehension like this: [attr for attr in attrs if attr in reference_set]
So my final dataframe should look something like this:
+--------+-------+--------------+
|ID      |country|filtered_attrs|
+--------+-------+--------------+
|ffae10af|US     |[1,2]         |
|3de27656|US     |[1,2]         |
|75ce4e58|US     |[1,2]         |
|908df65c|US     |[1]           |
|f0503257|US     |[1,2]         |
|2tBxD6j |US     |[1,2]         |
|33811685|US     |[1]           |
|aad21639|US     |[]            |
|e3d9e3bb|US     |[1]           |
|463f6f69|US     |[2]           |
+--------+-------+--------------+
How can I do this? I'm new to PySpark and can't work out the logic.
Edit: I posted my approach below. If there's a more efficient way of doing this, please let me know.
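Before moving to Spark, it can help to pin down the intended semantics in plain Python. The sketch below wraps the comprehension from the question in a function; the names filter_attrs and filter_attrs_distinct are hypothetical. Note that the comprehension keeps repeated values, so on the visible prefix [1,2,1,4] of row 75ce4e58 it yields [1, 2, 1], whereas the expected output shows [1, 2]; the distinct variant matches that table. A function like this could be registered with pyspark.sql.functions.udf and a return type of ArrayType(LongType()), though a Python UDF is generally slower than built-in column functions.

```python
from typing import Iterable, List, Optional, Set

# Hypothetical constant mirroring reference_set from the question.
REFERENCE_SET: Set[int] = {1, 2, 100, 500, 821}


def filter_attrs(attrs: Optional[Iterable[int]],
                 reference: Set[int] = REFERENCE_SET) -> Optional[List[int]]:
    """The question's list comprehension: keeps order and repeats."""
    if attrs is None:  # a null array in Spark arrives as None in a UDF
        return None
    return [attr for attr in attrs if attr in reference]


def filter_attrs_distinct(attrs: Optional[Iterable[int]],
                          reference: Set[int] = REFERENCE_SET) -> Optional[List[int]]:
    """Same filter, but drops repeats, keeping first-occurrence order."""
    if attrs is None:
        return None
    seen: Set[int] = set()
    out: List[int] = []
    for attr in attrs:
        if attr in reference and attr not in seen:
            seen.add(attr)
            out.append(attr)
    return out


print(filter_attrs([1, 2, 1, 4]))           # [1, 2, 1]
print(filter_attrs_distinct([1, 2, 1, 4]))  # [1, 2]
print(filter_attrs([7, 8, 9, 4]))           # []
```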
Solution 1:
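You can use the built-in function array_intersect (available since Spark 2.4). It returns the elements that appear in both arrays, without duplicates.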
from pyspark.sql import SparkSession
from pyspark.sql.functions import array_intersect, lit, split

spark = SparkSession.builder.getOrCreate()

# Sample dataframe
df = spark.createDataFrame([('ffae10af', 'US', [1,2,3,4])], ["ID", "Country", "attrs"])
reference_set = {1,2,100,500,821}
# Add the set as a literal array column: join it into one string, split it back and cast to array<bigint>
set_to_string = ",".join([str(x) for x in reference_set])
df.withColumn('reference_set', split(lit(set_to_string), ',').cast('array<bigint>')) \
    .withColumn('filtered_attrs', array_intersect('attrs', 'reference_set')) \
    .show(truncate=False)
+--------+-------+------------+---------------------+--------------+
|ID |Country|attrs |reference_set |filtered_attrs|
+--------+-------+------------+---------------------+--------------+
|ffae10af|US |[1, 2, 3, 4]|[1, 2, 100, 500, 821]|[1, 2] |
+--------+-------+------------+---------------------+--------------+