fix color

This commit is contained in:
Remi Cadene 2024-07-15 17:35:11 +02:00
parent 61d9d74308
commit 694a5dbca8
2 changed files with 32 additions and 13 deletions

View File

@ -477,7 +477,9 @@ class KochRobot:
# Read follower position # Read follower position
follower_pos = {} follower_pos = {}
for name in self.follower_arms: for name in self.follower_arms:
now = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
# Create state by concatenating follower current position # Create state by concatenating follower current position
state = [] state = []
@ -489,7 +491,10 @@ class KochRobot:
# Capture images from cameras # Capture images from cameras
images = {} images = {}
for name in self.cameras: for name in self.cameras:
now = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
# Populate output dictionnaries and format to pytorch # Populate output dictionnaries and format to pytorch
obs_dict = {} obs_dict = {}

View File

@ -95,6 +95,7 @@ import tqdm
from huggingface_hub import create_branch from huggingface_hub import create_branch
from omegaconf import DictConfig from omegaconf import DictConfig
from PIL import Image from PIL import Image
from termcolor import colored
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
@ -137,7 +138,7 @@ def none_or_int(value):
return int(value) return int(value)
def log_control_info(robot, dt_s, episode_index=None, frame_index=None): def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = [] log_items = []
if episode_index is not None: if episode_index is not None:
log_items += [f"ep:{episode_index}"] log_items += [f"ep:{episode_index}"]
@ -170,7 +171,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
if key in robot.logs: if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key]) log_dt(f"dtR{name}", robot.logs[key])
logging.info(" ".join(log_items)) info_str = " ".join(log_items)
if fps is not None:
actual_fps = 1 / dt_s
if actual_fps < fps - 1:
info_str = colored(info_str, "yellow")
logging.info(info_str)
def get_is_headless(): def get_is_headless():
@ -201,7 +207,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
log_control_info(robot, dt_s) log_control_info(robot, dt_s, fps=fps)
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s: if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
break break
@ -272,7 +278,7 @@ def record_dataset(
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
log_control_info(robot, dt_s) log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time timestamp = time.perf_counter() - start_time
@ -352,7 +358,7 @@ def record_dataset(
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
log_control_info(robot, dt_s) log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time timestamp = time.perf_counter() - start_time
@ -396,7 +402,7 @@ def record_dataset(
done[-1] = True done[-1] = True
ep_dict["next.done"] = done ep_dict["next.done"] = done
ep_path = episodes_dir / f"episode_{episode_index}.safetensors" ep_path = episodes_dir / f"episode_{episode_index}.pth"
print("Saving episode dictionary...") print("Saving episode dictionary...")
torch.save(ep_dict, ep_path) torch.save(ep_dict, ep_path)
@ -458,7 +464,7 @@ def record_dataset(
logging.info("Concatenating episodes") logging.info("Concatenating episodes")
ep_dicts = [] ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)): for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.safetensors" ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path) ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts) data_dict = concatenate_episodes(ep_dicts)
@ -534,16 +540,18 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
log_control_info(robot, dt_s) log_control_info(robot, dt_s, fps=fps)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None): def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
# TODO(rcadene): Add option to record eval dataset and logs # TODO(rcadene): Add option to record eval dataset and logs
policy.eval()
# Check device is available # Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True) device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed) set_global_seed(hydra_cfg.seed)
@ -561,24 +569,30 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
with ( with (
torch.inference_mode(), torch.inference_mode(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(), torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
): ):
# add batch dimension to 1 # add batch dimension to 1
for name in observation: for name in observation:
observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].unsqueeze(0)
if device.type == "mps":
for name in observation:
observation[name] = observation[name].to(device)
action = policy.select_action(observation) action = policy.select_action(observation)
# remove batch dimension # remove batch dimension
action = action.squeeze(0) action = action.squeeze(0)
robot.send_action(action) robot.send_action(action.to("cpu"))
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
log_control_info(robot, dt_s) log_control_info(robot, dt_s, fps=fps)
if run_time_s is not None and time.perf_counter() - start_time > run_time_s: if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break break
@ -643,7 +657,7 @@ if __name__ == "__main__":
default=1, default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.", help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
) )
parser.add_argument( parser_record.add_argument(
"--push-to-hub", "--push-to-hub",
type=int, type=int,
default=1, default=1,