Skip to content

Feat: Activation Checkpointing#154

Open
Chamberlain0w0 wants to merge 3 commits into
masterfrom
feat/activation_checkpointing
Open

Feat: Activation Checkpointing#154
Chamberlain0w0 wants to merge 3 commits into
masterfrom
feat/activation_checkpointing

Conversation

@Chamberlain0w0

Copy link
Copy Markdown
Contributor

No description provided.

@Chamberlain0w0 Chamberlain0w0 force-pushed the feat/activation_checkpointing branch from d87926e to ba8db8f Compare June 3, 2026 08:45
@Chamberlain0w0 Chamberlain0w0 changed the title [WIP] Feat: Activation Checkpointing Feat: Activation Checkpointing Jun 5, 2026
Comment thread example/gpt2/checkpoint_loader.h Outdated
Comment thread example/gpt2/checkpoint_loader.cc Outdated
Comment thread infini_train/src/utils/checkpoint.cc
Comment thread infini_train/include/utils/checkpoint.h Outdated
@@ -1,12 +1,15 @@
#include "infini_train/include/nn/modules/transformer/transformer.h"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之后扩展 recompute 其他功能(如selective)的话就不适合全放在transformer.cc里了,之后可以拆分activation_recompute.cc。本PR可以不做修改

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实,这块确实没太想清楚,主要是 megatron 的实现里面也把很多重计算逻辑融在了 transformer 模型层,之后可以再讨论下

Comment thread infini_train/src/utils/checkpoint.cc Outdated
Comment thread infini_train/include/autocast.h
Comment thread infini_train/src/nn/modules/module.cc Outdated
@Chamberlain0w0 Chamberlain0w0 force-pushed the feat/activation_checkpointing branch from 9cdf73c to a38a4d3 Compare June 11, 2026 01:21
@Chamberlain0w0 Chamberlain0w0 force-pushed the feat/activation_checkpointing branch from a38a4d3 to e594d9c Compare June 25, 2026 06:47
// Used by non-reentrant checkpoint recomputation so downstream SetupContext
// calls see the same needs_input_grad_ pattern as the original forward,
// without wiring the recompute graph into the engine.
class PropagateRequiresGradGuard {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本质上是需要一个 2*2=4 种情况的精细控制(with_grad/no_grad x 带引用计数建图/仅传递 requires_grad)。

也可以不单独实现一个 guard,而是给原先的 GradGuard 重载一个带参数 (比如 bool record_context = false) 的构造方式,让 GradGuard 本身能够覆盖这四种情况。

try {
forward_fn(detached_inputs);
} catch (const StopRecomputeError &) {
// Early-stop: expected when all needed tensors are recomputed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

只是用异常的语法和控制流机制来实现 early stop 的语义,本质上并不是真的抛一个会让程序 abort 的异常。torch 也采取的是一样的实现方式。

Image

这块暂时没有想到其他更优雅的处理方式。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants