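【Posted】:2020-12-10 21:10:49
【Question】:
I have a small problem with my code. I'm trying to build pairs of test images so that they can be compared later in few-shot learning. My code is as follows: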
```python
def make_oneshot_task(N, s="val", language=None):
    """Create pairs of test image, support set for testing N way one-shot learning. """
    if s == 'train':
        X = Xtrain
        X= X.reshape(-1,11,100,100,3)
        categories = train_classes
    else:
        X = X_val
        X= X.reshape(-1,4,100,100,3)
        categories = val_classes
    n_examples, n_classes, w, h, chan = X.shape
    #n_samples, n_examples, w, h = X.shape
    n_examples = 40
    n_classes = 11
    indices = rng.randint(0, n_examples,size=(N,))
    if language is not None: # if language is specified, select characters for that language
        low, high = categories[language]
        if N > high - low:
            raise ValueError("This language ({}) has less than {} letters".format(language, N))
        categories = rng.choice(range(low,high),size=(N,),replace=True)
    else: # if no language specified just pick a bunch of random letters
        categories = rng.choice(range(n_classes),size=(N,),replace=True)
    true_category = categories[0]
    ex1, ex2 = rng.choice(n_examples,replace=True,size=(2,))
    test_image = np.asarray([X[true_category,ex1,:,:,:]]*N).reshape(N, w, h,3)
    support_set = X[categories,indices,:,:,:]
    support_set[0,:,:] = X[true_category,ex2]
    support_set = support_set.reshape(N, w, h,3)
    targets = np.zeros((N,))
    targets[0] = 1
    targets, test_image, support_set = shuffle(targets, test_image, support_set)
    pairs = [test_image,support_set]
    return pairs, targets
```
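For comparison, here is a minimal sketch of an N-way one-shot task builder that reads every size from the array itself instead of hard-coding it. It assumes `X` is already shaped `(n_classes, n_examples, w, h, chan)`, uses NumPy's `Generator` API (`np.random.default_rng`) in place of the legacy `randint`-style calls, replaces sklearn's `shuffle` with one shared permutation, and omits the `language` branch; the name `make_oneshot_pairs` is hypothetical. Unlike the code above, it samples the N classes without replacement so that no distractor shares the true class.

```python
import numpy as np


def make_oneshot_pairs(X, N, rng):
    """Hypothetical helper: build one N-way one-shot task from X.

    Assumes X has shape (n_classes, n_examples, w, h, chan): axis 0 indexes
    classes, axis 1 indexes examples within a class.
    """
    # Read the sizes from the data instead of hard-coding them.
    n_classes, n_examples = X.shape[:2]
    if N > n_classes:
        raise ValueError(f"N={N} exceeds the number of classes ({n_classes})")
    if n_examples < 2:
        raise ValueError("each class needs at least two examples")

    # N distinct classes; the first one is the true class of the query image.
    categories = rng.choice(n_classes, size=N, replace=False)
    true_category = categories[0]

    # Two different examples of the true class: query vs. matching support image.
    ex1, ex2 = rng.choice(n_examples, size=2, replace=False)
    # One example index per support class, always drawn from range(n_examples).
    indices = rng.integers(0, n_examples, size=N)

    # The query image repeated N times: shape (N, w, h, chan).
    test_image = np.repeat(X[true_category, ex1][np.newaxis], N, axis=0)
    # Advanced indexing returns a copy, so editing support_set leaves X untouched.
    support_set = X[categories, indices]
    support_set[0] = X[true_category, ex2]

    # Label 1 marks the support image from the same class as the query.
    targets = np.zeros(N)
    targets[0] = 1

    # Shuffle all three arrays with one shared permutation.
    perm = rng.permutation(N)
    return [test_image[perm], support_set[perm]], targets[perm]
```

It would be called as `make_oneshot_pairs(X, N, np.random.default_rng())`, with `N` no larger than `X.shape[0]`.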
My main concern is the error raised by `test_image = np.asarray([X[true_category,ex1,:,:,:]]*N).reshape(N, w, h,3)`. Here is the traceback:
```
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-38-c6b9166704d8> in <module>
----> 1 make_oneshot_task(11)
<ipython-input-37-67c1a13297ca> in make_oneshot_task(N, s, language)
27 true_category = categories[0]
28 ex1, ex2 = rng.choice(n_examples,replace=True,size=(2,))
---> 29 test_image = np.asarray([X[true_category,ex1,:,:]]*N).reshape(N, w, h,3)
30 support_set = X[categories,indices,:,:,:]
31 support_set[0,:,:] = X[true_category,ex2]
IndexError: index 18 is out of bounds for axis 1 with size 4
```
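The message says axis 1 of `X` has length 4, which matches `X.reshape(-1,4,100,100,3)` in the default `s="val"` branch, while `ex1` is drawn from `range(40)` because of `n_examples = 40`. One plausible reading (an interpretation of the posted code, not something confirmed in the thread) is that `n_examples, n_classes, w, h, chan = X.shape` unpacks the first two axes in the wrong order, and the hard-coded values then override the real sizes; `indices`, used on axis 1 in the next line, has the same problem. The toy reproduction uses a small stand-in array whose class count (5) is arbitrary. Whichever axis is meant to hold classes, the index used on each axis has to be drawn from that axis's actual length.

```python
import numpy as np

# Toy stand-in for X_val after X.reshape(-1, 4, 100, 100, 3). The class count (5)
# and the 2x2 image size are arbitrary choices for this illustration.
X = np.zeros((5, 4, 2, 2, 3))

# Unpacking in the question's order puts the class count (5) into n_examples
# and the per-class example count (4) into n_classes.
n_examples, n_classes, w, h, chan = X.shape
n_examples = 40  # the hard-coded override from the question

# rng.choice(n_examples, ...) can now return e.g. 18, but axis 1 has only 4 entries.
ex1 = 18
try:
    X[0, ex1]
    error_message = None
except IndexError as err:
    error_message = str(err)
print(error_message)  # index 18 is out of bounds for axis 1 with size 4

# Unpacking in axis order, with no overrides, keeps every draw in range.
n_classes, n_examples, w, h, chan = X.shape
rng = np.random.default_rng(0)
ex1, ex2 = rng.choice(n_examples, size=2, replace=True)
query = X[0, ex1]  # one image, shape (w, h, chan)
```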
Can anyone help me figure out what's going wrong?
【Discussion】:
Tags: python tensorflow keras