1+ # -*- coding: utf-8 -*-
2+ """lerobot_pusht_bc_tutorial_Marktechpost.ipynb
3+
4+ Automatically generated by Colab.
5+
6+ Original file is located at
7+ https://colab.research.google.com/drive/14VCj4xMpHzaDaYjHm0Cyr_2HQZg8ZHw9
8+ """
9+
!pip -q install --upgrade lerobot torch torchvision timm "imageio[ffmpeg]"
11+
12+ import os , math , random , io , sys , json , pathlib , time
13+ import torch , torch .nn as nn , torch .nn .functional as F
14+ from torch .utils .data import DataLoader , Subset
15+ from torchvision .utils import make_grid , save_image
16+ import numpy as np
17+ import imageio .v2 as imageio
18+
19+ try :
20+ from lerobot .common .datasets .lerobot_dataset import LeRobotDataset
21+ except Exception :
22+ from lerobot .datasets .lerobot_dataset import LeRobotDataset
23+
24+ DEVICE = "cuda" if torch .cuda .is_available () else "cpu"
25+ SEED = 42
26+ random .seed (SEED ); np .random .seed (SEED ); torch .manual_seed (SEED )
27+
28+ REPO_ID = "lerobot/pusht"
29+ ds = LeRobotDataset (REPO_ID )
30+ print ("Dataset length:" , len (ds ))
31+
32+ s0 = ds [0 ]
33+ keys = list (s0 .keys ())
34+ print ("Sample keys:" , keys )
35+
36+ def key_with (prefixes ):
37+ for k in keys :
38+ for p in prefixes :
39+ if k .startswith (p ): return k
40+ return None
41+
42+ K_IMG = key_with (["observation.image" , "observation.images" , "observation.rgb" ])
43+ K_STATE = key_with (["observation.state" ])
44+ K_ACT = "action"
45+ assert K_ACT in s0 , f"No 'action' key found in sample. Found: { keys } "
46+ print ("Using keys -> IMG:" , K_IMG , "STATE:" , K_STATE , "ACT:" , K_ACT )
47+
48+ class PushTWrapper (torch .utils .data .Dataset ):
49+ def __init__ (self , base ):
50+ self .base = base
51+ def __len__ (self ): return len (self .base )
52+ def __getitem__ (self , i ):
53+ x = self .base [i ]
54+ img = x [K_IMG ]
55+ if img .ndim == 4 : img = img [- 1 ]
56+ img = img .float () / 255.0 if img .dtype == torch .uint8 else img .float ()
57+ state = x .get (K_STATE , torch .zeros (2 ))
58+ state = state .float ().reshape (- 1 )
59+ act = x [K_ACT ].float ().reshape (- 1 )
60+ if img .shape [- 2 :] != (96 ,96 ):
61+ img = F .interpolate (img .unsqueeze (0 ), size = (96 ,96 ), mode = "bilinear" , align_corners = False )[0 ]
62+ return {"image" : img , "state" : state , "action" : act }
63+
64+ wrapped = PushTWrapper (ds )
65+ N = len (wrapped )
66+ idx = list (range (N ))
67+ random .shuffle (idx )
68+ n_train = int (0.9 * N )
69+ train_idx , val_idx = idx [:n_train ], idx [n_train :]
70+
71+ train_ds = Subset (wrapped , train_idx [:12000 ])
72+ val_ds = Subset (wrapped , val_idx [:2000 ])
73+
74+ BATCH = 128
75+ train_loader = DataLoader (train_ds , batch_size = BATCH , shuffle = True , num_workers = 2 , pin_memory = True )
76+ val_loader = DataLoader (val_ds , batch_size = BATCH , shuffle = False , num_workers = 2 , pin_memory = True )
77+
78+ class SmallBackbone (nn .Module ):
79+ def __init__ (self , out = 256 ):
80+ super ().__init__ ()
81+ self .conv = nn .Sequential (
82+ nn .Conv2d (3 , 32 , 5 , 2 , 2 ), nn .ReLU (inplace = True ),
83+ nn .Conv2d (32 , 64 , 3 , 2 , 1 ), nn .ReLU (inplace = True ),
84+ nn .Conv2d (64 ,128 , 3 , 2 , 1 ), nn .ReLU (inplace = True ),
85+ nn .Conv2d (128 ,128 ,3 , 1 , 1 ), nn .ReLU (inplace = True ),
86+ )
87+ self .head = nn .Sequential (nn .AdaptiveAvgPool2d (1 ), nn .Flatten (), nn .Linear (128 , out ), nn .ReLU (inplace = True ))
88+ def forward (self , x ): return self .head (self .conv (x ))
89+
90+ class BCPolicy (nn .Module ):
91+ def __init__ (self , img_dim = 256 , state_dim = 2 , hidden = 256 , act_dim = 2 ):
92+ super ().__init__ ()
93+ self .backbone = SmallBackbone (img_dim )
94+ self .mlp = nn .Sequential (
95+ nn .Linear (img_dim + state_dim , hidden ), nn .ReLU (inplace = True ),
96+ nn .Linear (hidden , hidden // 2 ), nn .ReLU (inplace = True ),
97+ nn .Linear (hidden // 2 , act_dim )
98+ )
99+ def forward (self , img , state ):
100+ z = self .backbone (img )
101+ if state .ndim == 1 : state = state .unsqueeze (0 )
102+ z = torch .cat ([z , state ], dim = - 1 )
103+ return self .mlp (z )
104+
105+ policy = BCPolicy ().to (DEVICE )
106+ opt = torch .optim .AdamW (policy .parameters (), lr = 3e-4 , weight_decay = 1e-4 )
107+ scaler = torch .cuda .amp .GradScaler (enabled = (DEVICE == "cuda" ))
108+
109+ @torch .no_grad ()
110+ def evaluate ():
111+ policy .eval ()
112+ mse , n = 0.0 , 0
113+ for batch in val_loader :
114+ img = batch ["image" ].to (DEVICE , non_blocking = True )
115+ st = batch ["state" ].to (DEVICE , non_blocking = True )
116+ act = batch ["action" ].to (DEVICE , non_blocking = True )
117+ pred = policy (img , st )
118+ mse += F .mse_loss (pred , act , reduction = "sum" ).item ()
119+ n += act .numel ()
120+ return mse / n
121+
122+ def cosine_lr (step , total , base = 3e-4 , min_lr = 3e-5 ):
123+ if step >= total : return min_lr
124+ cos = 0.5 * (1 + math .cos (math .pi * step / total ))
125+ return min_lr + (base - min_lr )* cos
126+
127+ EPOCHS = 4
128+ steps_total = EPOCHS * len (train_loader )
129+ step = 0
130+ best = float ("inf" )
131+ ckpt = "/content/lerobot_pusht_bc.pt"
132+
133+ for epoch in range (EPOCHS ):
134+ policy .train ()
135+ for batch in train_loader :
136+ lr = cosine_lr (step , steps_total ); step += 1
137+ for g in opt .param_groups : g ["lr" ] = lr
138+
139+ img = batch ["image" ].to (DEVICE , non_blocking = True )
140+ st = batch ["state" ].to (DEVICE , non_blocking = True )
141+ act = batch ["action" ].to (DEVICE , non_blocking = True )
142+
143+ opt .zero_grad (set_to_none = True )
144+ with torch .cuda .amp .autocast (enabled = (DEVICE == "cuda" )):
145+ pred = policy (img , st )
146+ loss = F .smooth_l1_loss (pred , act )
147+ scaler .scale (loss ).backward ()
148+ nn .utils .clip_grad_norm_ (policy .parameters (), 1.0 )
149+ scaler .step (opt ); scaler .update ()
150+
151+ val_mse = evaluate ()
152+ print (f"Epoch { epoch + 1 } /{ EPOCHS } | Val MSE: { val_mse :.6f} " )
153+ if val_mse < best :
154+ best = val_mse
155+ torch .save ({"state_dict" : policy .state_dict (), "val_mse" : best }, ckpt )
156+
157+ print ("Best Val MSE:" , best , "| Saved:" , ckpt )
158+
159+ policy .load_state_dict (torch .load (ckpt )["state_dict" ]); policy .eval ()
160+ os .makedirs ("/content/vis" , exist_ok = True )
161+
162+ def draw_arrow (imgCHW , action_xy , scale = 40 ):
163+ import PIL .Image , PIL .ImageDraw
164+ C ,H ,W = imgCHW .shape
165+ arr = (imgCHW .clamp (0 ,1 ).permute (1 ,2 ,0 ).cpu ().numpy ()* 255 ).astype (np .uint8 )
166+ im = PIL .Image .fromarray (arr )
167+ dr = PIL .ImageDraw .Draw (im )
168+ cx , cy = W // 2 , H // 2
169+ dx , dy = float (action_xy [0 ])* scale , float (- action_xy [1 ])* scale
170+ dr .line ((cx , cy , cx + dx , cy + dy ), width = 3 , fill = (0 ,255 ,0 ))
171+ return np .array (im )
172+
173+ frames = []
174+ with torch .no_grad ():
175+ for i in range (60 ):
176+ b = wrapped [i ]
177+ img = b ["image" ].unsqueeze (0 ).to (DEVICE )
178+ st = b ["state" ].unsqueeze (0 ).to (DEVICE )
179+ pred = policy (img , st )[0 ].cpu ()
180+ frames .append (draw_arrow (b ["image" ], pred ))
181+ video_path = "/content/vis/pusht_pred.mp4"
182+ imageio .mimsave (video_path , frames , fps = 10 )
183+ print ("Wrote" , video_path )
184+
185+ grid = make_grid (torch .stack ([wrapped [i ]["image" ] for i in range (16 )]), nrow = 8 )
186+ save_image (grid , "/content/vis/grid.png" )
187+ print ("Saved grid:" , "/content/vis/grid.png" )
# (end of notebook export)