ARX5 MuJoCo World Model: Rollout Stability Ablation (step 125k)

Four fine-tuned checkpoints from a 2×2 ablation testing noise augmentation and BPTT (k=8) as interventions against autoregressive drift in an ARX5 MuJoCo proprioceptive world model.

All variants are fine-tuned from the same base, the cfgC_lr5e4 best checkpoint (GPT-h32, val_loss=0.005403 at 1M steps), using LR=1e-4 with a 2k-step warmup and cosine decay to 3e-5, batch size 64, and a 150k-step budget. The checkpoints here are at step 125k.
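For reference, the numbers above describe a standard warmup-then-cosine schedule. A minimal sketch of that curve (the exact implementation lives in the source repo and may differ in detail):

```python
import math

def finetune_lr(step, base_lr=1e-4, min_lr=3e-5, warmup=2_000, total=150_000):
    """Warmup-then-cosine LR schedule matching the fine-tuning recipe above
    (a sketch; the training code's exact decay may differ slightly)."""
    if step < warmup:
        return base_lr * (step + 1) / warmup           # linear warmup
    progress = (step - warmup) / (total - warmup)      # 0 -> 1 after warmup
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```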

Ablation design

| variant | noise end_std | forecast_horizon | rollout_grad_mode | description |
|---|---|---|---|---|
| B (noise_only) | 0.20 | 1 | n/a | noise augmentation only, no BPTT |
| C (bptt_only) | 0.05 | 8 | full | BPTT k=8, noise unchanged |
| D (noise+bptt) | 0.20 | 8 | full | both interventions together |
| E (bptt8_nograd) | 0.05 | 8 | detached | AR context without rollout gradients |
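A schematic of how the two axes in the table combine in a training step. The model call signature is a placeholder; the real interfaces are in the source repo:

```python
import torch
import torch.nn.functional as F

def rollout_loss(model, states, actions, horizon=8, noise_std=0.2,
                 detach_context=False):
    """Illustrative training step for the ablation axes above (hypothetical
    interface, not the repo's actual code).

    noise_std      -> Gaussian noise injected into the input state (noise aug).
    horizon        -> autoregressive forecast horizon (1 = teacher forcing).
    detach_context -> True feeds predictions back without rollout gradients
                      (variant E); False backpropagates through the rollout
                      (full BPTT, variants C and D).
    """
    s = states[:, 0] + noise_std * torch.randn_like(states[:, 0])
    loss = 0.0
    for t in range(horizon):
        s_pred = model(s, actions[:, t])                  # one-step prediction
        loss = loss + F.mse_loss(s_pred, states[:, t + 1])
        s = s_pred.detach() if detach_context else s_pred  # feed back own prediction
    return loss / horizon
```

In these terms, variant B is horizon=1 with noise_std=0.2; C is horizon=8 with noise_std=0.05; D is horizon=8 with noise_std=0.20; and E is horizon=8 with noise_std=0.05 and detach_context=True.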

Model architecture

GPT-based proprioceptive world model for ARX5 6-DOF arm in MuJoCo.

| param | value |
|---|---|
| type | GPT |
| n_embd | 384 |
| n_head | 12 |
| n_layer | 6 |
| block_size | 32 |
| state_dim | 14 (6 joint pos + 1 gripper + 7 joint vel) |
| action_dim | 7 |
| action_repr | delta |
| state_pred_mode | absolute |
| ensemble_size | 2 |
| history_horizon | 31 |
| angle_wrap_dims | [0, 1, 2, 3, 4, 5] (revolute joints, input-side wrap-to-pi) |
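The angle_wrap_dims entry means the six revolute-joint positions are wrapped into a single period on the input side. A small sketch of what that preprocessing plausibly looks like (the repo's exact convention, e.g. the choice of half-open interval, may differ):

```python
import torch

def wrap_input_angles(state, dims=(0, 1, 2, 3, 4, 5)):
    """Wrap the revolute-joint position dims to [-pi, pi) before they enter
    the model (sketch of the input-side angle_wrap_dims preprocessing)."""
    wrapped = state.clone()
    idx = list(dims)
    wrapped[..., idx] = torch.remainder(wrapped[..., idx] + torch.pi,
                                        2 * torch.pi) - torch.pi
    return wrapped
```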

Checkpoints

| variant | step | sha256 | size |
|---|---|---|---|
| B (noise_only) | 125000 | 00afc01d7fb44b9f1c485f33dc488afbd11bcdcb1e43af254aa158267fa5536c | 170 MB |
| C (bptt_only) | 125000 | e2de440ac018e7c12b1e5647671af0f2764018fc2834794cfffe9551e806a269 | 170 MB |
| D (noise+bptt) | 125000 | 8851ba8c42ad1ca95448f65691fda921ce3e81bfa238a5f49a9a38f681707ba7 | 170 MB |
| E (bptt8_nograd) | 125000 | ae180556f5ac12780eb2fcba0bf92a85da1372731f9faee82ef8d564ef742154 | 170 MB |

Verify (variant B shown; substitute the matching path for C, D, and E):

```bash
sha256sum checkpoints/variant_B_noise_only/step_125000/step_0125000_C.pt
```
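Or check the digests against the table from Python. Only the variant B path is given above; the other variants are assumed to follow the same directory layout:

```python
import hashlib

EXPECTED = {
    # path -> sha256 from the checkpoints table (variant B layout shown above)
    "checkpoints/variant_B_noise_only/step_125000/step_0125000_C.pt":
        "00afc01d7fb44b9f1c485f33dc488afbd11bcdcb1e43af254aa158267fa5536c",
}

for path, expected in EXPECTED.items():
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            h.update(chunk)
    assert h.hexdigest() == expected, f"checksum mismatch: {path}"
print("all checksums OK")
```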

Training curves

W&B project: arx5-mujoco-WM-finetune-rollout

Usage

These are full training checkpoints (model weights + optimizer/scheduler state). To load weights only:

```python
import torch

ckpt = torch.load("step_0125000_C.pt", map_location="cpu")
# ckpt["model_state_dict"]  -> model weights
# ckpt["config"]            -> training config snapshot
```

See the source repo for the model class and full loading utilities.
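Since the full checkpoints also carry optimizer and scheduler state, a weights-only copy is much lighter to ship. A small sketch using only the keys documented above (the output filename is just an example):

```python
import torch

# Keep only the weights and the config snapshot; drop optimizer/scheduler state.
ckpt = torch.load("step_0125000_C.pt", map_location="cpu")
slim = {"model_state_dict": ckpt["model_state_dict"], "config": ckpt["config"]}
torch.save(slim, "step_0125000_C_weights_only.pt")  # example output name
```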

Base checkpoint

All variants fine-tune from cfgC_lr5e4 at 1M steps:

- val_loss: 0.005403
- train_loss: 0.02983
- LR: 5e-4, 5k-step warmup, cosine decay