Update ManiSkill configuration and replay buffer to support truncation and dataset handling

- Reduced image size in ManiSkill environment configuration from 128 to 64
- Added support for truncation in replay buffer and actor server
- Updated SAC policy configuration to use a specific dataset and modify vision encoder settings
- Improved dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in learner server
This commit is contained in:
AdilZouitine 2025-02-24 16:53:37 +00:00
parent 546719137a
commit 42a038173f
5 changed files with 78 additions and 27 deletions

View File

@ -5,16 +5,20 @@ fps: 20
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 128
image_size: 64
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 128
render_size: 64
device: cuda
reward_classifier:
pretrained_path: null
config_path: null
config_path: null
wrapper:
joint_masking_action_space: null
delta_action: null

View File

@ -8,7 +8,8 @@
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: null
# dataset_repo_id: null
dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
training:
# Offline training dataloader
@ -52,12 +53,14 @@ policy:
n_action_steps: 1
shared_encoder: true
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
# vision_encoder_name: "helper2424/resnet10"
vision_encoder_name: null
# freeze_vision_encoder: true
freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 128, 128]
observation.image: [3, 64, 64]
output_shapes:
action: [7]

View File

@ -373,6 +373,7 @@ def act_with_policy(
reward=reward,
next_state=next_obs,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool
)
)

View File

@ -31,6 +31,7 @@ class Transition(TypedDict):
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, Any] = None
@ -40,6 +41,7 @@ class BatchTransition(TypedDict):
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
truncated: torch.Tensor
def move_transition_to_device(
@ -70,6 +72,11 @@ def move_transition_to_device(
device, non_blocking=device.type == "cuda"
)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(
device, non_blocking=device.type == "cuda"
)
# Move next_state tensors to CPU
transition["next_state"] = {
key: val.to(device, non_blocking=device.type == "cuda")
@ -205,6 +212,7 @@ class ReplayBuffer:
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
truncated: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
@ -229,6 +237,7 @@ class ReplayBuffer:
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
self.position = (self.position + 1) % self.capacity
@ -294,6 +303,7 @@ class ReplayBuffer:
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
truncated=False,
)
return replay_buffer
@ -352,6 +362,8 @@ class ReplayBuffer:
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
done = bool(current_sample["next.done"].item()) # ensure bool
# TODO: (azouitine) Handle truncation properly
truncated = bool(current_sample["next.done"].item()) # ensure bool
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
@ -374,6 +386,7 @@ class ReplayBuffer:
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
)
transitions.append(transition)
@ -419,6 +432,11 @@ class ReplayBuffer:
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched truncateds --
batch_truncateds = torch.tensor(
[t["truncated"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
state=batch_state,
@ -426,6 +444,7 @@ class ReplayBuffer:
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
)
def to_lerobot_dataset(
@ -501,7 +520,7 @@ class ReplayBuffer:
# Start writing images if needed. If you have no image features, this is harmless.
# Set num_processes or num_threads if you want concurrency.
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
# --------------------------------------------------------------------------------------------
# Convert transitions into episodes and frames
@ -513,7 +532,11 @@ class ReplayBuffer:
)
frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory):
for global_frame_idx, transition in tqdm(
enumerate(self.memory),
desc="Converting replay buffer to dataset",
total=len(self.memory),
):
frame_dict = {}
# Fill the data for state keys
@ -546,14 +569,15 @@ class ReplayBuffer:
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if transition["done"]:
# TODO: (azouitine) Handle truncation properly
if transition["done"] or transition["truncated"]:
# Use some placeholder name for the task
lerobot_dataset.save_episode(task="from_replay_buffer")
lerobot_dataset.save_episode(task=task_name)
episode_index += 1
frame_idx_in_episode = 0
# Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
episode_index=episode_index
)
# We are done adding frames
@ -624,6 +648,10 @@ def concatenate_batch_transitions(
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
left_batch_transitions["truncated"] = torch.cat(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
dim=0,
)
return left_batch_transitions

View File

@ -153,7 +153,7 @@ def initialize_replay_buffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
storage_device=device
storage_device=device,
)
dataset = LeRobotDataset(
@ -169,8 +169,13 @@ def initialize_replay_buffer(
)
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if (
policy.config.vision_encoder_name is None
or not policy.config.freeze_vision_encoder
):
return None, None
with torch.no_grad():
@ -338,6 +343,7 @@ def add_actor_information_and_train(
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
saved_data = False
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
@ -372,7 +378,6 @@ def add_actor_information_and_train(
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
@ -382,7 +387,9 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@ -415,7 +422,9 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@ -436,8 +445,10 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations,
observation_features=observation_features)
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
optimizers["actor"].zero_grad()
loss_actor.backward()
@ -447,7 +458,7 @@ def add_actor_information_and_train(
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
@ -458,7 +469,9 @@ def add_actor_information_and_train(
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
training_infos["Optimization step"] = optimization_step
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
logger.log_dict(
d=training_infos, mode="train", custom_step_key="Optimization step"
)
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
@ -621,11 +634,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,