Hardcoded some normalization parameters. TODO refactor
Added masking actions on the level of the intervention actions and offline dataset Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
459f22ed30
commit
c462a478c7
|
@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
# Load metadata
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
# self.pull_from_repo(allow_patterns="meta/")
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
|
@ -539,7 +539,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
# HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
|
||||
logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
|
||||
# self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
|
|
|
@ -137,7 +137,7 @@ class SACPolicy(
|
|||
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
|
||||
if self.actor.fixed_std is not None:
|
||||
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
|
||||
self.log_alpha = self.log_alpha.to(*args, **kwargs)
|
||||
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
|
||||
super().to(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -31,7 +31,7 @@ training:
|
|||
online_env_seed: 10000
|
||||
online_buffer_capacity: 1000000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 1000 #5000
|
||||
online_step_before_learning: 100 #5000
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
|
@ -61,7 +61,7 @@ policy:
|
|||
observation.images.side: [3, 128, 128]
|
||||
# observation.image: [3, 128, 128]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
action: [4] # ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
|
@ -84,9 +84,12 @@ policy:
|
|||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
# action:
|
||||
# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
action:
|
||||
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
min: [-145.283203125, -69.43359375, -78.75, -46.0546875]
|
||||
max: [145.283203125, 69.43359375, 78.75, 46.0546875]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
|
|
|
@ -201,6 +201,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
|||
"action": {"min": min_action_space, "max": max_action_space}
|
||||
}
|
||||
cfg.policy.output_normalization_params = output_normalization_params
|
||||
cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
|
@ -252,6 +253,8 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
|||
# NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
|
||||
if info["is_intervention"]:
|
||||
# TODO: Check the shape
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
|
||||
|
|
|
@ -195,6 +195,7 @@ class ReplayBuffer:
|
|||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
@ -229,6 +230,12 @@ class ReplayBuffer:
|
|||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(device)
|
||||
|
||||
if action_mask is not None:
|
||||
if data["action"].dim() == 1:
|
||||
data["action"] = data["action"][action_mask]
|
||||
else:
|
||||
data["action"] = data["action"][:, action_mask]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=data["action"],
|
||||
|
|
|
@ -328,7 +328,7 @@ class RewardWrapper(gym.Wrapper):
|
|||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class JointMaskingActionSpace(gym.ActionWrapper):
|
||||
class JointMaskingActionSpace(gym.Wrapper):
|
||||
def __init__(self, env, mask):
|
||||
"""
|
||||
Wrapper to mask out dimensions of the action space.
|
||||
|
@ -388,6 +388,16 @@ class JointMaskingActionSpace(gym.ActionWrapper):
|
|||
full_action[self.active_dims] = masked_action
|
||||
return full_action
|
||||
|
||||
def step(self, action):
|
||||
action = self.action(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
if "action_intervention" in info and info["action_intervention"] is not None:
|
||||
if info["action_intervention"].dim() == 1:
|
||||
info["action_intervention"] = info["action_intervention"][self.active_dims]
|
||||
else:
|
||||
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
class TimeLimitWrapper(gym.Wrapper):
|
||||
def __init__(self, env, control_time_s, fps):
|
||||
|
|
|
@ -354,7 +354,7 @@ def add_actor_information_and_train(
|
|||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if transition.get("complementary_info", {}).get("is_interaction"):
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
|
@ -568,6 +568,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
policy_lock = Lock()
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
|
@ -593,8 +594,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
|
|
Loading…
Reference in New Issue