fix unit test
This commit is contained in:
parent
52e760a88e
commit
68a561570c
|
@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
|
|||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -452,6 +453,7 @@ class KochRobot:
|
|||
return obs_dict, action_dict
|
||||
|
||||
def capture_observation(self):
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
|
@ -476,10 +478,15 @@ class KochRobot:
|
|||
obs_dict = {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||
img = torch.from_numpy(images[name])
|
||||
img = img.type(torch.float32) / 255
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
obs_dict[f"observation.images.{name}"] = img
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
|
||||
"""The provided action is expected to be a vector."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
|
|
|
@ -409,8 +409,15 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
|||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||
):
|
||||
# add batch dimension to 1
|
||||
for name in observation:
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
|
|
Loading…
Reference in New Issue