[Posted]: 2021-01-07 20:42:02
[Question]:
I want to find the index of the maximum value in a vector column of a Spark DataFrame using PySpark.
My Spark version is 3.0.0.
df:
id val (vector (nullable = true))
516 0: 1 1: 10 2: [] 3:[0.162, 0.511, 0.022, ....]
Is this a sparse vector? How do I access the array?
[0.162, 0.511, 0.022, ....]
Based on "How to find the index of the maximum value in a vector column?", "How to get the index of the highest value in a list per row in a Spark DataFrame? [PySpark]", and "How to find the argmax of a vector in PySpark ML", it looks like a dense vector? My code:
import numpy as np
import pyspark.sql.functions as F
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import DenseVector, SparseVector
from pyspark.sql.types import IntegerType

def max_index(a_col):
    # Pass null/empty values through unchanged
    if not a_col:
        return a_col
    # Densify sparse vectors before converting
    if isinstance(a_col, SparseVector):
        a_col = DenseVector(a_col)
    a_col = vector_to_array(a_col)
    return np.argmax(a_col)

my_f = F.udf(max_index, IntegerType())
t = df.withColumn("max_index_col", my_f("val"))  # fails because max_index does not work
t.show()
Error:
AttributeError: 'NoneType' object has no attribute '_jvm'
I have tried all the solutions mentioned in the links above, but none of them worked.
Am I missing something?
Thanks
Update: I also tried:
from pyspark.sql.types import ArrayType, FloatType

# Convert the vector column to a plain array of floats
vec_to_array = F.udf(lambda v: v.toArray().tolist(), ArrayType(FloatType()))

def find_max_index(v):
    return F.array_position(v, F.array_max(v))

t = df.withColumn("array_col", vec_to_array(F.col("features")))
t.withColumn("max_index", find_max_index(F.col("array_col"))).show(truncate=False)
Same error.
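A likely cause of the AttributeError in the first attempt is that vector_to_array is a column-level function that relies on the driver's SparkContext, which is not available inside a UDF running on the executors. The following is a minimal sketch, not a confirmed fix, that keeps the UDF body in plain Python. It assumes the vector column is named "val" as in the DataFrame shown in the question; the helper name with_max_index and the output column name are hypothetical. It returns a built-in int rather than a NumPy integer, since IntegerType UDFs generally expect native Python values.

```python
def max_index(v):
    """Return the position of the largest entry, or None for null/empty input."""
    if v is None:
        return None
    # DenseVector and SparseVector both expose toArray(); plain lists pass through.
    values = v.toArray().tolist() if hasattr(v, "toArray") else list(v)
    if not values:
        return None
    # Plain Python only: no pyspark.sql.functions calls inside the UDF body.
    # On ties, the first maximal position is returned.
    return max(range(len(values)), key=values.__getitem__)


def with_max_index(df, vector_col="val", out_col="max_index_col"):
    """Hypothetical helper: add the argmax of `vector_col` as an integer column."""
    # Imported lazily so max_index itself can be tried without a Spark install.
    import pyspark.sql.functions as F
    from pyspark.sql.types import IntegerType

    max_index_udf = F.udf(max_index, IntegerType())
    return df.withColumn(out_col, max_index_udf(F.col(vector_col)))
```

With a running Spark session, with_max_index(df).show() would be the intended usage. The indices are 0-based, unlike the SQL function array_position, which is 1-based.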
[Discussion]: