add train and evals
This commit is contained in: parent 797f79f182 · commit ef074d7281
@@ -27,9 +27,9 @@ Follow these steps:

## 0 - record examples

Run the `0_record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
```
DATA_DIR='./data' python 0_record_training_data.py \
DATA_DIR='./data' python record_training_data.py \
    --repo-id=thomwolf/blue_red_sort \
    --num-episodes=50 \
    --num-frames=400

@@ -44,15 +44,34 @@ TODO:

Use the standard dataset visualization script, pointing it to the right folder:
```
DATA_DIR='./data' python visualize_dataset.py python lerobot/scripts/visualize_dataset.py \
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
    --repo-id thomwolf/blue_red_sort \
    --episode-index 0
```

## (soon) Train a policy
## 2 - Train a policy

Run `1_train_real_policy.py` example
From the example directory, let's run this command to train a model using ACT:

## (soon) Evaluate the policy in the real world

```
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
    device=cuda \
    hydra.searchpath=[file://./train_config/] \
    hydra.run.dir=./outputs/train/blue_red_sort \
    dataset_repo_id=thomwolf/blue_red_sort \
    env=gym_real_world \
    policy=act_real_world \
    wandb.enable=false
```

Run `2_evaluate_real_policy.py` example
## 3 - Evaluate the policy in the real world

From the example directory, let's run this command to evaluate our policy.
The configuration for running the policy is stored in the model checkpoint.
You can override parameters as follows:

```
python run_policy.py \
    -p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/ \
    env.episode_length=1000
```
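As a quick sanity check between recording and training (not part of this commit), the recorded episodes can be loaded back as a `LeRobotDataset`. This is a minimal sketch assuming the `DATA_DIR` layout and repo id used in the commands above, and the `LeRobotDataset` constructor as exposed by lerobot at this point in time:

```
# sketch: load the locally recorded dataset and inspect one frame
import os

os.environ["DATA_DIR"] = "./data"  # must be set before lerobot reads it

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("thomwolf/blue_red_sort")
print(len(dataset))       # total number of frames across episodes
print(dataset[0].keys())  # e.g. observation.*, action, timestamp, ...
```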
@@ -11,13 +11,13 @@ from .robot import Robot

FPS = 30

CAMERAS_SHAPES = {
    "observation.images.high": (480, 640, 3),
    "observation.images.low": (480, 640, 3),
    "images.high": (480, 640, 3),
    "images.low": (480, 640, 3),
}

CAMERAS_PORTS = {
    "observation.images.high": "/dev/video6",
    "observation.images.low": "/dev/video0",
    "images.high": "/dev/video6",
    "images.low": "/dev/video0",
}

LEADER_PORT = "/dev/ttyACM1"

@@ -52,6 +52,8 @@ class RealEnv(gym.Env):
        leader_port: str = LEADER_PORT,
        warmup_steps: int = 100,
        trigger_torque=70,
        fps: int = FPS,
        fps_tolerance: float = 0.1,
    ):
        self.num_joints = num_joints
        self.cameras_shapes = cameras_shapes

@@ -62,6 +64,8 @@ class RealEnv(gym.Env):
        self.follower_port = follower_port
        self.leader_port = leader_port
        self.record = record
        self.fps = fps
        self.fps_tolerance = fps_tolerance

        # Initialize the robot
        self.follower = Robot(device_name=self.follower_port)

@@ -72,10 +76,13 @@ class RealEnv(gym.Env):
        # Initialize the cameras - sorted by camera names
        self.cameras = {}
        for cn, p in sorted(self.cameras_ports.items()):
            assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'."
            self.cameras[cn] = cv2.VideoCapture(p)
        if not all(c.isOpened() for c in self.cameras.values()):
            raise OSError("Cannot open all camera ports.")
            if not self.cameras[cn].isOpened():
                raise OSError(
                    f"Cannot open camera port {p} for {cn}."
                    " Make sure the camera is connected and the port is correct."
                    " Also check you are not spinning up several instances of the same environment (eval.batch_size)."
                )

        # Specify gym action and observation spaces
        observation_space = {}

@@ -98,7 +105,7 @@ class RealEnv(gym.Env):
        if self.cameras_shapes:
            for cn, hwc_shape in self.cameras_shapes.items():
                # Assumes images are unsigned int8 in [0,255]
                observation_space[f"images.{cn}"] = spaces.Box(
                observation_space[cn] = spaces.Box(
                    low=0,
                    high=255,
                    # height x width x channels (e.g. 480 x 640 x 3)

@@ -111,22 +118,20 @@ class RealEnv(gym.Env):

        self._observation = {}
        self._terminated = False
        self._action_time = time.time()
        self.starting_time = time.time()
        self.timestamps = []

    def _get_obs(self):
        qpos = self.follower.read_position()
        self._observation["agent_pos"] = pwm2pos(qpos)
        for cn, c in self.cameras.items():
            self._observation[f"images.{cn}"] = capture_image(
                c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0]
            )
            self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])

        if self.record:
            leader_pos = self.leader.read_position()
            self._observation["leader_pos"] = pwm2pos(leader_pos)
            action = self.leader.read_position()
            self._observation["leader_pos"] = pwm2pos(action)

    def reset(self, seed: int | None = None):
        del seed
        # Reset the robot and sync the leader and follower if we are recording
        for _ in range(self.warmup_steps):
            self._get_obs()

@@ -134,10 +139,22 @@ class RealEnv(gym.Env):
            self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
        self._terminated = False
        info = {}
        self.timestamps = []
        return self._observation, info

    def step(self, action: np.ndarray = None):
        # Reset the observation
        if self.timestamps:
            # wait the right amount of time to stay at the desired fps
            time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
            recording_time = time.time() - self.starting_time
        else:
            # it's the first step so we start the timer
            self.starting_time = time.time()
            recording_time = 0

        self.timestamps.append(recording_time)

        # Get the observation
        self._get_obs()
        if self.record:
            # Teleoperate the leader

@@ -145,9 +162,20 @@ class RealEnv(gym.Env):
        else:
            # Apply the action to the follower
            self.follower.set_goal_pos(pos2pwm(action))

        reward = 0
        terminated = truncated = self._terminated
        info = {}
        info = {"timestamp": recording_time, "fps_error": False}

        # Check if we are able to keep up with the desired fps
        if recording_time - self.timestamps[-1] > 1 / (self.fps - self.fps_tolerance):
            print(
                f"Error: recording time interval {recording_time - self.timestamps[-1]:.2f} is greater"
                f" than expected {1 / (self.fps - self.fps_tolerance):.2f}"
                f" at frame {len(self.timestamps)}"
            )
            info["fps_error"] = True

        return self._observation, reward, terminated, truncated, info

    def render(self): ...
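For context, here is a minimal usage sketch of the new `fps` / `fps_tolerance` keyword arguments. It mirrors how `record_training_data.py` (further down in this commit) constructs the environment, so only the literal values are illustrative:

```
# sketch: constructing the real-world env with the new timing kwargs
import gym_real_world  # noqa: F401  (registers gym_real_world/RealEnv-v0)
import gymnasium as gym

env = gym.make(
    "gym_real_world/RealEnv-v0",
    disable_env_checker=True,
    record=True,         # read the leader arm so its positions are logged as actions
    fps=30,              # target control/recording frequency
    fps_tolerance=0.1,   # allowed fps deviation before info["fps_error"] is set
)

observation, info = env.reset()
observation, reward, terminated, truncated, info = env.step(action=None)
if info["fps_error"]:
    print("Could not keep up with the requested fps; this episode should be dropped.")
```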
@@ -1,7 +1,11 @@
"""This script demonstrates how to record a LeRobot dataset of training data
using a very simple gym environment (see examples/real_robot_example/gym_real_world/gym_environment.py).

"""

import argparse
import copy
import os
import time

import gym_real_world  # noqa: F401
import gymnasium as gym

@@ -27,15 +31,12 @@ parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
    "--fps",
    type=int,
    default=30,
    help="Frames per second of the recording."
    " If we are not able to record at this fps, we will adjust the fps in the metadata.",
)
parser.add_argument(
    "--tolerance", type=float, default=0.01, help="Tolerance in seconds for the recording time."
    "--fps_tolerance",
    type=float,
    default=0.1,
    help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument(
    "--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."

@@ -47,7 +48,7 @@ num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision
fps = args.fps
tolerance = args.tolerance
fps_tolerance = args.fps_tolerance

out_data = DATA_DIR / repo_id

@@ -67,7 +68,7 @@ if not os.path.exists(videos_dir):
if __name__ == "__main__":
    # Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
    gym_handle = "gym_real_world/RealEnv-v0"
    env = gym.make(gym_handle, disable_env_checker=True, record=True)
    env = gym.make(gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance)

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

@@ -84,59 +85,46 @@ if __name__ == "__main__":
        os.system(f'spd-say "go {ep_idx}"')
        # init buffers
        obs_replay = {k: [] for k in env.observation_space}
        timestamps = []

        starting_time = time.time()
        drop_episode = False
        timestamps = []
        for _ in tqdm(range(num_frames)):
            # Apply the next action
            observation, _, _, _, _ = env.step(action=None)
            observation, _, _, _, info = env.step(action=None)
            # images_stacked = np.hstack(list(observation['pixels'].values()))
            # images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
            # cv2.imshow('frame', images_stacked)

            if info["fps_error"]:
                os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
                drop_episode = True
                break

            # store data
            for key in observation:
                obs_replay[key].append(copy.deepcopy(observation[key]))

            recording_time = time.time() - starting_time
            timestamps.append(recording_time)

            # Check if we are able to keep up with the desired fps
            if recording_time > num_frames / fps + tolerance:
                print(
                    f"Error: recording time {recording_time:.2f} is greater than expected {num_frames / fps:.2f}"
                    f" + tolerance {tolerance:.2f}"
                    f" at frame {len(timestamps)}"
                    f" in episode {ep_idx}."
                    f"Dropping the rest of the episode."
                )
                break

            # wait the right amount of time to stay at the desired fps
            time.sleep(max(0, 1 / fps - (time.time() - starting_time)))
            timestamps.append(info["timestamp"])

            # if cv2.waitKey(1) & 0xFF == ord('q'):
            #     break

        os.system('spd-say "stop"')

        if len(timestamps) == num_frames:
        if not drop_episode:
            os.system(f'spd-say "saving episode {ep_idx}"')
            ep_dict = {}
            # store images in png and create the video
            for img_key in env.cameras:
                save_images_concurrently(
                    obs_replay[f"images.{img_key}"],
                    obs_replay[img_key],
                    images_dir / f"{img_key}_episode_{ep_idx:06d}",
                    args.num_workers,
                )
                # for i in tqdm(range(num_frames)):
                #     cv2.imwrite(str(images_dir / f"{img_key}_episode_{ep_idx:06d}" / f"frame_{i:06d}.png"),
                #                 obs_replay[i]['pixels'][img_key])
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                # store the reference to the video frame
                ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]
                # shutil.rmtree(tmp_imgs_dir)
                ep_dict[f"observation.{img_key}"] = [
                    {"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
                ]

            state = torch.tensor(np.array(obs_replay["agent_pos"]))
            action = torch.tensor(np.array(obs_replay["leader_pos"]))

@@ -198,8 +186,6 @@ if __name__ == "__main__":
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add success
    # features["next.success"] = Value(dtype='bool', id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
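One detail worth spelling out: the `--fps` help text says the fps stored in the metadata may be adjusted when recording cannot keep up. A hedged sketch of how an effective fps could be recomputed from the per-frame `info["timestamp"]` values collected above (the exact adjustment done later in the script may differ):

```
# sketch: recompute an effective fps from the recorded timestamps
timestamps = [0.0, 0.034, 0.067, 0.101]  # seconds since episode start (illustrative values)

if len(timestamps) > 1:
    effective_fps = (len(timestamps) - 1) / (timestamps[-1] - timestamps[0])
else:
    effective_fps = float("nan")

print(f"effective fps: {effective_fps:.2f}")  # ~29.7 for the values above
```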
@@ -0,0 +1,60 @@
import argparse
import logging
from pathlib import Path

import gym_real_world  # noqa: F401
import gymnasium as gym  # noqa: F401
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError

from lerobot.common.utils.utils import init_logging
from lerobot.scripts.eval import eval

if __name__ == "__main__":
    init_logging()

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
            "(useful for debugging). This argument is mutually exclusive with `--config`."
        ),
    )
    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
    parser.add_argument(
        "overrides",
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested overrides).",
    )
    args = parser.parse_args()

    try:
        pretrained_policy_path = Path(
            snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
        )
    except (HFValidationError, RepositoryNotFoundError) as e:
        if isinstance(e, HFValidationError):
            error_message = (
                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
            )
        else:
            error_message = (
                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
            )

        logging.warning(f"{error_message} Treating it as a local directory.")
        pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
        if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
            raise ValueError(
                "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
                "repo ID, nor is it an existing local directory."
            )

    eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
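For completeness, a hedged sketch of calling the same entry point without going through argparse; it assumes `lerobot.scripts.eval.eval` accepts the `pretrained_policy_path` and `config_overrides` arguments exactly as the new script above passes them:

```
# sketch: invoke the evaluation entry point directly from Python
from pathlib import Path

from lerobot.scripts.eval import eval

eval(
    pretrained_policy_path=Path("./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/"),
    config_overrides=["env.episode_length=1000"],
)
```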
@@ -3,11 +3,10 @@

fps: 30

env:
  name: dora
  task: DoraKoch-v0
  name: real_world
  task: RealEnv-v0
  state_dim: 6
  action_dim: 6
  fps: ${fps}
  episode_length: 400
  gym:
    fps: ${fps}
  episode_length: 200
  real_world: true
@@ -1,8 +1,8 @@
# @package _global_

# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
# low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.

@@ -15,14 +15,14 @@
# ```

seed: 1000
dataset_repo_id: thomwolf/blue_sort
dataset_repo_id: ???

override_dataset_stats:
  observation.images.cam_high:
  observation.images.high:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
  observation.images.cam_low:
  observation.images.low:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

@@ -46,8 +46,8 @@ training:
  action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

eval:
  n_episodes: 50
  batch_size: 50
  n_episodes: 1
  batch_size: 1

# See `configuration_act.py` for more details.
policy:

@@ -60,16 +60,16 @@ policy:

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.cam_high: [3, 480, 640]
    observation.images.cam_low: [3, 480, 640]
    observation.images.high: [3, 480, 640]
    observation.images.low: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.cam_high: mean_std
    observation.images.cam_low: mean_std
    observation.images.high: mean_std
    observation.images.low: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std
@@ -56,7 +56,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
    )

    # A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
    if cfg.env.name != "dora":
    if not cfg.env.real_world:
        if isinstance(cfg.dataset_repo_id, str):
            dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
        else:

@@ -29,10 +29,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    # map to expected inputs for the policy
    return_observations = {}

    if isinstance(observations["pixels"], dict):
    if "pixels" in observations and isinstance(observations["pixels"], dict):
        imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
    else:
    elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
        imgs = {"observation.image": observations["pixels"]}
    else:
        imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
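To make the new branching concrete, here are the three observation layouts `preprocess_observation` now distinguishes; the arrays are placeholders, only the key structure matters:

```
# sketch: the three observation layouts handled by preprocess_observation
import numpy as np

# 1) simulated env with named cameras -> keys become "observation.images.<name>"
obs_named = {"pixels": {"top": np.zeros((480, 640, 3), dtype=np.uint8)}}

# 2) simulated env with a single camera -> key becomes "observation.image"
obs_single = {"pixels": np.zeros((96, 96, 3), dtype=np.uint8)}

# 3) real-world env (this commit): no "pixels" entry, flat "images.*" keys
#    -> keys become "observation.images.<name>"
obs_real = {
    "agent_pos": np.zeros(6, dtype=np.float32),
    "images.high": np.zeros((480, 640, 3), dtype=np.uint8),
    "images.low": np.zeros((480, 640, 3), dtype=np.uint8),
}
```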
@@ -9,6 +9,7 @@ env:
  action_dim: 14
  fps: ${fps}
  episode_length: 400
  real_world: false
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array

@@ -9,5 +9,6 @@ env:
  action_dim: 14
  fps: ${fps}
  episode_length: 400
  real_world: true
  gym:
    fps: ${fps}

@@ -10,6 +10,7 @@ env:
  action_dim: 2
  fps: ${fps}
  episode_length: 300
  real_world: false
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array

@@ -10,6 +10,7 @@ env:
  action_dim: 4
  fps: ${fps}
  episode_length: 25
  real_world: false
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array
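The `real_world` flag added to every env config is what the `make_dataset` change above reads. A small sketch of the resolved config, using OmegaConf purely for illustration:

```
# sketch: how the new env.real_world flag is consumed (cf. make_dataset above)
from omegaconf import OmegaConf

cfg = OmegaConf.create({"env": {"name": "real_world", "real_world": True}})

if not cfg.env.real_world:
    print("simulated env: run the dataset/env consistency check")
else:
    print("real-world env: skip the check")
```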
@@ -164,7 +164,7 @@ def rollout(
        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
        # available if none of the envs finished.
        if "final_info" in info:
            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
            successes = [i["is_success"] if i is not None else False for i in info["final_info"]]
        else:
            successes = [False] * env.num_envs
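The change above renames the comprehension variable so it no longer shadows the outer `info` dict; in Python 3 comprehensions have their own scope, so this is most likely a readability fix rather than a behavioral one. A toy illustration (not lerobot code):

```
# sketch: the comprehension before and after the rename behaves the same in Python 3
info = {"final_info": [None, {"is_success": True}]}

old = [info["is_success"] if info is not None else False for info in info["final_info"]]
new = [i["is_success"] if i is not None else False for i in info["final_info"]]

print(old == new)            # True: both give [False, True]
print("final_info" in info)  # True: the outer dict is untouched either way
```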
@@ -406,7 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):

        step += 1

    eval_env.close()
    if cfg.training.eval_freq > 0:
        eval_env.close()
    logging.info("End of training")