Hardcoded some normalization parameters. TODO: refactor.
Added action masking at the level of the intervention actions and the offline dataset.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
commit c462a478c7 (parent 459f22ed30)
@@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
         # Load metadata
         (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        # self.pull_from_repo(allow_patterns="meta/")
+        self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
         self.tasks = load_tasks(self.root)
@@ -539,7 +539,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         # HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
         logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
-        # self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
 
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
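Both hunks re-enable pull_from_repo, which filters what gets fetched from the Hub by pattern. A hedged, stand-alone sketch of that kind of filtered download with huggingface_hub (repo id and paths are placeholders; pull_from_repo's actual internals may differ):

    from huggingface_hub import snapshot_download

    # Hypothetical example: fetch only the metadata files of a dataset repo,
    # skipping the large video files. Repo id, local dir, and patterns are illustrative.
    snapshot_download(
        repo_id="lerobot/some_dataset",       # placeholder repo id
        repo_type="dataset",
        local_dir="./some_dataset",           # placeholder local path
        allow_patterns=["meta/*"],            # mirrors allow_patterns="meta/" above
        ignore_patterns=["videos/*"],         # mirrors ignore_patterns in the second call
    )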
@@ -137,7 +137,7 @@ class SACPolicy(
         """Override .to(device) method to involve moving the log_alpha fixed_std"""
         if self.actor.fixed_std is not None:
             self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
-        self.log_alpha = self.log_alpha.to(*args, **kwargs)
+        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
         super().to(*args, **kwargs)
 
     @torch.no_grad()
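Background on why the explicit move can become redundant: nn.Module.to() already relocates anything registered as a parameter or buffer, while plain tensor attributes (like fixed_std here) must be moved by hand. A minimal sketch of the distinction (class and attribute names are illustrative, not the repo's):

    import torch
    import torch.nn as nn

    class TemperatureHolder(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered parameter: moved automatically by nn.Module.to().
            self.log_alpha = nn.Parameter(torch.zeros(1))
            # Plain tensor attribute: NOT moved by nn.Module.to().
            self.fixed_std = torch.ones(4) * 0.1

        def to(self, *args, **kwargs):
            # Only the unregistered tensor needs an explicit move.
            self.fixed_std = self.fixed_std.to(*args, **kwargs)
            return super().to(*args, **kwargs)

If log_alpha is registered as a parameter (or buffer), the commented-out manual move is harmless, which would explain the change here.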
@@ -31,7 +31,7 @@ training:
   online_env_seed: 10000
   online_buffer_capacity: 1000000
   online_buffer_seed_size: 0
-  online_step_before_learning: 1000 #5000
+  online_step_before_learning: 100 #5000
   do_online_rollout_async: false
   policy_update_freq: 1
 
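Judging by its name, online_step_before_learning is the number of environment steps collected before gradient updates start, so lowering it from 1000 to 100 shortens the warm-up phase. A hedged skeleton of how such a threshold is typically used in an off-policy loop (helpers and names are illustrative, not the repo's):

    # Illustrative off-policy training skeleton, not the repo's actual loop.
    for step in range(total_steps):
        transition = collect_one_env_step(policy)   # hypothetical helper
        replay_buffer.add(**transition)

        # Skip optimization until enough warm-up transitions have been collected.
        if step < cfg.training.online_step_before_learning:
            continue

        for _ in range(cfg.training.policy_update_freq):
            batch = replay_buffer.sample(batch_size)
            update_policy(batch)                    # hypothetical helper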
@@ -61,7 +61,7 @@ policy:
     observation.images.side: [3, 128, 128]
     # observation.image: [3, 128, 128]
   output_shapes:
-    action: ["${env.action_dim}"]
+    action: [4] # ["${env.action_dim}"]
 
   # Normalization / Unnormalization
   input_normalization_modes:
@@ -84,9 +84,12 @@ policy:
   output_normalization_modes:
     action: min_max
   output_normalization_params:
+    # action:
+    # min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
+    # max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
     action:
-      min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
-      max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+      min: [-145.283203125, -69.43359375, -78.75, -46.0546875]
+      max: [145.283203125, 69.43359375, 78.75, 46.0546875]
 
   # Architecture / modeling.
   # Neural networks.
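The hardcoded bounds look like per-joint limits for a 4-dim action, replacing the placeholder 6-dim [-1, 1] values. Assuming the min_max mode rescales each action dimension to [-1, 1] (the usual convention; the real implementation lives in the policy's normalization module), the arithmetic is roughly:

    import torch

    # Hedged sketch of min/max (un)normalization with the bounds hardcoded above.
    action_min = torch.tensor([-145.283203125, -69.43359375, -78.75, -46.0546875])
    action_max = torch.tensor([145.283203125, 69.43359375, 78.75, 46.0546875])

    def normalize(a: torch.Tensor) -> torch.Tensor:
        # Map raw joint commands into [-1, 1].
        return (a - action_min) / (action_max - action_min) * 2.0 - 1.0

    def unnormalize(a: torch.Tensor) -> torch.Tensor:
        # Map the policy's [-1, 1] output back to raw joint commands.
        return (a + 1.0) / 2.0 * (action_max - action_min) + action_min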
@@ -201,6 +201,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
         "action": {"min": min_action_space, "max": max_action_space}
     }
     cfg.policy.output_normalization_params = output_normalization_params
+    cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
 
     ### Instantiate the policy in both the actor and learner processes
     ### To avoid sending a SACPolicy object through the port, we create a policy intance
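The actor now derives both the normalization bounds and the action shape from the wrapped environment rather than from the YAML. A hedged sketch of where such values typically come from when the action space is a gymnasium Tuple with a Box in the first slot, as the spaces[0] indexing suggests (the concrete space below is illustrative):

    import gymnasium as gym
    import numpy as np

    # Illustrative action space: a 4-DoF joint-command Box plus a discrete intervention flag.
    action_space = gym.spaces.Tuple(
        (
            gym.spaces.Box(
                low=np.array([-145.28, -69.43, -78.75, -46.05], dtype=np.float32),
                high=np.array([145.28, 69.43, 78.75, 46.05], dtype=np.float32),
            ),
            gym.spaces.Discrete(2),
        )
    )

    box = action_space.spaces[0]
    min_action_space = box.low.tolist()    # feeds output_normalization_params["action"]["min"]
    max_action_space = box.high.tolist()   # feeds output_normalization_params["action"]["max"]
    action_shape = box.shape               # feeds cfg.policy.output_shapes["action"] -> (4,)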
@@ -252,6 +253,8 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
         # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
         if info["is_intervention"]:
             # TODO: Check the shape
+            # NOTE: The action space for demonstration before hand is with the full action space
+            # but sometimes for example we want to deactivate the gripper
             action = info["action_intervention"]
             episode_intervention = True
 
@@ -195,6 +195,7 @@ class ReplayBuffer:
         device: str = "cuda:0",
         state_keys: Optional[Sequence[str]] = None,
         capacity: Optional[int] = None,
+        action_mask: Optional[Sequence[int]] = None,
     ) -> "ReplayBuffer":
         """
         Convert a LeRobotDataset into a ReplayBuffer.
@@ -229,6 +230,12 @@ class ReplayBuffer:
             elif isinstance(v, torch.Tensor):
                 data[k] = v.to(device)
 
+            if action_mask is not None:
+                if data["action"].dim() == 1:
+                    data["action"] = data["action"][action_mask]
+                else:
+                    data["action"] = data["action"][:, action_mask]
+
             replay_buffer.add(
                 state=data["state"],
                 action=data["action"],
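action_mask is a list of integer indices (not a boolean mask), so indexing keeps only the active action dimensions, for single transitions (1-D) and batches (2-D) alike. A small, self-contained illustration:

    import torch

    action_mask = [0, 1, 2, 3]                 # keep the first four dims, drop the rest

    single = torch.arange(6.0)                 # shape (6,)
    batch = torch.arange(12.0).reshape(2, 6)   # shape (2, 6)

    print(single[action_mask])                 # tensor([0., 1., 2., 3.])      -> shape (4,)
    print(batch[:, action_mask])               # columns 0..3 of every row     -> shape (2, 4)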
@@ -328,7 +328,7 @@ class RewardWrapper(gym.Wrapper):
         return self.env.reset(seed=seed, options=options)
 
 
-class JointMaskingActionSpace(gym.ActionWrapper):
+class JointMaskingActionSpace(gym.Wrapper):
     def __init__(self, env, mask):
         """
         Wrapper to mask out dimensions of the action space.
@@ -388,6 +388,16 @@ class JointMaskingActionSpace(gym.ActionWrapper):
         full_action[self.active_dims] = masked_action
         return full_action
 
+    def step(self, action):
+        action = self.action(action)
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        if "action_intervention" in info and info["action_intervention"] is not None:
+            if info["action_intervention"].dim() == 1:
+                info["action_intervention"] = info["action_intervention"][self.active_dims]
+            else:
+                info["action_intervention"] = info["action_intervention"][:, self.active_dims]
+        return obs, reward, terminated, truncated, info
+
 
 class TimeLimitWrapper(gym.Wrapper):
     def __init__(self, env, control_time_s, fps):
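With the base class switched from gym.ActionWrapper to gym.Wrapper, step() now handles both directions itself: it expands the reduced action to the full action space before calling the inner env, and it slices any intervention action reported in info down to the active dimensions. A self-contained sketch of those two projections (tensors are illustrative; this is not the repo implementation):

    import torch

    mask = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.bool)
    active_dims = torch.nonzero(mask).squeeze(-1)            # tensor([0, 1, 2, 3])

    # On the way in: expand a 4-dim policy action to the full 6-dim command.
    masked_action = torch.tensor([0.1, -0.2, 0.3, 0.05])
    full_action = torch.zeros(mask.shape[0])
    full_action[active_dims] = masked_action                 # inactive joints stay at 0

    # On the way out: slice a full 6-dim intervention action back to the active dims.
    action_intervention = torch.tensor([10., 20., 30., 40., 50., 60.])
    action_intervention = action_intervention[active_dims]   # tensor([10., 20., 30., 40.])

This keeps stored intervention actions shaped like policy actions, which is what the ReplayBuffer masking above relies on.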
@@ -354,7 +354,7 @@ def add_actor_information_and_train(
             transition = move_transition_to_device(transition, device=device)
             replay_buffer.add(**transition)
 
-            if transition.get("complementary_info", {}).get("is_interaction"):
+            if transition.get("complementary_info", {}).get("is_intervention"):
                 offline_replay_buffer.add(**transition)
 
         while not interaction_message_queue.empty():
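The key is renamed from is_interaction to is_intervention, so it is the human-intervention transitions that also land in the offline buffer. The chained .get() calls keep the check safe when complementary_info is absent, e.g.:

    transition = {"complementary_info": {"is_intervention": True}}
    print(transition.get("complementary_info", {}).get("is_intervention"))       # True

    bare_transition = {}
    print(bare_transition.get("complementary_info", {}).get("is_intervention"))  # None -> falsy, skipped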
@@ -568,6 +568,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     ### To avoid sending a SACPolicy object through the port, we create a policy intance
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
     # TODO: At some point we should just need make sac policy
+
     policy_lock = Lock()
     policy: SACPolicy = make_policy(
         hydra_cfg=cfg,
@@ -593,8 +594,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("make_dataset offline buffer")
     offline_dataset = make_dataset(cfg)
     logging.info("Convertion to a offline replay buffer")
+    active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
     offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-        offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        offline_dataset,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+        action_mask=active_action_dims,
     )
     batch_size: int = batch_size // 2 # We will sample from both replay buffer
 
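The learner builds the same index list from the boolean joint_masking_action_space config entry and passes it to from_lerobot_dataset, so offline demonstration actions are trimmed to the same active joints the actor uses online. A short sketch of the conversion (the config value below is illustrative):

    # Illustrative config value: 6 joints, the last two masked out.
    joint_masking_action_space = [True, True, True, True, False, False]

    # Boolean mask -> list of active indices, exactly as the added line does.
    active_action_dims = [i for i, m in enumerate(joint_masking_action_space) if m]
    assert active_action_dims == [0, 1, 2, 3]

    # Passed as action_mask, this keeps only columns 0..3 of each stored action,
    # matching the hardcoded output_shapes entry `action: [4]` above.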