首先,先看官方定义
dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1)
具体解释为:
当 dim=0 时,是对每一维度相同位置的数值进行softmax运算;
当 dim=1 时,是对某一维度的列进行softmax运算;
当 dim=2 或 -1 时,是对某一维度的行进行softmax运算;
Ref
pytorch中tf.nn.functional.softmax(x,dim = -1)对参数dim的理解