Merge pull request #45 from alexander-soare/fix_environment_seeding

Reproduce original diffusion policy pusht image eval
This commit is contained in:
Alexander Soare 2024-03-22 16:27:48 +00:00 committed by GitHub
commit e21ed6f510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 60 additions and 23 deletions

View File

@ -4,6 +4,8 @@ from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
from lerobot.common.utils import set_seed
class AbstractEnv(EnvBase):
def __init__(
@ -34,7 +36,13 @@ class AbstractEnv(EnvBase):
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
# you store the return value in self._next_seed (it will be a new randomly generated seed).
self._next_seed = seed
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
# self._reset is called, we use seed.
self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
@ -59,4 +67,4 @@ class AbstractEnv(EnvBase):
raise NotImplementedError("Abstract method")
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError("Abstract method")
set_seed(seed)

View File

@ -126,9 +126,8 @@ class AlohaEnv(AbstractEnv):
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
AlohaEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# Seed the environment and update the seed to be used for the next reset.
self._next_seed = self.set_seed(self._next_seed)
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
@ -137,8 +136,6 @@ class AlohaEnv(AbstractEnv):
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs.observation)

View File

@ -106,11 +106,9 @@ class PushtEnv(AbstractEnv):
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
PushtEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# Seed the environment and update the seed to be used for the next reset.
self._next_seed = self.set_seed(self._next_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs)
@ -239,5 +237,7 @@ class PushtEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
# Set global seed.
set_seed(seed)
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
self._env.seed(seed)

View File

@ -33,7 +33,7 @@ class PushTEnv(gym.Env):
def __init__(
self,
legacy=False,
legacy=True, # compatibility with original
block_cog=None,
damping=None,
render_action=True,

View File

@ -7,7 +7,8 @@ from lerobot.common.envs.pusht.pusht_env import PushTEnv
class PushTImageEnv(PushTEnv):
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
# Note: legacy defaults to True for compatibility with original
def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96):
super().__init__(
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
)

View File

@ -26,6 +26,7 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
"""
import argparse
import json
import logging
import os.path as osp
import threading
@ -72,6 +73,7 @@ def eval_policy(
sum_rewards = []
max_rewards = []
successes = []
seeds = []
threads = [] # for video saving threads
episode_counter = 0 # for saving the correct number of videos
@ -84,11 +86,16 @@ def eval_policy(
if save_video or (return_first_video and i == 0): # noqa: B023
ep_frames.append(env.render()) # noqa: B023
# Clear the policy's action queue before the start of a new rollout.
if policy is not None:
policy.clear_action_queue()
if env.is_closed:
env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy
seeds.extend(env._next_seed)
with torch.inference_mode():
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
if policy is not None:
policy.clear_action_queue()
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
@ -139,11 +146,31 @@ def eval_policy(
thread.join()
info = {
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
"per_episode": [
{
"episode_ix": i,
"sum_reward": sum_reward,
"max_reward": max_reward,
"success": success,
"seed": seed,
}
for i, (sum_reward, max_reward, success, seed) in enumerate(
zip(
sum_rewards[:num_episodes],
max_rewards[:num_episodes],
successes[:num_episodes],
seeds[:num_episodes],
strict=True,
)
)
],
"aggregated": {
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
}
if return_first_video:
return info, first_video
@ -182,7 +209,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
# when policy is None, rollout a random policy
policy = None
metrics = eval_policy(
info = eval_policy(
env,
policy=policy,
save_video=True,
@ -191,7 +218,11 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
max_steps=cfg.env.episode_length,
num_episodes=cfg.eval_episodes,
)
print(metrics)
print(info["aggregated"])
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2)
logging.info("End of eval")

View File

@ -183,7 +183,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
video_dir=Path(out_dir) / "eval",
save_video=True,
)
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logging.info("Resume training")