Merge pull request #45 from alexander-soare/fix_environment_seeding
Reproduce original diffusion policy pusht image eval
commit e21ed6f510

@@ -4,6 +4,8 @@ from typing import Optional
 from tensordict import TensorDict
 from torchrl.envs import EnvBase
 
+from lerobot.common.utils import set_seed
+
 
 class AbstractEnv(EnvBase):
     def __init__(

@@ -34,7 +36,13 @@ class AbstractEnv(EnvBase):
 
         self._make_env()
         self._make_spec()
-        self._current_seed = self.set_seed(seed)
+        # self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
+        # you store the return value in self._next_seed (it will be a new randomly generated seed).
+        self._next_seed = seed
+        # Don't store the result of this in self._next_seed, as we want to make sure that the first time
+        # self._reset is called, we use seed.
+        self.set_seed(seed)
 
         if self.num_prev_obs > 0:
             self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)

@@ -59,4 +67,4 @@ class AbstractEnv(EnvBase):
         raise NotImplementedError("Abstract method")
 
     def _set_seed(self, seed: Optional[int]):
-        raise NotImplementedError("Abstract method")
+        set_seed(seed)

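The comments above define the seed-chaining contract: self._next_seed holds the seed to be consumed by the next reset, and every call to self.set_seed returns a newly generated seed that should be stored back into self._next_seed. A minimal sketch of a concrete subclass following that contract (the class name, the _env attribute and _format_raw_obs are illustrative placeholders, not part of this commit; AbstractEnv and set_seed are assumed to be in scope):

class ExampleEnv(AbstractEnv):
    def _reset(self, tensordict=None):
        # Consume the seed reserved for this reset, then store the freshly
        # generated seed returned by set_seed for the reset after this one.
        self._next_seed = self.set_seed(self._next_seed)
        raw_obs = self._env.reset()
        obs = self._format_raw_obs(raw_obs)
        ...  # build and return the reset TensorDict as the concrete envs do

    def _set_seed(self, seed):
        # Seed the global RNGs (as the new base implementation does) and the
        # wrapped simulator, which keeps its own internal seed.
        set_seed(seed)
        self._env.seed(seed)
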
@@ -126,9 +126,8 @@ class AlohaEnv(AbstractEnv):
             logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
             AlohaEnv._reset_warning_issued = True
 
-        # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
-        self._current_seed += 1
-        self.set_seed(self._current_seed)
+        # Seed the environment and update the seed to be used for the next reset.
+        self._next_seed = self.set_seed(self._next_seed)
 
         # TODO(rcadene): do not use global variable for this
         if "sim_transfer_cube" in self.task:

@@ -137,8 +136,6 @@ class AlohaEnv(AbstractEnv):
             BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # used in sim reset
 
         raw_obs = self._env.reset()
-        # TODO(rcadene): add assert
-        # assert self._current_seed == self._env._seed
 
         obs = self._format_raw_obs(raw_obs.observation)
 

@@ -106,11 +106,9 @@ class PushtEnv(AbstractEnv):
             logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
             PushtEnv._reset_warning_issued = True
 
-        # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
-        self._current_seed += 1
-        self.set_seed(self._current_seed)
+        # Seed the environment and update the seed to be used for the next reset.
+        self._next_seed = self.set_seed(self._next_seed)
         raw_obs = self._env.reset()
-        assert self._current_seed == self._env._seed
 
         obs = self._format_raw_obs(raw_obs)
 

@@ -239,5 +237,7 @@ class PushtEnv(AbstractEnv):
         )
 
     def _set_seed(self, seed: Optional[int]):
+        # Set global seed.
         set_seed(seed)
+        # Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
         self._env.seed(seed)

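For context, set_seed here is the global helper imported from lerobot.common.utils in the first hunk; its body is not part of this diff. A rough sketch of what such a helper typically covers, as an assumption rather than the actual implementation:

import random

import numpy as np
import torch


def set_seed_sketch(seed):
    # Assumed scope: seed the Python, NumPy and torch RNGs so that everything
    # outside the simulator's own internal _seed is reproducible too.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
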
@@ -33,7 +33,7 @@ class PushTEnv(gym.Env):
 
     def __init__(
         self,
-        legacy=False,
+        legacy=True,  # compatibility with original
         block_cog=None,
         damping=None,
         render_action=True,

@@ -7,7 +7,8 @@ from lerobot.common.envs.pusht.pusht_env import PushTEnv
 class PushTImageEnv(PushTEnv):
     metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
 
-    def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
+    # Note: legacy defaults to True for compatibility with original
+    def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96):
         super().__init__(
             legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
         )

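With the default flipped to legacy=True in both constructors, building the image env with no arguments now matches the original diffusion policy setup, while the previous behaviour stays available by passing the flag explicitly. For example (the import path is an assumption):

from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv

env = PushTImageEnv()  # legacy=True, matches the original setup
env_non_legacy = PushTImageEnv(legacy=False)  # opt back into the old default
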
@@ -26,6 +26,7 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
 """
 
 import argparse
+import json
 import logging
 import os.path as osp
 import threading

@@ -72,6 +73,7 @@ def eval_policy(
     sum_rewards = []
     max_rewards = []
     successes = []
+    seeds = []
     threads = []  # for video saving threads
     episode_counter = 0  # for saving the correct number of videos
 

@@ -84,11 +86,16 @@ def eval_policy(
             if save_video or (return_first_video and i == 0):  # noqa: B023
                 ep_frames.append(env.render())  # noqa: B023
 
+        # Clear the policy's action queue before the start of a new rollout.
+        if policy is not None:
+            policy.clear_action_queue()
+
+        if env.is_closed:
+            env.start()  # needed to be able to get the seeds the first time as BatchedEnvs are lazy
+        seeds.extend(env._next_seed)
         with torch.inference_mode():
             # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
             # envs are done the first time. But we only use the first rollout. This is a waste of compute.
-            if policy is not None:
-                policy.clear_action_queue()
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,

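Clearing the action queue before each rollout matters for policies that predict a chunk of actions at once and replay it over several environment steps, as the diffusion policy does: a queue left over from the previous episode would otherwise leak stale actions into the new one. A minimal sketch of the pattern (hypothetical policy, not the lerobot implementation):

from collections import deque

import torch


class ChunkingPolicySketch:
    """Hypothetical policy that predicts a chunk of actions and replays it step by step."""

    def __init__(self, chunk_size=8, action_dim=2):
        self.chunk_size = chunk_size
        self.action_dim = action_dim
        self._action_queue = deque()

    def clear_action_queue(self):
        # Called at the start of every rollout so that actions predicted for
        # the previous episode are never executed in the new one.
        self._action_queue.clear()

    def select_action(self, observation):
        if not self._action_queue:
            # Placeholder for the real model call that predicts a whole chunk.
            chunk = torch.zeros(self.chunk_size, self.action_dim)
            self._action_queue.extend(chunk)
        return self._action_queue.popleft()
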
@@ -139,11 +146,31 @@ def eval_policy(
             thread.join()
 
     info = {
-        "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
-        "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
-        "pc_success": np.nanmean(successes[:num_episodes]) * 100,
-        "eval_s": time.time() - start,
-        "eval_ep_s": (time.time() - start) / num_episodes,
+        "per_episode": [
+            {
+                "episode_ix": i,
+                "sum_reward": sum_reward,
+                "max_reward": max_reward,
+                "success": success,
+                "seed": seed,
+            }
+            for i, (sum_reward, max_reward, success, seed) in enumerate(
+                zip(
+                    sum_rewards[:num_episodes],
+                    max_rewards[:num_episodes],
+                    successes[:num_episodes],
+                    seeds[:num_episodes],
+                    strict=True,
+                )
+            )
+        ],
+        "aggregated": {
+            "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
+            "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
+            "pc_success": np.nanmean(successes[:num_episodes]) * 100,
+            "eval_s": time.time() - start,
+            "eval_ep_s": (time.time() - start) / num_episodes,
+        },
     }
     if return_first_video:
         return info, first_video

@@ -182,7 +209,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         # when policy is None, rollout a random policy
         policy = None
 
-    metrics = eval_policy(
+    info = eval_policy(
         env,
         policy=policy,
         save_video=True,

@@ -191,7 +218,11 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
-    print(metrics)
+    print(info["aggregated"])
 
+    # Save info
+    with open(Path(out_dir) / "eval_info.json", "w") as f:
+        json.dump(info, f, indent=2)
+
     logging.info("End of eval")
 

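The evaluation output is now written to eval_info.json with a per_episode list (one entry per episode, including the seed it was run with) and an aggregated block carrying the previous summary metrics. A small sketch of consuming that file (the path is an example; in practice it sits in the eval output directory, out_dir):

import json
from pathlib import Path

with open(Path("outputs/eval/eval_info.json")) as f:
    info = json.load(f)

print(info["aggregated"]["pc_success"])
for episode in info["per_episode"]:
    print(episode["episode_ix"], episode["seed"], episode["max_reward"], episode["success"])
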
@@ -183,7 +183,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 video_dir=Path(out_dir) / "eval",
                 save_video=True,
             )
-            log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")