fix color

This commit is contained in:
Remi Cadene 2024-07-15 17:35:11 +02:00
parent 61d9d74308
commit 694a5dbca8
2 changed files with 32 additions and 13 deletions

View File

@ -477,7 +477,9 @@ class KochRobot:
# Read follower position
follower_pos = {}
for name in self.follower_arms:
now = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
# Create state by concatenating follower current position
state = []
@ -489,7 +491,10 @@ class KochRobot:
# Capture images from cameras
images = {}
for name in self.cameras:
now = time.perf_counter()
images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
# Populate output dictionnaries and format to pytorch
obs_dict = {}

View File

@ -95,6 +95,7 @@ import tqdm
from huggingface_hub import create_branch
from omegaconf import DictConfig
from PIL import Image
from termcolor import colored
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats
@ -137,7 +138,7 @@ def none_or_int(value):
return int(value)
def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items += [f"ep:{episode_index}"]
@ -170,7 +171,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
logging.info(" ".join(log_items))
info_str = " ".join(log_items)
if fps is not None:
actual_fps = 1 / dt_s
if actual_fps < fps - 1:
info_str = colored(info_str, "yellow")
logging.info(info_str)
def get_is_headless():
@ -201,7 +207,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
log_control_info(robot, dt_s, fps=fps)
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
break
@ -272,7 +278,7 @@ def record_dataset(
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
@ -352,7 +358,7 @@ def record_dataset(
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
@ -396,7 +402,7 @@ def record_dataset(
done[-1] = True
ep_dict["next.done"] = done
ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
ep_path = episodes_dir / f"episode_{episode_index}.pth"
print("Saving episode dictionary...")
torch.save(ep_dict, ep_path)
@ -458,7 +464,7 @@ def record_dataset(
logging.info("Concatenating episodes")
ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
@ -534,16 +540,18 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
log_control_info(robot, dt_s, fps=fps)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
# TODO(rcadene): Add option to record eval dataset and logs
policy.eval()
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
@ -561,24 +569,30 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
):
# add batch dimension to 1
for name in observation:
observation[name] = observation[name].unsqueeze(0)
if device.type == "mps":
for name in observation:
observation[name] = observation[name].to(device)
action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
robot.send_action(action)
robot.send_action(action.to("cpu"))
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s)
log_control_info(robot, dt_s, fps=fps)
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break
@ -643,7 +657,7 @@ if __name__ == "__main__":
default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
)
parser.add_argument(
parser_record.add_argument(
"--push-to-hub",
type=int,
default=1,