Add refactor env.py
This commit is contained in:
parent
6b5330b31c
commit
d6abce38eb
|
@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||||
if n_envs is not None and n_envs < 1:
|
if n_envs is not None and n_envs < 1:
|
||||||
raise ValueError("`n_envs must be at least 1")
|
raise ValueError("`n_envs must be at least 1")
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"obs_type": "pixels_agent_pos",
|
|
||||||
"render_mode": "rgb_array",
|
|
||||||
"max_episode_steps": cfg.env.episode_length,
|
|
||||||
"visualization_width": 384,
|
|
||||||
"visualization_height": 384,
|
|
||||||
}
|
|
||||||
|
|
||||||
package_name = f"gym_{cfg.env.name}"
|
package_name = f"gym_{cfg.env.name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||||
|
gym_kwgs = cfg.env.get("gym", {})
|
||||||
|
|
||||||
|
if cfg.env.get("episode_length"):
|
||||||
|
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||||
|
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
# batched version of the env that returns an observation of shape (b, c)
|
||||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||||
env = env_cls(
|
env = env_cls(
|
||||||
[
|
[
|
||||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -37,6 +37,8 @@ training:
|
||||||
save_freq: ???
|
save_freq: ???
|
||||||
log_freq: 250
|
log_freq: 250
|
||||||
save_checkpoint: true
|
save_checkpoint: true
|
||||||
|
num_workers: 4
|
||||||
|
batch_size: ???
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 1
|
n_episodes: 1
|
||||||
|
|
|
@ -5,10 +5,13 @@ fps: 50
|
||||||
env:
|
env:
|
||||||
name: aloha
|
name: aloha
|
||||||
task: AlohaInsertion-v0
|
task: AlohaInsertion-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: [3, 480, 640]
|
image_size: [3, 480, 640]
|
||||||
episode_length: 400
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 14
|
state_dim: 14
|
||||||
action_dim: 14
|
action_dim: 14
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|
|
@ -5,10 +5,13 @@ fps: 10
|
||||||
env:
|
env:
|
||||||
name: pusht
|
name: pusht
|
||||||
task: PushT-v0
|
task: PushT-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: 96
|
image_size: 96
|
||||||
episode_length: 300
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 2
|
state_dim: 2
|
||||||
action_dim: 2
|
action_dim: 2
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 300
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|
|
@ -5,10 +5,13 @@ fps: 15
|
||||||
env:
|
env:
|
||||||
name: xarm
|
name: xarm
|
||||||
task: XarmLift-v0
|
task: XarmLift-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: 84
|
image_size: 84
|
||||||
episode_length: 25
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 4
|
state_dim: 4
|
||||||
action_dim: 4
|
action_dim: 4
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 25
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|
|
@ -281,6 +281,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
logging.info("make_dataset")
|
logging.info("make_dataset")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
|
|
||||||
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
|
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||||
|
if cfg.training.eval_freq > 0:
|
||||||
logging.info("make_env")
|
logging.info("make_env")
|
||||||
eval_env = make_env(cfg)
|
eval_env = make_env(cfg)
|
||||||
|
|
||||||
|
@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
# Note: this helper will be used in offline and online training loops.
|
# Note: this helper will be used in offline and online training loops.
|
||||||
def evaluate_and_checkpoint_if_needed(step):
|
def evaluate_and_checkpoint_if_needed(step):
|
||||||
if step % cfg.training.eval_freq == 0:
|
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
|
@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
num_workers=4,
|
num_workers=cfg.training.num_workers,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=device.type != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
|
@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
|
logging.info("End of offline training")
|
||||||
|
|
||||||
|
if cfg.training.online_steps == 0:
|
||||||
|
if cfg.training.eval_freq > 0:
|
||||||
|
eval_env.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# create an env dedicated to online episodes collection from policy rollout
|
||||||
|
online_training_env = make_env(cfg, n_envs=1)
|
||||||
|
|
||||||
# create an empty online dataset similar to offline dataset
|
# create an empty online dataset similar to offline dataset
|
||||||
online_dataset = deepcopy(offline_dataset)
|
online_dataset = deepcopy(offline_dataset)
|
||||||
online_dataset.hf_dataset = {}
|
online_dataset.hf_dataset = {}
|
||||||
|
@ -406,8 +420,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.info("End of online training")
|
||||||
|
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
logging.info("End of training")
|
online_training_env.close()
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
|
Loading…
Reference in New Issue