How to access element of a VectorUDT column in a Spark DataFrame?
I have a dataframe df with a VectorUDT column named features. How do I get an element of the column, say the first element?
I've tried doing the following
from pyspark.sql.functions import udf
first_elem_udf = udf(lambda row: row.values[0])
df.select(first_elem_udf(df.features)).show()
but I get a net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype) error. I get the same error if I use first_elem_udf = udf(lambda row: row.toArray()[0]) instead.
I also tried explode(), but I get an error because it requires an array or map type.
This should be a common operation, I think.
Solution 1:
Convert output to float:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import lit, udf
def ith_(v, i):
    try:
        # cast to a plain Python float so Spark can serialize the result
        return float(v[i])
    except (ValueError, IndexError):
        # return null for an invalid or out-of-range index
        return None

ith = udf(ith_, DoubleType())
Example usage:
from pyspark.ml.linalg import Vectors
df = sc.parallelize([
    (1, Vectors.dense([1, 2, 3])),
    (2, Vectors.sparse(3, [1], [9]))
]).toDF(["id", "features"])
df.select(ith("features", lit(1))).show()
## +-----------------+
## |ith_(features, 1)|
## +-----------------+
## | 2.0|
## | 9.0|
## +-----------------+
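To directly answer the original question (the first element), pass lit(0). The index has to be wrapped in lit() because the arguments of a UDF call must be Column expressions (or column names), not plain Python ints:
df.select(ith("features", lit(0)).alias("first")).show()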
Explanation:
Output values have to be reserialized to equivalent Java objects. If you want to access values (beware of SparseVectors) you should use the item method:
v.values.item(0)
which returns standard Python scalars. Similarly, if you want to access all values as a dense structure:
v.toArray().tolist()
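The SparseVector caveat is worth spelling out: v.values only stores the non-zero entries, so a position in v.values does not generally match a position in the vector; indexing the vector itself (or converting with toArray()) accounts for the implicit zeros. A minimal sketch on the pyspark.ml.linalg types, run locally rather than inside a UDF:
from pyspark.ml.linalg import Vectors

sv = Vectors.sparse(3, [1], [9.0])  # logically [0.0, 9.0, 0.0]

sv.values.item(0)      # 9.0 -- first *stored* value, not the value at vector index 0
float(sv[0])           # 0.0 -- indexing the vector itself handles the implicit zeros
sv.toArray().tolist()  # [0.0, 9.0, 0.0] -- dense copy as plain Python floats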
Solution 2:
If you prefer using spark.sql, you can use the following custom function to_array to convert the vector into an array. Then you can manipulate it as an array.
from pyspark.sql.types import ArrayType, DoubleType

def to_array_(v):
    return v.toArray().tolist()

from pyspark.sql import SQLContext
sqlContext = SQLContext(spark.sparkContext, sparkSession=spark, jsqlContext=None)
# register the UDF for use in SQL queries (on Spark 2.x+, spark.udf.register("to_array", ...) is equivalent)
sqlContext.udf.register("to_array", to_array_, ArrayType(DoubleType()))
Example:
from pyspark.ml.linalg import Vectors
df = sc.parallelize([
    (1, Vectors.dense([1, 2, 3])),
    (2, Vectors.sparse(3, [1], [9]))
]).toDF(["id", "features"])
df.createOrReplaceTempView("tb")
spark.sql("""select * , to_array(features)[1] Second from tb """).toPandas()
Output:
   id         features  Second
0   1  [1.0, 2.0, 3.0]     2.0
1   2  (0.0, 9.0, 0.0)     9.0
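On Spark 3.0 and later you can skip the custom UDF entirely, since pyspark.ml.functions provides a built-in vector_to_array function. A brief sketch, assuming the same df as above and a Spark 3.x session:
from pyspark.ml.functions import vector_to_array

# convert the vector column to an array column, then index it with ordinary SQL syntax
df.withColumn("features_arr", vector_to_array(df["features"])) \
  .selectExpr("id", "features_arr[1] AS Second") \
  .show()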
Solution 3:
I ran into the same problem of not being able to use explode(). One thing you can do is use VectorSlicer from the pyspark.ml.feature module. Like so:
from pyspark.ml.feature import VectorSlicer
from pyspark.ml.linalg import Vectors
from pyspark.sql.types import Row
slicer = VectorSlicer(inputCol="features", outputCol="features_one", indices=[0])
output = slicer.transform(df)
output.select("features", "features_one").show()
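Note that features_one is still a one-element vector (VectorUDT), not a plain number, so to turn it into a scalar column you still need one of the extraction approaches above. A minimal sketch, reusing the ith UDF and lit from Solution 1:
# pull the single value out of the sliced vector as a DoubleType column
output.select(ith("features_one", lit(0)).alias("first_feature")).show()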