From 68a561570cf565393cc920eb071581735629745a Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 10 Jul 2024 14:05:58 +0200 Subject: [PATCH] fix unit test --- lerobot/common/robot_devices/robots/koch.py | 11 +++++++++-- lerobot/scripts/control_robot.py | 7 +++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/lerobot/common/robot_devices/robots/koch.py b/lerobot/common/robot_devices/robots/koch.py index f133f734..6c031070 100644 --- a/lerobot/common/robot_devices/robots/koch.py +++ b/lerobot/common/robot_devices/robots/koch.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace from pathlib import Path import time +import einops import numpy as np import torch @@ -452,6 +453,7 @@ class KochRobot: return obs_dict, action_dict def capture_observation(self): + """The returned observations do not have a batch dimension.""" if not self.is_connected: raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.") @@ -476,13 +478,18 @@ class KochRobot: obs_dict = {} obs_dict["observation.state"] = torch.from_numpy(state) for name in self.cameras: - obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name]) + # Convert to pytorch format: channel first and float32 in [0,1] + img = torch.from_numpy(images[name]) + img = img.type(torch.float32) / 255 + img = img.permute(2, 0, 1).contiguous() + obs_dict[f"observation.images.{name}"] = img return obs_dict def send_action(self, action: torch.Tensor): + """The provided action is expected to be a vector.""" if not self.is_connected: raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.") - + from_idx = 0 to_idx = 0 follower_goal_pos = {} diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index cccd6ef0..c10445f4 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -409,8 +409,15 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(), ): + # add batch dimension to 1 + for name in observation: + observation[name] = observation[name].unsqueeze(0) + action = policy.select_action(observation) + # remove batch dimension + action = action.squeeze(0) + robot.send_action(action) dt_s = time.perf_counter() - now