fix color
This commit is contained in:
parent
61d9d74308
commit
694a5dbca8
|
@ -477,7 +477,9 @@ class KochRobot:
|
||||||
# Read follower position
|
# Read follower position
|
||||||
follower_pos = {}
|
follower_pos = {}
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
|
now = time.perf_counter()
|
||||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||||
|
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
|
||||||
|
|
||||||
# Create state by concatenating follower current position
|
# Create state by concatenating follower current position
|
||||||
state = []
|
state = []
|
||||||
|
@ -489,7 +491,10 @@ class KochRobot:
|
||||||
# Capture images from cameras
|
# Capture images from cameras
|
||||||
images = {}
|
images = {}
|
||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
|
now = time.perf_counter()
|
||||||
images[name] = self.cameras[name].async_read()
|
images[name] = self.cameras[name].async_read()
|
||||||
|
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||||
|
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
|
||||||
|
|
||||||
# Populate output dictionnaries and format to pytorch
|
# Populate output dictionnaries and format to pytorch
|
||||||
obs_dict = {}
|
obs_dict = {}
|
||||||
|
|
|
@ -95,6 +95,7 @@ import tqdm
|
||||||
from huggingface_hub import create_branch
|
from huggingface_hub import create_branch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
# from safetensors.torch import load_file, save_file
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
|
@ -137,7 +138,7 @@ def none_or_int(value):
|
||||||
return int(value)
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||||
log_items = []
|
log_items = []
|
||||||
if episode_index is not None:
|
if episode_index is not None:
|
||||||
log_items += [f"ep:{episode_index}"]
|
log_items += [f"ep:{episode_index}"]
|
||||||
|
@ -170,7 +171,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
||||||
if key in robot.logs:
|
if key in robot.logs:
|
||||||
log_dt(f"dtR{name}", robot.logs[key])
|
log_dt(f"dtR{name}", robot.logs[key])
|
||||||
|
|
||||||
logging.info(" ".join(log_items))
|
info_str = " ".join(log_items)
|
||||||
|
if fps is not None:
|
||||||
|
actual_fps = 1 / dt_s
|
||||||
|
if actual_fps < fps - 1:
|
||||||
|
info_str = colored(info_str, "yellow")
|
||||||
|
logging.info(info_str)
|
||||||
|
|
||||||
|
|
||||||
def get_is_headless():
|
def get_is_headless():
|
||||||
|
@ -201,7 +207,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
|
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
|
||||||
break
|
break
|
||||||
|
@ -272,7 +278,7 @@ def record_dataset(
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
timestamp = time.perf_counter() - start_time
|
timestamp = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
@ -352,7 +358,7 @@ def record_dataset(
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
timestamp = time.perf_counter() - start_time
|
timestamp = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
@ -396,7 +402,7 @@ def record_dataset(
|
||||||
done[-1] = True
|
done[-1] = True
|
||||||
ep_dict["next.done"] = done
|
ep_dict["next.done"] = done
|
||||||
|
|
||||||
ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
|
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||||
print("Saving episode dictionary...")
|
print("Saving episode dictionary...")
|
||||||
torch.save(ep_dict, ep_path)
|
torch.save(ep_dict, ep_path)
|
||||||
|
|
||||||
|
@ -458,7 +464,7 @@ def record_dataset(
|
||||||
logging.info("Concatenating episodes")
|
logging.info("Concatenating episodes")
|
||||||
ep_dicts = []
|
ep_dicts = []
|
||||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||||
ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
|
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||||
ep_dict = torch.load(ep_path)
|
ep_dict = torch.load(ep_path)
|
||||||
ep_dicts.append(ep_dict)
|
ep_dicts.append(ep_dict)
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
@ -534,16 +540,18 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||||
# TODO(rcadene): Add option to record eval dataset and logs
|
# TODO(rcadene): Add option to record eval dataset and logs
|
||||||
policy.eval()
|
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||||
|
|
||||||
|
policy.eval()
|
||||||
|
policy.to(device)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_global_seed(hydra_cfg.seed)
|
set_global_seed(hydra_cfg.seed)
|
||||||
|
@ -561,24 +569,30 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
||||||
|
|
||||||
with (
|
with (
|
||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
torch.autocast(device_type=device.type)
|
||||||
|
if device.type == "cuda" and hydra_cfg.use_amp
|
||||||
|
else nullcontext(),
|
||||||
):
|
):
|
||||||
# add batch dimension to 1
|
# add batch dimension to 1
|
||||||
for name in observation:
|
for name in observation:
|
||||||
observation[name] = observation[name].unsqueeze(0)
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
|
|
||||||
|
if device.type == "mps":
|
||||||
|
for name in observation:
|
||||||
|
observation[name] = observation[name].to(device)
|
||||||
|
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
|
||||||
# remove batch dimension
|
# remove batch dimension
|
||||||
action = action.squeeze(0)
|
action = action.squeeze(0)
|
||||||
|
|
||||||
robot.send_action(action)
|
robot.send_action(action.to("cpu"))
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||||
break
|
break
|
||||||
|
@ -643,7 +657,7 @@ if __name__ == "__main__":
|
||||||
default=1,
|
default=1,
|
||||||
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
|
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser_record.add_argument(
|
||||||
"--push-to-hub",
|
"--push-to-hub",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
|
|
Loading…
Reference in New Issue