wan2.1代码解读——VAE模块中的CausalConv3d
组件作用记录定制化的 padding,特别是时间维只 pad 前面,实现因果性F.pad手动完成 paddingcache_x支持 autoregressive 或连续生成,传入历史帧禁用nn.Conv3d的自动 padding使用 Conv3d 的正常卷积在init构造函数中,self.padding是从哪里来的?在这个代码之前没有定义这个变量啊?你注意到在__init__中看起来好像“从没定义
项目地址:wan2.1
参考博文:torch.nn.Conv3D 参数及使用详解
记录一下看wan2.1仓库代码的过程,GPT辅助。
源代码:
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
这段代码定义了一个用于 因果性(Causal)3D卷积 的 PyTorch 模块 CausalConv3d,它继承自 nn.Conv3d,并修改了 padding 和输入策略,以保证 时间维度上的因果性 —— 即输出时间点 t 只依赖于时间 ≤ t 的输入,不会“看到未来”。
📌 一、因果性背景
在时间序列(如视频帧)建模中,因果卷积的核心目标是:
当前时刻的输出只能依赖于过去和当前的输入,不能访问未来帧。
对于 3D 卷积,维度顺序是 (N, C, T, H, W),其中:
T:时间维(通常是因果卷积处理的重点);H,W:空间维。
📦 二、类结构说明
class CausalConv3d(nn.Conv3d):
继承自 nn.Conv3d,表示这是一个标准的 3D 卷积,只是做了一些重写处理。
⚙️ 三、构造函数逻辑(__init__)
self._padding = (
self.padding[2], self.padding[2], # width
self.padding[1], self.padding[1], # height
2 * self.padding[0], 0 # time (causal)
)
self.padding = (0, 0, 0)
解释:
-
原始的
nn.Conv3d可能有 padding,例如(pt, ph, pw); -
这里保留了空间方向的对称 padding,但对 时间方向进行非对称 padding,构造成:
(left_w, right_w, top_h, bottom_h, front_t, back_t)对应
F.pad的顺序 (从最后一维到最前一维,这是函数的特性) ; -
关键点是时间维
front_t = 2 * pt,back_t = 0,只在“前面”pad,实现因果性; -
然后把模块自己的
padding设置为全 0,表示 不再让nn.Conv3d自己pad,改为我们自己手动 pad。
🧠 四、前向传播逻辑(forward)
def forward(self, x, cache_x=None):
x:当前时间片段的输入;cache_x:可选的历史帧缓存(可用于跨步生成,例如 autoregressive inference);
if cache_x is not None and self._padding[4] > 0:
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
- 如果有缓存帧,就拼接在时间维上(dim=2);
- 同时减少对应的 padding 长度,避免重复 pad。
x = F.pad(x, padding)
- 手动 pad,特别是时间前面的 2×padding。
return super().forward(x)
- 调用原始的
nn.Conv3d前向传播,但此时输入已手动 pad,padding 设为0。
📝 总结
| 组件 | 作用 |
|---|---|
self._padding |
记录定制化的 padding,特别是时间维只 pad 前面,实现因果性 |
F.pad |
手动完成 padding |
cache_x |
支持 autoregressive 或连续生成,传入历史帧 |
self.padding = (0,0,0) |
禁用 nn.Conv3d 的自动 padding |
super().forward(x) |
使用 Conv3d 的正常卷积 |
在init构造函数中,self.padding是从哪里来的?在这个代码之前没有定义这个变量啊?
在这段代码中:
class CausalConv3d(nn.Conv3d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
你注意到 self.padding 在 __init__ 中看起来好像“从没定义过就用了”,但实际上它 来自父类 nn.Conv3d 的初始化。
✅ self.padding 的来源解释:
这段代码继承了 PyTorch 的内置模块 nn.Conv3d:
super().__init__(*args, **kwargs)
这一行会调用 nn.Conv3d 的构造函数,而在 PyTorch 的实现中,nn.Conv3d 构造器里会初始化并设置一个成员变量 self.padding,这个 padding 是你在实例化 CausalConv3d 时传入的参数之一:
conv = CausalConv3d(in_channels=16, out_channels=32, kernel_size=(3, 3, 3), padding=(1, 1, 1))
所以 super().__init__() 执行完后,self.padding 就已经是一个元组了,比如:
self.padding = (1, 1, 1)
也就是 (pt, ph, pw) 对应时间、高、宽三个维度的 padding 大小。
🔁 然后我们做了什么?
我们用这三维的 self.padding 来构造一个 PyTorch F.pad 需要的六维 padding:
self._padding = (pw, pw, ph, ph, 2*pt, 0)
这是 (left_w, right_w, top_h, bottom_h, front_t, back_t),符合 F.pad 的要求(从最后一维往前 pad)。
🧹 最后为什么还把 self.padding 设置为 (0, 0, 0)?
这是为了防止 nn.Conv3d 自己再加 padding。因为我们已经手动用 F.pad 做了 padding,所以我们显式地把原来的 self.padding 清除掉,防止重复 padding。
✅ 总结
| 步骤 | 说明 |
|---|---|
super().__init__() |
调用父类 nn.Conv3d 的构造函数,设置了 self.padding |
self._padding |
从 self.padding 计算出 6 个维度的 padding 以供手动使用 |
self.padding = (0, 0, 0) |
禁用 nn.Conv3d 默认 padding,避免重复 pad |
测试模块
# ==== 测试 ====
# 假设我们输入 batch size 为 1,通道数 3,时间 4 帧,空间大小为 8x8
B, C_in, T, H, W = 1, 3, 4, 8, 8
x = torch.randn(B, C_in, T, H, W)
# 创建 CausalConv3d 实例:将 3通道输入映射为 8通道输出,kernel_size=3x3x3,padding=(1,1,1)
conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=(3,3,3), padding=(1,1,1))
# 前向传播
out = conv(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
输出:
Input shape: torch.Size([1, 3, 4, 8, 8])
Output shape: torch.Size([1, 8, 4, 8, 8])
关于cache机制
这段代码的目的是:实现时序因果卷积(Causal 3D Convolution)中的缓存机制,用于处理连续视频帧或特征块时,保留历史帧特征并拼接到当前输入上,从而保持时间上的连续性和因果性。
我们逐行解释:
if cache_x is not None and self._padding[4] > 0:
cache_x是「前一段输入」的缓存特征(例如上一个视频块的最后几帧),用于当前块的因果卷积。self._padding[4]是时间维的左侧 pad 数(也就是需要保留多少前帧来补时间维的因果性)。只有当这两个条件满足时才执行以下逻辑。
cache_x = cache_x.to(x.device)
- 确保缓存特征和当前输入
x在同一个设备(如 GPU)上。
x = torch.cat([cache_x, x], dim=2)
- 沿时间维 (
dim=2) 拼接:将「缓存帧」加在当前输入前面,构造因果窗口(即现在的 t 帧只能依赖 ≤t 的帧,而不能看到未来的帧)。
padding[4] -= cache_x.shape[2]
- 动态调整 pad 数。因为我们已经通过
cache_x提供了一部分前帧,因此左侧 padding 应减少对应数量,避免重复 pad。
🔁 举个例子(模拟)
假设你有:
- 当前输入
x.shape = [1, 3, 4, 8, 8](4 帧) - 卷积核大小
k_t=3,原始padding[0]=1 - 则
self._padding[4]=2(左 pad 2 帧,右 pad 0 帧)
你设置了:
cache_x.shape = [1, 3, 1, 8, 8](提供1帧历史)
执行后:
- 拼接后
x.shape = [1, 3, 5, 8, 8](1帧缓存 + 4帧当前) - 然后时间 pad 改为只 pad 1 帧(剩余1帧 pad)
✅ 总结作用:
这段代码通过引入 cache_x 实现:
- 因果性(不看未来帧)
- 跨块连续性(上下文不丢失)
- 减少无意义的 zero padding(尽量使用真实前帧)
适用于处理 视频分块生成 / 推理 场景中,前后时序需要关联但不能看到未来帧的任务。
更多推荐




所有评论(0)