maskrcnn_benchmark Training Process

->Training command:

python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1
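These overrides follow the linear scaling rule from the maskrcnn-benchmark README: shrinking the batch from the default 16 images to 2 divides the learning rate by 8 and multiplies the schedule length by 8. A small sketch of that arithmetic (the default values are the 8-GPU 1x schedule; `scale_schedule` is an illustrative helper, not a library function):

```python
# Defaults of the 1x schedule (16 images per batch across 8 GPUs).
DEFAULT_IMS_PER_BATCH = 16
DEFAULT_BASE_LR = 0.02
DEFAULT_MAX_ITER = 90000
DEFAULT_STEPS = (60000, 80000)

def scale_schedule(ims_per_batch):
    """Apply the linear scaling rule for a smaller batch size."""
    factor = DEFAULT_IMS_PER_BATCH // ims_per_batch
    return {
        "BASE_LR": DEFAULT_BASE_LR / factor,       # smaller LR
        "MAX_ITER": DEFAULT_MAX_ITER * factor,     # longer schedule
        "STEPS": tuple(s * factor for s in DEFAULT_STEPS),
    }

print(scale_schedule(2))
```

With `ims_per_batch=2` this reproduces exactly the numbers in the command above: LR 0.0025, 720000 iterations, steps at (480000, 640000).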

->train_net.py is invoked; its train() function builds the model, optimizer, data loader, checkpointer, etc., then enters the core training loop in trainer.py:

def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)
        ipdb.set_trace()  # breakpoint added for this walkthrough (requires `import ipdb`)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == (max_iter - 1):
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0 and iteration > 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)

    checkpointer.save("model_{:07d}".format(iteration), **arguments)
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
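The loop calls `reduce_loss_dict` so that logged losses are averaged over all GPUs (the backward pass still uses the local, un-reduced `losses`). A pure-Python illustration of that averaging, using plain floats instead of distributed tensors (`reduce_loss_dicts` is a hypothetical helper for this sketch, not the library's `torch.distributed`-based implementation):

```python
def reduce_loss_dicts(per_rank_loss_dicts):
    """Simulate the cross-GPU reduction: average each named loss over ranks.

    per_rank_loss_dicts: one {loss_name: value} dict per GPU/rank.
    """
    world_size = len(per_rank_loss_dicts)
    keys = sorted(per_rank_loss_dicts[0])
    return {k: sum(d[k] for d in per_rank_loss_dicts) / world_size
            for k in keys}

# Two "ranks" reporting different objectness losses -> the log shows 2.0.
print(reduce_loss_dicts([{"loss_objectness": 1.0}, {"loss_objectness": 3.0}]))
```

With a single GPU (as in the command above) the real `reduce_loss_dict` is effectively a no-op: there is nothing to average across.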

->Inspecting the variables of one iteration at the breakpoint; targets holds the ground-truth boxes for the batch of 2 images:

ipdb> loss_dict
{'loss_box_reg': tensor(0.1005, device='cuda:0', grad_fn=<DivBackward0>), 'loss_rpn_box_reg': tensor(0.0486, device='cuda:0', grad_fn=<DivBackward0>), 'loss_objectness': tensor(0.0165, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>), 'loss_classifier': tensor(0.2494, device='cuda:0', grad_fn=<NllLossBackward>), 'loss_mask': tensor(0.2332, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)}
ipdb> images
<maskrcnn_benchmark.structures.image_list.ImageList object at 0x7f9cb9190668>
ipdb> targets
[BoxList(num_boxes=3, image_width=1066, image_height=800, mode=xyxy), BoxList(num_boxes=17, image_width=1199, image_height=800, mode=xyxy)]
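The total loss passed to `backward()` is just the sum of these five scalars. A minimal sketch with the transcript's values as plain floats (in the real loop they are CUDA tensors carrying `grad_fn`):

```python
# The five scalar losses printed above.
loss_dict = {
    "loss_box_reg": 0.1005,
    "loss_rpn_box_reg": 0.0486,
    "loss_objectness": 0.0165,
    "loss_classifier": 0.2494,
    "loss_mask": 0.2332,
}

# Same expression the training loop uses to form the total loss.
losses = sum(loss for loss in loss_dict.values())
print(round(losses, 4))  # -> 0.6482
```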

Stepping inside the model:

->In generalized_rcnn.py the backbone network extracts the features: features = self.backbone(images.tensors)

ipdb> features[0].size()
torch.Size([2, 256, 200, 336])
ipdb> features[1].size()
torch.Size([2, 256, 100, 168])
ipdb> features[2].size()
torch.Size([2, 256, 50, 84])
ipdb> features[3].size()
torch.Size([2, 256, 25, 42])
ipdb> features[4].size()
torch.Size([2, 256, 13, 21])
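The five shapes follow directly from the FPN strides (4, 8, 16, 32, 64 for P2..P6) applied to the padded batch tensor. The padded size 800x1344 below is inferred from the printed P2 shape (200x4 = 800, 336x4 = 1344), since the two images (800x1066 and 800x1199) are padded to a common size before batching:

```python
import math

# Inferred padded batch size (both images are padded to this).
padded_h, padded_w = 800, 1344
strides = [4, 8, 16, 32, 64]  # FPN levels P2..P6

# Each level's spatial size is the padded size divided by its stride,
# rounded up for the coarsest levels.
sizes = [(math.ceil(padded_h / s), math.ceil(padded_w / s)) for s in strides]
print(sizes)  # -> [(200, 336), (100, 168), (50, 84), (25, 42), (13, 21)]
```

This matches the transcript, including the ceil at P6: 800/64 = 12.5 rounds up to 13.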

RPN network

->proposals, proposal_losses = self.rpn(images, features, targets)

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                correspond to different feature levels
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
                image.
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        objectness, rpn_box_regression = self.head(features)
        anchors = self.anchor_generator(images, features)

        if self.training:
            return self._forward_train(anchors, objectness, rpn_box_regression, targets)
        else:
            return self._forward_test(anchors, objectness, rpn_box_regression)

    def _forward_train(self, anchors, objectness, rpn_box_regression, targets):
        if self.cfg.MODEL.RPN_ONLY:
            # When training an RPN-only model, the loss is determined by the
            # predicted objectness and rpn_box_regression values and there is
            # no need to transform the anchors into predicted boxes; this is an
            # optimization that avoids the unnecessary transformation.
            boxes = anchors
        else:
            # For end-to-end models, anchors must be transformed into boxes and
            # sampled into a training batch.
            with torch.no_grad():
                boxes = self.box_selector_train(
                    anchors, objectness, rpn_box_regression, targets
                )
        loss_objectness, loss_rpn_box_reg = self.loss_evaluator(
            anchors, objectness, rpn_box_regression, targets
        )
        losses = {
            "loss_objectness": loss_objectness,
            "loss_rpn_box_reg": loss_rpn_box_reg,
        }
        return boxes, losses

->First, every feature map passes through the RPN head (a shared 3×3 conv followed by two 1×1 convs for classification and regression); the outputs are then matched against the generated anchors to compute the loss.

->objectness, rpn_box_regression = self.head(features) returns classification and regression outputs for the 5 FPN levels, with 3 anchors per location at each level:

ipdb> objectness[0].size()
torch.Size([2, 3, 200, 336])  # 200*336*3 = 201600 anchors per image at this level
ipdb> objectness[1].size()
torch.Size([2, 3, 100, 168])
ipdb> objectness[2].size()
torch.Size([2, 3, 50, 84])
ipdb> objectness[3].size()
torch.Size([2, 3, 25, 42])
ipdb> objectness[4].size()
torch.Size([2, 3, 13, 21])
ipdb> objectness[5].size()
*** IndexError: list index out of range
ipdb> rpn_box_regression[0].size()
torch.Size([2, 12, 200, 336])
ipdb> rpn_box_regression[4].size()
torch.Size([2, 12, 13, 21])

-> anchors = self.anchor_generator(images, features) generates the anchors:

ipdb> anchors[1][0]
BoxList(num_boxes=201600, image_width=1204, image_height=800, mode=xyxy)
ipdb> anchors[1][1]
BoxList(num_boxes=50400, image_width=1204, image_height=800, mode=xyxy)
ipdb> anchors[0][1]
BoxList(num_boxes=50400, image_width=1333, image_height=794, mode=xyxy)
ipdb> anchors[1][2]
BoxList(num_boxes=12600, image_width=1204, image_height=800, mode=xyxy)
ipdb> anchors[1][3]
BoxList(num_boxes=3150, image_width=1204, image_height=800, mode=xyxy)
ipdb> anchors[1][4]
BoxList(num_boxes=819, image_width=1204, image_height=800, mode=xyxy)
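The `num_boxes` at each level is exactly the level's spatial size times the 3 anchors per location, which ties the anchor counts back to the feature-map shapes printed earlier:

```python
# (H, W) of the five FPN levels, from the feature shapes above.
level_sizes = [(200, 336), (100, 168), (50, 84), (25, 42), (13, 21)]
num_anchors_per_loc = 3

counts = [h * w * num_anchors_per_loc for h, w in level_sizes]
print(counts)       # -> [201600, 50400, 12600, 3150, 819]
print(sum(counts))  # -> 268569 anchors per image in total
```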

 ->boxes = self.box_selector_train(anchors, objectness, rpn_box_regression, targets) selects proposal boxes for training the Fast R-CNN head; this step runs under torch.no_grad() since it requires no gradient updates. With FPN, the post-NMS top-N during training is applied over the whole batch rather than per image, which is why the two images end up with different numbers of proposals:

ipdb> boxes
[BoxList(num_boxes=316, image_width=1333, image_height=794, mode=xyxy), BoxList(num_boxes=1696, image_width=1204, image_height=800, mode=xyxy)]
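Inside box_selector_train, the top-scoring anchors per level are decoded into boxes, clipped to the image, and pruned with NMS. A minimal pure-Python sketch of the NMS pruning step (`iou`/`nms` here are illustrative helpers, not the library's batched CUDA kernels; the real defaults are nms_thresh 0.7 and a post-NMS cap of 2000 boxes):

```python
def iou(a, b):
    """Intersection-over-union of two boxes in xyxy format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def nms(boxes, scores, thresh=0.7, post_nms_top_n=2000):
    """Greedy NMS: keep highest-scoring boxes, drop overlapping ones."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= thresh for j in keep):
            keep.append(i)
        if len(keep) == post_nms_top_n:
            break
    return keep

boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, thresh=0.5))  # -> [0, 2]: box 1 overlaps box 0 too much
```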

 -> loss_objectness, loss_rpn_box_reg = self.loss_evaluator(anchors, objectness, rpn_box_regression, targets) computes the RPN loss, sampling anchors at a target 1:1 positive-to-negative ratio.

->512 samples are selected for training here: _C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256 per image, times two images. Because only 69 positive anchors are available, the shortfall is filled with negatives (69 + 443 = 512):
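The filling-with-negatives behavior can be sketched as follows (`sample_counts` is an illustrative helper modeled on the library's BalancedPositiveNegativeSampler, which additionally picks the actual indices at random):

```python
def sample_counts(num_pos_available, num_neg_available,
                  batch_size_per_image=256, positive_fraction=0.5):
    """How many positive/negative anchors to sample for a budget of anchors."""
    # Target at most half the budget as positives, capped by availability.
    num_pos = min(int(batch_size_per_image * positive_fraction),
                  num_pos_available)
    # Any shortfall in positives is made up with extra negatives.
    num_neg = min(batch_size_per_image - num_pos, num_neg_available)
    return num_pos, num_neg

# Combined budget of the two images (2 x 256), with 69 positives available:
print(sample_counts(69, 500000, batch_size_per_image=512))  # -> (69, 443)
```

This reproduces the split seen in the transcript below: 69 positives and 443 negatives.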

ipdb> sampled_pos_inds
tensor([ 16477,  16480,  16483,  16486,  17485,  17488,  17491,  17494,  18493,
         18496,  18499,  18502,  19501,  19504,  19507,  19510, 217452, 217453,
        217455, 217456, 217458, 217459, 217960, 268151, 529150, 534017, 534020,
        534143, 534146, 534586, 534607, 534712, 534733, 534838, 534859, 535356,
        535359, 535362, 535365, 535368, 536602, 536652, 536655, 536658, 536661,
        536664, 536667, 536670, 536715, 536718, 536721, 536724, 536727, 536730,
        536733, 536778, 536781, 536784, 536787, 536790, 536793, 536796, 536841,
        536844, 536847, 536850, 536853, 536856, 536859], device='cuda:0')
ipdb> sampled_neg_inds
tensor([  3045,   4275,   5323,   6555,   7538,   8406,   8469,   9761,  11316,
         11684,  12319,  13195,  13354,  15405,  20431,  25105,  26405,  26786,
         27324,  30698,  33503,  38168,  39244,  40064,  40535,  41046,  41162,
         41203,  41864,  43170,  44060,  44416,  44905,  45161,  47299,  48043,
         49890,  49900,  50992,  51248,  52082,  52236,  52371,  52568,  54079,
         54207,  55251,  56973,  57135,  58376,  59816,  61509,  62473,  62942,
         64722,  65548,  66681,  67925,  68650,  71368,  72610,  73268,  74727,
         75655,  77795,  78937,  79115,  80101,  80808,  81001,  83846,  87064,
         89891,  91207,  92579,  92771,  93113,  94118,  94526,  94586,  95822,
         96850,  97256,  97303,  97500,  98194,  98338, 101724, 102082, 103835,
        103947, 104678, 105168, 105630, 106132, 108751, 108933, 109684, 110552,
        111373, 111965, 114691, 114736, 115213, 115468, 120710, 121785, 123138,
        126383, 126957, 128197, 128282, 129449, 130472, 132269, 133131, 133384,
        135197, 135926, 136468, 137306, 137620, 138671, 141848, 142643, 145618,
        147402, 148283, 148353, 149313, 150389, 150528, 151949, 154413, 156156,
        157155, 158716, 160001, 160227, 160428, 160496, 160920, 161023, 162605,
        163131, 166371, 166561, 167200, 171280, 174531, 175690, 175957, 175996,
        179025, 179766, 180781, 182893, 182980, 183152, 183159, 183531, 183785,
        184531, 185565, 186520, 187194, 187772, 188100, 191068, 191289, 191419,
        192022, 193388, 194892, 196902, 204682, 206878, 207981, 208066, 208366,
        210761, 210862, 211624, 213567, 213627, 214601, 214651, 214770, 215032,
        216806, 218299, 220127, 220221, 221133, 222489, 223512, 224844, 225115,
        225225, 225337, 228044, 228580, 228691, 229787, 231390, 231405, 231666,
        233068, 233379, 233416, 234464, 236145, 238078, 239161, 239633, 240260,
        240492, 241033, 241702, 241758, 242546, 243372, 244102, 248078, 248632,
        255377, 256325, 257079, 258010, 259857, 260872, 261896, 271659, 274495,
        275822, 276450, 276728, 278865, 279179, 279338, 279735, 280208, 280216,
        282300, 283240, 283717, 285074, 285157, 287528, 287804, 288191, 289901,
        290179, 294877, 296999, 298420, 301631, 301890, 303575, 304982, 305983,
        305992, 307922, 312438, 313507, 314289, 316348, 318599, 319751, 321304,
        321735, 321748, 326308, 326315, 327131, 327290, 327671, 328439, 332674,
        333130, 333144, 334633, 336337, 337399, 340980, 341619, 347289, 347364,
        347579, 353057, 353309, 354001, 355039, 355271, 355597, 356617, 359064,
        359068, 360402, 362098, 362652, 363356, 363741, 364744, 365997, 370109,
        370949, 372977, 373248, 373992, 374786, 375293, 376785, 377661, 377761,
        378991, 379663, 380167, 380817, 382269, 383560, 387387, 388389, 389665,
        389862, 390138, 391941, 394183, 399113, 400423, 402411, 404907, 405436,
        406457, 407348, 408005, 408356, 409728, 411376, 411571, 412210, 412426,
        415363, 415453, 415601, 418159, 418174, 418928, 419064, 419394, 419783,
        421039, 421405, 423287, 426369, 429895, 430293, 431338, 432330, 432745,
        433529, 433699, 433738, 435389, 437567, 438410, 439164, 440481, 442532,
        445424, 446074, 446146, 446550, 447703, 449683, 450601, 451138, 452505,
        455922, 457464, 460557, 461150, 461431, 462641, 463544, 471945, 472032,
        473327, 474938, 475450, 477505, 477917, 478033, 479038, 480127, 481613,
        482384, 484433, 484542, 484556, 484588, 487380, 490897, 492173, 493279,
        493464, 494139, 498077, 498172, 498426, 499201, 500289, 500739, 503145,
        506227, 506661, 509266, 509355, 509382, 509556, 510331, 510346, 511426,
        511604, 512428, 512560, 513306, 514096, 515320, 516682, 516949, 517815,
        517984, 524421, 525174, 525384, 525697, 526692, 527047, 527576, 532272,
        535005, 535582], device='cuda:0')
ipdb> sampled_pos_inds.size()
torch.Size([69])
ipdb> sampled_neg_inds.size()
torch.Size([443])
