MAE Link: https://arxiv.org/abs/2111.06377 备注: MAE原始论文,本记录参考自视频
43、逐行讲解Masked AutoEncoder(MAE)的PyTorch代码_哔哩哔哩_bilibili
掩码自编码器(MAE)是一种简单的自编码方法,可以在给定部分观测值的情况下重建原始图像。
方法 掩码 效仿VIT,将图像划分为规则的不重叠的小块。然后对patch进行采样,并掩码(即删除)剩余的patch。我们的采样策略很简单:我们对随机斑块进行采样,不进行替换,遵循均匀分布。
文中解释了选取较高的掩码率的原因:去除冗余信息,如果掩码率较小,那么编码器很容易就从大量冗余信息中还原出原图像,但模型需要编码器学习更好的特征,所以掩码率会比较高。
MAE encoder 编码器就是VIT的编码器,它的输入是采样得到的25%原来的patch,将他们打上位置编码就可以送入transformer块,输出也将是同样的尺寸。
MAE decoder 解码器的输出是完整的patch序列,解码器的输入是可见的编码器输出和不可见的token。这些不可见的token也是可学习的。这些token会被加上原来位置的位置编码,如果不加就缺失了位置信息。
解码器仅仅在预训练阶段有用,在下游任务上有用的仅仅是编码器。编解码器可以独立设计。
重构目标 MAE通过预测每个被屏蔽token的像素值来重建输入,解码器输出中的每个元素都是代表一个patch的像素值向量。所以重构损失就是对应像素的mean squared error (MSE),这个损失只对被掩码的token做,未被掩码的token不需要计算该损失。
作者讨论了是否对每个patch做均值方差归一化后再做损失,其结论是做了归一化后效果会更好。
关于实现
首先为每个小patch生成一个embedding,然后加上对应的位置编码;
然后用shuffle操作打散,根据掩码率取前25%的patch;
送入Vit后,将输出的patch根据unshuffle操作还原到原始序列的对应位置上去;
对masked token用一个特殊的token代替,然后给完整的序列加上位置编码,送入解码器;
代码解读 模型 先看models_mae.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 def __init__ (self, img_size=224 , patch_size=16 , in_chans=3 , embed_dim=1024 , depth=24 , num_heads=16 , decoder_embed_dim=512 , decoder_depth=8 , decoder_num_heads=16 , mlp_ratio=4. , norm_layer=nn.LayerNorm, norm_pix_loss=False ): super ().__init__() self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1 , 1 , embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1 , num_patches + 1 , embed_dim), requires_grad=False ) self.blocks = nn.ModuleList([ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True , qk_scale=None , norm_layer=norm_layer) for i in range (depth)]) self.norm = norm_layer(embed_dim) self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True ) self.mask_token = nn.Parameter(torch.zeros(1 , 1 , decoder_embed_dim)) self.decoder_pos_embed = nn.Parameter(torch.zeros(1 , num_patches + 1 , decoder_embed_dim), requires_grad=False ) self.decoder_blocks = nn.ModuleList([ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True , qk_scale=None , norm_layer=norm_layer) for i in range (decoder_depth)]) self.decoder_norm = norm_layer(decoder_embed_dim) self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True ) self.norm_pix_loss = norm_pix_loss self.initialize_weights()
初始化权重的函数initialize_weights
和_init_weights
主要是对位置编码做了固定参数的初始化(无梯度)、patch embedding做了均值初始化、对每一层的线形层、归一化层做了平均初始化。
patchify
函数和unpatchify
函数做了将图像分patch和拼接成图像的操作,random_masking
函数就是随机掩码,其中用到了如下函数:
1 2 3 4 5 6 7 noise = torch.rand(N, L, device=x.device) ids_shuffle = torch.argsort(noise, dim=1 ) ids_restore = torch.argsort(ids_shuffle, dim=1 )
1 2 3 4 5 6 tensor([0.6141 , 0.9232 , 0.5108 , 0.0047 , 0.6740 ]) tensor([3 , 2 , 0 , 4 , 1 ]) tensor([2 , 4 , 1 , 0 , 3 ])
随后根据torch.gather和ids_keep即序列维度上需要被取的那些索引得到未被掩码的序列x_masked ;同理得到mask 表示被掩码的token;
1 2 x_masked = torch.gather(x, dim=1 , index=ids_keep.unsqueeze(-1 ).repeat(1 , 1 , D)) mask = torch.gather(mask, dim=1 , index=ids_restore)
剩余的forward_encoder
和 forward_decoder
没特别大的难度,略过。
forward_loss
损失中就是普通的mse loss,但是如果要做归一化的话,会进入分支先计算patch的均值和方差。此外由于只预测被掩码的部分,所以需要乘上掩码再计算损失。
1 loss = (loss * mask).sum () / mask.sum ()
训练 训练的代码支持cpu、单卡、多卡,已经非常完美了,我们要训练自己的数据集只需要修改data_path变量即可。训练代码的前153行都是在初始化和准备数据加载器。
代码使用了ddp包裹模型,在160行:
1 model_without_ddp = model
并在174行做了检验:
1 2 3 if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True ) model_without_ddp = model.module
训练代码的主体如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 for epoch in range (args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args ) if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs): misc.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) log_stats = {**{f'train_{k} ' : v for k, v in train_stats.items()}, 'epoch' : epoch,} if args.output_dir and misc.is_main_process(): if log_writer is not None : log_writer.flush() with open (os.path.join(args.output_dir, "log.txt" ), mode="a" , encoding="utf-8" ) as f: f.write(json.dumps(log_stats) + "\n" )
train_one_epoch函数中的训练主体如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 for data_iter_step, (samples, _) in enumerate (metric_logger.log_every(data_loader, print_freq, header)):if data_iter_step % accum_iter == 0 : lr_sched.adjust_learning_rate(optimizer, data_iter_step / len (data_loader) + epoch, args) samples = samples.to(device, non_blocking=True ) with torch.cuda.amp.autocast(): loss, _, _ = model(samples, mask_ratio=args.mask_ratio) loss_value = loss.item() if not math.isfinite(loss_value): print ("Loss is {}, stopping training" .format (loss_value)) sys.exit(1 ) loss /= accum_iter loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1 ) % accum_iter == 0 ) if (data_iter_step + 1 ) % accum_iter == 0 : optimizer.zero_grad() torch.cuda.synchronize() metric_logger.update(loss=loss_value) lr = optimizer.param_groups[0 ]["lr" ] metric_logger.update(lr=lr) loss_value_reduce = misc.all_reduce_mean(loss_value) if log_writer is not None and (data_iter_step + 1 ) % accum_iter == 0 : """ We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. """ epoch_1000x = int ((data_iter_step / len (data_loader) + epoch) * 1000 ) log_writer.add_scalar('train_loss' , loss_value_reduce, epoch_1000x) log_writer.add_scalar('lr' , lr, epoch_1000x)
根据视频所说,如果使用最新的timm需要将models_mae.py中的qk_scale注释掉。
微调 微调中的大部分和训练是一样的,它的数据集加载器使用了build_dataset
函数,其中做了额外的图像强增强。
在218行使用了mixup增强,是一种在分类任务上有效的增强方式。
1 2 3 4 5 6 7 8 mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: print ("Mixup is activated!" ) mixup_fn = Mixup( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes)
在导入模型时使用的vit写法和models_vit不一样但是参数上是一样的,且在导入参数后确保新加的两层分类层参数未被导入。
1 2 3 4 5 6 7 8 9 msg = model.load_state_dict(checkpoint_model, strict=False ) print (msg)if args.global_pool: assert set (msg.missing_keys) == {'head.weight' , 'head.bias' , 'fc_norm.weight' , 'fc_norm.bias' } else : assert set (msg.missing_keys) == {'head.weight' , 'head.bias' }
分类损失的选择上也做了一些判断:
1 2 3 4 5 6 7 if mixup_fn is not None :criterion = SoftTargetCrossEntropy() elif args.smoothing > 0. : criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else : criterion = torch.nn.CrossEntropyLoss()