# ARX5 MuJoCo World Model: Rollout Stability Ablation (step 125k)

Four fine-tuned checkpoints from a 2×2 ablation testing noise augmentation and BPTT (k=8) as interventions against autoregressive drift in an ARX5 MuJoCo proprioceptive world model.
All variants are fine-tuned from the same base: the best cfgC_lr5e4 checkpoint (GPT-h32,
val_loss = 0.005403 at 1M steps), using LR = 1e-4, a 2k-step warmup, cosine decay to 3e-5,
batch size 64, and a 150k-step budget. The checkpoints here are at step 125k.
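The fine-tuning schedule above (2k-step linear warmup to 1e-4, then cosine decay to 3e-5 over the 150k budget) can be sketched as follows; the function name and exact warmup shape are assumptions, not taken from the training code:

```python
import math

def lr_at(step, base_lr=1e-4, min_lr=3e-5, warmup=2000, total=150_000):
    """Linear warmup to base_lr, then cosine decay to min_lr (hedged sketch)."""
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```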
## Ablation design
| variant | noise end_std | forecast_horizon | rollout_grad_mode | description |
|---|---|---|---|---|
| B (noise_only) | 0.20 | 1 | n/a | noise augmentation only, no BPTT |
| C (bptt_only) | 0.05 | 8 | full | BPTT k=8, noise unchanged |
| D (noise+bptt) | 0.20 | 8 | full | both interventions together |
| E (bptt8_nograd) | 0.05 | 8 | detached | AR context without rollout gradients |
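The difference between the `full` and `detached` rollout gradient modes can be illustrated with a minimal sketch; the `rollout` function, its signature, and the toy model below are hypothetical, not the repo's actual training loop:

```python
import torch

def rollout(model, state, actions, grad_mode="full"):
    """k-step autoregressive rollout. 'full' backpropagates through time
    (BPTT); 'detached' feeds the model its own predictions as context but
    cuts gradients between steps, as in variant E (sketch)."""
    preds = []
    for a in actions:
        pred = model(torch.cat([state, a], dim=-1))
        preds.append(pred)
        state = pred if grad_mode == "full" else pred.detach()
    return torch.stack(preds)
```

Both modes produce identical forward predictions; they differ only in which gradients reach earlier rollout steps.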
## Model architecture
GPT-based proprioceptive world model for ARX5 6-DOF arm in MuJoCo.
| param | value |
|---|---|
| type | GPT |
| n_embd | 384 |
| n_head | 12 |
| n_layer | 6 |
| block_size | 32 |
| state_dim | 14 (6 joint pos + 1 gripper + 7 joint vel) |
| action_dim | 7 |
| action_repr | delta |
| state_pred_mode | absolute |
| ensemble_size | 2 |
| history_horizon | 31 |
| angle_wrap_dims | [0,1,2,3,4,5] (revolute joints, input-side wrap-to-pi) |
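The input-side wrap-to-pi applied to the six revolute-joint dimensions presumably maps angles into [-π, π); a minimal sketch (the function name and half-open interval convention are assumptions):

```python
import math

def wrap_to_pi(angle):
    """Map an angle in radians into [-pi, pi) (sketch of input-side wrapping)."""
    return (angle + math.pi) % (2 * math.pi) - math.pi
```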
## Checkpoints
| variant | step | sha256 | size |
|---|---|---|---|
| B (noise_only) | 125000 | 00afc01d7fb44b9f1c485f33dc488afbd11bcdcb1e43af254aa158267fa5536c | 170 MB |
| C (bptt_only) | 125000 | e2de440ac018e7c12b1e5647671af0f2764018fc2834794cfffe9551e806a269 | 170 MB |
| D (noise+bptt) | 125000 | 8851ba8c42ad1ca95448f65691fda921ce3e81bfa238a5f49a9a38f681707ba7 | 170 MB |
| E (bptt8_nograd) | 125000 | ae180556f5ac12780eb2fcba0bf92a85da1372731f9faee82ef8d564ef742154 | 170 MB |
Verify:

```bash
sha256sum checkpoints/variant_B_noise_only/step_125000/step_0125000_C.pt
```
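Equivalently in Python, streaming the file so the 170 MB checkpoint never loads into memory at once (the path is just an example; compare the result against the table above):

```python
import hashlib

def sha256_of(path, chunk=1 << 20):
    """Compute the SHA-256 hex digest of a file in 1 MB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):
            h.update(block)
    return h.hexdigest()
```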
## Training curves
W&B project: arx5-mujoco-WM-finetune-rollout
| variant | W&B run |
|---|---|
| B (noise_only) | https://wandb.ai/pravsels/arx5-mujoco-WM-finetune-rollout/runs/h9y55c3r |
| C (bptt_only) | https://wandb.ai/pravsels/arx5-mujoco-WM-finetune-rollout/runs/iwknpgy6 |
| D (noise+bptt) | https://wandb.ai/pravsels/arx5-mujoco-WM-finetune-rollout/runs/npthusza |
| E (bptt8_nograd) | https://wandb.ai/pravsels/arx5-mujoco-WM-finetune-rollout/runs/brw9ga1q |
## Usage
These are full training checkpoints (model weights + optimizer/scheduler state). To load the weights only:

```python
import torch

ckpt = torch.load("step_0125000_C.pt", map_location="cpu")
# ckpt["model_state_dict"] -> model weights
# ckpt["config"]           -> training config snapshot
```
See the source repo for the model class and full loading utilities.
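Since each full checkpoint carries optimizer and scheduler state, a lighter weights-only copy can be exported for inference. A hedged sketch, assuming only the `model_state_dict` and `config` keys shown above (the function name is an assumption):

```python
import torch

def export_weights_only(src, dst):
    """Re-save a checkpoint keeping only model weights and the config
    snapshot, dropping optimizer/scheduler state (sketch)."""
    ckpt = torch.load(src, map_location="cpu")
    torch.save({"model_state_dict": ckpt["model_state_dict"],
                "config": ckpt.get("config")}, dst)
```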
## Base checkpoint
All variants were fine-tuned from cfgC_lr5e4 at 1M steps:
- val_loss: 0.005403
- train_loss: 0.02983
- LR: 5e-4, 5k warmup, cosine