fix unit test

This commit is contained in:
Remi Cadene 2024-07-10 14:05:58 +02:00
parent 52e760a88e
commit 68a561570c
2 changed files with 16 additions and 2 deletions

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
import time import time
import einops
import numpy as np import numpy as np
import torch import torch
@ -452,6 +453,7 @@ class KochRobot:
return obs_dict, action_dict return obs_dict, action_dict
def capture_observation(self): def capture_observation(self):
"""The returned observations do not have a batch dimension."""
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.") raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
@ -476,10 +478,15 @@ class KochRobot:
obs_dict = {} obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state) obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras: for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name]) # Convert to pytorch format: channel first and float32 in [0,1]
img = torch.from_numpy(images[name])
img = img.type(torch.float32) / 255
img = img.permute(2, 0, 1).contiguous()
obs_dict[f"observation.images.{name}"] = img
return obs_dict return obs_dict
def send_action(self, action: torch.Tensor): def send_action(self, action: torch.Tensor):
"""The provided action is expected to be a vector."""
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.") raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")

View File

@ -409,8 +409,15 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
torch.inference_mode(), torch.inference_mode(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
): ):
# add batch dimension to 1
for name in observation:
observation[name] = observation[name].unsqueeze(0)
action = policy.select_action(observation) action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
robot.send_action(action) robot.send_action(action)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now