项目地址:wan2.1
参考博文:torch.nn.Conv3D 参数及使用详解

记录一下看wan2.1仓库代码的过程,GPT辅助。

源代码:

class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolusion.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)

这段代码定义了一个用于 因果性(Causal)3D卷积 的 PyTorch 模块 CausalConv3d,它继承自 nn.Conv3d,并修改了 padding 和输入策略,以保证 时间维度上的因果性 —— 即输出时间点 t 只依赖于时间 ≤ t 的输入,不会“看到未来”


📌 一、因果性背景

在时间序列(如视频帧)建模中,因果卷积的核心目标是:

当前时刻的输出只能依赖于过去和当前的输入,不能访问未来帧。

对于 3D 卷积,维度顺序是 (N, C, T, H, W),其中:

  • T:时间维(通常是因果卷积处理的重点);
  • H, W:空间维。

📦 二、类结构说明

class CausalConv3d(nn.Conv3d):

继承自 nn.Conv3d,表示这是一个标准的 3D 卷积,只是做了一些重写处理。


⚙️ 三、构造函数逻辑(__init__

self._padding = (
    self.padding[2], self.padding[2],    # width
    self.padding[1], self.padding[1],    # height
    2 * self.padding[0], 0               # time (causal)
)
self.padding = (0, 0, 0)
解释:
  • 原始的 nn.Conv3d 可能有 padding,例如 (pt, ph, pw)

  • 这里保留了空间方向的对称 padding,但对 时间方向进行非对称 padding,构造成:

    (left_w, right_w, top_h, bottom_h, front_t, back_t)
    

    对应 F.pad 的顺序 (从最后一维到最前一维,这是函数的特性)

  • 关键点是时间维 front_t = 2 * pt, back_t = 0,只在“前面”pad,实现因果性;

  • 然后把模块自己的 padding 设置为全 0,表示 不再让 nn.Conv3d 自己pad,改为我们自己手动 pad。


🧠 四、前向传播逻辑(forward

def forward(self, x, cache_x=None):
  • x:当前时间片段的输入;
  • cache_x:可选的历史帧缓存(可用于跨步生成,例如 autoregressive inference);
if cache_x is not None and self._padding[4] > 0:
    x = torch.cat([cache_x, x], dim=2)
    padding[4] -= cache_x.shape[2]
  • 如果有缓存帧,就拼接在时间维上(dim=2);
  • 同时减少对应的 padding 长度,避免重复 pad。
x = F.pad(x, padding)
  • 手动 pad,特别是时间前面的 2×padding。
return super().forward(x)
  • 调用原始的 nn.Conv3d 前向传播,但此时输入已手动 pad,padding 设为0。

📝 总结

组件 作用
self._padding 记录定制化的 padding,特别是时间维只 pad 前面,实现因果性
F.pad 手动完成 padding
cache_x 支持 autoregressive 或连续生成,传入历史帧
self.padding = (0,0,0) 禁用 nn.Conv3d 的自动 padding
super().forward(x) 使用 Conv3d 的正常卷积

在init构造函数中,self.padding是从哪里来的?在这个代码之前没有定义这个变量啊?

在这段代码中:

class CausalConv3d(nn.Conv3d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

你注意到 self.padding__init__ 中看起来好像“从没定义过就用了”,但实际上它 来自父类 nn.Conv3d 的初始化


self.padding 的来源解释:

这段代码继承了 PyTorch 的内置模块 nn.Conv3d

super().__init__(*args, **kwargs)

这一行会调用 nn.Conv3d 的构造函数,而在 PyTorch 的实现中,nn.Conv3d 构造器里会初始化并设置一个成员变量 self.padding,这个 padding 是你在实例化 CausalConv3d 时传入的参数之一:

conv = CausalConv3d(in_channels=16, out_channels=32, kernel_size=(3, 3, 3), padding=(1, 1, 1))

所以 super().__init__() 执行完后,self.padding 就已经是一个元组了,比如:

self.padding = (1, 1, 1)

也就是 (pt, ph, pw) 对应时间、高、宽三个维度的 padding 大小。


🔁 然后我们做了什么?

我们用这三维的 self.padding 来构造一个 PyTorch F.pad 需要的六维 padding:

self._padding = (pw, pw, ph, ph, 2*pt, 0)

这是 (left_w, right_w, top_h, bottom_h, front_t, back_t),符合 F.pad 的要求(从最后一维往前 pad)。


🧹 最后为什么还把 self.padding 设置为 (0, 0, 0)?

这是为了防止 nn.Conv3d 自己再加 padding。因为我们已经手动用 F.pad 做了 padding,所以我们显式地把原来的 self.padding 清除掉,防止重复 padding。


✅ 总结

步骤 说明
super().__init__() 调用父类 nn.Conv3d 的构造函数,设置了 self.padding
self._padding self.padding 计算出 6 个维度的 padding 以供手动使用
self.padding = (0, 0, 0) 禁用 nn.Conv3d 默认 padding,避免重复 pad

测试模块


# ==== 测试 ====
# 假设我们输入 batch size 为 1,通道数 3,时间 4 帧,空间大小为 8x8
B, C_in, T, H, W = 1, 3, 4, 8, 8
x = torch.randn(B, C_in, T, H, W)

# 创建 CausalConv3d 实例:将 3通道输入映射为 8通道输出,kernel_size=3x3x3,padding=(1,1,1)
conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=(3,3,3), padding=(1,1,1))

# 前向传播
out = conv(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")

输出:

Input shape:  torch.Size([1, 3, 4, 8, 8])
Output shape: torch.Size([1, 8, 4, 8, 8])

关于cache机制

这段代码的目的是:实现时序因果卷积(Causal 3D Convolution)中的缓存机制,用于处理连续视频帧或特征块时,保留历史帧特征并拼接到当前输入上,从而保持时间上的连续性和因果性。


我们逐行解释:

if cache_x is not None and self._padding[4] > 0:
  • cache_x 是「前一段输入」的缓存特征(例如上一个视频块的最后几帧),用于当前块的因果卷积。
  • self._padding[4] 是时间维的左侧 pad 数(也就是需要保留多少前帧来补时间维的因果性)。只有当这两个条件满足时才执行以下逻辑。

    cache_x = cache_x.to(x.device)
  • 确保缓存特征和当前输入 x 在同一个设备(如 GPU)上。

    x = torch.cat([cache_x, x], dim=2)
  • 沿时间维 (dim=2) 拼接:将「缓存帧」加在当前输入前面,构造因果窗口(即现在的 t 帧只能依赖 ≤t 的帧,而不能看到未来的帧)。

    padding[4] -= cache_x.shape[2]
  • 动态调整 pad 数。因为我们已经通过 cache_x 提供了一部分前帧,因此左侧 padding 应减少对应数量,避免重复 pad。

🔁 举个例子(模拟)

假设你有:

  • 当前输入 x.shape = [1, 3, 4, 8, 8](4 帧)
  • 卷积核大小 k_t=3,原始 padding[0]=1
  • self._padding[4]=2(左 pad 2 帧,右 pad 0 帧)

你设置了:

  • cache_x.shape = [1, 3, 1, 8, 8](提供1帧历史)

执行后:

  • 拼接后 x.shape = [1, 3, 5, 8, 8](1帧缓存 + 4帧当前)
  • 然后时间 pad 改为只 pad 1 帧(剩余1帧 pad)

✅ 总结作用:

这段代码通过引入 cache_x 实现:

  1. 因果性(不看未来帧)
  2. 跨块连续性(上下文不丢失)
  3. 减少无意义的 zero padding(尽量使用真实前帧)

适用于处理 视频分块生成 / 推理 场景中,前后时序需要关联但不能看到未来帧的任务。

Logo

电影级数字人,免显卡端渲染SDK,十行代码即可调用,工业级demo免费开源下载!

更多推荐