Passing a data frame column and external list to udf under withColumn
The cleanest solution is to pass additional arguments using a closure:
from pyspark.sql.functions import udf, col

def make_topic_word(topic_words):
    return udf(lambda c: label_maker_topic(c, topic_words))

df = sc.parallelize([(["union"], )]).toDF(["tokens"])

(df.withColumn("topics", make_topic_word(keyword_list)(col("tokens")))
    .show())
This doesn't require any changes in keyword_list or the function you wrap with a UDF. You can also use this method to pass an arbitrary object, for example a list of sets for efficient lookups.
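As a minimal sketch of that idea (the set conversion and the make_topic_word_sets helper are illustrative assumptions, not code from the answer), converting each keyword list to a set makes the membership test inside the UDF constant-time:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, IntegerType

keyword_sets = [set(ks) for ks in keyword_list]

def make_topic_word_sets(topic_word_sets):
    # a set membership test is O(1), unlike scanning a list
    return udf(lambda tokens: [sum(t in s for t in tokens) for s in topic_word_sets],
               ArrayType(IntegerType()))

df.withColumn("topics", make_topic_word_sets(keyword_sets)(col("tokens"))).show()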
If you want to use your current UDF and pass topic_words directly, you'll have to convert it to a column literal first:
from pyspark.sql.functions import array, lit
ks_lit = array(*[array(*[lit(k) for k in ks]) for ks in keyword_list])
df.withColumn("ad", topicWord(col("tokens"), ks_lit)).show()
Depending on your data and requirements there can be alternative, more efficient solutions which don't require UDFs (explode + aggregate + collapse) or lookups (hashing + vector operations).
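For the UDF-free route, a rough sketch (the row id and column names are assumptions, not part of the answer): explode the tokens, count per-topic matches with built-in expressions, then collapse back per row:

from pyspark.sql import functions as F

with_id = df.withColumn("id", F.monotonically_increasing_id())
exploded = with_id.withColumn("token", F.explode("tokens"))

# one sum per topic: 1 if the token is in the topic's word list, else 0
counts = exploded.groupBy("id").agg(
    F.array(*[F.sum(F.col("token").isin(ks).cast("int"))
              for ks in keyword_list]).alias("topics"))

with_id.join(counts, "id").drop("id").show()

Rows with an empty tokens array drop out of the inner join, so handle those separately if they matter.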
The following works fine as well; any external parameter can be passed to the UDF (tweaked code to help anyone):
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

topicWord = udf(lambda tkn: label_maker_topic(tkn, topic_words), StringType())
myDF = myDF.withColumn("topic_word_count", topicWord(myDF.bodyText_token))
The keyword_list list should be broadcast to all the nodes in the cluster if the list is big. I'm guessing zero's solution works because the list is tiny and gets shipped to the executors with the task closure automatically. It's better to broadcast explicitly in my opinion, to leave no doubts (explicit broadcasting is required for bigger lists).
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType

keyword_list = [
    ['union', 'workers', 'strike', 'pay', 'rally', 'free', 'immigration'],
    ['farmer', 'plants', 'fruits', 'workers'],
    ['outside', 'field', 'party', 'clothes', 'fashions']]
def label_maker_topic(tokens, topic_words_broadcasted):
    twt_list = []
    # count how many of the row's tokens fall in each topic's word list
    for topic in topic_words_broadcasted.value:
        count = 0
        for tkn in tokens:
            if tkn in topic:
                count += 1
        twt_list.append(count)
    return twt_list
def make_topic_word_better(topic_words_broadcasted):
    def f(tokens):
        return label_maker_topic(tokens, topic_words_broadcasted)
    # declare the return type so Spark yields a real array column
    return F.udf(f, ArrayType(IntegerType()))
# each row's tokens must be a list; with bare strings the UDF would
# iterate over characters and every count would come out zero
df = spark.createDataFrame([(["union"],), (["party"],)], ["tokens"])
b = spark.sparkContext.broadcast(keyword_list)
df.withColumn("topics", make_topic_word_better(b)(F.col("tokens"))).show()
Here's the output:
+-------+---------+
| tokens|   topics|
+-------+---------+
|[union]|[1, 0, 0]|
|[party]|[0, 0, 1]|
+-------+---------+
Note that you need to call value to access the list that's been broadcast (e.g. topic_words_broadcasted.value). It's a tricky implementation, but important to master, because a lot of PySpark UDFs rely on a list or dictionary that's been broadcast.
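The same pattern works for dictionaries (a hypothetical sketch; the states table and full_state_name helper are made up for illustration): broadcast the dict once, then read .value inside the UDF:

states = {"ny": "New York", "ca": "California"}
states_b = spark.sparkContext.broadcast(states)

@F.udf("string")
def full_state_name(abbrev):
    # .value fetches the broadcast dictionary on the executor
    return states_b.value.get(abbrev)

spark.createDataFrame([("ny",), ("ca",)], ["abbrev"]) \
    .withColumn("state", full_state_name(F.col("abbrev"))).show()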