
cofai.backbone

BackboneProtocol

Protocol defining the interface for backbone models.

decode

decode(h: Tensor, tasks: list[str]) -> torch.Tensor

Decode encoded features for downstream tasks.

Parameters:

  • h (Tensor): Encoded features from the encode method, shape (B, N, C). Required.
  • tasks (list[str]): List of tasks to decode for, e.g. ["rec", "cls", "seg"]. Required.

Returns:

  • task_feats (dict): Dictionary of task-specific features.

encode

encode(x: Tensor) -> torch.Tensor

Encode input images to intermediate features.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features of shape (B, N, C).
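For orientation, a minimal sketch of a class satisfying this protocol. The DummyBackbone name, its projection layer, and the trivial decode are hypothetical placeholders, not part of cofai:

```python
import torch
from torch import Tensor


class DummyBackbone:
    """Hypothetical class satisfying BackboneProtocol."""

    def __init__(self, dim: int = 384):
        # 16x16 patch embedding as a stand-in encoder.
        self.proj = torch.nn.Conv2d(3, dim, kernel_size=16, stride=16)

    def encode(self, x: Tensor) -> Tensor:
        # (B, 3, H, W) -> (B, N, C) patch tokens
        return self.proj(x).flatten(2).transpose(1, 2)

    def decode(self, h: Tensor, tasks: list[str]) -> dict:
        # One feature tensor per requested task, e.g. ["rec", "cls", "seg"].
        return {task: h for task in tasks}
```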

Dinov2OrgBackbone

Dinov2OrgBackbone(model_size='small', img_size=256, patch_size=16, dynamic_size=False, slot=-4, n_last_blocks=4, ckpt_path=None)

DINOv2 backbone using the original Facebook Research implementation.

This class extends the original DINOv2 model to provide flexible feature extraction. The slot parameter determines the splitting point for dividing the ViT blocks into:

  • encode part: blocks[:slot],
  • decode part: blocks[slot:]

Intermediate features are extracted after the encode part and before the decode part.
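A minimal illustration of the split implied by slot; the strings below merely stand in for the ViT's transformer blocks:

```python
# With slot = -4 on a 12-block ViT:
blocks = [f"block_{i}" for i in range(12)]  # stand-in for the nn.ModuleList of blocks

slot = -4
encode_blocks = blocks[:slot]  # blocks 0..7, run inside encode()
decode_blocks = blocks[slot:]  # blocks 8..11, run inside decode()

assert len(decode_blocks) == 4  # consistent with n_last_blocks=4
```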

Parameters:

  • model_size (str): Model variant ('small', 'base', 'large', 'giant'). Defaults to 'small'.
  • img_size (int): Base input image size. Defaults to 256.
  • patch_size (int): Patch embedding size. Defaults to 16.
  • dynamic_size (bool): Whether to support dynamically varying input sizes. Defaults to False.
  • slot (int): Block slicing position for feature extraction; follows Python list slicing conventions, so -4 means the fourth block from the end. Defaults to -4.
  • n_last_blocks (int): Number of final blocks used for feature aggregation. Defaults to 4.
  • ckpt_path (str): Path to a pre-trained checkpoint for initialization. Defaults to None.

decode

decode(h, token_res=None, task='whole')

Decode encoded features through the decoder part of the DINOv2 model.

Parameters:

  • h (Tensor): Encoded features from the encoder. Required.
  • token_res (tuple): Token resolution (H, W) for reshaping patch tokens. Defaults to None.
  • task (str): Decoding task type, one of:
      • "whole": return full token sequences from multiple layers.
      • "cls": return class tokens and patch tokens separately.
      • "seg": return patch tokens reshaped to 2D spatial format.
    Defaults to "whole".

Returns:

  • feats (list[Tensor]): Decoded features; format depends on task.

encode

encode(x)

Encode input images through the encoder part of the DINOv2 model.

The encoding process applies input normalization, prepares tokens with masks, and runs the input through the encoder blocks (blocks[:slot]).

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features after the encoder blocks.

forward

forward(x, task='whole')

Forward pass through the backbone.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.
  • task (str): Task type, one of ["whole", "cls", "seg"]. Defaults to "whole".

Returns:

  • feats (list[Tensor]): Output features; format depends on task.
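A hedged usage sketch, assuming the backbone is a standard nn.Module so that calling it invokes forward; constructor arguments follow the documented defaults, and the exact contents of each feature list depend on the task as noted above:

```python
import torch

backbone = Dinov2OrgBackbone(model_size="small", img_size=256, patch_size=16)
backbone.eval()

x = torch.randn(2, 3, 256, 256)  # (B, 3, H, W)
with torch.no_grad():
    whole_feats = backbone(x, task="whole")  # full token sequences
    cls_feats = backbone(x, task="cls")      # class tokens + patch tokens
    seg_feats = backbone(x, task="seg")      # patch tokens in 2D spatial layout
```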

slide_decode_seg

slide_decode_seg(feature_list, slide_res)

Decode features from sliding window encoding for segmentation task.

Parameters:

  • feature_list (list): List of encoded features from slide_encode. Required.
  • slide_res (tuple): Token resolution (H, W) for each crop. Required.

Returns:

  • multi_crop_feats (list[list[Tensor]]): Decoded segmentation features, one entry per crop.

slide_encode

slide_encode(img, slide_window, slide_stride)

Encode images using sliding window approach.

This method extracts features from overlapping image crops using a sliding window strategy. Useful for processing large images that don't fit in memory or for extracting features at multiple scales.

Parameters:

  • img (Tensor): Input image of shape (B, 3, H_img, W_img). Required.
  • slide_window (tuple): Window size for cropping (h_crop, w_crop). Required.
  • slide_stride (tuple): Stride for the sliding window (h_stride, w_stride). Required.

Returns:

  • multi_crop_feats (list[Tensor]): Encoded features, one entry per crop.
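A minimal sketch of the crop grid a sliding window like this typically produces; the grid arithmetic is an assumption about the usual strategy, not the verified internals of slide_encode:

```python
import math
import torch

def sliding_crops(img: torch.Tensor, window: tuple, stride: tuple):
    """Yield overlapping crops covering the whole image."""
    _, _, H, W = img.shape
    h_crop, w_crop = window
    h_stride, w_stride = stride
    h_grids = max(math.ceil((H - h_crop) / h_stride), 0) + 1
    w_grids = max(math.ceil((W - w_crop) / w_stride), 0) + 1
    for i in range(h_grids):
        for j in range(w_grids):
            top = min(i * h_stride, H - h_crop)   # clamp so the crop stays in bounds
            left = min(j * w_stride, W - w_crop)
            yield img[:, :, top:top + h_crop, left:left + w_crop]

img = torch.randn(1, 3, 512, 512)
crops = list(sliding_crops(img, window=(256, 256), stride=(192, 192)))
assert all(c.shape[-2:] == (256, 256) for c in crops)
# Each crop would then be passed through encode(), giving one feature tensor per crop.
```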

Dinov2TimmBackbone

Dinov2TimmBackbone(model_size: str = 'small', img_size: int = 256, patch_size: int = 16, dynamic_size: bool = False, slot: int = -4, n_last_blocks: int = 4, ckpt_path: str = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu', cast_dtype: str = 'float', with_registers: bool = False)

DINOv2 backbone using timm library.

This class extends the DINOv2 model to provide flexible feature extraction. The DINOv2 backbone implemented with timm supports variable patch sizes and dynamic input image sizes. The slot parameter determines the splitting point for dividing the ViT blocks into:

  • encode part: blocks[:slot],
  • decode part: blocks[slot:]

Intermediate features are extracted after the encode part and before the decode part.

Parameters:

  • model_size (str): Model variant ('small', 'base', 'large', 'giant'). Defaults to 'small'.
  • img_size (int): Base input image size. Defaults to 256.
  • patch_size (int): Patch embedding size. Defaults to 16.
  • dynamic_size (bool): Whether to support dynamically varying input sizes. Defaults to False.
  • slot (int or None): Block slicing position for feature extraction; follows Python list slicing conventions. Defaults to -4.
  • n_last_blocks (int): Number of final blocks used for feature aggregation. Defaults to 4.
  • ckpt_path (str): Path to a pre-trained checkpoint for initialization. Defaults to None.
  • device (str): Device to run the model on. Defaults to "cuda" if available, else "cpu".
  • cast_dtype (str or dtype): Data type for autocast mixed precision. Accepts strings such as "torch.float", "torch.float16", or "float32". Defaults to "torch.float".
  • with_registers (bool): Whether to use register tokens in the model. Defaults to False.
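A sketch of how a string like "torch.float16" could be resolved to a torch.dtype for autocast; this resolver is an illustrative assumption, and the class may parse the string differently:

```python
import torch

def resolve_dtype(cast_dtype) -> torch.dtype:
    """Map strings like 'torch.float16' or 'float32' to a torch.dtype."""
    if isinstance(cast_dtype, torch.dtype):
        return cast_dtype
    name = cast_dtype.removeprefix("torch.")  # "torch.float16" -> "float16"
    dtype = getattr(torch, name)
    assert isinstance(dtype, torch.dtype), f"not a dtype: {cast_dtype}"
    return dtype

assert resolve_dtype("torch.float16") is torch.float16
assert resolve_dtype("float32") is torch.float32
assert resolve_dtype("float") is torch.float32  # torch.float aliases float32
```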

decode

decode(h, token_res=None, task='whole')

Decode encoded features through the decoder part of the DINOv2 model.

Parameters:

  • h (Tensor): Encoded features from the encoder. Required.
  • token_res (tuple): Token resolution (H, W) for reshaping patch tokens. Defaults to None.
  • task (str): Decoding task type, one of:
      • "whole": return full token sequences from multiple layers.
      • "cls": return class tokens and patch tokens separately.
      • "seg": return patch tokens reshaped to 2D spatial format.
    Defaults to "whole".

Returns:

  • feats (list[Tensor]): Decoded features; format depends on task.

decode_rae

decode_rae(h, token_res=None)

Decode encoded features for RAE decoder input.

This method processes encoded features through the remaining transformer blocks and returns the patch tokens (without prefix tokens) for the RAE decoder. The input h should be the output of the encode method (i.e. after blocks[:slot]).

This method uses LayerNorm without learnable affine parameters (matching Dinov2TimmwithNorm with normalize=True), which is equivalent to simple normalization.
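A minimal sketch of the normalization described above; the embedding dimension 768 is a placeholder:

```python
import torch
from torch import nn

dim = 768  # placeholder embedding dimension
h = torch.randn(2, 196, dim)

# LayerNorm without learnable affine parameters: pure per-token standardization.
norm = nn.LayerNorm(dim, elementwise_affine=False)
out = norm(h)

# Equivalent "simple normalization" by hand:
mean = h.mean(dim=-1, keepdim=True)
var = h.var(dim=-1, keepdim=True, unbiased=False)
manual = (h - mean) / torch.sqrt(var + norm.eps)
assert torch.allclose(out, manual, atol=1e-5)
```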

encode

encode(x: Tensor) -> torch.Tensor

Encode input images through the encoder part of the DINOv2 model.

The encoding process applies input normalization, patch embedding, and positional embedding, then runs the input through the encoder blocks (blocks[:slot]).

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features after the encoder blocks, shape (B, N, C).

forward

forward(x, task='whole')

Forward pass through the backbone.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.
  • task (str): Task type, one of ["whole", "cls", "seg"]. Defaults to "whole".

Returns:

  • feats (list[Tensor]): Output features; format depends on task.

Dinov2TransformersBackbone

Dinov2TransformersBackbone(model_name: str = 'facebook/dinov2-base', img_size: int = 224, slot: int = -4, normalize: bool = True)

DINOv2 backbone using transformers library.

This backbone uses Dinov2Model or Dinov2WithRegistersModel from transformers library and provides encode/decode functionality for RAE.

Parameters:

  • model_name (str): HuggingFace model name, e.g. 'facebook/dinov2-base' or 'facebook/dinov2-with-registers-base'. Defaults to 'facebook/dinov2-base'.
  • img_size (int): Input image size. Defaults to 224.
  • slot (int): Block slicing position for feature extraction; -4 means the fourth block from the end, and -1 means use all blocks (matches the old behavior). Defaults to -4.
  • normalize (bool): Whether to remove LayerNorm affine parameters (for RAE compatibility). Defaults to True.

encode

encode(x: Tensor) -> torch.Tensor

Encode input images through the DINOv2 encoder.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features, shape (B, N, C).
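A hedged usage sketch; it assumes network access to download the HuggingFace weights, and input preprocessing is simplified to random data:

```python
import torch

backbone = Dinov2TransformersBackbone(
    model_name="facebook/dinov2-base", img_size=224, slot=-4
)
backbone.eval()

x = torch.randn(1, 3, 224, 224)  # (B, 3, H, W), assumed already normalized
with torch.no_grad():
    h = backbone.encode(x)       # (B, N, C) intermediate tokens
print(h.shape)
```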

MAETimmBackbone

MAETimmBackbone(model_name: str = 'vit_base_patch16_224.mae', img_size: int = 256, patch_size: int = 16, slot: int = -1, device: str = 'cuda' if torch.cuda.is_available() else 'cpu', cast_dtype: str = 'float')

MAE (Masked Autoencoder) backbone using timm library.

This backbone uses ViT-MAE from timm library and provides encode/decode functionality for RAE.

Available model names in timm:
  • 'vit_base_patch16_224.mae' (default, standard MAE model)
  • 'vit_base_patch16_mae.in1k' (alternative naming)

Parameters:

  • model_name (str): timm model name, e.g. 'vit_base_patch16_224.mae'. Defaults to 'vit_base_patch16_224.mae'.
  • img_size (int): Input image size. Defaults to 256.
  • patch_size (int): Patch embedding size. Defaults to 16.
  • slot (int): Block slicing position for feature extraction; -1 means use all blocks. Defaults to -1.
  • device (str): Device to run the model on. Defaults to "cuda" if available, else "cpu".
  • cast_dtype (str or dtype): Data type for autocast mixed precision. Defaults to "float".
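For reference, a sketch of how such a timm ViT-MAE encoder can be created directly; the create_model call is standard timm usage, but whether this class wraps it exactly this way is an assumption:

```python
import timm
import torch

# Standard timm usage; img_size=256 overrides the pretrained 224 resolution.
model = timm.create_model(
    "vit_base_patch16_224.mae", pretrained=True, img_size=256, num_classes=0
)
model.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    tokens = model.forward_features(x)  # (B, N, C) tokens, including the class token
print(tokens.shape)
```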

encode

encode(x: Tensor) -> torch.Tensor

Encode input images through the MAE encoder.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features, shape (B, N, C).

MAETransformersBackbone

MAETransformersBackbone(model_name: str = 'facebook/vit-mae-base', img_size: int = 256, slot: int = -1)

MAE (Masked Autoencoder) backbone using transformers library.

This backbone uses ViTMAE from transformers library and provides encode/decode functionality for RAE.

Parameters:

  • model_name (str): HuggingFace model name, e.g. 'facebook/vit-mae-base'. Defaults to 'facebook/vit-mae-base'.
  • img_size (int): Input image size. Defaults to 256.
  • slot (int): Block slicing position for feature extraction; -1 means use all blocks. Defaults to -1.

encode

encode(x: Tensor) -> torch.Tensor

Encode input images through the MAE encoder.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features, shape (B, N, C).

SigLIP2TimmBackbone

SigLIP2TimmBackbone(model_name: str = 'vit_base_patch16_siglip_224', img_size: int = 224, patch_size: int = 16, slot: int = -1, device: str = 'cuda' if torch.cuda.is_available() else 'cpu', cast_dtype: str = 'float')

SigLIP2 backbone using timm library.

This backbone uses SigLIP2 from timm library and provides encode/decode functionality for RAE.

Available model names in timm:
  • 'vit_base_patch16_siglip_224' (default, standard SigLIP model)
  • 'vit_base_patch16_siglip.in1k' (alternative naming)
  • 'vit_base_patch16_siglip_256' (alternative size)

Parameters:

  • model_name (str): timm model name for SigLIP2. Defaults to 'vit_base_patch16_siglip_224'.
  • img_size (int): Input image size. Defaults to 224.
  • patch_size (int): Patch embedding size. Defaults to 16.
  • slot (int): Block slicing position for feature extraction; -1 means use all blocks. Defaults to -1.
  • device (str): Device to run the model on. Defaults to "cuda" if available, else "cpu".
  • cast_dtype (str or dtype): Data type for autocast mixed precision. Defaults to "float".

encode

encode(x: Tensor) -> torch.Tensor

Encode input images through the SigLIP2 encoder.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features, shape (B, N, C).

SigLIP2TransformersBackbone

SigLIP2TransformersBackbone(model_name: str = 'google/siglip2-base-patch16-256', img_size: int = 256, slot: int = -1)

SigLIP2 backbone using transformers library.

This backbone uses SiglipModel from transformers library and provides encode/decode functionality for RAE.

Parameters:

  • model_name (str): HuggingFace model name, e.g. 'google/siglip2-base-patch16-256'. Defaults to 'google/siglip2-base-patch16-256'.
  • img_size (int): Input image size. Defaults to 256.
  • slot (int): Block slicing position for feature extraction; -1 means use all blocks. Defaults to -1.

encode

encode(x: Tensor) -> torch.Tensor

Encode input images through the SigLIP2 encoder.

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W). Required.

Returns:

  • h (Tensor): Encoded features, shape (B, N, C).

VGGBackbone

VGGBackbone(device)

extract

extract(frame_bgr)

Corresponds to the original extract_vgg_feature function. Input: a BGR image read with cv2 (numpy array). Output: a Tensor of shape [1, 512, 1, 1].

VqganBackbone

VqganBackbone(**kwargs)

VQGAN-based backbone for image encoding and decoding.

This backbone uses a VQGAN (Vector Quantized Generative Adversarial Network) model to encode images into discrete tokens and decode them back to images. The encoding process converts images to latent codes and quantizes them using a codebook.

Parameters:

  • vqgan_config (dict): Configuration dictionary for the VQModel initialization. Required.
  • **kwargs (dict): Unused keyword arguments, kept for API compatibility.

Attributes:

  • vqgan (VQModel): The underlying VQGAN model.
  • codebook_size (int): Size of the quantization codebook.

decode

decode(z_q)

Decode quantized latent codes back to images.

Parameters:

  • z_q (Tensor): Quantized latent codes of shape (B, C, H, W). Required.

Returns:

  • x_hat (Tensor): Reconstructed images of shape (B, 3, H, W) in range [0, 1].

encode

encode(x)

Encode input images into latent codes and tokens.

The input images x are expected to be in the range [0, 1], which are then transformed to [-1, 1] for the VQGAN encoder. The encoder produces latent codes that are quantized using the codebook to produce discrete tokens. The quantization process produces z_q' = z + (z_q - z).detach(), which incurs a small MSE error (approximately 1e-18) between z_q' and z_q. We use z_q as the context for consistency.
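A minimal sketch of the straight-through estimator mentioned above; the codebook and shapes are placeholders, and the point is that z_q_ste carries the quantized values forward while gradients flow back to z:

```python
import torch

z = torch.randn(2, 8, 16, 16, requires_grad=True)   # continuous latents (B, C, H, W)
codebook = torch.randn(512, 8)                        # placeholder codebook (K, C)

# Nearest-codebook-entry quantization over the channel dimension.
flat = z.permute(0, 2, 3, 1).reshape(-1, 8)           # (B*H*W, C)
tokens = torch.cdist(flat, codebook).argmin(dim=1)    # (B*H*W,) discrete indices
z_q = codebook[tokens].reshape(2, 16, 16, 8).permute(0, 3, 1, 2)

# Straight-through estimator: forward values of z_q, gradients of z.
z_q_ste = z + (z_q - z).detach()
assert torch.allclose(z_q_ste, z_q)  # near-identical, up to the tiny MSE noted above
```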

Parameters:

  • x (Tensor): Input images of shape (B, 3, H, W) in range [0, 1]. Required.

Returns:

  • vqgan_enc (dict): A dictionary containing:
      • "z" (torch.Tensor): Continuous latent codes before quantization.
      • "z_q" (torch.Tensor): Quantized latent codes of shape (B, C, H, W).
      • "tokens" (torch.Tensor): Discrete token indices of shape (B, H, W).
      • "shape" (tuple): Spatial dimensions (H, W) of the latent representation.

tokens_to_features

tokens_to_features(tokens)

Convert discrete tokens to quantized latent features.

Parameters:

  • tokens (Tensor): Discrete token indices of shape (B, H, W). Required.

Returns:

  • z_q (Tensor): Quantized latent features of shape (B, C, H, W).
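A sketch of the lookup this method implies: an embedding lookup followed by a channel-first permute. The exact codebook access inside VQModel is an assumption:

```python
import torch

codebook = torch.nn.Embedding(512, 8)        # placeholder (codebook_size, C)
tokens = torch.randint(0, 512, (2, 16, 16))  # (B, H, W) discrete indices

z_q = codebook(tokens)                       # (B, H, W, C) embedding lookup
z_q = z_q.permute(0, 3, 1, 2).contiguous()   # -> (B, C, H, W)
assert z_q.shape == (2, 8, 16, 16)
```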

extract_vgg_features

extract_vgg_features(vgg_model, frame_tensor)

Extract VGG features (runs in parallel with process_video).

Parameters:

  • vgg_model: Pre-initialized VGG feature extractor. Required.
  • frame_tensor: Input image tensor of shape [1, 3, H, W], range [0, 1]. Required.

Returns:

  • features: Extracted feature vector of shape [1, C].

setup_vgg_feature_extractor

setup_vgg_feature_extractor(device='cuda')

Initialize the VGG16 feature extractor.
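A hedged sketch of the VGG16 extractor these helpers describe, using torchvision's stock VGG16; pooling to [1, 512, 1, 1] matches the documented output, but the exact layers kept are an assumption:

```python
import torch
from torchvision import models

def setup_vgg16(device: str = "cuda"):
    """Conv features of VGG16 followed by global average pooling."""
    vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    extractor = torch.nn.Sequential(
        vgg.features,                   # conv blocks -> (1, 512, h, w)
        torch.nn.AdaptiveAvgPool2d(1),  # -> (1, 512, 1, 1), as documented
    ).to(device).eval()
    return extractor

extractor = setup_vgg16("cpu")
# cv2 frames are BGR; convert to RGB and scale to [0, 1] before this step.
frame = torch.rand(1, 3, 224, 224)      # [1, 3, H, W] in [0, 1]
with torch.no_grad():
    feats = extractor(frame)            # (1, 512, 1, 1)
    vec = feats.flatten(1)              # (1, C), as in extract_vgg_features
```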