Update ManiSkill configuration and replay buffer to support truncation and dataset handling
- Reduced image size in ManiSkill environment configuration from 128 to 64
- Added support for truncation in replay buffer and actor server
- Updated SAC policy configuration to use a specific dataset and modify vision encoder settings
- Improved dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in learner server
parent 546719137a
commit 42a038173f
@@ -5,16 +5,20 @@ fps: 20
 env:
   name: maniskill/pushcube
   task: PushCube-v1
-  image_size: 128
+  image_size: 64
   control_mode: pd_ee_delta_pose
   state_dim: 25
   action_dim: 7
   fps: ${fps}
   obs: rgb
   render_mode: rgb_array
-  render_size: 128
+  render_size: 64
   device: cuda
 
   reward_classifier:
     pretrained_path: null
     config_path: null
+
+  wrapper:
+    joint_masking_action_space: null
+    delta_action: null
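The `${fps}` interpolation and the new `env.wrapper` block are ordinary OmegaConf/Hydra-style config. A minimal sketch of how the resolved values line up (the inline YAML below is a stand-in for the file above, not its full contents):

```python
from omegaconf import OmegaConf

# Stand-in for the config above; image_size/render_size now match the policy's
# observation.image shape of [3, 64, 64].
cfg = OmegaConf.create(
    """
    fps: 20
    env:
      name: maniskill/pushcube
      task: PushCube-v1
      image_size: 64
      render_size: 64
      fps: ${fps}
      wrapper:
        joint_masking_action_space: null
        delta_action: null
    """
)
assert cfg.env.fps == 20                   # ${fps} resolves against the root value
assert cfg.env.image_size == cfg.env.render_size == 64
assert cfg.env.wrapper.joint_masking_action_space is None
```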
@@ -8,7 +8,8 @@
 # env.gym.obs_type=environment_state_agent_pos \
 
 seed: 1
-dataset_repo_id: null
+# dataset_repo_id: null
+dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
 
 training:
   # Offline training dataloader
@@ -52,12 +53,14 @@ policy:
   n_action_steps: 1
 
   shared_encoder: true
-  vision_encoder_name: "helper2424/resnet10"
-  freeze_vision_encoder: true
+  # vision_encoder_name: "helper2424/resnet10"
+  vision_encoder_name: null
+  # freeze_vision_encoder: true
+  freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.image: [3, 128, 128]
+    observation.image: [3, 64, 64]
   output_shapes:
     action: [7]
 
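With `vision_encoder_name: null` the policy falls back to its own trainable image encoder, which is why `freeze_vision_encoder` is flipped to false, and the image inputs shrink to 64×64 to match the environment change. A quick sketch of tensors matching the updated shapes (batch size and variable names are illustrative):

```python
import torch

batch_size = 2
observation = {
    "observation.state": torch.zeros(batch_size, 25),         # ${env.state_dim}
    "observation.image": torch.zeros(batch_size, 3, 64, 64),  # was [3, 128, 128]
}
action = torch.zeros(batch_size, 7)                           # output_shapes.action
```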
@@ -373,6 +373,7 @@ def act_with_policy(
                     reward=reward,
                     next_state=next_obs,
                     done=done,
+                    truncated=truncated,  # TODO: (azouitine) Handle truncation properly
                     complementary_info=info,  # TODO Handle information for the transition, is_demonstraction: bool
                 )
             )
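Gymnasium-style environments already report truncation separately from termination in `step()`; a sketch of where the new `truncated` flag would come from in the acting loop (the environment and variable names are illustrative, and ManiSkill's gym interface should follow the same five-tuple convention):

```python
import gymnasium as gym

# Any Gymnasium env illustrates the five-tuple step API: termination and
# time-limit truncation are distinct signals.
env = gym.make("CartPole-v1", max_episode_steps=5)
obs, info = env.reset(seed=0)
for _ in range(5):
    next_obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated  # true termination only; truncation is tracked separately
    if terminated or truncated:
        break
```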
@@ -31,6 +31,7 @@ class Transition(TypedDict):
     reward: float
     next_state: dict[str, torch.Tensor]
     done: bool
+    truncated: bool
     complementary_info: dict[str, Any] = None
 
 
@@ -40,6 +41,7 @@ class BatchTransition(TypedDict):
     reward: torch.Tensor
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
+    truncated: torch.Tensor
 
 
 def move_transition_to_device(
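Keeping `truncated` as its own field rather than folding it into `done` matters for TD bootstrapping: time-limit truncations should still bootstrap from the next state, while true terminations should not. The TODOs in this commit suggest that handling is not fully wired up yet; the sketch below only illustrates the convention:

```python
import torch

gamma = 0.99
reward = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 0.0])       # no true terminations
truncated = torch.tensor([0.0, 1.0])  # the second transition hit a time limit
next_q = torch.tensor([5.0, 5.0])

# Folding truncation into done wrongly kills the bootstrap on time limits:
folded = reward + gamma * (1.0 - torch.maximum(done, truncated)) * next_q
# Keeping the flags separate lets truncated transitions still bootstrap:
separate = reward + gamma * (1.0 - done) * next_q

print(folded)    # tensor([5.9500, 0.0000])
print(separate)  # tensor([5.9500, 4.9500])
```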
@@ -70,6 +72,11 @@ def move_transition_to_device(
             device, non_blocking=device.type == "cuda"
         )
 
+    if isinstance(transition["truncated"], torch.Tensor):
+        transition["truncated"] = transition["truncated"].to(
+            device, non_blocking=device.type == "cuda"
+        )
+
     # Move next_state tensors to CPU
     transition["next_state"] = {
         key: val.to(device, non_blocking=device.type == "cuda")
@@ -205,6 +212,7 @@ class ReplayBuffer:
         reward: float,
         next_state: dict[str, torch.Tensor],
         done: bool,
+        truncated: bool,
         complementary_info: Optional[dict[str, torch.Tensor]] = None,
     ):
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
@@ -229,6 +237,7 @@ class ReplayBuffer:
             reward=reward,
             next_state=next_state,
             done=done,
+            truncated=truncated,
             complementary_info=complementary_info,
         )
         self.position = (self.position + 1) % self.capacity
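A hypothetical call site reflecting the extended `add()` signature, continuing the variable names from the Gymnasium step sketch earlier (values are illustrative, not the actor server's exact code):

```python
# Assumes `replay_buffer` is a ReplayBuffer instance and the step outputs come
# from the acting loop; every field below now has a slot in the stored Transition.
replay_buffer.add(
    state=obs,
    action=action,
    reward=float(reward),
    next_state=next_obs,
    done=bool(terminated),
    truncated=bool(truncated),
    complementary_info=info,
)
```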
@@ -294,6 +303,7 @@ class ReplayBuffer:
                 reward=data["reward"],
                 next_state=data["next_state"],
                 done=data["done"],
+                truncated=False,
             )
         return replay_buffer
 
@@ -352,6 +362,8 @@ class ReplayBuffer:
         # ----- 3) Reward and done -----
         reward = float(current_sample["next.reward"].item())  # ensure float
         done = bool(current_sample["next.done"].item())  # ensure bool
+        # TODO: (azouitine) Handle truncation properly
+        truncated = bool(current_sample["next.done"].item())  # ensure bool
 
         # ----- 4) Next state -----
         # If not done and the next sample is in the same episode, we pull the next sample's state.
@@ -374,6 +386,7 @@ class ReplayBuffer:
                 reward=reward,
                 next_state=next_state,
                 done=done,
+                truncated=truncated,
             )
             transitions.append(transition)
 
@@ -419,6 +432,11 @@ class ReplayBuffer:
             [t["done"] for t in list_of_transitions], dtype=torch.float32
         ).to(self.device)
 
+        # -- Build batched truncateds --
+        batch_truncateds = torch.tensor(
+            [t["truncated"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
+
         # Return a BatchTransition typed dict
         return BatchTransition(
             state=batch_state,
@@ -426,6 +444,7 @@ class ReplayBuffer:
             reward=batch_rewards,
             next_state=batch_next_state,
             done=batch_dones,
+            truncated=batch_truncateds,
         )
 
     def to_lerobot_dataset(
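After this change a sampled batch carries `truncated` alongside `done`, both as float tensors of shape `[batch_size]`; a quick sanity-check sketch, assuming a populated buffer:

```python
batch = replay_buffer.sample(batch_size=256)
assert batch["done"].shape == batch["truncated"].shape == (256,)
assert batch["truncated"].dtype == torch.float32
```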
@@ -501,7 +520,7 @@ class ReplayBuffer:
 
         # Start writing images if needed. If you have no image features, this is harmless.
         # Set num_processes or num_threads if you want concurrency.
-        lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
+        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
 
         # --------------------------------------------------------------------------------------------
         # Convert transitions into episodes and frames
@@ -513,7 +532,11 @@ class ReplayBuffer:
         )
 
         frame_idx_in_episode = 0
-        for global_frame_idx, transition in enumerate(self.memory):
+        for global_frame_idx, transition in tqdm(
+            enumerate(self.memory),
+            desc="Converting replay buffer to dataset",
+            total=len(self.memory),
+        ):
             frame_dict = {}
 
             # Fill the data for state keys
@@ -546,14 +569,15 @@ class ReplayBuffer:
             # Move to next frame
             frame_idx_in_episode += 1
             # If we reached an episode boundary, call save_episode, reset counters
-            if transition["done"]:
+            # TODO: (azouitine) Handle truncation properly
+            if transition["done"] or transition["truncated"]:
                 # Use some placeholder name for the task
-                lerobot_dataset.save_episode(task="from_replay_buffer")
+                lerobot_dataset.save_episode(task=task_name)
                 episode_index += 1
                 frame_idx_in_episode = 0
                 # Start a new buffer for the next episode
                 lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-                    episode_index
+                    episode_index=episode_index
                 )
 
         # We are done adding frames
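The conversion loop above now flushes an episode on either flag; a self-contained sketch of that splitting rule with the same tqdm-style progress reporting (generic dicts stand in for the stored transitions):

```python
from tqdm import tqdm

def split_episodes(transitions):
    """Group transitions into episodes, ending one on done OR truncated."""
    episodes, current = [], []
    for t in tqdm(transitions, desc="Converting replay buffer to dataset", total=len(transitions)):
        current.append(t)
        if t["done"] or t["truncated"]:
            episodes.append(current)
            current = []
    return episodes

demo = [
    {"done": False, "truncated": False},
    {"done": False, "truncated": True},   # time limit -> episode boundary
    {"done": True, "truncated": False},
]
assert [len(e) for e in split_episodes(demo)] == [2, 1]
```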
@@ -624,6 +648,10 @@ def concatenate_batch_transitions(
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
+    left_batch_transitions["truncated"] = torch.cat(
+        [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
+        dim=0,
+    )
     return left_batch_transitions
 
 
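Mirroring the learner loop, concatenating an online and an offline batch now also stacks `truncated`; a sketch of the expected shape, assuming both buffers are populated:

```python
batch = replay_buffer.sample(batch_size)
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
assert batch["truncated"].shape[0] == 2 * batch_size
```

Note that offline transitions loaded through `from_lerobot_dataset` currently get `truncated=False`, so the offline half of the concatenated flag is all zeros.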
@@ -153,7 +153,7 @@ def initialize_replay_buffer(
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
-        storage_device=device
+        storage_device=device,
     )
 
     dataset = LeRobotDataset(
@@ -169,8 +169,13 @@ def initialize_replay_buffer(
     )
 
 
-def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
+def get_observation_features(
+    policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    if (
+        policy.config.vision_encoder_name is None
+        or not policy.config.freeze_vision_encoder
+    ):
         return None, None
 
     with torch.no_grad():
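The gate above returns `(None, None)` unless a pretrained vision encoder is both configured and frozen, because caching features is only valid when the encoder's weights cannot change between the losses that reuse them. A generic sketch of that idea (not this repo's encoder):

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 128))
for p in encoder.parameters():
    p.requires_grad_(False)  # frozen, as with freeze_vision_encoder: true

images = torch.zeros(8, 3, 64, 64)
with torch.no_grad():
    cached = encoder(images)  # computed once, reusable by critic/actor/temperature losses
```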
@@ -338,6 +343,7 @@ def add_actor_information_and_train(
     interaction_step_shift = (
         resume_interaction_step if resume_interaction_step is not None else 0
     )
+    saved_data = False
     while True:
         if shutdown_event is not None and shutdown_event.is_set():
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
@@ -372,7 +378,6 @@ def add_actor_information_and_train(
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)
-
 
             actions = batch["action"]
             rewards = batch["reward"]
             observations = batch["state"]
@@ -382,7 +387,9 @@ def add_actor_information_and_train(
                 observations=observations, actions=actions, next_state=next_observations
             )
 
-            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
+            observation_features, next_observation_features = get_observation_features(
+                policy, observations, next_observations
+            )
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
@@ -415,7 +422,9 @@ def add_actor_information_and_train(
                 observations=observations, actions=actions, next_state=next_observations
             )
 
-            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
+            observation_features, next_observation_features = get_observation_features(
+                policy, observations, next_observations
+            )
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
@@ -436,8 +445,10 @@ def add_actor_information_and_train(
             if optimization_step % cfg.training.policy_update_freq == 0:
                 for _ in range(cfg.training.policy_update_freq):
                     with policy_lock:
-                        loss_actor = policy.compute_loss_actor(observations=observations,
-                                                               observation_features=observation_features)
+                        loss_actor = policy.compute_loss_actor(
+                            observations=observations,
+                            observation_features=observation_features,
+                        )
 
                     optimizers["actor"].zero_grad()
                     loss_actor.backward()
@@ -447,7 +458,7 @@ def add_actor_information_and_train(
 
                     loss_temperature = policy.compute_loss_temperature(
                         observations=observations,
-                        observation_features=observation_features
+                        observation_features=observation_features,
                     )
                     optimizers["temperature"].zero_grad()
                     loss_temperature.backward()
@@ -458,7 +469,9 @@ def add_actor_information_and_train(
             policy.update_target_networks()
             if optimization_step % cfg.training.log_freq == 0:
                 training_infos["Optimization step"] = optimization_step
-                logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
+                logger.log_dict(
+                    d=training_infos, mode="train", custom_step_key="Optimization step"
+                )
                 # logging.info(f"Training infos: {training_infos}")
 
             time_for_one_optimization_step = time.time() - time_for_one_optimization_step
@@ -621,6 +634,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("make_dataset offline buffer")
     offline_dataset = make_dataset(cfg)
     logging.info("Convertion to a offline replay buffer")
-    active_action_dims = [
-        i
-        for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+    active_action_dims = None
+    if cfg.env.wrapper.joint_masking_action_space is not None:
+        active_action_dims = [
+            i
+            for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
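A hypothetical mask to illustrate how `active_action_dims` is derived when `env.wrapper.joint_masking_action_space` is set (the exact mask format, beyond being enumerable with truthy entries, is not shown in this hunk):

```python
joint_masking_action_space = [1, 1, 1, 0, 0, 0, 0]  # keep the first three joints (illustrative)
active_action_dims = [
    i for i, mask in enumerate(joint_masking_action_space) if mask
]
assert active_action_dims == [0, 1, 2]
```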