【发布时间】:2021-12-26 13:58:10
【问题描述】:
我正在尝试对 10 张图像进行对抗性攻击,我需要将所有受干扰的图像保存在一个文件夹中。所以,我在 pytorch 中使用了torch.utils.save_image,效果很好。我希望所有图像都保存在文件夹中,但相反,它们被覆盖,最后看到的图像是唯一保存的图像。我有以下attack() 函数,它需要单个图像来扰动
def attack(img, label, net, target=None, pixels=1, maxiter=75, popsize=400, verbose=False):
# img: 1*3*W*H tensor
# label: a number
targeted_attack = target is not None
target_calss = target if targeted_attack else label
bounds = [(0,32), (0,32), (0,255), (0,255), (0,255)] * pixels
popmul = max(1, popsize//len(bounds))
predict_fn = lambda xs: predict_classes(
xs, img, target_calss, net, target is None)
callback_fn = lambda x, convergence: attack_success(
x, img, target_calss, net, targeted_attack, verbose)
inits = np.zeros([popmul*len(bounds), len(bounds)])
count = 1
for init in inits:
for i in range(pixels):
init[i*5+0] = np.random.random()*32
init[i*5+1] = np.random.random()*32
init[i*5+2] = np.random.normal(128,127)
init[i*5+3] = np.random.normal(128,127)
init[i*5+4] = np.random.normal(128,127)
attack_result = differential_evolution(predict_fn, bounds, maxiter=maxiter, popsize=popmul,
recombination=1, atol=-1, callback=callback_fn, polish=False, init=inits)
attack_image = perturb_image(attack_result.x, img)
# attack_var = Variable(attack_image, volatile=True).cuda()
with torch.no_grad():
attack_var = attack_image.to(device)
predicted_probs = F.softmax(net(attack_var), dim=1).data.cpu().numpy()[0]
predicted_class = np.argmax(predicted_probs)
vutils.save_image(vutils.make_grid(attack_image, normalize=True, scale_each=True), 'result_img/adversarial' + str(count) + '.png')
vutils.save_image(vutils.make_grid(img, normalize=True, scale_each=True), 'result_img/original' + str(count) + '.png')
count = count + 1
if (not targeted_attack and predicted_class != label) or (targeted_attack and predicted_class == target_calss):
return 1, attack_result.x.astype(int)
return 0, [None]
下面是 attack_all() 函数,它扰乱了一批图像(整个测试集),在我的例子中是 10 张图像。
def attack_all(net, loader, pixels=1, targeted=False, maxiter=75, popsize=400, verbose=False):
correct = 0
success = 0
for batch_idx, (input, target) in enumerate(loader):
# img_var = Variable(input, volatile=True).cuda()
with torch.no_grad():
img_var = input.to(device)
target = target
prior_probs = F.softmax(net(img_var), dim=1)
_, indices = torch.max(prior_probs, 1)
if target[0] != indices.data.cpu()[0]:
continue
correct += 1
target = target.numpy()
targets = [None] if not targeted else range(10)
for target_calss in targets:
if (targeted):
if (target_calss == target[0]):
continue
flag, x = attack(input, target[0], net, target_calss, pixels=pixels, maxiter=maxiter, popsize=popsize, verbose=verbose)
success += flag
if (targeted):
success_rate = float(success)/(9*correct)
else:
success_rate = float(success)/correct
if flag == 1:
print("success rate: %.4f (%d/%d) [(x,y) = (%d,%d) and (R,G,B)=(%d,%d,%d)]"%(
success_rate, success, correct, x[0],x[1],x[2],x[3],x[4]))
if correct == args.samples:
break
return success_rate
下面是main() 类,我用attack_all() 攻击这10 个图像。我希望保存所有 10 张图像(包括原始图像和扰动图像),但只保存最后看到的图像。
def main():
print ("==> Loading data and model...")
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=tranfrom_test)
test_set = Cifar10Dataset(csv_file='mydata/cifar10.csv', root_dir = 'mydata/cifar_selected_10', transform = transform_test)
testloader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=True, num_workers=2)
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/%s.t7'%args.model)
net = checkpoint['net']
net.cuda()
cudnn.benchmark = True
print ("==> Starting attack...")
results = attack_all(net, testloader, pixels=args.pixels, targeted=args.targeted, maxiter=args.maxiter, popsize=args.popsize, verbose=args.verbose)
print ("Final success rate: %.4f"%results)
【问题讨论】:
标签: python pytorch computer-vision