fix unit test
This commit is contained in:
parent
52e760a88e
commit
68a561570c
|
@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -452,6 +453,7 @@ class KochRobot:
|
||||||
return obs_dict, action_dict
|
return obs_dict, action_dict
|
||||||
|
|
||||||
def capture_observation(self):
|
def capture_observation(self):
|
||||||
|
"""The returned observations do not have a batch dimension."""
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||||
|
|
||||||
|
@ -476,10 +478,15 @@ class KochRobot:
|
||||||
obs_dict = {}
|
obs_dict = {}
|
||||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||||
|
img = torch.from_numpy(images[name])
|
||||||
|
img = img.type(torch.float32) / 255
|
||||||
|
img = img.permute(2, 0, 1).contiguous()
|
||||||
|
obs_dict[f"observation.images.{name}"] = img
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action: torch.Tensor):
|
def send_action(self, action: torch.Tensor):
|
||||||
|
"""The provided action is expected to be a vector."""
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||||
|
|
||||||
|
|
|
@ -409,8 +409,15 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||||
):
|
):
|
||||||
|
# add batch dimension to 1
|
||||||
|
for name in observation:
|
||||||
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
|
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
|
||||||
|
# remove batch dimension
|
||||||
|
action = action.squeeze(0)
|
||||||
|
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
|
|
Loading…
Reference in New Issue