【发布时间】:2021-03-31 05:35:01
【问题描述】:
我在 Colab 中以两个独立的实例(两个独立的文件)运行以下代码。其中一个实例可以正常运行,另一个则不行。在出错的实例中,函数 np.cumsum() 返回的数组长度似乎是输入数组的两倍,从而引发 ValueError: "operands could not be broadcast together with shapes (2000,) (1000,)"。
我弄不清楚这为什么会发生,甚至不明白这怎么可能发生。我也没能在网上找到任何答案(连同类问题的类似案例都没有找到),因此非常感谢任何帮助!
'''
def b2_run_advanced_strategies_experiment(env_name='BanditTwoArmedUniform-v0'):
    """Run each baseline strategy on every seeded environment and collect metrics.

    For every environment seed in ``SEEDS`` an environment is created with
    ``gym.make(env_name, seed=env_seed)``.  For every run seed in ``SEEDS``
    and every strategy in the local ``experiments`` list, the environment,
    NumPy and ``random`` are re-seeded, the strategy is executed, and the
    per-run metrics are appended to the result dictionary.

    Args:
        env_name: Identifier passed to ``gym.make``.

    Returns:
        A dict keyed by the name each strategy returns.  Each value is a dict
        with the keys ``'Re'``, ``'Qe'``, ``'Ae'``, ``'cum_regret'`` and
        ``'episode_mean_rew'``, each holding a list with one entry per run.

    Raises:
        ValueError: If the running mean reward cannot be computed because the
            flattened cumulative sum of ``Re`` does not broadcast against the
            per-episode counts (for example when ``Re`` is two-dimensional).
    """
    results = {}
    experiments = [
        # baseline strategies
        lambda env: pure_exploitation(env, N_Episodes),
        lambda env: pure_exploration(env, N_Episodes),
    ]
    for env_seed in tqdm(SEEDS, desc='All experiments'):
        env = gym.make(env_name, seed=env_seed)
        env.reset()
        # Per-action value used as ground truth for the regret computation.
        true_Q = np.array(env.env.p_dist * env.env.r_dist)
        opt_V = np.max(true_Q)
        for seed in tqdm(SEEDS, desc='All environments', leave=False):
            for experiment in tqdm(experiments,
                                   desc='Experiments with seed {}'.format(seed),
                                   leave=False):
                env.seed(seed)
                np.random.seed(seed)
                random.seed(seed)
                name, Re, Qe, Ae = experiment(env)
                Ae = np.expand_dims(Ae, -1)

                # np.cumsum without ``axis`` flattens its input, so a
                # multi-dimensional Re (e.g. shape (1000, 2)) yields 2000
                # values while len(Re) is still 1000.  Report the real shape
                # instead of the opaque broadcasting error.
                rewards = np.asarray(Re)
                try:
                    episode_mean_rew = (
                        np.cumsum(rewards) / (np.arange(len(rewards)) + 1))
                except ValueError as err:
                    raise ValueError(
                        'Re returned by {!r} has shape {}; np.cumsum flattens '
                        'multi-dimensional input, so the running sum no '
                        'longer lines up with the {} per-episode '
                        'counts'.format(name, rewards.shape, len(rewards))
                    ) from err

                # Look up the true value of the action taken in each episode.
                Q_selected = np.take_along_axis(
                    np.tile(true_Q, Ae.shape), Ae, axis=1).squeeze()
                regret = opt_V - Q_selected
                cum_regret = np.cumsum(regret)

                per_strategy = results.setdefault(name, {})
                for key, value in (('Re', Re),
                                   ('Qe', Qe),
                                   ('Ae', Ae),
                                   ('cum_regret', cum_regret),
                                   ('episode_mean_rew', episode_mean_rew)):
                    per_strategy.setdefault(key, []).append(value)
    return results
# Run the experiment with the default environment name and keep the
# per-strategy results dictionary it returns.
b2_results_a = b2_run_advanced_strategies_experiment()
'''
【问题讨论】:
-
`shape` 比 `len` 提供更多信息:`len(Re)` 只给出第一维的长度,而 `np.cumsum` 在未指定 `axis` 时会先把数组展平再累加。