基于 pytorch-2.8.0/torch/utils/checkpoint.py
reentrant 是什么,为什么叫“可重入”
“reentrant(可重入)checkpoint”的做法是:在前向阶段用 torch.no_grad() 运行被 checkpoint 的函数,不记录它的 autograd 计算图;等到反向阶段,autograd 引擎会“再次进入(re-enter)用户前向函数”,把这段前向重新跑一遍,用于重建需要的激活并继续求梯度。简单说:正向不记图,反向再“重入前向”重算。
“non-reentrant(非可重入)checkpoint”则在正向时正常记图,并通过“已保存张量钩子”来在需要时只重建必要的中间量;因此它允许你在 checkpoint 区域内部也能再做一次 backward 等等,适配面更广。
之所以叫“可重入”,就是因为反向时会“重入”的前向函数去重算(而不是复用正向时记录的图)。当前更推荐 non-reentrant。
可重入实现(old) 在这不重点分析了
forward:正向
把 run_function、preserve_rng_state、推断出的 设备类型(cuda/cpu/其他)和 autocast 配置 缓存在 ctx 上。
若需要保存 RNG:保存 CPU RNG,并在已初始化的相应设备上(如 CUDA)保存每个参与的 device 的 RNG 状态(get_device_states)。
仅把 Tensor 型输入通过 ctx.save_for_backward 存起来,非 Tensor 输入保存在 ctx.inputs,Tensor 在那里留一个 None 的占位,然后在 torch.no_grad() 下执行真正的前向,因此中间激活并不会入图保存。
backward:重建现场 + 重算 + 取梯度
首先检查兼容性:当 use_reentrant=True 时,不允许使用 .grad() 或 .backward(inputs=…) 这种“显式指定输入”的反向路径,否则抛错;这是这条实现最重要的限制之一。
复原输入:把刚才保存的 Tensors 填回 ctx.inputs 里的占位。
RNG 恢复:用 torch.random.fork_rng 把外部 RNG 状态备份一下,然后把 CPU RNG + 对应设备的 RNG 恢复到“前向当时”的状态。
AMP/Autocast 恢复:根据 forward 时记录的设备与 cpu 的 autocast 状态,分别进入相同的 autocast 上下文。
重算:对分离后的输入(detach_variable,并保留 requires_grad 标志)调用 run_function。这样会重新构建一张“临时前向图”。
喂入反向梯度:把 backward 传进来的 *args 当作上游梯度,调用 Autograd 让它在这张“重算图”上做传播。
非可重入实现(propose) 不再在 Autograd 的 backward 回调里重跑前向,而是借助 Saved Tensor Hooks 控制“什么时候、如何重算出需要的中间张量”。
入口 非重入实现使用以下逻辑
1 2 3 4 5 6 7 8 9 10 11 12 13 14 @torch._disable_dynamo def checkpoint (... ): gen = _checkpoint_without_reentrant_generator( function, preserve, context_fn, determinism_check, debug, *args, **kwargs ) next (gen) ret = function(*args, **kwargs) try : next (gen) except StopIteration: return ret
gen = _checkpoint_...() : 在 checkpoint 函数中,一个生成器对象 gen 被创建。此时,_checkpoint_... 函数内的代码一行都未执行 。
next(gen) - 第一次迭代 :
前置工作
最重要的是进入了 with _checkpoint_hook(...) 上下文管理器 。
控制权返回到 checkpoint 函数,next(gen) 调用结束。
ret = function(...) - 核心执行 :
checkpoint 函数现在调用用户传入的 function。
虽然这个调用发生在 checkpoint 函数的作用域里,但由于上一步激活的 _checkpoint_hook 仍然“存活”在暂停的生成器中,这个钩子会对 function 的执行产生作用!它会拦截所有需要保存以备反向传播的张量,用一个轻量级的占位符 _Holder 替换它们,从而实现节省内存的目的。
function 执行完毕,返回结果 ret。
next(gen) - 第二次迭代 :
with _checkpoint_hook(...) 语句块执行完毕,上下文管理器正常退出。
执行清理代码。
_checkpoint_... 函数执行到末尾,自然结束。
StopIteration :
当一个生成器函数执行完毕时,它会自动引发一个 StopIteration 异常。
checkpoint 函数中的 except StopIteration: 捕获了这个异常,这标志着整个流程(前置 - 执行 - 后置)已经成功完成。
函数返回 ret。
一般做法是将生成器转化为上下文管理器 contextlib.contextmanager。这里是反过来了。一个在生成器内部启动的上下文,其效果作用于生成器外部的函数调用
可能的解释是:为了将“状态的生命周期”与“函数的执行”解耦:checkpoint 的设计巧妙地将状态的创建/销毁逻辑(在生成器中)与被该状态影响的核心业务逻辑(用户的 function)分离开来。checkpoint 函数作为“指挥官”,协调着两者的交互,但两者本身互不知晓对方的内部实现。这是一种高度内聚、低耦合的优雅设计。
TorchDynamo 说明
TorchDynamo 不会进入 utils.checkpoint 函数内部进行分析。整个流程如下:
当 TorchDynamo 遇到 utils.checkpoint 函数时,会尝试将其封装为一个高阶操作符(HigherOrderOp)。这个过程分为三个阶段:
TorchDynamo 会先试探性地检查传入的前向传播函数是否可以安全地被追踪(即函数内部是否包含 Dynamo 支持的运算)。如果函数逻辑简单且符合追踪规范,则进入下一步。
如果前向传播函数被判定为安全,TorchDynamo 会将 utils.checkpoint 整体封装为一个高阶操作符,并将其加入生成的 Fx 计算图中。此时,Dynamo 不会深入分析 utils.checkpoint 内部的具体实现逻辑,而是将其视为一个不可拆分的原子操作。
如果前向传播函数无法被安全追踪(例如包含动态控制流或未知操作),TorchDynamo 会触发图中断(graph break),回退到 PyTorch 的急切模式(eager mode)执行。此时,@torch._disable_dynamo 装饰器会确保 Dynamo 不会对 utils.checkpoint 内部创建的计算帧再次触发优化流程。
生成器逻辑 _checkpoint_without_reentrant_generator
保存状态:保存 CPU 和 GPU 的随机数生成器状态,确保重计算时有完全相同的随机行为。
定义 recompute_fn:创建一个闭包,它知道如何使用保存的 RNG 状态和上下文来重新执行原始函数 fn。
创建 _CheckpointFrame:初始化一个状态对象,用于在前后向之间传递信息。
应用 _NoopSaveInputs.apply(dummy, *args),将所有输入 args“保存”到计算图中。这样在反向传播时,我们可以通过它的 grad_fn 访问到这些原始输入。
设置钩子:通过 with _checkpoint_hook(frame): 激活我们的自定义保存行为。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 def _checkpoint_without_reentrant_generator (fn, preserve_rng_state, ... ): if preserve_rng_state: fwd_cpu_state = torch.get_rng_state() def recompute_fn (*inputs ): ... new_frame = _CheckpointFrame(recompute_fn, ...) dummy = torch.empty((0 ,), requires_grad=True ) new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) if new_frame.input_saver.grad_fn is None : yield return with _checkpoint_hook(new_frame), forward_context: yield ...
环境状态捕获 _CheckpointFrame 1 2 new_frame = _CheckpointFrame(recompute_fn, ...)
管理前向期间所有被 checkpoint 的张量 (用弱引用占位符替代本体)。
调度和缓存反向期间的重计算结果 ,支持嵌套、多次 backward 等复杂场景。
支持提前终止重计算 ,提升效率。
元数据比对与 Debug 支持 ,强力保障模型梯度正确性。
通过 WeakKeyDictionary 等机制高效管理内存和生命周期 ,无需手动清理。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 class _CheckpointFrame : def __init__ (self, recompute_fn, early_stop, ... 用于异常和校验 ): """ recompute_fn: 封装了 如何重算前向区域 的函数闭包。注意它会复现原来的 RNG 状态、上下文环境等。 early_stop: 是否允许在重计算时“提前终止”。如果为 True,只要所有需要的张量已重算出来就立即中断,节约计算。 """ self.recompute_fn = recompute_fn self.input_saver = None self.weak_holders: List [ReferenceType] = [] self.recomputed: DefaultDict[ int , weakref.WeakKeyDictionary[_Handle, torch.Tensor] ] = defaultdict(weakref.WeakKeyDictionary) self.recomp_counter: DefaultDict[int , int ] = defaultdict(int ) self.is_recomputed: DefaultDict[int , bool ] = defaultdict(bool ) self.early_stop = early_stop ...
_No Operation Save Inputs_:与 Autograd 计算图生命周期绑定的存储方式
问题 :我们需要一种独立于 Python GC、与计算图生命周期同步的方式来保存 checkpoint 的输入。
方案 :我们不自己实现这个复杂的生命周期管理,而是“欺骗”并“利用”Autograd 引擎,让它来为我们做这件事。
实现 :通过一个“空操作”的 autograd.Function,Autograd 引擎可以将张量的生命周期与计算图绑定。只要这个计算节点还在图中(即未来的反向传播还需要它),Autograd 就会确保 save_for_backward 保存的张量是可访问的。通过上下文对象ctx.saved_tensors 把它们取回来
基类 torch.autograd.Function 是 PyTorch 中自定义自动微分逻的核心抽象类,用于突破默认自动微分(依赖计算图追踪)的限制,让开发者手动定义算子的前向传播(forward)与反向传播(backward,梯度计算)规则。
在这里我们需要的是借助 ctx 上下文对象保存张量
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 class _SingleLevelFunction ( _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta ): @staticmethod def forward (*args: Any , **kwargs: Any ) -> Any : r"""定义自定义 autograd Function 的前向传播。 这个函数需要被所有子类重写。 定义 forward 有两种方式: 用法 1 (合并 forward 和 ctx):: @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass - 它必须接受一个上下文对象 ctx 作为第一个参数,后面跟着任意数量的参数(张量或其他类型)。 - 更多细节请参见 :ref:`combining-forward-context`。 用法 2 (分离 forward 和 ctx):: @staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass - 此时 forward不再接受 ctx 参数。 - 相反,你必须同时重写 :meth:`torch.autograd.Function.setup_context` 这个静态方法来处理 `ctx` 对象的设置。 `output` 是 forward 的输出,`inputs` 是一个包含 forward 所有输入的元组(Tuple)。 - 更多细节请参见 :ref:`extending-autograd`。 上下文 `ctx` 可以用来存储任意数据,这些数据可以在反向传播过程中被取出。 张量不应该直接存储在 `ctx` 上(尽管为了向后兼容目前没有强制执行)。 相反,如果张量打算在 `backward` (等价于 `vjp`) 中使用,应该用 :func:`ctx.save_for_backward` 保存; 如果打算在 `jvp` 中使用,则用 :func:`ctx.save_for_forward` 保存。 """ raise NotImplementedError( "你必须为自定义的 autograd.Function 实现 forward 函数。" ) @staticmethod def setup_context (ctx: Any , inputs: tuple [Any , ...], output: Any ) -> Any : r"""定义 autograd.Function 前向传播有两种方式。 其中一种是: 1. 使用 `forward(ctx, *args, **kwargs)` 签名来重写 forward。 此时 `setup_context` 不需要被重写。为反向传播设置 ctx 的操作在 `forward` 内部完成。 2. 使用 `forward(*args, **kwargs)` 签名来重写 forward,并同时重写 `setup_context`。 此时为反向传播设置 ctx 的操作在 `setup_context` 内部完成(而不是在 `forward` 内部)。 更多细节请参见 :meth:`torch.autograd.Function.forward` 和 :ref:`extending-autograd`。 """ raise NotImplementedError("setup_context 未被实现。" ) @staticmethod def backward (ctx: Any , *grad_outputs: Any ) -> Any : r"""为该操作定义在反向模式自动微分下的求导法则。 这个函数需要被所有子类重写。 (定义这个函数等价于定义 `vjp` 函数。) 它必须接受一个上下文 :attr:`ctx` 作为第一个参数,后面跟着与 :func:`forward` 返回值一样多的梯度输出 (对于 forward 返回的非张量输出,这里会传入 None), 并且它应该返回与 :func:`forward` 的输入一样多的张量。 每个参数都是对应输出的梯度,而每个返回值应该是对应输入的梯度。 如果某个输入不是张量,或者是一个不需要梯度的张量,你可以为那个输入直接返回 None 作为梯度。 上下文 `ctx` 可以用来取回在前向传播中保存的张量。 它还有一个属性 :attr:`ctx.needs_input_grad`,是一个布尔值的元组,表示每个输入是否需要梯度。 例如,如果 :func:`forward` 的第一个输入需要计算梯度,那么在 :func:`backward` 中 ``ctx.needs_input_grad[0]`` 的值就会是 ``True``。 """ raise NotImplementedError( "为了在反向模式 AD 中使用你的自定义 autograd.Function," "你必须实现 backward 或 vjp 方法。" ) vjp = backward @staticmethod def jvp (ctx: Any , *grad_inputs: Any ) -> Any : r"""为该操作定义在前向模式自动微分下的求导法则。 """ class Function (_SingleLevelFunction ): r"""用于创建自定义 `autograd.Function` 的基类。 要创建一个自定义的 `autograd.Function`,需要继承这个类并实现 :meth:`forward` 和 :meth:`backward` 这两个静态方法。然后,在前向传播中使用你的 自定义操作时,调用类方法 ``apply``。不要直接调用 :meth:`forward`。 为了确保正确性和最佳性能,请确保你在 ``ctx`` 上调用了正确的方法,并使用 :func:`torch.autograd.gradcheck` 来验证你的 backward 函数。 """ def __init__ (self, *args, **kwargs ): warnings.warn() def __call__ (self, *args, **kwargs ): raise RuntimeError() @staticmethod def vmap (info, in_dims, *args ): raise NotImplementedError() @classmethod def apply (cls, *args, **kwargs ): def bind_default_args (func, *args, **kwargs ): signature = inspect.signature(func) bound_args = signature.bind(*args, **kwargs) bound_args.apply_defaults() return bound_args.args is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context) if is_setup_ctx_defined: args = bind_default_args(cls.forward, *args, **kwargs) if not torch._C._are_functorch_transforms_active(): args = _functorch.utils.unwrap_dead_wrappers(args) return super ().apply(*args, **kwargs) if not is_setup_ctx_defined: raise RuntimeError( "为了将 autograd.Function 与 functorch 转换(vmap, grad, jvp, jacrev, ...)一起使用," "它必须重写 setup_context 静态方法。更多细节,请参见 " "https://pytorch.org/docs/main/notes/extending.func.html" ) return custom_function_call(cls, *args, **kwargs) def _is_setup_context_defined (fn ): return fn != _SingleLevelFunction.setup_contex
以 y = x² 为例,手动实现前向与反向(反向梯度应为 dy/dx = 2x):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 import torchclass SquareFunction (torch.autograd.Function ): @staticmethod def forward (ctx, x ): y = x ** 2 ctx.save_for_backward(x) return y @staticmethod def backward (ctx, grad_y ): x, = ctx.saved_tensors grad_x = grad_y * 2 * x return grad_x x = torch.tensor([1.0 , 2.0 ], requires_grad=True ) y = SquareFunction.apply(x) y.sum ().backward() print (x.grad)
自定义 Function 会被视为计算图中的一个“原子节点”:
前向时,Function.apply(x) 会在计算图中插入一个“自定义节点”,并通过 ctx 隐式关联前向数据与反向逻辑;
反向时,Autograd 引擎遇到该节点,会直接调用其 backward 方法,而非自动推导梯度。
实现
一个自定义的 autograd.Function。
它的 forward 方法几乎什么都不做(No-op)。
setup_context 使用 ctx.save_for_backward 把所有我们需要的输入保存起来。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 class _NoopSaveInputs (torch.autograd.Function ): @staticmethod def forward (*args ): return torch.empty((0 ,)) @staticmethod def setup_context (ctx: Any , inputs: Tuple [Any , ...], output: Any ) -> None : tensor_indices, tensors = zip ( *[(i, o) for i, o in enumerate (inputs) if isinstance (o, torch.Tensor)] ) idx2saved_idx = {b: a for a, b in enumerate (tensor_indices)} args = [None if isinstance (o, torch.Tensor) else o for o in inputs] def get_args (saved_tensors ): ret = [ saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o for i, o in enumerate (args) ] return ret[1 :] ctx.get_args = get_args ctx.save_for_backward(*tensors) @staticmethod def backward (ctx, *grad_outputs ): raise AssertionError("Did not expect to backward on this graph" )
1 2 3 4 dummy = torch.empty((0 ,), requires_grad=True ) new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)
torch.autograd.Function 是自定义可微操作的核心类,而 apply 方法是其唯一实例化调用入口,用于触发自定义操作的前向传播(forward)与后续反向传播(backward)逻辑
dummy: 这是一个需要梯度的张量(torch.empty((0,), requires_grad=True))。它的作用是确保 _NoopSaveInputs 这个操作被记录到计算图中。如果所有输入都不需要梯度,Autograd 引擎可能会优化掉这个节点,保存机制就会失效。
kwargs, *args: 需要传递给用户函数 fn 的所有参数
.apply(...): 执行时:
_NoopSaveInputs.forward 被调用,它接收 (dummy, kwargs, *args) 作为参数,然后返回一个空张量。
_NoopSaveInputs.setup_context 被调用,打包了所有 args 和 kwargs,将张量部分交给 ctx.save_for_backward,并创建了一个 ctx.get_args 函数用于未来的恢复。
new_frame.input_saver = ...: .apply() 的返回值是一个张量,这个张量拥有一个 .grad_fn 属性,它指向一个代表 _NoopSaveInputs 反向节点的对象。这个对象内部就包含了我们设置好的 ctx
_图不一定连通_:_NoopSaveInputs 形成了一个并行的、逻辑上的分支。虽然它不直接参与梯度计算流,但它通过 dummy.requires_grad=True 这个属性,作为一个有效的图节点存在。
在反向传播期间,checkpoint 内部的钩子 _checkpoint_hook 会被触发,这个钩子可以通过 new_frame.input_saver 找到这个 _NoopSaveInputsBackward 节点,并从中提取出保存的输入 x 来进行重计算。
1 2 3 4 5 6 7 8 9 10 11 12 13 def unpack_hook (holder ): ctx = frame.input_saver.grad_fn original_inputs = ctx.get_args(ctx.saved_tensors) frame.recompute_fn(*original_inputs)
占位符 _Holder
_Holder:保存张量的占位符对象(不是张量本体)。正向在 saved_tensors_hooks(pack, unpack) 的 pack_hook 中为每个需要“保存”的张量创建一个 _Holder 并把它塞回计算图里,真正的张量稍后在反向触发“重算”时再填充到缓存。_Holder 内部维护一个 handles: Dict[int, Optional[_Handle]],把一次前向里出现的保存位映射到对应的 _Handle
_Handle:每个“保存位”的键对象。帧级缓存用 WeakKeyDictionary[_Handle, Tensor] 存(key 是 _Handle,value 才是重算得到的真实 Tensor)。这样只要 _Handle 没有强引用(由 _Holder 持有的引用消失),缓存项就会自动被回收,适配“只做了部分反向”的情况。
1 2 3 4 5 _CheckpointFrame .weak_holders: List [ReferenceType] .recomputed: DefaultDict[int , weakref.WeakKeyDictionary[_Handle, torch.Tensor]] .recomp_counter: DefaultDict[int , int ] .is_recomputed: DefaultDict[int , bool ]
1 2 3 4 5 6 7 8 9 10 11 12 13 class _Handle : """ 它的实例被用作 `_CheckpointFrame.recomputed` 字典的键(key) """ pass class _Holder : def __init__ (self ): """ - 键 `int` 类型,代表 `gid` (graph_task_id),即当前反向传播任务的唯一 ID。每次调用 `.backward()` 都会生成一个新的 `gid`。 - 值 `Optional[_Handle]`。它是一个 `_Handle` 实例,作为在 `_CheckpointFrame.recomputed` 字典中查找重计算张量的键。当张量被解包使用后,这个值会被设为 `None`,以防止在同一次反向传播中被重复使用。 """ self.handles: Dict [int , Optional [_Handle]] = {}
_Holder 的生命周期贯穿了检查点机制的三个关键阶段:前向传播(打包)、反向传播(解包与重计算)、以及状态清理。
阶段一:前向传播(打包 - Pack)
1 2 3 4 def pack_hook (x ): holder = _Holder() frame.weak_holders.append(weakref.ref(holder)) return holder
为什么是弱引用? Autograd 引擎会将 holder 实例保存在一个 SavedVariable 对象中。如果 retain_graph=False,反向传播结束后,SavedVariable 会被销毁,对 holder 的强引用就消失了。这时,如果没有其他地方强引用 holder,垃圾回收器就会回收它。
阶段二:反向传播(解包 - Unpack 与重计算)
当反向传播过程需要用到被 _Holder 替代的那个张量时,_checkpoint_hook 的 unpack_hook 会被调用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 def unpack_hook (holder ): gid = torch._C._current_graph_task_id() if gid == -1 : gid = int (uuid.uuid4()) if not frame.is_recomputed[gid]: with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad(): frame.recompute_fn(*args) ... frame.is_recomputed[gid] = True _internal_assert(gid in holder.handles) if holder.handles[gid] is None : raise CheckpointError(...) handle_key = holder.handles[gid] _internal_assert(handle_key in frame.recomputed[gid]) ret = frame.recomputed[gid][handle_key] holder.handles[gid] = None return ret
_Holder 在重计算过程中的角色:
重计算由 _recomputation_hook 控制。它的 pack_hook 会在重计算过程中再次“保存”张量。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class _recomputation_hook (torch.autograd.graph.saved_tensors_hooks ): def __init__ (self, target_frame_ref: ReferenceType, gid: int ): def pack_hook (x ): holder = target_frame.weak_holders[recomp_idx]() if holder is not None : _internal_assert(holder.handles.get(gid, None ) is None ) holder.handles[gid] = _Handle() target_frame.recomputed[gid][holder.handles[gid]] = x return x
前向反向 Hook _checkpoint_hook 基类 saved_tensors_hooks 是一个上下文管理器(with 语句),用于定义张量在反向传播中被保存和读取时的自定义操作。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 class saved_tensors_hooks : """上下文管理器,用于为保存的张量设置一对打包 / 解包钩子。 使用这个上下文管理器来定义操作的中间结果应该如何在保存前进行打包,并在检索时进行解包。 在这个上下文中,每次操作保存一个张量用于反向传播时(这包括使用 :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` 保存的中间结果, 也包括由 PyTorch 定义的操作记录的张量),都会调用 ``pack_hook`` 函数。``pack_hook`` 的输出将被存储在计算图中,而不是原始张量。 当需要访问保存的张量时(即在执行 :func:`torch.Tensor.backward()` 或 :func:`torch.autograd.grad()` 时),会调用 ``unpack_hook``。它接收 ``pack_hook`` 返回的*打包后的对象*作为参数,并应返回一个与原始张量内容相同的张量(即在对应 ``pack_hook`` 中传入的张量)。 钩子函数应具有如下签名: pack_hook(tensor: Tensor) -> Any unpack_hook(Any) -> Tensor 其中 ``pack_hook`` 的返回值应是 ``unpack_hook`` 的有效输入。 通常,你希望 ``unpack_hook(pack_hook(t))`` 在值、大小、数据类型和设备上与 ``t`` 相等。 示例:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> def pack_hook(x): ... print("Packing", x) ... return x >>> >>> def unpack_hook(x): ... print("Unpacking", x) ... return x >>> >>> a = torch.ones(5, requires_grad=True) >>> b = torch.ones(5, requires_grad=True) * 2 >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): ... y = a * b Packing tensor([1., 1., 1., 1., 1.], requires_grad=True) Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>) >>> y.sum().backward() Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True) Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>) .. warning :: 在钩子的输入上执行原地操作可能导致未定义行为。 .. warning :: <--- 会用到这个特性 同一时间只允许使用一对钩子。当递归嵌套使用此上下文管理器时,只有最内层的一对钩子会被应用。 """ def __init__ ( self, pack_hook: Callable [[torch.Tensor], Any ], unpack_hook: Callable [[Any ], torch.Tensor], ) -> None :
checkpoint 继承:
1 2 3 4 5 class _checkpoint_hook (torch.autograd.graph.saved_tensors_hooks ): def __init__ (self, frame: _CheckpointFrame ): def pack_hook (x ): ... def unpack_hook (holder ): ... super ().__init__(pack_hook, unpack_hook)
前向实现 pack_hook 在前向传播时,当 Autograd 引擎想要保存一个大的中间激活张量时,_checkpoint_hook 会介入,阻止保存这个大张量,而是保存一个极小的“凭证”(placeholder)。
1 2 3 with _checkpoint_hook(new_frame), forward_context: yield
_checkpoint_hook 依赖于 _CheckpointFrame 来存储状态 依赖于 _Holder 作为返回给 Autograd 的占位符
1 _CheckpointFrame.weak_holders: List [ReferenceType] = []
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class _checkpoint_hook (torch.autograd.graph.saved_tensors_hooks ): def __init__ (self, frame: _CheckpointFrame ): def pack_hook (x ): holder = _Holder() frame.weak_holders.append(weakref.ref(holder)) return holder def unpack_hook (holder ): ...
重算 Hook _recomputation_hook 实现 在反向传播时,当 Autograd 引擎需要这个张量来计算梯度时,_checkpoint_hook 再次介入。它接收到之前保存的“凭证”,意识到张量并不存在,于是立即触发一段“重计算”流程来重新生成这个张量,然后将新鲜出炉的张量交给 Autograd 引擎。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 class _CheckpointFrame : def __init__ (self, recompute_fn, early_stop, unpack_error_cb, metadata_fn ): self.recompute_fn = recompute_fn self.input_saver = None self.weak_holders: List [ReferenceType] = [] class _checkpoint_hook (torch.autograd.graph.saved_tensors_hooks ): def __init__ (self, frame: _CheckpointFrame ): def pack_hook (x ): ... def unpack_hook (holder ): gid = torch._C._current_graph_task_id() if gid == -1 : gid = int (uuid.uuid4()) if not frame.is_recomputed[gid]: ctx = frame.input_saver.grad_fn args = ctx.get_args(ctx.saved_tensors) try : with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad(): frame.recompute_fn(*args) except _StopRecomputationError: pass frame.is_recomputed[gid] = True ret = frame.recomputed[gid][holder.handles[gid]] holder.handles[gid] = None return ret
_为什么 _recomputation_hook 要单独设计?_:如果用同一个 hook,就必须在 pack_hook 内加多层 if 来区分“我是 forward 还是 recompute”,还要避免 forward 阶段误存真实张量/重算阶段误产 Holder,逻辑会极度耦合与脆弱。而拆分成两个 hook,可以利用“上下文覆盖”这一特性,天然保证不会交叉污染。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 class _recomputation_hook (torch.autograd.graph.saved_tensors_hooks ): def __init__ (self, target_frame_ref: ReferenceType, gid: int ): def pack_hook (x ): x = x.detach() if x.requires_grad else x target_frame = target_frame_ref() assert target_frame is not None recomp_idx = target_frame.recomp_counter[gid] target_frame.recomp_counter[gid] += 1 if recomp_idx >= len (target_frame.weak_holders): assert not target_frame.early_stop if not target_frame.forward_completed: target_frame.ignore_saved_mismatch = True return x raise CheckpointError() holder = target_frame.weak_holders[recomp_idx]() if holder is not None : _internal_assert(holder.handles.get(gid, None ) is None ) holder.handles[gid] = _Handle() target_frame.recomputed[gid][holder.handles[gid]] = x if target_frame.early_stop and target_frame.recomp_counter[gid] == len ( target_frame.weak_holders ): raise _StopRecomputationError return x def unpack_hook (x ): return x super ().__init__(pack_hook, unpack_hook)
时序图:前向 (Forward Pass) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 sequenceDiagram participant C as checkpoint() participant G as _checkpoint_..._generator participant Autograd as Autograd 引擎 participant UserFunc as 用户 function() participant Hook as _checkpoint_hook %% ====================== 1. 准备阶段 ====================== rect rgb(230, 247, 255) note over C: 开始执行 checkpoint() 函数 C->>+G: gen = _checkpoint_..._generator(...) note right of G: 生成器对象被创建,代码<br/>此时一行都未执行。 G-->>-C: 返回 gen 对象 end %% ====================== 2. 前置设置与暂停 (关键阶段) ====================== rect rgb(255, 243, 230) C->>G: next(gen) note right of G: 第一次唤醒:执行生成器的“前置”逻辑 G->>G: 1. 保存RNG状态、创建 _CheckpointFrame (frame) note over G, Autograd: 2. 应用 _NoopSaveInputs 保存输入 G->>Autograd: _NoopSaveInputs.apply(dummy, *args, **kwargs) note left of Autograd: Autograd 执行 _NoopSaveInputs 的<br/> `forward` 和 `setup_context` Autograd->>Autograd: a. 在计算图中创建 _NoopSaveInputsBackward (grad_fn) 节点 Autograd->>Autograd: b. 将 *args, **kwargs 保存到该 grad_fn 节点中 Autograd-->>G: 返回一个哑输出 note right of G: 将这个 grad_fn 节点的引用<br/>保存在 frame.input_saver 中 note over G, Hook: 3. 进入 `with _checkpoint_hook(frame)` 上下文 G->>Hook: 激活 _checkpoint_hook note right of Hook: 状态: **激活** <br/> (pack_hook 已准备就绪) note over G: 4. 到达 `yield`,暂停执行! G-->>C: 控制权返回给 checkpoint() end %% ====================== 3. 核心执行与钩子拦截====================== rect rgb(232, 255, 232) note over C: 生成器已暂停,但其内部的<br/>`_checkpoint_hook` 仍然激活! C->>+UserFunc: ret = function(*args, **kwargs) note right of UserFunc: 开始执行用户代码... UserFunc->>UserFunc: e.g., y = op1(x) note over UserFunc, Autograd: Autograd 尝试为 op1 保存其输入 x UserFunc->>Autograd: (内部调用) note left of Autograd: 发现有激活的 saved_tensors_hooks Autograd->>Hook: 调用 pack_hook(x) note right of Hook: 丢弃张量,返回占位符 Hook->>Hook: a. 创建 _Holder 占位符 Hook->>Hook: b. 在 frame 中存入 holder 的弱引用 Hook->>Hook: c. 丢弃对真实张量 x 的引用 Hook-->>Autograd: d. 返回 _Holder 对象 note left of Autograd: 收到并保存了轻量级的 _Holder,<br/>而非沉重的张量 x。内存节省达成! Autograd-->>UserFunc: UserFunc->>UserFunc: ...更多操作... UserFunc-->>-C: function 执行完毕,返回 ret end %% ====================== 4. 恢复与清理 ====================== rect rgb(255, 243, 230) note over C: 准备执行“后置”逻辑 C->>G: next(gen) note right of G: 第二次唤醒:从 `yield` 之后恢复执行 note over G, Hook: 5. 离开 `with _checkpoint_hook(...)` 上下文 G->>Hook: 停用 _checkpoint_hook note right of Hook: 状态: **非激活** G->>G: 6. 执行清理工作 (e.g., frame.forward_completed = True) note right of G: 7. 生成器函数执行完毕 G-->>C: 抛出 StopIteration 异常 end %% ====================== 5. 结束 ====================== rect rgb(230, 247, 255) note over C: 捕获 StopIteration,表示流程正常结束 C->>C: return ret note over C: checkpoint() 执行结束 end
时序图:反向 & 重计算 (Backward Pass) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 sequenceDiagram participant User as 用户代码 participant Autograd as Autograd 引擎 participant UnpackHook as _checkpoint_hook.unpack_hook participant Frame as _CheckpointFrame (状态中心) participant Inputs as _NoopSaveInputs Backward (输入源) participant RecompHook as _recomputation_hook participant RecompFn as recompute_fn %% ====================== 1. 反向传播启动,遇到“凭证” ====================== rect rgb(255, 235, 238) User->>Autograd: loss.backward() note over Autograd: 开始反向传播... Autograd->>Autograd: 遇到一个需要已保存张量 `x` 的节点 note over Autograd: 从计算图中取出保存物,<br/>发现它是一个 `_Holder` 对象,而非 Tensor! note over Autograd: 检测到有 saved_tensors_hooks (unpack_hook) Autograd->>+UnpackHook: 调用 unpack_hook(holder) end %% ====================== 2. UnpackHook 协调重计算 (核心逻辑) ====================== rect rgb(255, 248, 225) note over UnpackHook: 接收到 holder 凭证 UnpackHook->>Frame: 检查 is_recomputed[gid] 是否为 False note right of Frame: 状态: 首次请求,返回 False note over UnpackHook, Inputs: **需要启动重计算!**<br/>第一步:获取原始输入。 UnpackHook->>Frame: 获取 frame.input_saver UnpackHook->>Inputs: 从 grad_fn 节点中<br/>提取保存的 *args, **kwargs Inputs-->>UnpackHook: 返回原始输入 note over UnpackHook, RecompFn: 第二步:启动生产线 (recompute_fn) UnpackHook->>RecompHook: 激活 `with _recomputation_hook(...)` note right of RecompHook: 状态: **激活**<br/>(pack_hook 已准备就绪,开始“监工”) UnpackHook->>+RecompFn: 调用 recompute_fn(原始输入) end %% ====================== 3. 重计算执行,RecompHook 捕获张量 ====================== rect rgb(232, 242, 255) note over RecompFn: 开始执行重计算...<br/>(内部逻辑与用户原始 function 一致) RecompFn->>RecompFn: e.g., y = op1(x) note over RecompFn, Autograd: 重计算图中的 Autograd 尝试保存 `x` RecompFn->>Autograd: (内部调用) note left of Autograd: 发现有激活的 hooks (RecompHook) Autograd->>RecompHook: 调用 pack_hook(x) note right of RecompHook: **执行“捕获与缓存”操作** RecompHook->>RecompHook: a. detach 张量 x RecompHook->>RecompHook: b. 创建唯一句柄 _Handle RecompHook->>Frame: c. 将 (handle -> x) 存入<br/> frame.recomputed[gid] 缓存 RecompHook-->>Autograd: d. 返回原始张量 x note left of Autograd: 收到 x,重计算图<br/>内部的反向传播得以正常进行 Autograd-->>RecompFn: note over RecompFn: ...更多重计算操作... RecompFn-->>-UnpackHook: recompute_fn 执行完毕 UnpackHook->>RecompHook: `with` 语句结束,停用 RecompHook note right of RecompHook: 状态: **非激活** end %% ====================== 4. UnpackHook 返回结果,Autograd 继续 ====================== rect rgb(224, 247, 224) note over UnpackHook, Frame: 第三步:完成收尾,提供“货物” UnpackHook->>Frame: 标记 is_recomputed[gid] = True UnpackHook->>Frame: 从 frame.recomputed[gid] 缓存中<br/>根据 holder 凭证找到张量 x Frame-->>UnpackHook: 返回重计算出的张量 x UnpackHook-->>-Autograd: 返回新鲜出炉的张量 x note over Autograd: 成功拿到所需的张量! Autograd->>Autograd: 使用 x 计算梯度... Autograd->>Autograd: ...继续反向传播... end
时序图:_Holder 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 sequenceDiagram participant User as User code participant CP as checkpoint participant Frame as _CheckpointFrame participant Hooks as saved_tensors_hooks(pack,unpack) participant SV as SavedVariable(持有“保存物”) participant Bwd as backward() User->>CP: 调用 checkpoint(fn, *args) CP->>Frame: 构造 Frame(recompute_fn / weak_holders / cache) CP->>Hooks: 进入 hooks 上下文(pack, unpack) note over Hooks: 正向阶段 Hooks->>Hooks: pack(x) 创建 holder=_Holder() Hooks->>Frame: weak_holders.append(weakref(holder)) Hooks-->>SV: 返回 holder 作为“保存物”(不是 x) User-->>CP: 正向结束(未保存 x) note over Bwd,Hooks: 反向阶段第一次需要该保存位 Bwd->>Hooks: 触发 unpack(holder) alt 缓存已有 Hooks-->>Bwd: 直接返回缓存 Tensor else 首次解包 Hooks->>Frame: 触发 recompute_fn() 做一次重算 Hooks->>Hooks: 内层 hooks(inner_pack) Hooks->>Frame: inner_pack 顺序遇到保存位 alt 对应 weak_holder 仍存活 Hooks->>Frame: 把真实 Tensor 写入 WeakKeyDictionary[handle](或旧版以 holder 为键) else 已死亡 Hooks-->>Frame: 跳过(不填充,省内存/算力) end Hooks-->>Bwd: 从缓存取出 Tensor 返回 end note over SV,Frame: 随反向推进 SavedVariable reset_data()\ note over SV,Frame: holder 失去强引用,弱键缓存项自动被清理
其他的细节 简单使用示例 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 import torchfrom torch import nnfrom torch.utils.checkpoint import checkpointclass Block (nn.Module ): def __init__ (self, dim ): super ().__init__() self.net = nn.Sequential( nn.Linear(dim, dim*4 ), nn.GELU(), nn.Linear(dim*4 , dim) ) def forward (self, x ): return self.net(x) class Model (nn.Module ): def __init__ (self, depth=12 , dim=768 ): super ().__init__() self.blocks = nn.ModuleList([Block(dim) for _ in range (depth)]) def forward (self, x ): for i, blk in enumerate (self.blocks): x = checkpoint(blk, x, use_reentrant=False ) return x x = torch.randn(8 , 768 , requires_grad=True ) m = Model() out = m(x) loss = out.pow (2 ).mean() loss.backward()
嵌套 checkpoint 源码中的解释:
注意 [ 可嵌套的检查点(Nestable Checkpoint)]
嵌套检查点的语义可以通过两个基本规则来定义。遵循这两个规则会引出一个重要的推论,该推论是设计动机的核心。
规则 1:保存的张量仅由最内层的检查点管理,并对所有外层检查点隐藏。 规则 2:内层检查点的输入被视为被保存到其父级检查点中的张量。 推论: 要重新计算任何一个给定的已保存张量,我们必须重新计算所有包裹它的检查点。
为什么这是必然的? 在反向传播中解包某个已保存张量 X 时,我们需要重新计算包含 X 的最内层检查点(根据规则 1)。而为了重新计算这个检查点,我们又需要它的输入,这些输入是由该检查点的父级检查点所管理的(根据规则 2),因此父级也必须先被重新计算。以此类推可以发现:为了成功解包 X,所有在 X 被保存时处于激活状态的检查点都必须被重新计算(除非在当前反向传播过程中,由于其他张量的需求已经完成过这些重计算)。
在实践中,我们使用一个无操作的自动微分函数(noop autograd Function)将输入作为“已保存张量”进行存储。在解包时调用 ctx.saved_tensor 会触发父级检查点的重新计算。
规则 3:开始重新计算时应假设当前没有任何检查点处于激活状态。但在重新计算过程中遇到的新检查点仍需被尊重。 当我们启动重新计算时,会在栈上压入一个专用于重新计算的“已保存变量钩子”(saved variable hook)。更多上下文请参见规则 6 中的示例。
除了上述针对嵌套检查点的基本语义之外,我们还施加了若干额外约束,这些约束可能普遍适用于检查点机制本身。
规则 4:重新计算出的张量的生命周期
重新计算出的张量只属于特定一次反向传播调用,在它们被解包后立即清除。 特别地,即使设置了 retain_graph=True,我们也要求这样做。
[规则 4 的实现细节] 如果我们允许在设置 retain_graph=True 时让重新计算出的张量继续存活,那么我们可以将这些张量作为值存入一个 WeakKeyDictionary(弱键字典),并将强引用打包为键。这样,在反向传播结束时,只要 retain_graph=False,打包的键就会被清除,从而自动删除字典中的对应项。
但如果我们希望在 retain_graph=True 的情况下也能在解包时立刻清除 重新计算出的张量,则不能依赖反向传播自动清除打包的键。取而代之的是:我们将强引用包装在一个容器对象中,并在解包时手动清空该容器。
一个重要细节是:如果发生了第二次反向传播,第二次的重新计算需要重置该容器并创建一个新的键。
规则 5:一旦完成了所需张量的重新计算,就应立即停止重新计算过程。 [规则 5 的实现细节] 在重新计算期间,当已重新计算的张量数量达到预期目标数量时,抛出一个异常。我们在外部通过 try-catch 捕获这一特定异常以提前终止执行。具体示例见下方规则 6。
规则 6:支持在检查点上下文中执行反向传播 [保留图结构的情况:retain_graph=True] 1 2 3 4 5 6 7 8 def fn (x ): y = x.sin() z = y.cos() gx, = torch.autograd.grad(z, x, retain_grad=True ) return gx, z out = checkpoint(fn)(inp) out.backward()
由于 z 是在启用检查点的情况下由 cos() 保存的,它实际上不会被真正保存。因此内部的 .grad() 调用必须触发一次重新计算。
在重新计算过程中,“内部打包钩子”有两个职责:
和通常一样,填充用于存储重新计算出张量的 WeakKeyDictionary;
打包实际的张量(已分离),以便可以在重新构建的计算图上执行反向传播。这些被打包的张量将一直存在直到本次重新计算结束;或者更早地,如果有人以 retain_graph=False 执行反向传播,则会被提前释放。
更一般地说,在以下情况中会对重新构建的图执行反向传播:
如果在前向传播中执行了反向传播:
在原始前向传播期间(若未启用提前停止)
在原始反向传播期间
如果存在多个 .grad() 或 .backward() 调用,即使启用了提前停止,我们也会对重新构建的图执行反向传播(见下例)
[不保留图结构的情况:retain_graph=False] 下面的例子展示了:在重新计算过程中,如果我们发现某些试图重新计算的张量已经被清除会发生什么。
剧透:我们不做特殊处理,直接跳过它们!
1 2 3 4 5 6 7 8 def fn (x ): y = x.sin() z = y.cos() gx, = torch.autograd.grad(z, x) return x.cos() * gx out = checkpoint(fn)(inp) out.backward()
**(1)(2)**:由于处于检查点内,不保存 x 和 y。
**(3)**:触发 fn 的重新计算,因为 x 和 y 未被保存。 根据是否启用提前停止,可能运行到 (2) 就停止,或继续执行下去。 因为此次反向传播使用了 retain_graph=False,所以我们会清除 x 和 y 对应的持有者(holder)。
**(4)**:仍在检查点内,不保存 x。
**(5)**:调用 .backward() 再次触发 fn 的重新计算。 在这次重新计算中,我们发现 x 和 y 在原图中的持有者已是 None(已被清除),于是直接跳过它们。 但我们仍然会在 (4) 处保存 x(因为此时它的持有者仍然有效)。
嵌套使用示例
假设 Block 内又想对子结构再切分:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 class SubBlock (nn.Module ): def __init__ (self, dim ): super ().__init__() self.lin1 = nn.Linear(dim, dim) self.lin2 = nn.Linear(dim, dim) def forward (self, x ): def inner (y ): return torch.relu(self.lin2(torch.relu(self.lin1(y)))) return checkpoint(inner, x, use_reentrant=False ) class OuterBlock (nn.Module ): def __init__ (self, dim ): super ().__init__() self.sub1 = SubBlock(dim) self.sub2 = SubBlock(dim) def forward (self, x ): def big (y ): y = self.sub1(y) y = self.sub2(y) return y return checkpoint(big, x, use_reentrant=False )
此时内层与外层是嵌套 checkpoint,反向时若只需要来自 big 内某一步保存的中间张量,会触发两层级联重算但利用“早停 + 不重复”逻辑最小化代价。