class BaseTrackerPredictor(nn.Module):
    def forward(
        self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1
    ):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            note that HH and WW are the size of the feature maps, not of the original images
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape
assert D == 2
        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map,
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)
        # Init with coords as the query points
        # It means the search will start from the position of the query points at the reference frame
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
        # Sample/extract the features of the query points in the query frame
        # track_feats are the track features; their initial value is obtained by sampling
        # the query frame's feature map at the query point locations
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
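        # --- Illustrative sketch (assumption, not the repo's sample_features4d): point-wise
        # feature sampling is typically bilinear interpolation of a B x C x H x W feature map
        # at N (x, y) pixel locations; the helper name, normalization, and align_corners choice
        # below are assumptions. Assumes `import torch.nn.functional as F`.
        def _sample_points_bilinear(fmap, points):
            _, _, H_, W_ = fmap.shape
            x = 2.0 * points[..., 0] / max(W_ - 1, 1) - 1.0          # normalize x to [-1, 1]
            y = 2.0 * points[..., 1] / max(H_ - 1, 1) - 1.0          # normalize y to [-1, 1]
            grid = torch.stack([x, y], dim=-1).unsqueeze(2)          # B x N x 1 x 2
            sampled = F.grid_sample(fmap, grid, align_corners=True)  # B x C x N x 1
            return sampled.squeeze(-1).permute(0, 2, 1)              # B x N x C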
        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C

        # back up the init coords
        coords_backup = coords.clone()

        # fcorr_fn, the correlation helper used below (see CorrBlock), is constructed
        # from fmaps at this point in the full source; it is elided in this excerpt.
        coord_preds = []
        # Iterative Refinement
        for itr in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()
            # Compute the correlation (check the implementation of CorrBlock)
            # correlation features
            if self.efficient_corr:
                fcorrs = fcorr_fn.sample(coords, track_feats)
            else:
                fcorr_fn.corr(track_feats)
                fcorrs = fcorr_fn.sample(coords)  # B, S, N, corrdim
            # Movement of current coords relative to query points
            # relative-offset embedding features
            flows = (
                (coords - coords[:, 0:1])
                .permute(0, 2, 1, 3)
                .reshape(B * N, S, 2)
            )

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat([flows_emb, flows], dim=-1)
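            # --- Illustrative sketch (assumption, not the repo's get_2d_embedding): a typical
            # sinusoidal embedding maps each component of the 2D flow to sin/cos features at
            # several frequencies and concatenates them; the helper name and frequency schedule
            # below are assumptions.
            def _sincos_embed(vec, num_freqs):
                # vec: (..., 2) flow in x/y; returns (..., 4 * num_freqs)
                freqs = torch.arange(num_freqs, dtype=vec.dtype, device=vec.device)
                freqs = 1.0 / (10000.0 ** (freqs / num_freqs))
                angles = vec.unsqueeze(-1) * freqs                          # (..., 2, num_freqs)
                return torch.cat([angles.sin(), angles.cos()], dim=-1).flatten(-2)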
            # correlation features reshaped to per-track sequences: B*N, S, corrdim
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, -1)

            # track features
            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
                B * N, S, self.latent_dim
            )
            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
            if transformer_input.shape[2] < self.transformer_dim:
                # pad the features to match the dimension
                pad_dim = self.transformer_dim - transformer_input.shape[2]
                pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
                transformer_input = torch.cat([transformer_input, pad], dim=2)
            # 2D positional embed
            # TODO: this can be much simplified
            pos_embed = get_2d_sincos_pos_embed(
                self.transformer_dim, grid_size=(HH, WW)
            ).to(query_points.device)
            sampled_pos_emb = sample_features4d(
                pos_embed.expand(B, -1, -1, -1), coords[:, 0]
            )
            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
x = transformer_input + sampled_pos_emb
            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)
            # Compute the delta coordinates and delta track features
            # iterative update
            delta = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]
            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
            # Update the track features
            track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
                0, 2, 1, 3
            )  # B x S x N x C
            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]
            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)
        # B, S, N
        if not self.fine:
            vis_e = self.vis_predictor(
                track_feats.reshape(B * S * N, self.latent_dim)
            ).reshape(B, S, N)
            vis_e = torch.sigmoid(vis_e)
        else:
            vis_e = None
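        # --- Illustrative note (assumption; the module definition is not shown in this
        # excerpt): a minimal visibility head matching the call above would map each latent
        # track feature to a single logit, e.g. defined in __init__ as
        #   self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
        # so that, after the sigmoid, vis_e is a per-frame, per-track visibility probability
        # in [0, 1], while coord_preds holds one B x S x N x 2 prediction (in original-image
        # pixels) per refinement iteration.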
        # Each image is divided into patches, and each patch gets a positional encoding.
        # The camera_token and register_token do not need this patch position encoding,
        # so they are skipped.
        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = (
                torch.zeros(B * S, self.patch_start_idx, 2)
                .to(images.device)
                .to(pos.dtype)
            )
            pos = torch.cat([pos_special, pos], dim=1)
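        # --- Illustrative sketch (assumption, not the repo's position_getter): for RoPE,
        # each patch token needs its (y, x) index on the patch grid; a minimal version
        # builds that grid explicitly. The helper name is hypothetical.
        def _make_patch_positions(batch, h_patches, w_patches, device):
            ys = torch.arange(h_patches, device=device)
            xs = torch.arange(w_patches, device=device)
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
            pos_grid = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)  # (h*w, 2)
            return pos_grid.unsqueeze(0).expand(batch, -1, -1)               # (batch, h*w, 2)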
        # update P because we added special tokens
        _, P, C = tokens.shape
        for i in range(len(frame_intermediates)):
            # concat frame and global intermediates, [B x S x P x 2C]
            concat_inter = torch.cat(
                [frame_intermediates[i], global_intermediates[i]], dim=-1
            )
            output_list.append(concat_inter)
        del concat_inter
        del frame_intermediates
        del global_intermediates

        return output_list, self.patch_start_idx
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            # self.frame_blocks is a list whose elements are Block modules;
            # each Block wraps a self-attention layer
            tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
return tokens, frame_idx, intermediates
    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)
        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)
intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            # self.global_blocks is a list whose elements are Block modules;
            # each Block wraps a self-attention layer
            tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
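        # --- Illustrative sketch (hypothetical shapes, not from the source): the two token
        # layouts used by the alternating attention. Frame attention mixes the P patch tokens
        # within each frame, so the tensor is viewed as (B*S, P, C); global attention mixes
        # tokens across all S frames, so the same tensor is viewed as (B, S*P, C).
        #   _B, _S, _P, _C = 2, 4, 196, 1024
        #   _tokens = torch.randn(_B, _S, _P, _C)
        #   _frame_layout = _tokens.view(_B * _S, _P, _C)   # attention within each frame
        #   _global_layout = _tokens.view(_B, _S * _P, _C)  # attention across all frames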
    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []
patch_h, patch_w = H // self.patch_size, W // self.patch_size
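        # --- Worked example (hypothetical numbers, not from the source): a 518 x 518 input
        # with patch_size = 14 gives patch_h = patch_w = 518 // 14 = 37 patches per side.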
        out = []
        dpt_idx = 0
        # DPT
        for layer_idx in self.intermediate_layer_idx:
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            x = x.view(B * S, -1, x.shape[-1])

            x = self.norm(x)

            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        out = self.scratch_forward(out)
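        # --- Illustrative sketch (hypothetical shapes, not from the source): the token ->
        # feature-map reshape performed inside the loop above.
        #   _x = torch.randn(8, 37 * 37, 1024)                   # (B*S, patch_h*patch_w, C)
        #   _x = _x.permute(0, 2, 1).reshape(8, 1024, 37, 37)    # (B*S, C, patch_h, patch_w)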
        # resize to the target resolution
        out = custom_interpolate(
            out,
            (
                int(patch_h * self.patch_size / self.down_ratio),
                int(patch_w * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        # the tracker head only needs the feature map output by the DPT,
        # not the subsequent activations
        if self.feature_only:
            return out.view(B, S, *out.shape[1:])
        # 3x3 conv + 1x1 conv
        out = self.scratch.output_conv2(out)
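
# --- Illustrative sketch (assumption, following the reference DPT head rather than this
# repo's exact layer sizes): an output_conv2 of the "3x3 conv + 1x1 conv" kind first mixes
# features spatially with a 3x3 convolution, then maps to the final channel count with a
# 1x1 convolution. The helper name and channel sizes are hypothetical.
import torch.nn as nn

def make_output_conv2(in_channels: int, hidden_channels: int, out_channels: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=1, padding=0),
    )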