How to interact with each element of an ArrayType column in pyspark?

If I have an ArrayType column in pyspark

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType

spark = SparkSession.builder.getOrCreate()
schema = StructType([
    StructField("a", IntegerType()),
    StructField("b", ArrayType(IntegerType())),
])
df = spark.createDataFrame([(1, []), (2, [1, 2, 3]), (3, [-2])], schema=schema)
df.show()
output:
+---+---------+
|  a|        b|
+---+---------+
|  1|       []|
|  2|[1, 2, 3]|
|  3|     [-2]|
+---+---------+

Now, I want to be able to operate on each element of column b, for example:

  1. Divide each element by 5, giving:
+---+---------------+
|  a|              b|
+---+---------------+
|  1|             []|
|  2|[0.2, 0.4, 0.6]|
|  3|         [-0.4]|
+---+---------------+
  2. Add a constant to each element, etc.

How do I go about such transformations where some operator or function is applied to each element of the array type columns?


Solution 1:

You are looking for the transform function, which applies a computation to each element of an array column.

from pyspark.sql import functions as F

# Spark < 3.1.0
df.withColumn("b", F.expr("transform(b, x -> x / 5)")).show()

"""
+---+---------------+
|  a|              b|
+---+---------------+
|  1|             []|
|  2|[0.2, 0.4, 0.6]|
|  3|         [-0.4]|
+---+---------------+
"""

# Spark >= 3.1.0

df.withColumn("b", F.transform("b", lambda x: x / 5)).show()
"""
+---+---------------+
|  a|              b|
+---+---------------+
|  1|             []|
|  2|[0.2, 0.4, 0.6]|
|  3|         [-0.4]|
+---+---------------+
"""