Update ManiSkill configuration and replay buffer to support truncation and dataset handling
- Reduced the image size in the ManiSkill environment configuration from 128 to 64
- Added support for truncation in the replay buffer and actor server
- Updated the SAC policy configuration to use a specific dataset and modified the vision encoder settings
- Improved the dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in the learner server
parent 546719137a
commit 42a038173f
ManiSkill environment configuration:

```diff
@@ -5,16 +5,20 @@ fps: 20
 env:
   name: maniskill/pushcube
   task: PushCube-v1
-  image_size: 128
+  image_size: 64
   control_mode: pd_ee_delta_pose
   state_dim: 25
   action_dim: 7
   fps: ${fps}
   obs: rgb
   render_mode: rgb_array
-  render_size: 128
+  render_size: 64
   device: cuda

   reward_classifier:
     pretrained_path: null
     config_path: null
+
+  wrapper:
+    joint_masking_action_space: null
+    delta_action: null
```
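An illustrative aside, not part of the commit: `fps: ${fps}` uses OmegaConf-style interpolation, so the env block resolves against the top-level `fps: 20` visible in the hunk header. A minimal sketch of how such a config resolves, assuming only that OmegaConf is installed (the inline YAML is a stand-in for the real file):

```python
from omegaconf import OmegaConf

# Hypothetical stand-in for the real config file; mirrors the fields above.
cfg = OmegaConf.create(
    """
    fps: 20
    env:
      name: maniskill/pushcube
      image_size: 64
      fps: ${fps}
    """
)
assert cfg.env.fps == 20  # ${fps} resolves to the top-level value on access
print(OmegaConf.to_yaml(cfg, resolve=True))
```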
SAC policy configuration:

```diff
@@ -8,7 +8,8 @@
 # env.gym.obs_type=environment_state_agent_pos \

 seed: 1
-dataset_repo_id: null
+# dataset_repo_id: null
+dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"

 training:
   # Offline training dataloader
@@ -52,12 +53,14 @@ policy:
   n_action_steps: 1

   shared_encoder: true
-  vision_encoder_name: "helper2424/resnet10"
-  freeze_vision_encoder: true
+  # vision_encoder_name: "helper2424/resnet10"
+  vision_encoder_name: null
+  # freeze_vision_encoder: true
+  freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.image: [3, 128, 128]
+    observation.image: [3, 64, 64]
   output_shapes:
     action: [7]
```
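Illustrative aside: `observation.image` has to track `env.image_size`, which is why both drop from 128 to 64 in this commit. A hedged sketch of a consistency check one could run on the merged config (the config path is hypothetical):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # hypothetical path to the merged config
c, h, w = cfg.policy.input_shapes["observation.image"]
# The policy's expected image resolution should match what the env renders.
assert (h, w) == (cfg.env.image_size, cfg.env.image_size), (
    f"policy expects {h}x{w} images but env produces {cfg.env.image_size}"
)
```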
Actor server (`act_with_policy`):

```diff
@@ -373,6 +373,7 @@ def act_with_policy(
                     reward=reward,
                     next_state=next_obs,
                     done=done,
+                    truncated=truncated,  # TODO: (azouitine) Handle truncation properly
                     complementary_info=info,  # TODO: handle information for the transition, e.g. is_demonstration: bool
                 )
             )
```
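Illustrative aside: the `truncated` flag the actor forwards above comes from Gymnasium's five-element `step` return, which separates time-limit cutoffs from true terminations. A minimal sketch (the env id is a stand-in for the ManiSkill env):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")  # stand-in for the ManiSkill env
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
# `terminated` means the MDP ended (done); `truncated` means a time limit or
# external cutoff ended the episode, so value bootstrapping should differ.
done = terminated
```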
Replay buffer (`Transition`, `BatchTransition`, `ReplayBuffer`):

```diff
@@ -31,6 +31,7 @@ class Transition(TypedDict):
     reward: float
     next_state: dict[str, torch.Tensor]
     done: bool
+    truncated: bool
     complementary_info: dict[str, Any] = None
```
```diff
@@ -40,6 +41,7 @@ class BatchTransition(TypedDict):
     reward: torch.Tensor
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
+    truncated: torch.Tensor


 def move_transition_to_device(
```
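Illustrative aside: with the new field, a single transition carries both end-of-episode flags. A sketch of constructing one, mirroring the `Transition` definition above (the 25/7 dimensions follow `state_dim` and `action_dim` from the env config; the tensor values are placeholders):

```python
from typing import Any, TypedDict
import torch

class Transition(TypedDict):  # mirrors the class in the diff
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    truncated: bool
    complementary_info: dict[str, Any]

t = Transition(
    state={"observation.state": torch.zeros(25)},
    action=torch.zeros(7),
    reward=0.0,
    next_state={"observation.state": torch.zeros(25)},
    done=False,
    truncated=True,  # episode cut off by a time limit, not by termination
    complementary_info={},
)
```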
```diff
@@ -70,6 +72,11 @@ def move_transition_to_device(
             device, non_blocking=device.type == "cuda"
         )

+    if isinstance(transition["truncated"], torch.Tensor):
+        transition["truncated"] = transition["truncated"].to(
+            device, non_blocking=device.type == "cuda"
+        )
+
     # Move next_state tensors to CPU
     transition["next_state"] = {
         key: val.to(device, non_blocking=device.type == "cuda")
```
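Illustrative aside: the `non_blocking=device.type == "cuda"` idiom requests asynchronous copies only when the target is a GPU, where they can overlap with compute. The same pattern in isolation:

```python
import torch

def to_device(value: torch.Tensor, device: torch.device) -> torch.Tensor:
    # non_blocking only has an effect for CPU->CUDA copies of pinned memory;
    # passing it unconditionally is harmless on CPU targets.
    return value.to(device, non_blocking=device.type == "cuda")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flag = to_device(torch.tensor(True), device)
```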
```diff
@@ -205,6 +212,7 @@ class ReplayBuffer:
         reward: float,
         next_state: dict[str, torch.Tensor],
         done: bool,
+        truncated: bool,
         complementary_info: Optional[dict[str, torch.Tensor]] = None,
     ):
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
```
```diff
@@ -229,6 +237,7 @@ class ReplayBuffer:
             reward=reward,
             next_state=next_state,
             done=done,
+            truncated=truncated,
             complementary_info=complementary_info,
         )
         self.position = (self.position + 1) % self.capacity
```
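Illustrative aside: the `self.position = (self.position + 1) % self.capacity` line above implements a circular buffer that overwrites the oldest transition once capacity is reached. A tiny self-contained sketch of that mechanic:

```python
capacity = 3
memory, position = [], 0

for step in range(5):  # add 5 items into a capacity-3 ring buffer
    item = {"step": step}
    if len(memory) < capacity:
        memory.append(item)
    else:
        memory[position] = item  # overwrite the oldest slot
    position = (position + 1) % capacity

print([m["step"] for m in memory])  # [3, 4, 2]: oldest entries overwritten
```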
```diff
@@ -294,6 +303,7 @@ class ReplayBuffer:
                 reward=data["reward"],
                 next_state=data["next_state"],
                 done=data["done"],
+                truncated=False,
             )
         return replay_buffer
```
```diff
@@ -352,6 +362,8 @@ class ReplayBuffer:
         # ----- 3) Reward and done -----
         reward = float(current_sample["next.reward"].item())  # ensure float
         done = bool(current_sample["next.done"].item())  # ensure bool
+        # TODO: (azouitine) Handle truncation properly
+        truncated = bool(current_sample["next.done"].item())  # ensure bool

         # ----- 4) Next state -----
         # If not done and the next sample is in the same episode, we pull the next sample's state.
```
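Illustrative aside: the TODO above flags that copying `next.done` into `truncated` is a stopgap. One plausible fix, sketched under the assumption that the caller knows whether the current frame is the last one in its episode (that lookup is not shown in the diff):

```python
def derive_truncated(current_sample: dict, is_last_frame_in_episode: bool) -> bool:
    # A frame is "truncated" when the episode ends there without the MDP
    # signalling termination, e.g. a time limit cut the episode short.
    done = bool(current_sample["next.done"].item())
    return is_last_frame_in_episode and not done
```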
```diff
@@ -374,6 +386,7 @@ class ReplayBuffer:
                 reward=reward,
                 next_state=next_state,
                 done=done,
+                truncated=truncated,
             )
             transitions.append(transition)
```
```diff
@@ -419,6 +432,11 @@ class ReplayBuffer:
             [t["done"] for t in list_of_transitions], dtype=torch.float32
         ).to(self.device)

+        # -- Build batched truncateds --
+        batch_truncateds = torch.tensor(
+            [t["truncated"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
+
         # Return a BatchTransition typed dict
         return BatchTransition(
             state=batch_state,
```
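Illustrative aside: the boolean flags are batched as float32 rather than bool, which lets them act directly as masks in TD targets. A standalone sketch, including the usual convention that truncated steps, unlike done steps, still bootstrap:

```python
import torch

dones = [False, True, False]
truncateds = [True, False, False]

batch_dones = torch.tensor(dones, dtype=torch.float32)
batch_truncateds = torch.tensor(truncateds, dtype=torch.float32)

# Typical use: mask bootstrapping on terminal steps only. A truncated step
# still bootstraps, since the episode was cut off rather than terminated.
gamma, next_q, reward = 0.99, torch.ones(3), torch.zeros(3)
td_target = reward + (1.0 - batch_dones) * gamma * next_q
print(td_target)  # tensor([0.9900, 0.0000, 0.9900])
```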
```diff
@@ -426,6 +444,7 @@ class ReplayBuffer:
             reward=batch_rewards,
             next_state=batch_next_state,
             done=batch_dones,
+            truncated=batch_truncateds,
         )

     def to_lerobot_dataset(
```
```diff
@@ -501,7 +520,7 @@ class ReplayBuffer:

         # Start writing images if needed. If you have no image features, this is harmless.
         # Set num_processes or num_threads if you want concurrency.
-        lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
+        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)

         # --------------------------------------------------------------------------------------------
         # Convert transitions into episodes and frames
```
```diff
@@ -513,7 +532,11 @@ class ReplayBuffer:
         )

         frame_idx_in_episode = 0
-        for global_frame_idx, transition in enumerate(self.memory):
+        for global_frame_idx, transition in tqdm(
+            enumerate(self.memory),
+            desc="Converting replay buffer to dataset",
+            total=len(self.memory),
+        ):
             frame_dict = {}

             # Fill the data for state keys
```
```diff
@@ -546,14 +569,15 @@ class ReplayBuffer:
             # Move to next frame
             frame_idx_in_episode += 1
             # If we reached an episode boundary, call save_episode, reset counters
-            if transition["done"]:
+            # TODO: (azouitine) Handle truncation properly
+            if transition["done"] or transition["truncated"]:
                 # Use some placeholder name for the task
-                lerobot_dataset.save_episode(task="from_replay_buffer")
+                lerobot_dataset.save_episode(task=task_name)
                 episode_index += 1
                 frame_idx_in_episode = 0
                 # Start a new buffer for the next episode
                 lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-                    episode_index
+                    episode_index=episode_index
                 )

         # We are done adding frames
```
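Illustrative aside: treating `done or truncated` as the episode boundary is the core of this conversion loop. The same splitting logic in a self-contained sketch:

```python
transitions = [
    {"done": False, "truncated": False},
    {"done": True, "truncated": False},   # episode 0 ends by termination
    {"done": False, "truncated": True},   # episode 1 ends by truncation
]

episodes, current = [], []
for t in transitions:
    current.append(t)
    if t["done"] or t["truncated"]:  # either flag closes the episode
        episodes.append(current)
        current = []

print(len(episodes))  # 2
```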
```diff
@@ -624,6 +648,10 @@ def concatenate_batch_transitions(
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
+    left_batch_transitions["truncated"] = torch.cat(
+        [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
+        dim=0,
+    )
     return left_batch_transitions
```
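Illustrative aside: this helper lets the learner mix online and offline samples into one batch (see the `concatenate_batch_transitions(batch, batch_offline)` call in the learner diff below). The concatenation in isolation, on a reduced two-field batch:

```python
import torch

online = {"done": torch.zeros(4), "truncated": torch.zeros(4)}
offline = {"done": torch.ones(4), "truncated": torch.ones(4)}

# Every tensor field must be concatenated, or downstream losses would see
# mismatched batch sizes; forgetting the new `truncated` key was the risk here.
for key in ("done", "truncated"):
    online[key] = torch.cat([online[key], offline[key]], dim=0)

assert online["done"].shape[0] == 8
```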
Learner server:

```diff
@@ -153,7 +153,7 @@ def initialize_replay_buffer(
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
-        storage_device=device
+        storage_device=device,
     )

     dataset = LeRobotDataset(
```
```diff
@@ -169,8 +169,13 @@ def initialize_replay_buffer(
     )


-def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
+def get_observation_features(
+    policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    if (
+        policy.config.vision_encoder_name is None
+        or not policy.config.freeze_vision_encoder
+    ):
         return None, None

     with torch.no_grad():
```
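Illustrative aside: the early `return None, None` means features are only precomputed when a vision encoder exists and is frozen; cached activations are stale the moment encoder weights can change. A hedged sketch of the caching rationale (the encoder is a stand-in):

```python
import torch
from torch import nn

encoder = nn.Linear(32, 8)  # stand-in for a frozen vision encoder
encoder.requires_grad_(False)

def get_features(frozen: bool, obs: torch.Tensor) -> torch.Tensor | None:
    if not frozen:
        return None  # weights may change; recompute features inside each loss
    with torch.no_grad():
        return encoder(obs)  # safe to precompute once and reuse

obs = torch.randn(4, 32)
feats = get_features(frozen=True, obs=obs)
# `feats` can now be passed to the critic, actor and temperature losses
# without re-running the encoder three times per optimization step.
```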
```diff
@@ -338,6 +343,7 @@ def add_actor_information_and_train(
     interaction_step_shift = (
         resume_interaction_step if resume_interaction_step is not None else 0
     )
+    saved_data = False
     while True:
         if shutdown_event is not None and shutdown_event.is_set():
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
```
```diff
@@ -372,7 +378,6 @@ def add_actor_information_and_train(
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)

-
             actions = batch["action"]
             rewards = batch["reward"]
             observations = batch["state"]
```
```diff
@@ -382,7 +387,9 @@ def add_actor_information_and_train(
                 observations=observations, actions=actions, next_state=next_observations
             )

-            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
+            observation_features, next_observation_features = get_observation_features(
+                policy, observations, next_observations
+            )
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
```
```diff
@@ -415,7 +422,9 @@ def add_actor_information_and_train(
                 observations=observations, actions=actions, next_state=next_observations
             )

-            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
+            observation_features, next_observation_features = get_observation_features(
+                policy, observations, next_observations
+            )
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
```
```diff
@@ -436,8 +445,10 @@ def add_actor_information_and_train(
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations,
-                                                           observation_features=observation_features)
+                    loss_actor = policy.compute_loss_actor(
+                        observations=observations,
+                        observation_features=observation_features,
+                    )

                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
```
```diff
@@ -447,7 +458,7 @@ def add_actor_information_and_train(

                 loss_temperature = policy.compute_loss_temperature(
                     observations=observations,
-                    observation_features=observation_features
+                    observation_features=observation_features,
                 )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
```
```diff
@@ -458,7 +469,9 @@ def add_actor_information_and_train(
             policy.update_target_networks()
         if optimization_step % cfg.training.log_freq == 0:
             training_infos["Optimization step"] = optimization_step
-            logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
+            logger.log_dict(
+                d=training_infos, mode="train", custom_step_key="Optimization step"
+            )
             # logging.info(f"Training infos: {training_infos}")

         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
```
```diff
@@ -621,11 +634,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info("make_dataset offline buffer")
     offline_dataset = make_dataset(cfg)
     logging.info("Conversion to an offline replay buffer")
-    active_action_dims = [
-        i
-        for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
-        if mask
-    ]
+    active_action_dims = None
+    if cfg.env.wrapper.joint_masking_action_space is not None:
+        active_action_dims = [
+            i
+            for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+            if mask
+        ]
     offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
         offline_dataset,
         device=device,
```
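Illustrative aside: the new guard lets `joint_masking_action_space: null` (see the env config hunk above) mean "no masking" instead of crashing on `enumerate(None)`. The masking logic in isolation:

```python
def active_dims(joint_masking_action_space: list[int] | None) -> list[int] | None:
    # None means "no masking": the learner keeps the full action space.
    if joint_masking_action_space is None:
        return None
    return [i for i, mask in enumerate(joint_masking_action_space) if mask]

print(active_dims([1, 1, 1, 0, 0, 0, 1]))  # [0, 1, 2, 6]
print(active_dims(None))                   # None
```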