【发布时间】:2021-02-02 19:34:01
【问题描述】:
我正在尝试创建一个图形,显示图像“重建”作为 PC 数量的函数。我想对此进行动画处理以显示原始图像、累积图像(在 PC 1、...、i 上)以及仍有待“重建”的部分。除此之外,我想将原始图像和重建图像之间的距离显示为 PC 数量的函数。
我设法创建了下图,它为底部的散点图和顶部的图像设置了动画。
问题是,一旦动画开始,右侧的两个图像“消失”,我认为它们出现在“原始图像”下
这是我的代码(创建包含所有 3 个图像和散点图的动画帧,然后形成图形):
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
from sklearn.decomposition import PCA
pio.templates["custom"] = go.layout.Template(
layout=go.Layout(
margin=dict(l=20, r=20, t=40, b=0)
)
)
pio.templates.default = "simple_white+custom"
class AnimationButtons():
def play_scatter(frame_duration = 500, transition_duration = 300):
return dict(label="Play", method="animate", args=
[None, {"frame": {"duration": frame_duration, "redraw": False},
"fromcurrent": True, "transition": {"duration": transition_duration, "easing": "quadratic-in-out"}}])
def play(frame_duration = 1000, transition_duration = 0):
return dict(label="Play", method="animate", args=
[None, {"frame": {"duration": frame_duration, "redraw": True},
"mode":"immediate",
"fromcurrent": True, "transition": {"duration": transition_duration, "easing": "linear"}}])
def pause():
return dict(label="Pause", method="animate", args=
[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}])
pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))
img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T
reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
# Reconstruct image using the first i principal components
reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
distortion.append(np.sum((img - reconstructed) ** 2))
# Append animation frame every 5'th reconstruction
if i % 2 == 0 or i == pca.n_components_-1:
frames.append(go.Frame(
data = [px.imshow(img, binary_string=True).data[0],
px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
px.imshow(reconstructed.copy(), binary_string=True).data[0],
go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
traces = [0, 1, 2, 3],
layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))
fig = make_subplots(rows=2, cols=3,
subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
fig.add_traces(data=frames[0]["data"], rows = [1,1,1,2], cols = [1,2,3,1])
fig.update(frames=frames)
fig.update_layout(title=frames[0]["layout"]["title"],
xaxis4=dict(range=[0, 50], autorange=False),
yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
margin = dict(t = 100),
width=800,
updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()
我尝试找到类似的问题,但找不到任何可以同时显示 px.imshow 和 go.Scatter 的内容以及子图和动画。
数据X是居中后的MNIST数字图像。这是一个带有这样一张图片的 numpy 数组:(X.shape=(16,5,5) - 16 张 5x5 的图片 - 仅在第一张图片上显示动画)
X=np.array( [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-1.04166667e-06],
[ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
-2.71484375e-05],
[ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
-4.69401042e-05],
[ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
-3.15950521e-04]]] )
将上述代码放在Jupyter notebook on GitHub
【问题讨论】:
-
如果你花时间去share a sample dataset,或者一个与你的数据集结构相似的数据集,那么我很确定你会得到所需的帮助。
-
@vestland - 你是对的!添加了sn -p中使用的数据
-
对我来说
X = np.array...不起作用,可能缺少逗号? -
也许
np.array2string(X, separator=',')会有所帮助。 -
@vestland - 这是一个天真的 Git 存储库,其中包含我为上面的示例运行的代码:github.com/GreenGilad/Stackoverflow.git