add train and evals

Thomas Wolf 2024-06-09 14:03:47 +02:00
parent 797f79f182
commit ef074d7281
14 changed files with 184 additions and 85 deletions

View File

@@ -27,9 +27,9 @@ Follow these steps:
 ## 0 - record examples

-Run the `0_record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
+Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
 ```
-DATA_DIR='./data' python 0_record_training_data.py \
+DATA_DIR='./data' python record_training_data.py \
   --repo-id=thomwolf/blue_red_sort \
   --num-episodes=50 \
   --num-frames=400
@@ -44,15 +44,34 @@ TODO:
 Use the standard dataset visualization script pointing it to the right folder:
 ```
-DATA_DIR='./data' python visualize_dataset.py python lerobot/scripts/visualize_dataset.py \
+DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
   --repo-id thomwolf/blue_red_sort \
   --episode-index 0
 ```
-## (soon) Train a policy
-Run `1_train_real_policy.py` example
-## (soon) Evaluate the policy in the real world
-Run `2_evaluate_real_policy.py` example
+## 2 - Train a policy
+From the example directory, let's run this command to train a model using ACT:
+```
+DATA_DIR='./data' python ../../lerobot/scripts/train.py \
+  device=cuda \
+  hydra.searchpath=[file://./train_config/] \
+  hydra.run.dir=./outputs/train/blue_red_sort \
+  dataset_repo_id=thomwolf/blue_red_sort \
+  env=gym_real_world \
+  policy=act_real_world \
+  wandb.enable=false
+```
+## 3 - Evaluate the policy in the real world
+From the example directory, let's run this command to evaluate our policy.
+The configuration for running the policy is stored in the model checkpoint.
+You can override parameters as follows:
+```
+python run_policy.py \
+  -p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/ \
+  env.episode_length=1000
+```

View File

@@ -11,13 +11,13 @@ from .robot import Robot

 FPS = 30

 CAMERAS_SHAPES = {
-    "observation.images.high": (480, 640, 3),
-    "observation.images.low": (480, 640, 3),
+    "images.high": (480, 640, 3),
+    "images.low": (480, 640, 3),
 }

 CAMERAS_PORTS = {
-    "observation.images.high": "/dev/video6",
-    "observation.images.low": "/dev/video0",
+    "images.high": "/dev/video6",
+    "images.low": "/dev/video0",
 }

 LEADER_PORT = "/dev/ttyACM1"
@@ -52,6 +52,8 @@ class RealEnv(gym.Env):
         leader_port: str = LEADER_PORT,
         warmup_steps: int = 100,
         trigger_torque=70,
+        fps: int = FPS,
+        fps_tolerance: float = 0.1,
     ):
         self.num_joints = num_joints
         self.cameras_shapes = cameras_shapes
@@ -62,6 +64,8 @@ class RealEnv(gym.Env):
         self.follower_port = follower_port
         self.leader_port = leader_port
         self.record = record
+        self.fps = fps
+        self.fps_tolerance = fps_tolerance

         # Initialize the robot
         self.follower = Robot(device_name=self.follower_port)
@@ -72,10 +76,13 @@ class RealEnv(gym.Env):
         # Initialize the cameras - sorted by camera names
         self.cameras = {}
         for cn, p in sorted(self.cameras_ports.items()):
-            assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'."
             self.cameras[cn] = cv2.VideoCapture(p)
-        if not all(c.isOpened() for c in self.cameras.values()):
-            raise OSError("Cannot open all camera ports.")
+            if not self.cameras[cn].isOpened():
+                raise OSError(
+                    f"Cannot open camera port {p} for {cn}."
+                    f" Make sure the camera is connected and the port is correct."
+                    f" Also check you are not spinning up several instances of the same environment (eval.batch_size)."
+                )

         # Specify gym action and observation spaces
         observation_space = {}
@@ -98,7 +105,7 @@ class RealEnv(gym.Env):
         if self.cameras_shapes:
             for cn, hwc_shape in self.cameras_shapes.items():
                 # Assumes images are unsigned int8 in [0,255]
-                observation_space[f"images.{cn}"] = spaces.Box(
+                observation_space[cn] = spaces.Box(
                     low=0,
                     high=255,
                     # height x width x channels (e.g. 480 x 640 x 3)
@@ -111,22 +118,20 @@ class RealEnv(gym.Env):
         self._observation = {}
         self._terminated = False
-        self._action_time = time.time()
+        self.starting_time = time.time()
+        self.timestamps = []

     def _get_obs(self):
         qpos = self.follower.read_position()
         self._observation["agent_pos"] = pwm2pos(qpos)
         for cn, c in self.cameras.items():
-            self._observation[f"images.{cn}"] = capture_image(
-                c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0]
-            )
+            self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])

         if self.record:
-            leader_pos = self.leader.read_position()
-            self._observation["leader_pos"] = pwm2pos(leader_pos)
+            action = self.leader.read_position()
+            self._observation["leader_pos"] = pwm2pos(action)

     def reset(self, seed: int | None = None):
-        del seed
         # Reset the robot and sync the leader and follower if we are recording
         for _ in range(self.warmup_steps):
             self._get_obs()
@@ -134,10 +139,22 @@ class RealEnv(gym.Env):
             self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
         self._terminated = False
         info = {}
+        self.timestamps = []
         return self._observation, info

     def step(self, action: np.ndarray = None):
-        # Reset the observation
+        if self.timestamps:
+            # wait the right amount of time to stay at the desired fps
+            # (self.timestamps stores times relative to self.starting_time)
+            time.sleep(max(0, 1 / self.fps - (time.time() - self.starting_time - self.timestamps[-1])))
+            recording_time = time.time() - self.starting_time
+        else:
+            # it's the first step so we start the timer
+            self.starting_time = time.time()
+            recording_time = 0
+        self.timestamps.append(recording_time)

+        # Get the observation
         self._get_obs()
         if self.record:
             # Teleoperate the leader
@@ -145,9 +162,20 @@ class RealEnv(gym.Env):
         else:
             # Apply the action to the follower
             self.follower.set_goal_pos(pos2pwm(action))

         reward = 0
         terminated = truncated = self._terminated
-        info = {}
+        info = {"timestamp": recording_time, "fps_error": False}
+        # Check if we are able to keep up with the desired fps
+        if len(self.timestamps) > 1 and self.timestamps[-1] - self.timestamps[-2] > 1 / (self.fps - self.fps_tolerance):
+            print(
+                f"Error: recording time interval {self.timestamps[-1] - self.timestamps[-2]:.2f} is greater"
+                f" than expected {1 / (self.fps - self.fps_tolerance):.2f}"
+                f" at frame {len(self.timestamps)}"
+            )
+            info["fps_error"] = True
         return self._observation, reward, terminated, truncated, info

     def render(self): ...
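For illustration (not part of this commit), the fps-pacing pattern that `step()` now implements boils down to this standalone sketch; `FPS` and `FPS_TOLERANCE` are example values:
```
import time

FPS = 30
FPS_TOLERANCE = 0.1

starting_time = time.time()
timestamps = []
for _ in range(5):
    if timestamps:
        # sleep so consecutive frames land ~1/FPS apart
        elapsed = time.time() - (starting_time + timestamps[-1])
        time.sleep(max(0, 1 / FPS - elapsed))
    timestamps.append(time.time() - starting_time)
    # flag the frame if the interval to the previous one was too long
    fps_error = len(timestamps) > 1 and (
        timestamps[-1] - timestamps[-2] > 1 / (FPS - FPS_TOLERANCE)
    )
    print(f"t={timestamps[-1]:.3f}s fps_error={fps_error}")
```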

View File

@@ -1,7 +1,11 @@
+"""This script demonstrates how to record a LeRobot dataset of training data
+using a very simple gym environment (see examples/real_robot_example/gym_real_world/gym_environment.py).
+"""
+
 import argparse
 import copy
 import os
-import time

 import gym_real_world  # noqa: F401
 import gymnasium as gym
@@ -27,15 +31,12 @@ parser.add_argument("--num-frames", type=int, default=400)
 parser.add_argument("--num-workers", type=int, default=16)
 parser.add_argument("--keep-last", action="store_true")
 parser.add_argument("--push-to-hub", action="store_true")
+parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
 parser.add_argument(
-    "--fps",
-    type=int,
-    default=30,
-    help="Frames per second of the recording."
-    "If we are not able to record at this fps, we will adjust the fps in the metadata.",
-)
-parser.add_argument(
-    "--tolerance", type=float, default=0.01, help="Tolerance in seconds for the recording time."
+    "--fps_tolerance",
+    type=float,
+    default=0.1,
+    help="Tolerance in fps for the recording before dropping episodes.",
 )
 parser.add_argument(
     "--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
@@ -47,7 +48,7 @@ num_episodes = args.num_episodes
 num_frames = args.num_frames
 revision = args.revision
 fps = args.fps
-tolerance = args.tolerance
+fps_tolerance = args.fps_tolerance

 out_data = DATA_DIR / repo_id
@@ -67,7 +68,7 @@ if not os.path.exists(videos_dir):
 if __name__ == "__main__":
     # Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
     gym_handle = "gym_real_world/RealEnv-v0"
-    env = gym.make(gym_handle, disable_env_checker=True, record=True)
+    env = gym.make(gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance)

     ep_dicts = []
     episode_data_index = {"from": [], "to": []}
@@ -84,59 +85,46 @@ if __name__ == "__main__":
         os.system(f'spd-say "go {ep_idx}"')

         # init buffers
         obs_replay = {k: [] for k in env.observation_space}
-        timestamps = []
-        starting_time = time.time()
+        drop_episode = False
+        timestamps = []

         for _ in tqdm(range(num_frames)):
             # Apply the next action
-            observation, _, _, _, _ = env.step(action=None)
+            observation, _, _, _, info = env.step(action=None)
             # images_stacked = np.hstack(list(observation['pixels'].values()))
             # images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
             # cv2.imshow('frame', images_stacked)

+            if info["fps_error"]:
+                os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
+                drop_episode = True
+                break

             # store data
             for key in observation:
                 obs_replay[key].append(copy.deepcopy(observation[key]))
+            timestamps.append(info["timestamp"])

-            recording_time = time.time() - starting_time
-            timestamps.append(recording_time)
-            # Check if we are able to keep up with the desired fps
-            if recording_time > num_frames / fps + tolerance:
-                print(
-                    f"Error: recording time {recording_time:.2f} is greater than expected {num_frames / fps:.2f}"
-                    f" + tolerance {tolerance:.2f}"
-                    f" at frame {len(timestamps)}"
-                    f" in episode {ep_idx}."
-                    f"Dropping the rest of the episode."
-                )
-                break
-            # wait the right amount of time to stay at the desired fps
-            time.sleep(max(0, 1 / fps - (time.time() - starting_time)))

             # if cv2.waitKey(1) & 0xFF == ord('q'):
             #     break

         os.system('spd-say "stop"')

-        if len(timestamps) == num_frames:
+        if not drop_episode:
             os.system(f'spd-say "saving episode {ep_idx}"')
             ep_dict = {}
             # store images in png and create the video
             for img_key in env.cameras:
                 save_images_concurrently(
-                    obs_replay[f"images.{img_key}"],
+                    obs_replay[img_key],
                     images_dir / f"{img_key}_episode_{ep_idx:06d}",
                     args.num_workers,
                 )
-                # for i in tqdm(range(num_frames)):
-                #     cv2.imwrite(str(images_dir / f"{img_key}_episode_{ep_idx:06d}" / f"frame_{i:06d}.png"),
-                #                 obs_replay[i]['pixels'][img_key])

                 fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                 # store the reference to the video frame
-                ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]
-                # shutil.rmtree(tmp_imgs_dir)
+                ep_dict[f"observation.{img_key}"] = [
+                    {"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
+                ]

             state = torch.tensor(np.array(obs_replay["agent_pos"]))
             action = torch.tensor(np.array(obs_replay["leader_pos"]))
@@ -198,8 +186,6 @@ if __name__ == "__main__":
     features["timestamp"] = Value(dtype="float32", id=None)
     features["next.done"] = Value(dtype="bool", id=None)
     features["index"] = Value(dtype="int64", id=None)
-    # TODO(rcadene): add success
-    # features["next.success"] = Value(dtype='bool', id=None)

     hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
     hf_dataset.set_transform(hf_transform_to_torch)
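As an illustrative aside (not part of the commit), here is a minimal sketch of how the scalar columns declared above assemble into a `datasets.Dataset` (toy values only; the real dataset also carries image references, states and actions):
```
from datasets import Dataset, Features, Value

features = Features(
    {
        "timestamp": Value(dtype="float32"),
        "next.done": Value(dtype="bool"),
        "index": Value(dtype="int64"),
    }
)
data = {
    "timestamp": [0.0, 1 / 30, 2 / 30],
    "next.done": [False, False, True],
    "index": [0, 1, 2],
}
hf_dataset = Dataset.from_dict(data, features=features)
print(hf_dataset[0])  # {'timestamp': 0.0, 'next.done': False, 'index': 0}
```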

View File

@@ -0,0 +1,60 @@
+import argparse
+import logging
+from pathlib import Path
+
+import gym_real_world  # noqa: F401
+import gymnasium as gym  # noqa: F401
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils._errors import RepositoryNotFoundError
+from huggingface_hub.utils._validators import HFValidationError
+
+from lerobot.common.utils.utils import init_logging
+from lerobot.scripts.eval import eval
+
+if __name__ == "__main__":
+    init_logging()
+
+    parser = argparse.ArgumentParser(
+        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
+    )
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
+            "(useful for debugging). This argument is mutually exclusive with `--config`."
+        ),
+    )
+    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
+    parser.add_argument(
+        "overrides",
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
+    )
+    args = parser.parse_args()
+
+    try:
+        pretrained_policy_path = Path(
+            snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
+        )
+    except (HFValidationError, RepositoryNotFoundError) as e:
+        if isinstance(e, HFValidationError):
+            error_message = (
+                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
+            )
+        else:
+            error_message = (
+                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
+            )
+        logging.warning(f"{error_message} Treating it as a local directory.")
+        pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
+
+    if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
+        raise ValueError(
+            "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
+            "repo ID, nor is it an existing local directory."
+        )
+
+    eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
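For illustration (not part of this commit), the Hub-then-local fallback above condenses to this sketch; the helper name `resolve_policy_path` is ours, not lerobot's:
```
import logging
from pathlib import Path

from huggingface_hub import snapshot_download
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError


def resolve_policy_path(name_or_path: str, revision: str | None = None) -> Path:
    # Try the Hub first; fall back to interpreting the argument as a local dir.
    try:
        return Path(snapshot_download(name_or_path, revision=revision))
    except (HFValidationError, RepositoryNotFoundError):
        logging.warning("Not a Hub repo ID; treating it as a local directory.")
        local = Path(name_or_path)
        if not local.is_dir():
            raise ValueError(f"{name_or_path} is neither a Hub repo ID nor an existing directory.")
        return local


print(resolve_policy_path("."))  # "." is not a valid repo ID, so it resolves locally
```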

View File

@@ -3,11 +3,10 @@
 fps: 30

 env:
-  name: dora
-  task: DoraKoch-v0
+  name: real_world
+  task: RealEnv-v0
   state_dim: 6
   action_dim: 6
   fps: ${fps}
-  episode_length: 400
-  gym:
-    fps: ${fps}
+  episode_length: 200
+  real_world: true
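For readers unfamiliar with these configs, a standalone sketch (not part of the commit) of how the `${fps}` interpolation resolves with omegaconf:
```
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
fps: 30
env:
  name: real_world
  task: RealEnv-v0
  fps: ${fps}
  episode_length: 200
  real_world: true
"""
)
print(cfg.env.fps)  # 30 -- ${fps} resolves against the top-level key
```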

View File

@ -1,8 +1,8 @@
# @package _global_ # @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets. # Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images, # Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used # low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation. # to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot). # This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world. # Look at its README for more information on how to evaluate a checkpoint in the real-world.
@@ -15,14 +15,14 @@
 # ```
 seed: 1000
-dataset_repo_id: thomwolf/blue_sort
+dataset_repo_id: ???

 override_dataset_stats:
-  observation.images.cam_high:
+  observation.images.high:
     # stats from imagenet, since we use a pretrained vision model
     mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
     std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
-  observation.images.cam_low:
+  observation.images.low:
     # stats from imagenet, since we use a pretrained vision model
     mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
     std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
@@ -46,8 +46,8 @@ training:
     action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

 eval:
-  n_episodes: 50
-  batch_size: 50
+  n_episodes: 1
+  batch_size: 1

 # See `configuration_act.py` for more details.
 policy:
@@ -60,16 +60,16 @@ policy:
   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
-    observation.images.cam_high: [3, 480, 640]
-    observation.images.cam_low: [3, 480, 640]
+    observation.images.high: [3, 480, 640]
+    observation.images.low: [3, 480, 640]
     observation.state: ["${env.state_dim}"]
   output_shapes:
     action: ["${env.action_dim}"]

   # Normalization / Unnormalization
   input_normalization_modes:
-    observation.images.cam_high: mean_std
-    observation.images.cam_low: mean_std
+    observation.images.high: mean_std
+    observation.images.low: mean_std
     observation.state: mean_std
   output_normalization_modes:
     action: mean_std
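For intuition (not part of this commit), the `mean_std` normalization pinned to ImageNet statistics amounts to the following sketch; the (c,1,1) shapes let the stats broadcast over height and width:
```
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)  # (c,1,1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)  # (c,1,1)

img = torch.rand(3, 480, 640)  # an image already scaled to [0, 1]
normalized = (img - mean) / std  # broadcasts over the 480x640 spatial dims
```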

View File

@@ -56,7 +56,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
     )

     # A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
-    if cfg.env.name != "dora":
+    if not cfg.env.real_world:
         if isinstance(cfg.dataset_repo_id, str):
             dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
         else:

View File

@@ -29,10 +29,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     # map to expected inputs for the policy
     return_observations = {}

-    if isinstance(observations["pixels"], dict):
+    if "pixels" in observations and isinstance(observations["pixels"], dict):
         imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-    else:
+    elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
         imgs = {"observation.image": observations["pixels"]}
+    else:
+        imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}

     for imgkey, img in imgs.items():
         img = torch.from_numpy(img)
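To make the three branches concrete, a toy reproduction of just the key-mapping (not part of the commit; names are illustrative, and the real function also converts the images to tensors):
```
import numpy as np


def map_image_keys(observations: dict) -> dict:
    if "pixels" in observations and isinstance(observations["pixels"], dict):
        return {f"observation.images.{k}": v for k, v in observations["pixels"].items()}
    elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
        return {"observation.image": observations["pixels"]}
    else:
        return {f"observation.{k}": v for k, v in observations.items() if "images" in k}


obs = {"images.high": np.zeros((480, 640, 3), np.uint8), "agent_pos": np.zeros(6)}
print(list(map_image_keys(obs)))  # ['observation.images.high']
```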

View File

@ -9,6 +9,7 @@ env:
action_dim: 14 action_dim: 14
fps: ${fps} fps: ${fps}
episode_length: 400 episode_length: 400
real_world: false
gym: gym:
obs_type: pixels_agent_pos obs_type: pixels_agent_pos
render_mode: rgb_array render_mode: rgb_array

View File

@ -9,5 +9,6 @@ env:
action_dim: 14 action_dim: 14
fps: ${fps} fps: ${fps}
episode_length: 400 episode_length: 400
real_world: true
gym: gym:
fps: ${fps} fps: ${fps}

View File

@ -10,6 +10,7 @@ env:
action_dim: 2 action_dim: 2
fps: ${fps} fps: ${fps}
episode_length: 300 episode_length: 300
real_world: false
gym: gym:
obs_type: pixels_agent_pos obs_type: pixels_agent_pos
render_mode: rgb_array render_mode: rgb_array

View File

@ -10,6 +10,7 @@ env:
action_dim: 4 action_dim: 4
fps: ${fps} fps: ${fps}
episode_length: 25 episode_length: 25
real_world: false
gym: gym:
obs_type: pixels_agent_pos obs_type: pixels_agent_pos
render_mode: rgb_array render_mode: rgb_array

View File

@@ -164,7 +164,7 @@ def rollout(
         # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
         # available if none of the envs finished.
         if "final_info" in info:
-            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+            successes = [i["is_success"] if i is not None else False for i in info["final_info"]]
         else:
             successes = [False] * env.num_envs

View File

@@ -406,7 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             step += 1

-    eval_env.close()
+    if cfg.training.eval_freq > 0:
+        eval_env.close()
     logging.info("End of training")