Skip to content

Commit b62c60d

Browse files
committed
Create lerobot_pusht_bc_tutorial_marktechpost.py
1 parent 9ee6a03 commit b62c60d

1 file changed

Lines changed: 187 additions & 0 deletions

File tree

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# -*- coding: utf-8 -*-
"""lerobot_pusht_bc_tutorial_Marktechpost.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/14VCj4xMpHzaDaYjHm0Cyr_2HQZg8ZHw9
"""

# BUG FIX: `!pip ...` is IPython shell magic and is a SyntaxError in a plain
# .py file. Run pip through a subprocess instead so the script stays valid
# Python outside a notebook. `check=False` mirrors the notebook's
# fire-and-forget behaviour (a failed install is reported, not fatal here).
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "-q", "install", "--upgrade",
     "lerobot", "torch", "torchvision", "timm", "imageio[ffmpeg]"],
    check=False,
)
11+
12+
import io
import json
import math
import os
import pathlib
import random
import subprocess
import sys
import time

import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.utils import make_grid, save_image
18+
19+
# Import LeRobotDataset from whichever layout this lerobot version ships:
# newer releases moved datasets out of the `lerobot.common` namespace.
# Catch ImportError specifically (was `except Exception`, which would also
# silently swallow unrelated errors raised while importing the module).
try:
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
except ImportError:
    from lerobot.datasets.lerobot_dataset import LeRobotDataset
23+
24+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Seed every RNG in play so the shuffle/split and weight init are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

REPO_ID = "lerobot/pusht"
ds = LeRobotDataset(REPO_ID)
print("Dataset length:", len(ds))

# Inspect one sample to discover which observation/action keys this dataset uses.
s0 = ds[0]
keys = list(s0.keys())
print("Sample keys:", keys)
35+
36+
def key_with(prefixes, candidates=None):
    """Return the first candidate key starting with any of `prefixes`, else None.

    Args:
        prefixes: iterable of string prefixes to try, in priority order of keys.
        candidates: keys to scan; defaults to the module-level `keys` list so
            the original zero-argument-style call sites keep working. The
            parameter makes the helper reusable and testable in isolation.
    """
    pool = keys if candidates is None else candidates
    # First key (in pool order) that matches any prefix wins, as before.
    for k in pool:
        if any(k.startswith(p) for p in prefixes):
            return k
    return None
41+
42+
# Resolve the dataset-specific key names once, up front.
K_IMG = key_with(["observation.image", "observation.images", "observation.rgb"])
K_STATE = key_with(["observation.state"])
K_ACT = "action"
# BUG FIX: `assert` is stripped under `python -O`; raise explicitly so this
# sanity check on the dataset schema always runs.
if K_ACT not in s0:
    raise KeyError(f"No 'action' key found in sample. Found: {keys}")
print("Using keys -> IMG:", K_IMG, "STATE:", K_STATE, "ACT:", K_ACT)
47+
48+
class PushTWrapper(torch.utils.data.Dataset):
    """Adapt LeRobotDataset samples into ``{"image", "state", "action"}`` dicts.

    Images come out as float CHW tensors in [0, 1] resized to 96x96; state
    and action are flattened float vectors. If the dataset has no state key
    (``K_STATE is None``), a zero 2-vector stands in.
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        sample = self.base[i]

        img = sample[K_IMG]
        # A temporal stack (T, C, H, W): keep only the most recent frame.
        if img.ndim == 4:
            img = img[-1]
        # uint8 images are rescaled to [0, 1]; float images pass through.
        if img.dtype == torch.uint8:
            img = img.float() / 255.0
        else:
            img = img.float()

        state = sample.get(K_STATE, torch.zeros(2)).float().reshape(-1)
        action = sample[K_ACT].float().reshape(-1)

        if img.shape[-2:] != (96, 96):
            img = F.interpolate(img.unsqueeze(0), size=(96, 96),
                                mode="bilinear", align_corners=False)[0]
        return {"image": img, "state": state, "action": action}
63+
64+
wrapped = PushTWrapper(ds)
N = len(wrapped)

# Shuffled 90/10 train/val split (seeded above), then capped so that a
# tutorial-sized run stays fast.
indices = list(range(N))
random.shuffle(indices)
n_train = int(0.9 * N)
train_idx = indices[:n_train]
val_idx = indices[n_train:]

train_ds = Subset(wrapped, train_idx[:12000])
val_ds = Subset(wrapped, val_idx[:2000])

BATCH = 128
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                          num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False,
                        num_workers=2, pin_memory=True)
77+
78+
class SmallBackbone(nn.Module):
    """Tiny CNN encoder mapping (B, 3, H, W) images to (B, out) features."""

    def __init__(self, out=256):
        super().__init__()

        def conv_relu(c_in, c_out, k, s, p):
            # One conv + ReLU stage; keeps the Sequential below readable.
            return [nn.Conv2d(c_in, c_out, k, s, p), nn.ReLU(inplace=True)]

        # Three stride-2 stages (8x downsample) plus one stride-1 refiner.
        self.conv = nn.Sequential(
            *conv_relu(3, 32, 5, 2, 2),
            *conv_relu(32, 64, 3, 2, 1),
            *conv_relu(64, 128, 3, 2, 1),
            *conv_relu(128, 128, 3, 1, 1),
        )
        # Global-average-pool then project to the requested feature width.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.head(self.conv(x))
89+
90+
class BCPolicy(nn.Module):
    """Behaviour-cloning policy: image features + state vector -> action.

    The CNN backbone embeds the image; the embedding is concatenated with
    the proprioceptive state and mapped to an action by a small MLP.
    """

    def __init__(self, img_dim=256, state_dim=2, hidden=256, act_dim=2):
        super().__init__()
        self.backbone = SmallBackbone(img_dim)
        self.mlp = nn.Sequential(
            nn.Linear(img_dim + state_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden // 2, act_dim),
        )

    def forward(self, img, state):
        features = self.backbone(img)
        # Accept an unbatched state vector alongside a batched image.
        if state.ndim == 1:
            state = state.unsqueeze(0)
        return self.mlp(torch.cat([features, state], dim=-1))
104+
105+
# Model, optimizer, and AMP gradient scaler (the scaler is a no-op on CPU).
policy = BCPolicy().to(DEVICE)
opt = torch.optim.AdamW(policy.parameters(), lr=3e-4, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
108+
109+
@torch.no_grad()
def evaluate(model=None, loader=None, device=None):
    """Return the mean per-element MSE of `model` over `loader`.

    Args:
        model: policy taking (image, state) batches; defaults to the
            module-level `policy`.
        loader: iterable of {"image", "state", "action"} batches; defaults
            to the module-level `val_loader`.
        device: target device string; defaults to the module-level `DEVICE`.

    The defaults keep the original zero-argument call sites working while
    making the function reusable and testable in isolation. Raises
    ZeroDivisionError on an empty loader (as the original did).
    """
    model = policy if model is None else model
    loader = val_loader if loader is None else loader
    device = DEVICE if device is None else device

    model.eval()
    sq_err, n_elems = 0.0, 0
    for batch in loader:
        img = batch["image"].to(device, non_blocking=True)
        st = batch["state"].to(device, non_blocking=True)
        act = batch["action"].to(device, non_blocking=True)
        pred = model(img, st)
        # Accumulate the summed error and divide once at the end so the mean
        # is exact even when the final batch is smaller than the rest.
        sq_err += F.mse_loss(pred, act, reduction="sum").item()
        n_elems += act.numel()
    return sq_err / n_elems
121+
122+
def cosine_lr(step, total, base=3e-4, min_lr=3e-5):
    """Cosine-annealed LR: `base` at step 0, decaying to `min_lr` at `total`.

    Steps at or beyond `total` are clamped to `min_lr`.
    """
    if step >= total:
        return min_lr
    # Half-cosine weight: 1.0 at the start of training, 0.0 at the end.
    weight = 0.5 * (1 + math.cos(math.pi * step / total))
    return min_lr + (base - min_lr) * weight
126+
127+
# Training schedule and checkpoint bookkeeping.
EPOCHS = 4
steps_total = EPOCHS * len(train_loader)
step = 0
best = float("inf")  # lowest validation MSE observed so far
ckpt = "/content/lerobot_pusht_bc.pt"
132+
133+
for epoch in range(EPOCHS):
    policy.train()
    for batch in train_loader:
        # Cosine-annealed learning rate, refreshed on every optimizer step.
        lr = cosine_lr(step, steps_total)
        step += 1
        for g in opt.param_groups:
            g["lr"] = lr

        img = batch["image"].to(DEVICE, non_blocking=True)
        st = batch["state"].to(DEVICE, non_blocking=True)
        act = batch["action"].to(DEVICE, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            pred = policy(img, st)
            loss = F.smooth_l1_loss(pred, act)
        scaler.scale(loss).backward()
        # BUG FIX: gradients must be unscaled before clipping, otherwise the
        # 1.0 max-norm is applied to loss-scale-multiplied gradients and the
        # clip is effectively a no-op (per the PyTorch AMP docs).
        scaler.unscale_(opt)
        nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

    # Validate once per epoch and keep the best checkpoint on disk.
    val_mse = evaluate()
    print(f"Epoch {epoch+1}/{EPOCHS} | Val MSE: {val_mse:.6f}")
    if val_mse < best:
        best = val_mse
        torch.save({"state_dict": policy.state_dict(), "val_mse": best}, ckpt)
156+
157+
print("Best Val MSE:", best, "| Saved:", ckpt)

# Reload the best checkpoint before running inference/visualization.
checkpoint = torch.load(ckpt)
policy.load_state_dict(checkpoint["state_dict"])
policy.eval()
os.makedirs("/content/vis", exist_ok=True)
161+
162+
def draw_arrow(imgCHW, action_xy, scale=40):
    """Overlay a green action arrow on a CHW float image.

    Args:
        imgCHW: (C, H, W) float tensor in [0, 1].
        action_xy: 2-vector; x points right, y points up in action space.
        scale: pixels per unit of action magnitude.

    Returns an (H, W, 3) uint8 array suitable for video frames.
    """
    import PIL.Image, PIL.ImageDraw
    _, H, W = imgCHW.shape
    pixels = (imgCHW.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    canvas = PIL.Image.fromarray(pixels)
    pen = PIL.ImageDraw.Draw(canvas)
    cx, cy = W // 2, H // 2
    # Negate y: action y points "up" while image rows grow downward.
    dx = float(action_xy[0]) * scale
    dy = float(-action_xy[1]) * scale
    pen.line((cx, cy, cx + dx, cy + dy), width=3, fill=(0, 255, 0))
    return np.array(canvas)
172+
173+
# Visualize predicted actions on the first 60 samples and write them to video.
frames = []
with torch.no_grad():
    for i in range(60):
        sample = wrapped[i]
        img = sample["image"].unsqueeze(0).to(DEVICE)
        st = sample["state"].unsqueeze(0).to(DEVICE)
        pred = policy(img, st)[0].cpu()
        frames.append(draw_arrow(sample["image"], pred))

video_path = "/content/vis/pusht_pred.mp4"
imageio.mimsave(video_path, frames, fps=10)
print("Wrote", video_path)
184+
185+
# Save a 2x8 contact sheet of the first 16 preprocessed frames.
sample_images = [wrapped[i]["image"] for i in range(16)]
grid = make_grid(torch.stack(sample_images), nrow=8)
save_image(grid, "/content/vis/grid.png")
print("Saved grid:", "/content/vis/grid.png")

0 commit comments

Comments
 (0)