fix unit test

This commit is contained in:
Remi Cadene 2024-07-10 14:05:58 +02:00
parent 52e760a88e
commit 68a561570c
2 changed files with 16 additions and 2 deletions

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
from pathlib import Path
import time
import einops
import numpy as np
import torch
@ -452,6 +453,7 @@ class KochRobot:
return obs_dict, action_dict
def capture_observation(self):
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
@ -476,10 +478,15 @@ class KochRobot:
obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
# Convert to pytorch format: channel first and float32 in [0,1]
img = torch.from_numpy(images[name])
img = img.type(torch.float32) / 255
img = img.permute(2, 0, 1).contiguous()
obs_dict[f"observation.images.{name}"] = img
return obs_dict
def send_action(self, action: torch.Tensor):
"""The provided action is expected to be a vector."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")

View File

@ -409,8 +409,15 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
torch.inference_mode(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
):
# add batch dimension to 1
for name in observation:
observation[name] = observation[name].unsqueeze(0)
action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
robot.send_action(action)
dt_s = time.perf_counter() - now