0%

VGGT和VGGsfm论文阅读

VGGT和VGGsfm

VGGT是目前比较新的sfm方法,和dust3r、mast3r非常类似,最近出的一个效果很惊艳的sfm方法,其前身是VGGsfm。我们之前已经解析过dust3r和mast3r,他们几乎完全抛弃了所有的几何方法来进行重建,从特征提取到恢复3D点云全是都是网络输出的,但它也有局限性,就是一次只能对两张图像进行重建,超过两张图像则是通过后处理的方式来实现配准。

与之相比,VGGsfm则不同,它同样具有特征提取、匹配、三角化、BA等传统sfm流程,也同样保留很多基于几何的方法,比如8点法、DLT之类的,只不过在这些过程中引入网络来一步到位,并且网络的输入还会嵌入几何估计的结果。

VGGT虽然论文里说是基于VGGsfm的,但实际上两者的网络结构和流程都有很大的不同,VGGT是一个完全基于网络的sfm方法,更像是dust3r的近亲,不过它本身的架构就支持一次处理多张图像,而不是像dust3r那样一次处理一对图像。

周末正好没事干,顺手写下文章,就先来看看VGGsfm,然后再看VGGT。

VGGsfm

前面也提到了,VGGsfm和传统基于几何的sfm类似,具有特征提取、匹配、三角化、BA这些步骤,只是每个步骤都引入网络最终实现end2end。

下图是VGGsfm的总体架构,蓝颜色的框是包含网络的组件,tracker负责在输入图像之间进行匹配,然后分别使用两个head来估计camera的位姿和深度图,最后使用一个可微分BA来完成优化。

特征提取与匹配

特征提取非常简单,文中没提,代码里是用sp、sift和aliked几种方法来提取的。但是注意只是用来提取特征点的位置,而不是用这些方法计算出来的特征描述子(而且显然也没法混用描述子是不是)

我们着重看一下匹配,VGGsfm的tracker模块和用于在视频里光流估计方法类似,但是由于sfm输入的图像序列不一定是时间连续的,所以缺少了时间上的信息。

对于第张图像,假设提取的特征点为,首先使用不同大小的CNN来计算的特征子图,然后用双线性插值到同一维度上去最终得到图像的特征图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class BasicEncoder(nn.Module):
def forward(self, x):
# x是输入图像
_, _, H, W = x.shape

x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)

# a b c d 分别对应了不同尺度的特征图
# layer1-4分别是带有残差连接的两个卷积层
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)

# 双线性插值到同一维度上去
a = _bilinear_intepolate(a, self.stride, H, W)
b = _bilinear_intepolate(b, self.stride, H, W)
c = _bilinear_intepolate(c, self.stride, H, W)
d = _bilinear_intepolate(d, self.stride, H, W)

# 拼接然后输出特征图
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
return x

def _bilinear_intepolate(x, stride, H, W):
return F.interpolate(
x, (H // stride, W // stride), mode="bilinear", align_corners=True
)

拿到特征点以及对应的特征图后就可以正式开始匹配了,首先需要把特征点坐标放缩到特征图上去。由于输入的图像没有时间上的先后,所以会同时处理所有的图像来进行匹配。

首先,对于计算好的特征图集合张量{}进行降采样,从而得到不同尺度下的特征张量金字塔;然后对于所有待查询的特征点,计算以其为中心的局部窗口的坐标,考虑金字塔层级尺度不同引入额外的补偿区域,最终得到待查询区域的坐标;接着对于每个特征点,从特征图中采样,以获取局部窗口的特征值,最后计算目标特征与采样特征之间的相关性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# 计算相关性的代码
class EfficientCorrBlock:
def __init__(self, fmaps, num_levels=4, radius=4):
B, S, C, H, W = fmaps.shape
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.fmaps_pyramid.append(fmaps)
# 降采样计算特征张量金字塔
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)

def sample(self, coords, target):
r = self.radius
device = coords.device
B, S, N, D = coords.shape
assert D == 2
target = target.permute(0, 1, 3, 2).unsqueeze(-1)

out_pyramid = []

# 金字塔每层都会计算相关性
for i in range(self.num_levels):
pyramid = self.fmaps_pyramid[i]
C, H, W = pyramid.shape[2:]
# 计算以查询坐标为中心的局部窗口的坐标
centroid_lvl = (
torch.cat(
[torch.zeros_like(coords[..., :1], device=device), coords],
dim=-1,
).reshape(B * S, N, 1, 1, 3)
/ 2**i
)

# 考虑金字塔层级尺度不同引入额外的补偿区域
dx = torch.linspace(-r, r, 2 * r + 1, device=device)
dy = torch.linspace(-r, r, 2 * r + 1, device=device)
xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
zgrid = torch.zeros_like(xgrid, device=device)
delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)

# 最终待查询区域
coords_lvl = centroid_lvl + delta_lvl

# 双线性插值从特征图中采样,以获取局部窗口的特征值
pyramid_sample = bilinear_sampler(
pyramid.reshape(B * S, C, 1, H, W), coords_lvl
)

# 计算目标特征与采样特征之间的相关性
# 对应元素相乘再相加
corr = torch.sum(
target * pyramid_sample.reshape(B, S, C, N, -1), dim=2
)
corr = corr / torch.sqrt(torch.tensor(C).float())
out_pyramid.append(corr)

out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
return out

然后把相关特征拉平,在保证第一帧的特征点坐标为初始值的前提下,将跟踪特征、相关性特征和一个相对偏移编码特征送入Transformer并计算坐标和特征的增量,整个过程会持续迭代直到收敛。这部分内容看代码会比较好理解一点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class BaseTrackerPredictor(nn.Module):
def forward(
self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1
):
"""
query_points: B x N x 2, the number of batches, tracks, and xy
fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
note HH and WW is the size of feature maps instead of original images
"""
B, N, D = query_points.shape
B, S, C, HH, WW = fmaps.shape

assert D == 2

# Scale the input query_points because we may downsample the images
# by down_ratio or self.stride
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
# its query_points should be query_points/4
if down_ratio > 1:
query_points = query_points / float(down_ratio)
query_points = query_points / float(self.stride)

# Init with coords as the query points
# It means the search will start from the position of query points at the reference frames
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

# Sample/extract the features of the query points in the query frame
# track_feats表示跟踪特征,初始值是通过在查询帧中对查询点的特征进行采样得到的
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

# init track feats by query feats
track_feats = query_track_feat.unsqueeze(1).repeat(
1, S, 1, 1
) # B, S, N, C
# back up the init coords
coords_backup = coords.clone()

# Construct the correlation block
# 初始化计算相关特征的函数,有两种不同的构造
if self.efficient_corr:
fcorr_fn = EfficientCorrBlock(
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
)
else:
fcorr_fn = CorrBlock(
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
)

# 最终预测结果
coord_preds = []

# Iterative Refinement
for itr in range(iters):
# Detach the gradients from the last iteration
# (in my experience, not very important for performance)
coords = coords.detach()

# Compute the correlation (check the implementation of CorrBlock)
# 相关性特征
if self.efficient_corr:
fcorrs = fcorr_fn.sample(coords, track_feats)
else:
fcorr_fn.corr(track_feats)
fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim

corrdim = fcorrs.shape[3]
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim)

# Movement of current coords relative to query points
# 相对偏移编码特征
flows = (
(coords - coords[:, 0:1])
.permute(0, 2, 1, 3)
.reshape(B * N, S, 2)
)
flows_emb = get_2d_embedding(
flows, self.flows_emb_dim, cat_coords=False
)
# (In my trials, it is also okay to just add the flows_emb instead of concat)
flows_emb = torch.cat([flows_emb, flows], dim=-1)

# 跟踪特征
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
B * N, S, self.latent_dim
)

# Concatenate them as the input for the transformers
transformer_input = torch.cat(
[flows_emb, fcorrs_, track_feats_], dim=2
)

if transformer_input.shape[2] < self.transformer_dim:
# pad the features to match the dimension
pad_dim = self.transformer_dim - transformer_input.shape[2]
pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
transformer_input = torch.cat([transformer_input, pad], dim=2)

# 2D positional embed
# TODO: this can be much simplified
pos_embed = get_2d_sincos_pos_embed(
self.transformer_dim, grid_size=(HH, WW)
).to(query_points.device)
sampled_pos_emb = sample_features4d(
pos_embed.expand(B, -1, -1, -1), coords[:, 0]
)
sampled_pos_emb = rearrange(
sampled_pos_emb, "b n c -> (b n) c"
).unsqueeze(1)

x = transformer_input + sampled_pos_emb

# B, N, S, C
x = rearrange(x, "(b n) s d -> b n s d", b=B)

# Compute the delta coordinates and delta track features
# 迭代更新
delta = self.updateformer(x)
# BN, S, C
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
delta_coords_ = delta[:, :, :2]
delta_feats_ = delta[:, :, 2:]

track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

# Update the track features
track_feats_ = (
self.ffeat_updater(self.norm(delta_feats_)) + track_feats_
)
track_feats = track_feats_.reshape(
B, N, S, self.latent_dim
).permute(
0, 2, 1, 3
) # BxSxNxC

# B x S x N x 2
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(
0, 2, 1, 3
)

# Force coord0 as query
# because we assume the query points should not be changed
coords[:, 0] = coords_backup[:, 0]

# The predicted tracks are in the original image scale
if down_ratio > 1:
coord_preds.append(coords * self.stride * down_ratio)
else:
coord_preds.append(coords * self.stride)

# B, S, N
if not self.fine:
vis_e = self.vis_predictor(
track_feats.reshape(B * S * N, self.latent_dim)
).reshape(B, S, N)
vis_e = torch.sigmoid(vis_e)
else:
vis_e = None

if return_feat:
return coord_preds, vis_e, track_feats, query_track_feat
else:
return coord_preds, vis_e

这样子就完成了一次匹配,VGGsfm有粗匹配和精匹配两步匹配,架构是一样的,只不过粗匹配是使用一张图像的完整特征图,而精匹配是对粗匹配结果的附近特征图进行裁剪后再重新匹配,从而得到亚像素级别的匹配精度。可以看到,VGGsfm在匹配上并没有非常粗暴的直接使用网络一步到位,相反还留有浓厚的传统sfm的pipeline的影子。

估计相机位姿与三角化

这里结构其实跟上面差不太多,预测相机pose的也是一个Transformer,只不过输入是由ResNet50提取的全局图像特征和前面计算好的匹配,还有一点值得注意的是经过cross-attention的输出会和8点法初步估计的位姿嵌合起来,最终才估计出相机位姿。三角化模块的输入是匹配模块输出的跟踪特征、预测出来的相机位姿和由DLT估计出来的初始点云,同样经过Transformer处理后输出三角化后的3D点云。

(这部分就懒得贴代码了,一方面是代码非常长而且很容易找到,另一方面是我也只是粗略看了一下)

可微分BA

一般来说,BA都是放在网络后面的一个单独的步骤,因为BA本身迭代次数可能非常多而且可能会包括一些不可微的操作,一般都被看成是一个不可导层,也就没办法直接参与到网络的训练中。有些间接参与训练的方法是这样子来实现的,但是效果并不一定会很好。

可微分BA就是能让BA本身变得可微,从而能够当作网络的一个层,直接参与到网络的训练中,一般来说可微分BA都是通过隐函数定理来间接求导的。

VGGsfm直接用的Theseus这个库来实现可微分BA的,但是这里还是简单推导一下。

欧式空间可微分BA

记优化为,其中为初始化输入,为优化后的输出,如果最终Loss为,那么需要计算的梯度为

如果存在函数,使得,那么根据隐函数定理可以将看作为的函数,这是根据隐函数定理就可以得到

这样就可以通过求解来求得。那如何找到呢,一般来说,BA最终的输出都会满足一阶导为0,所以可以将定义为BA的目标函数的导数即可。

当然,这里只是简单推导了一下欧式空间中的可微分BA,实际因为需要优化重投影误差,需要在李空间里优化旋转,我们再来看下李空间的可微分BA。

李空间可微分BA

对于位姿可以用李代数参数化:

其中将李代数向量转换为反对称矩阵,是指数映射。优化变量包括李代数参数和三维点坐标

记李空间的优化问题为,最终损失函数为,需计算梯度:

在最优解处,目标函数的导数为零,仿照欧式空间的推导,记为BA的目标函数的导数:

这定义了隐式方程,将视为的函数。

对隐式方程两边关于求导:

解出:

整理为矩阵形式:

进一步写成:

其中Hessian矩阵

解得:

那么最终梯度为:

不过由于是流形空间对应正切空间的变量,计算对应导数的时候需要记得乘以对应的Jacobian矩阵。

损失函数

记网络输出点云为,对应BA优化后点云为,那么可以得到点云Loss,其中为Huber loss:

记网络输出相机位姿为,对应BA优化后相机位姿为,那么可以得到位姿Loss:

记网络输出的跟踪点为,对应的置信度为,那么可以得到跟踪Loss,它描述了真值跟踪点在以为均值、以为方差的高斯分布下的似然度:

最终的Loss就是上面三种Loss的和。

VGGsfm总结

说实话我也只是粗略的看了一下,主要是它文章很多地方写的确实不太清楚,而且VGGT已经是它的上位替代了,也就没有仔细读它代码的欲望了...不过VGGsfm的架构还是很有意思的,它保留了很多基于几何法的sfm的pipeline,但是又在每个步骤中引入了网络,这样能复用很多基于几何的sfm的组件,比如可以把几何法计算的配准结果嵌入到网络的输入中。但是鉴于后续大量的sfm都抛弃了这种做法而转向拥抱纯粹的end2end,我感觉这种基于几何法的sfm的pipeline可能并不是引入网络的最好的选择。

VGGT

下面正式进入VGGT,和VGGsfm很不一样,VGGT基本上没有几何法sfm pipeline的影子了,就是纯粹的end2end;而且也不像dust3r、mast3r那样,先对两张图像进行重建,再通过后处理的方式来配准,而是本身就支持一次处理多张图像,也不需要额外的BA,当然有BA的话会更准一点,但最后没BA也不是不行。

由上图可以看到,网络可以一次性输入一组图像,然后输出每张图像的相机参数,深度图,点图和用于跟踪点的特征矩阵

其中相机参数包含内参和外参,其中外参由来参数化,而内参而是用视场角fov来参数化。深度图就是描述了每个点的深度信息。点图则跟dust3r中定义一致,即从光心到每个像素构成的射线碰到的第一个空间点的坐标(点图中空间点全部定义在第一帧图像的坐标系里,因此点图和深度图是有很大差异的,点图深度图+外参)。而特征矩阵是一个由所有图像中对应二维点组成的轨迹,用于跟踪点。

网络架构

网络本身并不复杂,直接来看代码,和论文写的一致,整体很好理解,Aggregator模块对图像处理然后输出token,CameraHead模块负责预测相机位姿,DPTHead模块负责预测深度图,还有一个可选的TrackHead模块,如果提供了需要跟踪的点那么它会输出跟踪点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class VGGT(nn.Module, PyTorchModelHubMixin):
def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
super().__init__()
# 5个模块初始化
self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
self.camera_head = CameraHead(dim_in=2 * embed_dim)
self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

def forward(
self,
images: torch.Tensor,
query_points: torch.Tensor = None,
):
if len(images.shape) == 4:
images = images.unsqueeze(0)
if query_points is not None and len(query_points.shape) == 2:
query_points = query_points.unsqueeze(0)

# aggregator提取token,因为会同时输出camera相关的token和feature相关token
# 需要patch_start_idx标记两者分界
aggregated_tokens_list, patch_start_idx = self.aggregator(images)

predictions = {}

with torch.cuda.amp.autocast(enabled=False):
# 计算相机位姿
if self.camera_head is not None:
pose_enc_list = self.camera_head(aggregated_tokens_list)
...

# 计算深度图
if self.depth_head is not None:
depth, depth_conf = self.depth_head(
aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
)
...

# 计算点图
if self.point_head is not None:
pts3d, pts3d_conf = self.point_head(
aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
)
...

# 可选的tracker模块用来估计给定点的轨迹
if self.track_head is not None and query_points is not None:
track_list, vis, conf = self.track_head(
aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
)
...

return predictions

下面就分别介绍这几个模块,当然重点肯定是Aggregator。

Aggregator和Alternating Attention

Aggregator输入的是一组图像,输出的是一组token。每两个self-attention组成一个Alternating Attention,每个Alternating Attention中的第一个self-attention负责对每帧内特征进行关联提取,第二个self-attention负责对所有帧的特征进行关联提取。然后堆叠多个Alternating Attention就组成了Aggregator模块的骨干网络。

当有多张图像输入时,会首先使用dinov2这个模型为每张图像计算patch_tokens并添加上对应的编码信息,然后为每帧图像都初始化camera_token和register_token(就是把默认的camera_token和register_token拓展下batchsize,他们初始化的时候是一些极小值),把这三种token拼接后就作为骨干网络的输入了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class Aggregator(nn.Module):
def __build_patch_embed__(...):
if "conv" in patch_embed:
self.patch_embed = PatchEmbed(...)
else:
vit_models = {
"dinov2_vitl14_reg": vit_large,
"dinov2_vitb14_reg": vit_base,
"dinov2_vits14_reg": vit_small,
"dinov2_vitg2_reg": vit_giant2,
}

self.patch_embed = vit_models[patch_embed](...)


def forward(
self,
images: torch.Tensor,
) -> Tuple[List[torch.Tensor], int]:

# 归一化然后用dinov2来计算每张图的patch_tokens
images = (images - self._resnet_mean) / self._resnet_std
images = images.view(B * S, C_in, H, W)
patch_tokens = self.patch_embed(images)

if isinstance(patch_tokens, dict):
patch_tokens = patch_tokens["x_norm_patchtokens"]

_, P, C = patch_tokens.shape

# 把camera_token和register_token分别拉起到当前batchsize上
camera_token = slice_expand_and_flatten(self.camera_token, B, S)
register_token = slice_expand_and_flatten(self.register_token, B, S)

# 然后把patch_tokens和camera_token和register_token拼接起来作为网络的输入
tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

# 因为每张图会划分成一块一块的,给每块添加位置编码
# 但是camera_token和register_token又不需要这种块位置编码,所以跳过他们
pos = None
if self.rope is not None:
pos = self.position_getter(
B * S, H // self.patch_size, W // self.patch_size, device=images.device
)
if self.patch_start_idx > 0:
# do not use position embedding for special tokens (camera and register tokens)
# so set pos to 0 for the special tokens
pos = pos + 1
pos_special = (
torch.zeros(B * S, self.patch_start_idx, 2)
.to(images.device)
.to(pos.dtype)
)
pos = torch.cat([pos_special, pos], dim=1)

# update P because we added special tokens
_, P, C = tokens.shape

frame_idx = 0
global_idx = 0
output_list = []

# 重复交替执行frame_attention和global_attention就完事了
for _ in range(self.aa_block_num):
for attn_type in self.aa_order:
if attn_type == "frame":
tokens, frame_idx, frame_intermediates = (
self._process_frame_attention(
tokens, B, S, P, C, frame_idx, pos=pos
)
)
elif attn_type == "global":
tokens, global_idx, global_intermediates = (
self._process_global_attention(
tokens, B, S, P, C, global_idx, pos=pos
)
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")

for i in range(len(frame_intermediates)):
# concat frame and global intermediates, [B x S x P x 2C]
concat_inter = torch.cat(
[frame_intermediates[i], global_intermediates[i]], dim=-1
)
output_list.append(concat_inter)

del concat_inter
del frame_intermediates
del global_intermediates
return output_list, self.patch_start_idx

显然,Aggregator的核心就是Alternating Attention(AA),论文里说了,AA里交替执行的frame attention和global attention都是self-attention,而且构建frame attention和global attention的参数都完全一样。其本质是控制batch
这个维度出现的位置来让不同帧的token出现在同一batch/不同batch,如果出现在同一batch里,那么就是global attention,所有帧都会参与彼此token的关联计算;反之出现在不同batch里,那么就是frame attention,每一帧计算的attention时只会影响到自己的token。

这么说可能有点抽象,简单举个例子。如果输入张图像,每张图像有个patch,每个patch的token的维度是,那么就可以构造张量,在frame attention时将其变换成,这样每张图像的patch token就会被看成是一个单独的 batch,这样就只会计算每一帧内的所有patch token;而在global attention时将其变换成,这样所有patch token都会被看成为一个batch,计算时所有帧的patch token都会参与彼此attention的计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Aggregator(nn.Module):
def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
"""
Process frame attention blocks. We keep tokens in shape (B*S, P, C).
"""
# If needed, reshape tokens or positions:
if tokens.shape != (B * S, P, C):
tokens = tokens.view(B, S, P, C).view(B * S, P, C)

if pos is not None and pos.shape != (B * S, P, 2):
pos = pos.view(B, S, P, 2).view(B * S, P, 2)

intermediates = []

# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
# self.frame_blocks是一个列表,每个元素都是Block类,Block类是一个self attention的封装
tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
frame_idx += 1
intermediates.append(tokens.view(B, S, P, C))

return tokens, frame_idx, intermediates

def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
"""
Process global attention blocks. We keep tokens in shape (B, S*P, C).
"""
if tokens.shape != (B, S * P, C):
tokens = tokens.view(B, S, P, C).view(B, S * P, C)

if pos is not None and pos.shape != (B, S * P, 2):
pos = pos.view(B, S, P, 2).view(B, S * P, 2)

intermediates = []

# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
# self.global_blocks是一个列表,每个元素都是Block类,Block类是一个self attention的封装
tokens = self.global_blocks[global_idx](tokens, pos=pos)
global_idx += 1
intermediates.append(tokens.view(B, S, P, C))

return tokens, global_idx, intermediates

self.global_blocksself.frame_blocks的每个元素都是BLOCK类,而且他们的参数也完全相同。BLOCK类是对self attention的封装,代码就不贴了,大致画了下结构。

CameraHead

还记得Aggregator模块里的camera_token和register_token嘛,每一张图像都会使用1个camera_token和4个register_token来增强,并且第一帧的camera_token和register_token跟其余帧的还不一样,这是因为最终的点图是要在第一帧的坐标系里,使用额外的token可以把第一帧给单独拿出来。

1
2
3
4
5
6
7
8
9
def slice_expand_and_flatten(token_tensor, B, S):
# 第一个位置的token给第一帧用
query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
# 其余所有帧的token都是others
others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])

combined = torch.cat([query, others], dim=1)
combined = combined.view(B * S, *combined.shape[2:])
return combined

不过,最终CameraHead会丢弃掉register_token,仅使用camera_token来预测相机位姿。这里从camera_token解码出相机内外参的步骤还是稍微有点绕的,他会使用一个初始为全0的token来对camera_token进行偏移调制(类似FiLM),并用门控单元来控制残差连接时的贡献度,完成调制和残差连接后才会输入给attention,这个过程会迭代几次,过程中会不断更新调制token。下图大致画了一下主要步骤,把一些norm操作省略掉了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class CameraHead(nn.Module):
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
# aggregated_tokens_list是Aggregator每次AA输出的结果
# aggregated_tokens_list[-1]即经过所有AA后的迭代
tokens = aggregated_tokens_list[-1]

# camera_token在最前面,只保留camera_token,扔掉register_token
pose_tokens = tokens[:, :, 0]
pose_tokens = self.token_norm(pose_tokens)

# 迭代细化
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
return pred_pose_enc_list

def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
B, S, C = pose_tokens.shape # S is expected to be 1.
pred_pose_enc = None
pred_pose_enc_list = []

for _ in range(num_iterations):
if pred_pose_enc is None:
# 首次迭代:使用 empty_pose_tokens,全0的token
# embed_pose是个线性变化
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
else:
# 后续迭代:使用上一轮预测的姿势编码(detached防止梯度爆炸,类似RNN截断)
pred_pose_enc = pred_pose_enc.detach()
module_input = self.embed_pose(pred_pose_enc)

# poseLN_modulation是一个简单的MLP,生成shift、scale、gate三个调制参数
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

# adaln_norm自适应归一化,无仿射参数的LayerNorm
# modulate函数用用于对归一化后的特征进行仿射变换:x * (1 + scale) + shift
# 并用门控单元来控制贡献度
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
# 通过残差连接保留原始特征
pose_tokens_modulated = pose_tokens_modulated + pose_tokens

# 将调制后的特征输入attention,trunk是多个Block类,跟之前的AA是一样的
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
# 通过MLP预测参数增量
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

# 更新增量
if pred_pose_enc is None:
pred_pose_enc = pred_pose_enc_delta
else:
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

# 解码出内外参
activated_pose = activate_pose(
pred_pose_enc,
trans_act=self.trans_act,
quat_act=self.quat_act,
fl_act=self.fl_act,
)
pred_pose_enc_list.append(activated_pose)

return pred_pose_enc_list

解码其实相对比较简单了,最后的MLP会输出一个9维的tensor,前三维是平移,接着四维是四元数,最后两维是视场角fov,然后通过一些线性或非线性变换来解码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
T = pred_pose_enc[..., :3]
quat = pred_pose_enc[..., 3:7]
fl = pred_pose_enc[..., 7:] # or fov

T = base_pose_act(T, trans_act)
quat = base_pose_act(quat, quat_act)
fl = base_pose_act(fl, fl_act) # or fov

pred_pose_enc = torch.cat([T, quat, fl], dim=-1)

return pred_pose_enc


def base_pose_act(pose_enc, act_type="linear"):
if act_type == "linear":
return pose_enc
elif act_type == "inv_log":
return inverse_log_transform(pose_enc)
elif act_type == "exp":
return torch.exp(pose_enc)
elif act_type == "relu":
return F.relu(pose_enc)
else:
raise ValueError(f"Unknown act_type: {act_type}")

DPTHead和TrackHead

关于点图和深度图都是通过DPThead来完成解码的,这反而其实没啥好说的,具体细节去看DPT原文吧。VGGT这里会先用DPT预测一个稠密的特征图,然后用两个卷积层和激活函数来解码出深度图和置信度。

另外,TrackHead需要使用DPT输出的这个稠密特征图,所以可以看到DPTHead的实现里如果指定了只要特征的flag,就不会执行后续的激活解码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class DPTHead(nn.Module):
def _forward_impl(...) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if frames_start_idx is not None and frames_end_idx is not None:
images = images[:, frames_start_idx:frames_end_idx].contiguous()

B, S, _, H, W = images.shape

patch_h, patch_w = H // self.patch_size, W // self.patch_size

out = []
dpt_idx = 0

# DPT
for layer_idx in self.intermediate_layer_idx:
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
if frames_start_idx is not None and frames_end_idx is not None:
x = x[:, frames_start_idx:frames_end_idx]
x = x.view(B * S, -1, x.shape[-1])
x = self.norm(x)
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[dpt_idx](x)
if self.pos_embed:
x = self._apply_pos_embed(x, W, H)
x = self.resize_layers[dpt_idx](x)
out.append(x)
dpt_idx += 1
out = self.scratch_forward(out)

# 调整到目标分辨率
out = custom_interpolate(
out,
(
int(patch_h * self.patch_size / self.down_ratio),
int(patch_w * self.patch_size / self.down_ratio),
),
mode="bilinear",
align_corners=True,
)
if self.pos_embed:
out = self._apply_pos_embed(out, W, H)
# trackerhead需要DPT输出的特征图,而不需要后续的激活
if self.feature_only:
return out.view(B, S, *out.shape[1:])

# 3*3卷积+1*1卷积
out = self.scratch.output_conv2(out)

# 通过线性或非线性来输出3D点和置信度
preds, conf = activate_head(
out, activation=self.activation, conf_activation=self.conf_activation
)

preds = preds.view(B, S, *preds.shape[1:])
conf = conf.view(B, S, *conf.shape[1:])
return preds, conf

关于TrackHead,VGGT其实是从VGGsfm中把tracker那一部分移植过来了,我们之前提到过,VGGsfm使用不同大小的CNN来计算特征图,VGGT就直接用DPT的输出作为特征图了,然后复用了VGGsfm的BaseTrackerPredictor来跟踪点,这个函数基本上没有太大变化,也是计算相关性金字塔然后把跟踪特征、相关特征和相对偏移编码特征送入Transformer计算增量这样子不断迭代。因为前文已经展示过代码了,这里就不列具体细节了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class TrackHead(nn.Module):
def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
B, S, _, H, W = images.shape

# DPTHead提取特征图
feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

# BaseTrackerPredictor用来跟踪特征
# 这个BaseTrackerPredictor跟VGGsfm里的BaseTrackerPredictor长得基本一模一样
coord_preds, vis_scores, conf_scores = self.tracker(
query_points=query_points,
fmaps=feature_maps,
iters=iters,
)

return coord_preds, vis_scores, conf_scores

损失函数

因为VGGT有4个输出部分,所有显然Loss也是由4个部分组成的。

首先是相机位姿Loss,这里的是Huber loss:

然后是深度图损失,跟其中是置信度,表示对应元素相乘,也就是用置信度给损失加权。这里面的第一项和第三项跟dust3r的loss长得差不多,分别表示真实深度加权和正则化损失;而第二项是VGGT额外引入的,是深度的梯度项,这保证了输出深度图的边缘和结构信息和真实结构尽可能一致,从而较好的保留几何细节:

点图损失跟深度图损失如出一辙:

最后是跟踪Loss,跟VGGsfm里的一样,外层遍历查询图像的所有真实查询点,其中是真值,而是预测结果:

最终的Loss为:

效果测试

同样用之前拍的图测试了一下效果,参数调的偏保守一点,几何结构都相当正确,但是一些置信度低的点云就不可见了。