added the ability to record with a policy; added temporary fixes to train.py to enable training on Mac
parent 9a5356d0ac
commit 5e01c21692
@@ -345,7 +345,7 @@ class TDMPCPolicy(
                 batch[key] = batch[key].transpose(1, 0)

        action = batch["action"]  # (t, b, action_dim)
-        reward = batch["reward"]  # (t, b)
+        reward = batch["next.reward"]  # (t, b)
        observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

        # Apply random image augmentations.
@@ -422,7 +422,7 @@ class TDMPCPolicy(
            (
                temporal_loss_coeffs
                * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                # `reward_preds` depends on the current observation and the actions.
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
@@ -443,7 +443,7 @@ class TDMPCPolicy(
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
                # q_targets depends on the reward and the next observations.
-                * ~batch["reward_is_pad"]
+                * ~batch["next.reward_is_pad"]
                * ~batch["observation.state_is_pad"][1:]
            )
            .sum(0)
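
For reference, a minimal sketch of the batch layout these TDMPC hunks assume after the batch/time transpose: tensors are time-first, the reward is keyed `next.reward`, and each padded quantity carries a boolean `*_is_pad` mask that the loss terms invert with `~`. The sizes below are illustrative, not taken from the commit.

import torch

t, b, action_dim, state_dim = 5, 2, 4, 6  # illustrative horizon, batch size, dims
batch = {
    "action": torch.zeros(t, b, action_dim),
    "next.reward": torch.zeros(t, b),
    "observation.state": torch.zeros(t, b, state_dim),
    "action_is_pad": torch.zeros(t, b, dtype=torch.bool),
    "next.reward_is_pad": torch.zeros(t, b, dtype=torch.bool),
    "observation.state_is_pad": torch.zeros(t, b, dtype=torch.bool),
}
# Padded steps contribute zero to the losses because each term is multiplied by ~batch["..._is_pad"].
valid = ~batch["next.reward_is_pad"]  # (t, b) mask, True where the reward is real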
@@ -85,11 +85,15 @@ from functools import cache
 from pathlib import Path
 import gymnasium as gym
 import multiprocessing
+from contextlib import nullcontext
+

 import cv2
 import torch
 import numpy as np
 import tqdm
+from omegaconf import DictConfig

 from PIL import Image
 from datasets import Dataset, Features, Sequence, Value
+
@@ -99,12 +103,15 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
 from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch, hf_transform_to_torch
+from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.envs.factory import make_env
 from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.scripts.eval import get_pretrained_policy_path
 from lerobot.scripts.push_dataset_to_hub import (
     push_dataset_card_to_hub,
     push_meta_data_to_hub,
@@ -178,6 +185,29 @@ def is_headless():
        print()
        return True

+def get_action_from_policy(policy, observation, device, use_amp=False):
+    with (
+        torch.inference_mode(),
+        torch.autocast(device_type=device.type)
+        if device.type == "cuda" and use_amp
+        else nullcontext(),
+    ):
+        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
+        for name in observation:
+            if "image" in name:
+                observation[name] = observation[name].type(torch.float32) / 255
+                observation[name] = observation[name].permute(2, 0, 1).contiguous()
+            observation[name] = observation[name].unsqueeze(0)
+            observation[name] = observation[name].to(device)
+
+        # Compute the next action with the policy
+        # based on the current observation
+        action = policy.select_action(observation)
+        # Remove batch dimension
+        action = action.squeeze(0)
+        # Move to cpu, if not already the case
+        return action.to("cpu")
+
 def init_read_leader(robot, fps, **kwargs):
    axis_directions = kwargs.get('axis_directions', [1])
    offsets = kwargs.get('offsets', [0])
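
A small usage sketch for the `get_action_from_policy` helper added above. Judging from the `permute(2, 0, 1)` and the division by 255, image observations are expected as HWC uint8 tensors and other entries as 1-D float tensors; the dummy policy, key names, and shapes below are placeholders, and the helper itself is assumed to be in scope.

import torch

class DummyPolicy:
    # Stand-in for a loaded LeRobot policy; only select_action is needed here.
    def select_action(self, observation):
        return torch.zeros(1, 6)  # (batch, action_dim)

device = torch.device("cpu")
observation = {
    "observation.images.top": torch.zeros(480, 640, 3, dtype=torch.uint8),  # HWC uint8
    "observation.state": torch.zeros(6, dtype=torch.float32),
}
action = get_action_from_policy(DummyPolicy(), observation, device)
assert action.shape == (6,) and action.device.type == "cpu"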
@@ -240,7 +270,7 @@ def create_rl_hf_dataset(data_dict):
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
-    features["reward"] = Value(dtype="float32", id=None)
+    features["next.reward"] = Value(dtype="float32", id=None)

     features["seed"] = Value(dtype="int64", id=None)
     features["episode_index"] = Value(dtype="int64", id=None)
@@ -277,6 +307,8 @@ def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs):
 def record(
     env,
     robot: Robot,
+    policy: torch.nn.Module | None = None,
+    policy_cfg: DictConfig | None = None,
     fps: int | None = None,
     root="data",
     repo_id="lerobot/debug",
@@ -355,7 +387,23 @@ def record(
     num_image_writers = num_image_writers_per_camera * 2 ###############
     num_image_writers = max(num_image_writers, 1)
-    read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+
+    # Load policy if any
+    if policy is not None:
+        # Check device is available
+        device = get_safe_torch_device(policy_cfg.device, log=True)
+
+        policy.eval()
+        policy.to(device)
+
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+        set_global_seed(policy_cfg.seed)
+
+        # override fps using policy fps
+        fps = policy_cfg.env.fps
+    else:
+        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)

     if not is_headless() and visualize_images:
         observations_queue = multiprocessing.Queue(1000)
         show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
@@ -369,7 +417,7 @@ def record(
     while episode_index < num_episodes:
         logging.info(f"Recording episode {episode_index}")
         say(f"Recording episode {episode_index}")
-        ep_dict = {'action':[], 'reward':[]}
+        ep_dict = {'action':[], 'next.reward':[]}
         for k in state_keys_dict:
             ep_dict[k] = []
         frame_index = 0
@@ -381,9 +429,14 @@ def record(
         observation, info = env.reset(seed=seed)
         #with stop_reading_leader.get_lock():
         #stop_reading_leader.Value = 0
-        read_leader.start()
+        if policy is None:
+            read_leader.start()
         while timestamp < episode_time_s:
-            action = command_queue.get()
+            if policy is None:
+                action = command_queue.get()
+            else:
+                action = get_action_from_policy(policy, observation)
+
             for key in image_keys:
                 str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
                 futures += [
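
Note that the policy branch above calls the helper as `get_action_from_policy(policy, observation)`, while the definition earlier in this commit takes `(policy, observation, device, use_amp=False)` with `device` positional; if that signature is what actually ships, the loop body would presumably need to pass the device resolved when the policy was loaded, along the lines of this sketch:

if policy is None:
    action = command_queue.get()
else:
    # `device` as returned by get_safe_torch_device(policy_cfg.device) in the policy-loading block
    action = get_action_from_policy(policy, observation, device)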
@@ -402,7 +455,7 @@ def record(
             action = np.expand_dims(action, 0)
             observation, reward, _, _ , info = env.step(action)
             ep_dict['action'].append(torch.from_numpy(action))
-            ep_dict['reward'].append(torch.tensor(reward))
+            ep_dict['next.reward'].append(torch.tensor(reward))
             print(reward)

             frame_index += 1
@@ -417,6 +470,7 @@ def record(
         #stop_reading_leader.Value = 1
         # TODO (michel_aractinig): temp fix until I figure out the problem with shared memory
         # stop_reading_leader is blocking
-        command_queue.close()
-        read_leader.terminate()
-        read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
+        if policy is None:
+            command_queue.close()
+            read_leader.terminate()
+            read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
@@ -451,7 +505,7 @@ def record(
         for key in state_keys_dict:
             ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
         ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
-        ep_dict['reward'] = torch.stack(ep_dict['reward'])
+        ep_dict['next.reward'] = torch.stack(ep_dict['next.reward'])

         ep_dict["seed"] = torch.tensor([seed] * num_frames)
         ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
@@ -577,7 +631,11 @@ def record(
     return lerobot_dataset


-def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
+def replay(env,
+           episodes: list,
+           fps: int | None = None,
+           root="data",
+           repo_id="lerobot/debug"):

     env = env()
     local_dir = Path(root) / repo_id
@@ -700,6 +758,21 @@ if __name__ == "__main__":
         default=0,
         help="Visualize image observations with opencv.",
     )
+    parser_record.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        type=str,
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`."
+        ),
+    )
+    parser_record.add_argument(
+        "--policy-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
+    )

     parser_replay = subparsers.add_parser("replay", parents=[base_parser])
     parser_replay.add_argument(
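
A self-contained sketch of how the two new `record` flags parse; the parser below is illustrative (the real ones are added to `parser_record`), and the path and override strings are made up:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--pretrained-policy-name-or-path", type=str)
parser.add_argument("--policy-overrides", type=str, nargs="*")

args = parser.parse_args(
    ["-p", "path/to/pretrained_model", "--policy-overrides", "device=mps", "env.fps=30"]
)
print(args.pretrained_policy_name_or_path)  # path/to/pretrained_model
print(args.policy_overrides)                # ['device=mps', 'env.fps=30']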
@@ -748,6 +821,16 @@ if __name__ == "__main__":
         teleoperate(env_fn, robot, **kwargs)

     elif control_mode == "record":
+        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
+        policy_overrides = args.policy_overrides
+        del kwargs["pretrained_policy_name_or_path"]
+        del kwargs["policy_overrides"]
+
+        if pretrained_policy_name_or_path is not None:
+            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+            kwargs["policy_cfg"] = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
+            kwargs["policy"] = make_policy(hydra_cfg=kwargs["policy_cfg"], pretrained_policy_name_or_path=pretrained_policy_path)
+
         record(env_fn, robot, **kwargs)

     elif control_mode == "replay":
@@ -135,7 +135,7 @@ def update_policy(

    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
+    #with lock if lock is not None else nullcontext():
+    grad_scaler.step(optimizer)
    # Updates the scale for next iteration.
    grad_scaler.update()
@@ -311,6 +311,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
+
+    remove_indices=['observation.images.image_top', 'observation.velocity', 'seed']
+    # temp fix michel_Aractingi TODO
+    offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(remove_indices)
+
     if isinstance(offline_dataset, MultiLeRobotDataset):
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
@@ -477,7 +482,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
                 "next.reward": {"shape": (), "dtype": np.dtype("float32")},
                 "next.done": {"shape": (), "dtype": np.dtype("?")},
-                "next.success": {"shape": (), "dtype": np.dtype("?")},
+                #"next.success": {"shape": (), "dtype": np.dtype("?")},
             },
             buffer_capacity=cfg.training.online_buffer_capacity,
             fps=online_env.unwrapped.metadata["render_fps"],
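
As an aside on the `data_spec` above: `np.dtype("?")` is NumPy shorthand for the boolean dtype, so `next.done` (and the now-commented `next.success`) are stored as booleans while `next.reward` stays float32:

import numpy as np

print(np.dtype("?"))               # bool
print(np.dtype("?") == np.bool_)   # True
print(np.dtype("float32"))         # float32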
@@ -504,6 +509,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             num_samples=len(concat_dataset),
             replacement=True,
         )
+
+        # TODO michel_aractingi temp fix for incosistent keys
+
         dataloader = torch.utils.data.DataLoader(
             concat_dataset,
             batch_size=cfg.training.batch_size,
@@ -538,7 +546,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

         def sample_trajectory_and_update_buffer():
             nonlocal rollout_start_seed
-            with lock:
-                online_rollout_policy.load_state_dict(policy.state_dict())
+            #with lock:
+            online_rollout_policy.load_state_dict(policy.state_dict())
             online_rollout_policy.eval()
             start_rollout_time = time.perf_counter()
@@ -556,13 +564,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             )
             online_rollout_s = time.perf_counter() - start_rollout_time

-            with lock:
-                start_update_buffer_time = time.perf_counter()
-                online_dataset.add_data(eval_info["episodes"])
-
-                # Update the concatenated dataset length used during sampling.
-                concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-                # Update the sampling weights.
-                sampler.weights = compute_sampler_weights(
-                    offline_dataset,
+            #with lock:
+            start_update_buffer_time = time.perf_counter()
+            online_dataset.add_data(eval_info["episodes"])
+            # Update the concatenated dataset length used during sampling.
+            concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+            # Update the sampling weights.
+            sampler.weights = compute_sampler_weights(
+                offline_dataset,
@@ -574,19 +580,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 online_sampling_ratio=cfg.training.online_sampling_ratio,
             )
             sampler.num_samples = len(concat_dataset)

             update_online_buffer_s = time.perf_counter() - start_update_buffer_time

             return online_rollout_s, update_online_buffer_s
-
-        future = executor.submit(sample_trajectory_and_update_buffer)
+        # TODO remove parallelization for sim
+        #future = executor.submit(sample_trajectory_and_update_buffer)
         # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
         # here until the rollout and buffer update is done, before proceeding to the policy update steps.
         if (
             not cfg.training.do_online_rollout_async
             or len(online_dataset) <= cfg.training.online_buffer_seed_size
         ):
-            online_rollout_s, update_online_buffer_s = future.result()
+            online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()

         if len(online_dataset) <= cfg.training.online_buffer_seed_size:
             logging.info(
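
The change above replaces the asynchronous `executor.submit(...)` / `future.result()` pair with a direct, blocking call to the same function. When the result is awaited immediately anyway, the two are effectively equivalent, as in this self-contained sketch:

import concurrent.futures
import time

def rollout():
    time.sleep(0.01)  # pretend to collect an episode
    return "rollout stats"

# Asynchronous form (what the original code used):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(rollout)
    print(future.result())  # blocks here until the rollout finishes

# Synchronous form (what the temporary fix does):
print(rollout())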
@@ -596,12 +602,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

         policy.train()
         for _ in range(cfg.training.online_steps_between_rollouts):
-            with lock:
-                start_time = time.perf_counter()
-                batch = next(dl_iter)
-                dataloading_s = time.perf_counter() - start_time
+            #with lock:
+            start_time = time.perf_counter()
+            batch = next(dl_iter)
+            dataloading_s = time.perf_counter() - start_time

             for key in batch:
+                # TODO michel aractingi convert float64 to float32 for mac
+                if batch[key].dtype == torch.float64:
+                    batch[key] = batch[key].float()
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)

             train_info = update_policy(
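
The float64-to-float32 cast above looks like a workaround for PyTorch's MPS backend, which does not support float64 tensors; a minimal standalone version of the same guard (device string and keys are illustrative):

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
batch = {"action": torch.zeros(4, 2, dtype=torch.float64), "index": torch.arange(4)}
for key in batch:
    if batch[key].dtype == torch.float64:
        batch[key] = batch[key].float()  # downcast before moving; MPS has no float64
    batch[key] = batch[key].to(device, non_blocking=True)
print({k: (v.dtype, str(v.device)) for k, v in batch.items()})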
@@ -619,7 +628,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             train_info["online_rollout_s"] = online_rollout_s
             train_info["update_online_buffer_s"] = update_online_buffer_s
             train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-            with lock:
-                train_info["online_buffer_size"] = len(online_dataset)
+            #with lock:
+            train_info["online_buffer_size"] = len(online_dataset)

             if step % cfg.training.log_freq == 0:
@@ -634,10 +643,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

         # If we're doing async rollouts, we should now wait until we've completed them before proceeding
         # to do the next batch of rollouts.
-        if future.running():
-            start = time.perf_counter()
-            online_rollout_s, update_online_buffer_s = future.result()
-            await_update_online_buffer_s = time.perf_counter() - start
+        #if future.running():
+        #start = time.perf_counter()
+        #online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()
+        #await_update_online_buffer_s = time.perf_counter() - start

         if online_step >= cfg.training.online_steps:
             break