Hardcoded some normalization parameters. TODO refactor

Added action masking at the level of the intervention actions and the offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-13 14:27:14 +01:00
parent 459f22ed30
commit c462a478c7
7 changed files with 38 additions and 10 deletions
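For readers skimming the diff below: the masking described in the commit message boils down to selecting a subset of action dimensions (the "active" joints) everywhere an action is produced, stored, or replayed. A minimal sketch of the idea, not code from this commit; the per-joint boolean mask format is assumed from the joint_masking_action_space config key used further down:

    import torch

    # Hypothetical mask: True = joint controlled by the policy, False = frozen (e.g. the gripper).
    joint_mask = [True, True, True, True, False, False]
    active_dims = torch.tensor([i for i, m in enumerate(joint_mask) if m])

    full_action = torch.randn(len(joint_mask))   # action recorded in the full joint space
    masked_action = full_action[active_dims]     # 4-dim action seen by the policy and the buffers

    # Re-expanding a masked action before sending it to the robot:
    restored = torch.zeros(len(joint_mask))
    restored[active_dims] = masked_action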

View File

@@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
-# self.pull_from_repo(allow_patterns="meta/")
+self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
@@ -539,7 +539,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
-# self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""

View File

@@ -137,7 +137,7 @@ class SACPolicy(
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
-self.log_alpha = self.log_alpha.to(*args, **kwargs)
+# self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()

View File

@@ -31,7 +31,7 @@ training:
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_seed_size: 0
-online_step_before_learning: 1000 #5000
+online_step_before_learning: 100 #5000
do_online_rollout_async: false
policy_update_freq: 1
@@ -61,7 +61,7 @@ policy:
observation.images.side: [3, 128, 128]
# observation.image: [3, 128, 128]
output_shapes:
-action: ["${env.action_dim}"]
+action: [4] # ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
@@ -84,9 +84,12 @@ policy:
output_normalization_modes:
action: min_max
output_normalization_params:
+# action:
+# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
+# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
action:
-min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
-max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+min: [-145.283203125, -69.43359375, -78.75, -46.0546875]
+max: [145.283203125, 69.43359375, 78.75, 46.0546875]
# Architecture / modeling.
# Neural networks.
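The hardcoded min/max above are per-joint bounds for the 4-dim action declared in output_shapes. Assuming min_max normalization follows the usual convention of mapping policy outputs in [-1, 1] back to these bounds, unnormalization would look roughly like the sketch below (illustrative only, not LeRobot's normalization code):

    import torch

    ACTION_MIN = torch.tensor([-145.283203125, -69.43359375, -78.75, -46.0546875])
    ACTION_MAX = torch.tensor([145.283203125, 69.43359375, 78.75, 46.0546875])

    def unnormalize_action(a: torch.Tensor) -> torch.Tensor:
        # Map an action in [-1, 1] to the hardcoded joint bounds (assumed convention).
        return ACTION_MIN + (a + 1.0) * (ACTION_MAX - ACTION_MIN) / 2.0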

View File

@@ -201,6 +201,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
"action": {"min": min_action_space, "max": max_action_space}
}
cfg.policy.output_normalization_params = output_normalization_params
+cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy intance
@@ -252,6 +253,8 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
if info["is_intervention"]:
# TODO: Check the shape
+# NOTE: The action space for demonstration before hand is with the full action space
+# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
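The NOTE above is the core of the intervention handling: when a human intervenes, the stored transition should contain the action that was actually executed, and that action is already reduced to the active joints by the masking wrapper. A rough sketch of the override inside a rollout loop (policy and env are stand-ins, not the actual actor code):

    action = policy.select_action(obs)        # policy acts in the masked action space
    obs, reward, terminated, truncated, info = env.step(action)

    if info.get("is_intervention"):
        # Store the teleoperated action instead of the policy action; the masking
        # wrapper has already sliced it to the active joints.
        action = info["action_intervention"]
        episode_intervention = True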

View File

@@ -195,6 +195,7 @@ class ReplayBuffer:
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
+action_mask: Optional[Sequence[int]] = None,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
@@ -229,6 +230,12 @@ class ReplayBuffer:
elif isinstance(v, torch.Tensor):
data[k] = v.to(device)
+if action_mask is not None:
+    if data["action"].dim() == 1:
+        data["action"] = data["action"][action_mask]
+    else:
+        data["action"] = data["action"][:, action_mask]
replay_buffer.add(
state=data["state"],
action=data["action"],
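With the new action_mask argument, converting a demonstration dataset while dropping inactive joints could be invoked as in the sketch below (dataset and state keys are placeholders; the signature matches the one added above):

    active_action_dims = [0, 1, 2, 3]  # indices of the joints kept by the mask

    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
        offline_dataset,                   # a LeRobotDataset (placeholder)
        device="cuda:0",
        state_keys=["observation.state"],  # placeholder state keys
        action_mask=active_action_dims,    # actions are sliced to these dims on insertion
    )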

View File

@@ -328,7 +328,7 @@ class RewardWrapper(gym.Wrapper):
return self.env.reset(seed=seed, options=options)
-class JointMaskingActionSpace(gym.ActionWrapper):
+class JointMaskingActionSpace(gym.Wrapper):
def __init__(self, env, mask):
"""
Wrapper to mask out dimensions of the action space.
@@ -388,6 +388,16 @@ class JointMaskingActionSpace(gym.ActionWrapper):
full_action[self.active_dims] = masked_action
return full_action
+def step(self, action):
+    action = self.action(action)
+    obs, reward, terminated, truncated, info = self.env.step(action)
+    if "action_intervention" in info and info["action_intervention"] is not None:
+        if info["action_intervention"].dim() == 1:
+            info["action_intervention"] = info["action_intervention"][self.active_dims]
+        else:
+            info["action_intervention"] = info["action_intervention"][:, self.active_dims]
+    return obs, reward, terminated, truncated, info
class TimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
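Switching the wrapper from gym.ActionWrapper to gym.Wrapper lets it rewrite not only the outgoing action (via self.action) but also the action_intervention entry that lower wrappers put into info, so interventions come back in the same masked space the policy uses. A usage sketch, assuming the mask is the per-joint boolean list from the config and that actions are torch tensors:

    import torch

    env = JointMaskingActionSpace(env, mask=[True, True, True, True, False, False])

    # The policy outputs only the 4 active joints; the wrapper re-expands to the full space.
    obs, reward, terminated, truncated, info = env.step(torch.zeros(4))

    # Any reported intervention action is sliced to the same active dims.
    if info.get("action_intervention") is not None:
        assert info["action_intervention"].shape[-1] == 4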

View File

@@ -354,7 +354,7 @@ def add_actor_information_and_train(
transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition)
-if transition.get("complementary_info", {}).get("is_interaction"):
+if transition.get("complementary_info", {}).get("is_intervention"):
offline_replay_buffer.add(**transition)
while not interaction_message_queue.empty():
@@ -568,6 +568,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
### To avoid sending a SACPolicy object through the port, we create a policy intance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
+policy_lock = Lock()
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
@@ -593,8 +594,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
+active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+    offline_dataset,
+    device=device,
+    state_keys=cfg.policy.input_shapes.keys(),
+    action_mask=active_action_dims,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
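Halving batch_size reflects that each update now samples half of its batch from the online buffer and half from the offline (demonstration/intervention) buffer. A schematic of that mixing, assuming both buffers expose a sample(batch_size) method returning flat dicts of tensors (an assumption, not the learner's exact code):

    import torch

    def sample_mixed_batch(online_buffer, offline_buffer, batch_size):
        # Draw half of the batch from each buffer and concatenate matching tensor fields.
        half = batch_size // 2
        online = online_buffer.sample(half)    # assumed API
        offline = offline_buffer.sample(half)  # assumed API
        return {
            k: torch.cat([online[k], offline[k]], dim=0)
            for k in online
            if isinstance(online[k], torch.Tensor)  # nested dicts (e.g. state) would need the same treatment
        }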