论文阅读--MAE

Huang Zhiwei

MAE

Link: https://arxiv.org/abs/2111.06377
备注: MAE原始论文,本记录参考自视频

43、逐行讲解Masked AutoEncoder(MAE)的PyTorch代码_哔哩哔哩_bilibili

MaeMain

掩码自编码器(MAE)是一种简单的自编码方法,可以在给定部分观测值的情况下重建原始图像。

方法

掩码

效仿VIT,将图像划分为规则的不重叠的小块。然后对patch进行采样,并掩码(即删除)剩余的patch。我们的采样策略很简单:我们对随机斑块进行采样,不进行替换,遵循均匀分布。

文中解释了选取较高的掩码率的原因:去除冗余信息,如果掩码率较小,那么编码器很容易就从大量冗余信息中还原出原图像,但模型需要编码器学习更好的特征,所以掩码率会比较高。

MAE encoder

编码器就是VIT的编码器,它的输入是采样得到的25%原来的patch,将他们打上位置编码就可以送入transformer块,输出也将是同样的尺寸。

MAE decoder

解码器的输出是完整的patch序列,解码器的输入是可见的编码器输出和不可见的token。这些不可见的token也是可学习的。这些token会被加上原来位置的位置编码,如果不加就缺失了位置信息。

解码器仅仅在预训练阶段有用,在下游任务上有用的仅仅是编码器。编解码器可以独立设计。

重构目标

MAE通过预测每个被屏蔽token的像素值来重建输入,解码器输出中的每个元素都是代表一个patch的像素值向量。所以重构损失就是对应像素的mean squared error (MSE),这个损失只对被掩码的token做,未被掩码的token不需要计算该损失。

作者讨论了是否对每个patch做均值方差归一化后再做损失,其结论是做了归一化后效果会更好。

关于实现

  • 首先为每个小patch生成一个embedding,然后加上对应的位置编码;
  • 然后用shuffle操作打散,根据掩码率取前25%的patch;
  • 送入Vit后,将输出的patch根据unshuffle操作还原到原始序列的对应位置上去;
  • 对masked token用一个特殊的token代替,然后给完整的序列加上位置编码,送入解码器;

代码解读

模型

先看models_mae.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
super().__init__()

# --------------------------------------------------------------------------
# MAE encoder specifics
# patchembedding
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
# class token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 位置编码,长度是patch个数加class token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
# 重复模块depth次
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------

# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding

self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(decoder_depth)])
# 解码器输出归一化
self.decoder_norm = norm_layer(decoder_embed_dim)
# 线性层,输出是patch面积*3通道
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------

self.norm_pix_loss = norm_pix_loss

self.initialize_weights()

初始化权重的函数initialize_weights_init_weights主要是对位置编码做了固定参数的初始化(无梯度)、patch embedding做了均值初始化、对每一层的线形层、归一化层做了平均初始化。

patchify函数和unpatchify函数做了将图像分patch和拼接成图像的操作,random_masking函数就是随机掩码,其中用到了如下函数:

1
2
3
4
5
6
7
# 生成和patch尺寸一样的随机张量
noise = torch.rand(N, L, device=x.device)

# sort noise for each sample
# 根据第一个维度排序(sequence length维度),那么这个顺序就是打乱的顺序,相应的将这个排序的索引再排序就是还原的顺序了。
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
1
2
3
4
5
6
#随机生成的张量
tensor([0.6141, 0.9232, 0.5108, 0.0047, 0.6740])
#张量排序
tensor([3, 2, 0, 4, 1])
#第二个位置的元素是原序列的第0个,第4个位置的元素是原序列的第1个。。。
tensor([2, 4, 1, 0, 3])

随后根据torch.gather和ids_keep即序列维度上需要被取的那些索引得到未被掩码的序列x_masked ;同理得到mask 表示被掩码的token;

1
2
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
mask = torch.gather(mask, dim=1, index=ids_restore)

剩余的forward_encoderforward_decoder 没特别大的难度,略过。

forward_loss 损失中就是普通的mse loss,但是如果要做归一化的话,会进入分支先计算patch的均值和方差。此外由于只预测被掩码的部分,所以需要乘上掩码再计算损失。

1
loss = (loss * mask).sum() / mask.sum()

训练

训练的代码支持cpu、单卡、多卡,已经非常完美了,我们要训练自己的数据集只需要修改data_path变量即可。训练代码的前153行都是在初始化和准备数据加载器。

代码使用了ddp包裹模型,在160行:

1
model_without_ddp = model

并在174行做了检验:

1
2
3
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module

训练代码的主体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
for epoch in range(args.start_epoch, args.epochs):
#分布式训练要确保每张卡分到的数据不一样
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
#外层循环只有epoch,还有一层循环即一次epoch遍历数据的循环写在了train_one_epoch中
train_stats = train_one_epoch(
model, data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
args=args
)
if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
# misc库的保存模型函数
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)

log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,}

if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")

train_one_epoch函数中的训练主体如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 由于是自监督学习,所以(samples, _)即样本和标签中的标签是不需要的
for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

samples = samples.to(device, non_blocking=True)
# 混合精度
with torch.cuda.amp.autocast():
loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

loss_value = loss.item()

if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
# 判断是否需要参数更新,并做参数更新
loss /= accum_iter
loss_scaler(loss, optimizer, parameters=model.parameters(),
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()

torch.cuda.synchronize()

metric_logger.update(loss=loss_value)

lr = optimizer.param_groups[0]["lr"]
metric_logger.update(lr=lr)

loss_value_reduce = misc.all_reduce_mean(loss_value)
# 日志相关
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)

根据视频所说,如果使用最新的timm需要将models_mae.py中的qk_scale注释掉。

微调

微调中的大部分和训练是一样的,它的数据集加载器使用了build_dataset 函数,其中做了额外的图像强增强。

在218行使用了mixup增强,是一种在分类任务上有效的增强方式。

1
2
3
4
5
6
7
8
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)

在导入模型时使用的vit写法和models_vit不一样但是参数上是一样的,且在导入参数后确保新加的两层分类层参数未被导入。

1
2
3
4
5
6
7
8
9
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)

if args.global_pool:
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
else:
assert set(msg.missing_keys) == {'head.weight', 'head.bias'}

分类损失的选择上也做了一些判断:

1
2
3
4
5
6
7
if mixup_fn is not None:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif args.smoothing > 0.:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
  • 标题: 论文阅读--MAE
  • 作者: Huang Zhiwei
  • 创建于: 2023-05-03 22:57:32
  • 更新于: 2023-09-02 23:06:24
  • 链接: https://huangzhw0221.github.io/2023/05/03/Paper-MAE/
  • 版权声明: 本文章采用 CC BY-NC-SA 4.0 进行许可。
 评论