Lokasi ngalangkungan proxy:   [ UP ]  
[Ngawartoskeun bug]   [Panyetelan cookie]                
Skip to content

zero3: invalidate coordinator trace on hook re-registration#8043

Merged
tohtana merged 6 commits into
deepspeedai:masterfrom
roycho96:fix/zero3-hook-cycle-trace-invalidate
Jun 10, 2026
Merged

zero3: invalidate coordinator trace on hook re-registration#8043
tohtana merged 6 commits into
deepspeedai:masterfrom
roycho96:fix/zero3-hook-cycle-trace-invalidate

Conversation

@roycho96

@roycho96 roycho96 commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

Summary

Re-registering ZeRO-3 module hooks after they were removed (e.g. via unwrap_model_for_generation) leaves the param coordinator's recorded trace stale. The next training forward raises IndexError: pop from an empty deque from _start_of_forward_hook -> reset_step -> record_parameters -> popleft.

Repro

DeepSpeed master, torch 2.8.0+cu128, transformers, peft. Single GPU.

import torch, deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

m = "hf-internal-testing/tiny-random-gpt2"
tok = AutoTokenizer.from_pretrained(m); tok.pad_token = tok.eos_token
model = get_peft_model(AutoModelForCausalLM.from_pretrained(m, dtype=torch.bfloat16),
                       LoraConfig(task_type=TaskType.CAUSAL_LM, r=4, target_modules=["c_attn"]))
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

cfg = {"train_micro_batch_size_per_gpu": 1, "bf16": {"enabled": True},
       "zero_optimization": {"stage": 3, "stage3_param_persistence_threshold": 0},
       "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
engine, *_ = deepspeed.initialize(model=model, config=cfg,
                                  model_parameters=[p for p in model.parameters() if p.requires_grad])

ids = tok("hello", return_tensors="pt").input_ids.to(engine.device)
for _ in range(2):
    with unwrap_model_for_generation(engine) as unwrapped:
        with torch.no_grad():
            unwrapped.generate(ids, max_new_tokens=4, do_sample=False, pad_token_id=tok.pad_token_id)
    out = engine(input_ids=ids, labels=ids)
    engine.backward(out.loss); engine.step()

Run with torchrun --nproc-per-node=1 repro.py. Second iteration raises the IndexError.

Fix

Two small edits in deepspeed/runtime/zero/:

  • parameter_offload.py::_register_deepspeed_module: when the root module is re-registered, invalidate the coordinator trace so the next forward re-records cleanly.
  • partitioned_param_coordinator.py::_clear_trace_structures: also clear __step_id_module_fetched_for, which was being left populated and caused the empty-deque pop.

Both guards are no-ops on initial registration (trace is already INVALID) and on non-root submodule walks.

Test

tests/unit/runtime/zero/test_unwrap_model.py::TestUnwrapModelTraceInvalidate covers the path: run one training step, wrap with unwrap_model_for_generation, assert the coordinator returns to INVALID. World size 2.

roycho96 added 3 commits June 2, 2026 22:23
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@tohtana

tohtana commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

Hi @roycho96,

Thank you for the fix! The fix looks good to me, but can you also add an iteration after unwrap_model_for_generation? I think assert coordinator.is_invalid_trace() doesn't fully guarantees expected state as the currently bug appears after that. Something like this will be more robust.

training step 1:
  forward
  backward
  step

unwrap/generate phase:
  with unwrap_model_for_generation(engine):
      ... optional forward/generate/no-op ...

training step 2:
  forward
  backward
  step

The previous assertion only checked the coordinator's internal flag.
The actual bug surfaces during the next training step, when reset_step
constructs the parameter trace and pops an empty deque.

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96 roycho96 force-pushed the fix/zero3-hook-cycle-trace-invalidate branch from 3009717 to ced26cb Compare June 7, 2026 15:30
@roycho96

roycho96 commented Jun 7, 2026

Copy link
Copy Markdown
Contributor Author

Hi @roycho96,

Thank you for the fix! The fix looks good to me, but can you also add an iteration after unwrap_model_for_generation? I think assert coordinator.is_invalid_trace() doesn't fully guarantees expected state as the currently bug appears after that. Something like this will be more robust.

training step 1:
  forward
  backward
  step

unwrap/generate phase:
  with unwrap_model_for_generation(engine):
      ... optional forward/generate/no-op ...

training step 2:
  forward
  backward
  step

Thanks @tohtana, updated the test per your suggestion. Confirmed it fails without the fix and passes with it.

@tohtana tohtana enabled auto-merge (squash) June 10, 2026 07:35

@tohtana tohtana left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thank you @roycho96!

@tohtana tohtana merged commit 87f31d5 into deepspeedai:master Jun 10, 2026
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants