Add refactor env.py

commit d6abce38eb
parent 6b5330b31c
@@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
     if n_envs is not None and n_envs < 1:
         raise ValueError("`n_envs` must be at least 1")
 
-    kwargs = {
-        "obs_type": "pixels_agent_pos",
-        "render_mode": "rgb_array",
-        "max_episode_steps": cfg.env.episode_length,
-        "visualization_width": 384,
-        "visualization_height": 384,
-    }
-
     package_name = f"gym_{cfg.env.name}"
 
     try:
@@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
         raise e
 
     gym_handle = f"{package_name}/{cfg.env.task}"
+    gym_kwgs = cfg.env.get("gym", {})
+
+    if cfg.env.get("episode_length"):
+        gym_kwgs["max_episode_steps"] = cfg.env.episode_length
 
     # batched version of the env that returns an observation of shape (b, c)
     env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
     env = env_cls(
         [
-            lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
+            lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
            for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
         ]
     )
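With this hunk, the environment-specific kwargs for `gym.make` come from the config's `env.gym` node instead of a hard-coded dict, and `max_episode_steps` is layered on top whenever `env.episode_length` is set. A minimal sketch of the resulting flow, assuming an OmegaConf config shaped like the env YAML files below (the inline values are illustrative, and the `gym_pusht` handle only resolves with that package installed):

import gymnasium as gym
from omegaconf import OmegaConf

# Illustrative config mirroring the new `env.gym` layout.
cfg = OmegaConf.create(
    {
        "env": {
            "name": "pusht",
            "task": "PushT-v0",
            "episode_length": 300,
            "gym": {"obs_type": "pixels_agent_pos", "render_mode": "rgb_array"},
        }
    }
)

gym_kwgs = cfg.env.get("gym", {})  # env-specific kwargs now live in config
if cfg.env.get("episode_length"):
    gym_kwgs["max_episode_steps"] = cfg.env.episode_length

# Each vector-env worker builds the same underlying env.
env = gym.make(f"gym_{cfg.env.name}/{cfg.env.task}", disable_env_checker=True, **gym_kwgs)

The payoff is that supporting a new simulation backend becomes a pure config change: a new per-env YAML with its own `gym:` block, with no edits to the factory.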
@@ -37,6 +37,8 @@ training:
   save_freq: ???
   log_freq: 250
   save_checkpoint: true
+  num_workers: 4
+  batch_size: ???
 
 eval:
   n_episodes: 1
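In these defaults, `???` is OmegaConf's mandatory-missing marker: the key is declared, but a more specific config or a command-line override must supply a value before it is first read. A quick sketch of that behavior:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"training": {"batch_size": "???"}})
try:
    _ = cfg.training.batch_size  # reading a ??? value raises
except MissingMandatoryValue:
    print("training.batch_size must be set by a policy config or CLI override")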
@@ -5,10 +5,13 @@ fps: 50
 env:
   name: aloha
   task: AlohaInsertion-v0
-  from_pixels: True
-  pixels_only: False
   image_size: [3, 480, 640]
-  episode_length: 400
-  fps: ${fps}
   state_dim: 14
   action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
@@ -5,10 +5,13 @@ fps: 10
 env:
   name: pusht
   task: PushT-v0
-  from_pixels: True
-  pixels_only: False
   image_size: 96
-  episode_length: 300
-  fps: ${fps}
   state_dim: 2
   action_dim: 2
+  fps: ${fps}
+  episode_length: 300
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
@@ -5,10 +5,13 @@ fps: 15
 env:
   name: xarm
   task: XarmLift-v0
-  from_pixels: True
-  pixels_only: False
   image_size: 84
-  episode_length: 25
-  fps: ${fps}
   state_dim: 4
   action_dim: 4
+  fps: ${fps}
+  episode_length: 25
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
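All three env configs now carry the same `gym:` block consumed by the factory above, and `fps: ${fps}` is an OmegaConf interpolation pointing back at the file's top-level `fps` key. A small sketch using the xarm values:

from omegaconf import OmegaConf

# Values taken from the xarm file above; the structure is what matters.
cfg = OmegaConf.create(
    {"fps": 15, "env": {"name": "xarm", "fps": "${fps}", "episode_length": 25}}
)
assert cfg.env.fps == 15  # resolves against the top-level `fps` key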
@@ -281,8 +281,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
 
-    logging.info("make_env")
-    eval_env = make_env(cfg)
+    # Create environment used for evaluating checkpoints during training on simulation data.
+    # On real-world data, no need to create an environment as evaluations are done outside train.py,
+    # using the eval.py instead, with gym_dora environment and dora-rs.
+    if cfg.training.eval_freq > 0:
+        logging.info("make_env")
+        eval_env = make_env(cfg)
 
     logging.info("make_policy")
     policy = make_policy(
@@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
 
     # Note: this helper will be used in offline and online training loops.
     def evaluate_and_checkpoint_if_needed(step):
-        if step % cfg.training.eval_freq == 0:
+        if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
                 eval_info = eval_policy(
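The reworked condition is more than a guard on whether evaluation is configured: `step % 0` raises `ZeroDivisionError` in Python, so with evaluation disabled (`eval_freq: 0`) the old check would crash on the first call. A standalone illustration (`should_eval` is a name invented for this sketch):

def should_eval(step: int, eval_freq: int) -> bool:
    # Short-circuit first: `step % 0` would raise ZeroDivisionError.
    return eval_freq > 0 and step % eval_freq == 0

assert should_eval(500, 250)
assert not should_eval(500, 0)  # evals disabled; no crash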
@@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=cfg.training.num_workers,
         batch_size=cfg.training.batch_size,
         shuffle=True,
         pin_memory=device.type != "cpu",
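The worker count now follows `training.num_workers` (default 4, per the config hunk above), so it can be tuned as an ordinary Hydra override such as `training.num_workers=8` rather than by editing code. A self-contained sketch of the same wiring, with a stand-in dataset:

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({"training": {"num_workers": 4, "batch_size": 8}})

loader = torch.utils.data.DataLoader(
    list(range(32)),                       # stand-in dataset for illustration
    num_workers=cfg.training.num_workers,  # previously hard-coded to 4
    batch_size=cfg.training.batch_size,
    shuffle=True,
)
print(len(loader))  # 4 batches of 8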
@@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         step += 1
 
     logging.info("End of offline training")
 
     if cfg.training.online_steps == 0:
+        if cfg.training.eval_freq > 0:
+            eval_env.close()
         return
+
+    # create an env dedicated to online episodes collection from policy rollout
+    online_training_env = make_env(cfg, n_envs=1)
+
+    # create an empty online dataset similar to offline dataset
+    online_dataset = deepcopy(offline_dataset)
+    online_dataset.hf_dataset = {}
+
@@ -406,8 +420,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         drop_last=False,
     )
 
     logging.info("End of online training")
 
     eval_env.close()
+    logging.info("End of training")
+    online_training_env.close()
 
 @hydra.main(version_base="1.2", config_name="default", config_path="../configs")