【发布时间】:2021-10-08 03:38:06
【问题描述】:
我正在尝试实现 MI 算法。 这是我的代码:
import numpy as np
from copy import copy
from sklearn import metrics
from sklearn import preprocessing
from sklearn.datasets import load_digits
data, labels = load_digits(return_X_y=True) # labels is Y vector #data shape = (1797, 64), and labels shape is (1797,1)
(n_samples, n_features), n_digits = data.shape, np.unique(labels).size
#normalize the data and set into DataFrame
scaler = preprocessing.StandardScaler()
d = scaler.fit_transform(data)
现在这里是我的 MI 实现:
def mi_algo(_data, _labels, size, defualt_x = 17, defualt_y = 10):
theta = np.zeros(_data.shape[1])
x_len = y_len = _data.shape[0]
py = np.array([len(_labels[_labels==y_val])/y_len for y_val in range(defualt_y)]) #P(y)
for col in range(len(theta)):
temp = np.copy(_data[:,col])
px = np.array([len(temp[temp==x_val])/x_len for x_val in range(defualt_x)]) #P(x)
for x in range(defualt_x):
if px[x] == 0:
continue
for y in range(defualt_y):
if py[y] == 0:
continue
pxy = np.sum((temp == x) & (labels == y)) #P(x,y)
pxy = np.divide(pxy,x_len)
yx = np.multiply(px[x],py[y])
pxy = np.divide(pxy, yx)
log = np.log2(pxy)
theta[col] += np.multiply(pxy,log)
return theta
我使用了所有的 np,因为我遇到了一些错误。 这是输出:
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:23: RuntimeWarning: divide by zero encountered in log2
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:24: RuntimeWarning: invalid value encountered in multiply
[ 0. nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan 0. nan nan nan
nan nan nan 0. nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan]
现在,我知道有一些除以零,但我无法弄清楚。谢谢你的帮助!
【问题讨论】:
-
你可以通过尝试提出一个更有针对性的问题来帮助别人帮助你;就目前而言,一个人唯一能做的就是运行你的代码,看看会发生什么——期望这样是不合理的。有几件事让你看看。一个是 MI = log(something)/log(something else) 的平均值——是否有任何项等于零?另一个是,尝试一个你知道答案的更简单的例子。例如。尝试通过构造 MI = 0 的示例。 MI的最大值是多少?尝试为此构建一个示例。
-
@RobertDodier thnx,我没想到有人会运行我的代码,只是稍微回顾一下。正如警告所说,我知道日志中有一些零,但是我运行的所有测试,所有参数都不是零
-
您是否尝试过在您的 yx 中添加一个 epsilon 值来查看它是否会使 NaN 消失?例如像 1e-8 这样的东西。