fix color
This commit is contained in:
parent
61d9d74308
commit
694a5dbca8
|
@ -477,7 +477,9 @@ class KochRobot:
|
|||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.follower_arms:
|
||||
now = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
|
@ -489,7 +491,10 @@ class KochRobot:
|
|||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
now = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict = {}
|
||||
|
|
|
@ -95,6 +95,7 @@ import tqdm
|
|||
from huggingface_hub import create_branch
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
from termcolor import colored
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
|
@ -137,7 +138,7 @@ def none_or_int(value):
|
|||
return int(value)
|
||||
|
||||
|
||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items += [f"ep:{episode_index}"]
|
||||
|
@ -170,7 +171,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
|||
if key in robot.logs:
|
||||
log_dt(f"dtR{name}", robot.logs[key])
|
||||
|
||||
logging.info(" ".join(log_items))
|
||||
info_str = " ".join(log_items)
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_s
|
||||
if actual_fps < fps - 1:
|
||||
info_str = colored(info_str, "yellow")
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
def get_is_headless():
|
||||
|
@ -201,7 +207,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
|
|||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
|
||||
break
|
||||
|
@ -272,7 +278,7 @@ def record_dataset(
|
|||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
|
@ -352,7 +358,7 @@ def record_dataset(
|
|||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
|
@ -396,7 +402,7 @@ def record_dataset(
|
|||
done[-1] = True
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
print("Saving episode dictionary...")
|
||||
torch.save(ep_dict, ep_path)
|
||||
|
||||
|
@ -458,7 +464,7 @@ def record_dataset(
|
|||
logging.info("Concatenating episodes")
|
||||
ep_dicts = []
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
ep_dict = torch.load(ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
@ -534,16 +540,18 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
|||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record eval dataset and logs
|
||||
policy.eval()
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
@ -561,24 +569,30 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
|||
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# add batch dimension to 1
|
||||
for name in observation:
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
|
||||
if device.type == "mps":
|
||||
for name in observation:
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
robot.send_action(action)
|
||||
robot.send_action(action.to("cpu"))
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s)
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||
break
|
||||
|
@ -643,7 +657,7 @@ if __name__ == "__main__":
|
|||
default=1,
|
||||
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
parser_record.add_argument(
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
|
|
Loading…
Reference in New Issue