From ff223c106db753202fcac25fbfe50a305ad6e685 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 21 Feb 2025 10:13:43 +0000 Subject: [PATCH] Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen. Added tensordict dependencies Updated the version of torch and torchvision Co-authored-by: Adil Zouitine --- lerobot/common/policies/sac/modeling_sac.py | 28 ++++++++------- lerobot/configs/env/maniskill_example.yaml | 4 +-- lerobot/configs/policy/sac_maniskill.yaml | 13 +++---- lerobot/scripts/server/actor_server.py | 2 +- lerobot/scripts/server/learner_server.py | 35 ++++++++++++++++--- lerobot/scripts/server/learner_service.py | 22 ++++++------ .../scripts/server/maniskill_manipulator.py | 1 - pyproject.toml | 3 +- 8 files changed, 66 insertions(+), 42 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 7cb41ebd..db596982 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -153,7 +153,7 @@ class SACPolicy( return actions def critic_forward( - self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False + self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None ) -> Tensor: """Forward pass through a critic network ensemble @@ -166,7 +166,7 @@ class SACPolicy( Tensor of Q-values from all critics """ critics = self.critic_target if use_target else self.critic_ensemble - q_values = critics(observations, actions) + q_values = critics(observations, actions, observation_features) return q_values def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ... @@ -180,14 +180,14 @@ class SACPolicy( + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor: + def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor: temperature = self.log_alpha.exp().item() with torch.no_grad(): - next_action_preds, next_log_probs, _ = self.actor(next_observations) + next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) # 2- compute q targets q_targets = self.critic_forward( - observations=next_observations, actions=next_action_preds, use_target=True + observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features ) # subsample critics to prevent overfitting if use high UTD (update to date) @@ -204,7 +204,7 @@ class SACPolicy( td_target = rewards + (1 - done) * self.config.discount * min_q # 3- compute predicted qs - q_preds = self.critic_forward(observations, actions, use_target=False) + q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. @@ -219,20 +219,20 @@ class SACPolicy( ).sum() return critics_loss - def compute_loss_temperature(self, observations) -> Tensor: + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: """Compute the temperature loss""" # calculate temperature loss with torch.no_grad(): - _, log_probs, _ = self.actor(observations) + _, log_probs, _ = self.actor(observations, observation_features) temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean() return temperature_loss - def compute_loss_actor(self, observations) -> Tensor: + def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor: temperature = self.log_alpha.exp().item() - actions_pi, log_probs, _ = self.actor(observations) + actions_pi, log_probs, _ = self.actor(observations, observation_features) - q_preds = self.critic_forward(observations, actions_pi, use_target=False) + q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features) min_q_preds = q_preds.min(dim=0)[0] actor_loss = ((temperature * log_probs) - min_q_preds).mean() @@ -370,6 +370,7 @@ class CriticEnsemble(nn.Module): self, observations: dict[str, torch.Tensor], actions: torch.Tensor, + observation_features: torch.Tensor | None = None, ) -> torch.Tensor: device = get_device_from_parameters(self) # Move each tensor in observations to device @@ -380,7 +381,7 @@ class CriticEnsemble(nn.Module): actions = self.output_normalization(actions)["action"] actions = actions.to(device) - obs_enc = observations if self.encoder is None else self.encoder(observations) + obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations)) inputs = torch.cat([obs_enc, actions], dim=-1) q_values = self.ensemble(inputs) # [num_critics, B, 1] @@ -441,9 +442,10 @@ class Policy(nn.Module): def forward( self, observations: torch.Tensor, + observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists - obs_enc = observations if self.encoder is None else self.encoder(observations) + obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations)) # Get network outputs outputs = self.network(obs_enc) diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml index 2b9966c9..03814614 100644 --- a/lerobot/configs/env/maniskill_example.yaml +++ b/lerobot/configs/env/maniskill_example.yaml @@ -5,14 +5,14 @@ fps: 20 env: name: maniskill/pushcube task: PushCube-v1 - image_size: 64 + image_size: 128 control_mode: pd_ee_delta_pose state_dim: 25 action_dim: 7 fps: ${fps} obs: rgb render_mode: rgb_array - render_size: 64 + render_size: 128 device: cuda reward_classifier: diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml index 3e0dbe61..d23c0017 100644 --- a/lerobot/configs/policy/sac_maniskill.yaml +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -31,7 +31,7 @@ training: online_env_seed: 10000 online_buffer_capacity: 1000000 online_buffer_seed_size: 0 - online_step_before_learning: 5000 + online_step_before_learning: 500 do_online_rollout_async: false policy_update_freq: 1 @@ -52,19 +52,16 @@ policy: n_action_steps: 1 shared_encoder: true - vision_encoder_name: null - # vision_encoder_name: "helper2424/resnet10" - # freeze_vision_encoder: true - freeze_vision_encoder: false + vision_encoder_name: "helper2424/resnet10" + freeze_vision_encoder: true input_shapes: # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? observation.state: ["${env.state_dim}"] - observation.image: [3, 64, 64] - observation.image.2: [3, 64, 64] + observation.image: [3, 128, 128] output_shapes: action: [7] - camera_number: 2 + camera_number: 1 # Normalization / Unnormalization input_normalization_modes: null diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index f0c6f2a9..64091883 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -217,7 +217,7 @@ def learner_service_client( { "name": [{}], # Applies to ALL methods in ALL services "retryPolicy": { - "maxAttempts": 5, # Max retries (total attempts = 5) + "maxAttempts": 7, # Max retries (total attempts = 5) "initialBackoff": "0.1s", # First retry after 0.1s "maxBackoff": "2s", # Max wait time between retries "backoffMultiplier": 2, # Exponential backoff factor diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 2d00e7ed..e46681f9 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -169,6 +169,25 @@ def initialize_replay_buffer( ) +def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: + return None, None + + with torch.no_grad(): + observation_features = ( + policy.actor.encoder(observations) + if policy.actor.encoder is not None + else None + ) + next_observation_features = ( + policy.actor.encoder(next_observations) + if policy.actor.encoder is not None + else None + ) + + return observation_features, next_observation_features + + def start_learner_threads( cfg: DictConfig, device: str, @@ -345,9 +364,6 @@ def add_actor_information_and_train( if len(replay_buffer) < cfg.training.online_step_before_learning: continue - # logging.info(f"Size of replay buffer: {len(replay_buffer)}") - # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}") - time_for_one_optimization_step = time.time() for _ in range(cfg.policy.utd_ratio - 1): batch = replay_buffer.sample(batch_size) @@ -356,6 +372,7 @@ def add_actor_information_and_train( batch_offline = offline_replay_buffer.sample(batch_size) batch = concatenate_batch_transitions(batch, batch_offline) + actions = batch["action"] rewards = batch["reward"] observations = batch["state"] @@ -365,6 +382,7 @@ def add_actor_information_and_train( observations=observations, actions=actions, next_state=next_observations ) + observation_features, next_observation_features = get_observation_features(policy, observations, next_observations) with policy_lock: loss_critic = policy.compute_loss_critic( observations=observations, @@ -372,6 +390,8 @@ def add_actor_information_and_train( rewards=rewards, next_observations=next_observations, done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, ) optimizers["critic"].zero_grad() loss_critic.backward() @@ -395,6 +415,7 @@ def add_actor_information_and_train( observations=observations, actions=actions, next_state=next_observations ) + observation_features, next_observation_features = get_observation_features(policy, observations, next_observations) with policy_lock: loss_critic = policy.compute_loss_critic( observations=observations, @@ -402,6 +423,8 @@ def add_actor_information_and_train( rewards=rewards, next_observations=next_observations, done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, ) optimizers["critic"].zero_grad() loss_critic.backward() @@ -413,7 +436,8 @@ def add_actor_information_and_train( if optimization_step % cfg.training.policy_update_freq == 0: for _ in range(cfg.training.policy_update_freq): with policy_lock: - loss_actor = policy.compute_loss_actor(observations=observations) + loss_actor = policy.compute_loss_actor(observations=observations, + observation_features=observation_features) optimizers["actor"].zero_grad() loss_actor.backward() @@ -422,7 +446,8 @@ def add_actor_information_and_train( training_infos["loss_actor"] = loss_actor.item() loss_temperature = policy.compute_loss_temperature( - observations=observations + observations=observations, + observation_features=observation_features ) optimizers["temperature"].zero_grad() loss_temperature.backward() diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/server/learner_service.py index 97601528..d6e6b5b7 100644 --- a/lerobot/scripts/server/learner_service.py +++ b/lerobot/scripts/server/learner_service.py @@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): def _get_policy_state(self): with self.policy_lock: params_dict = self.policy.actor.state_dict() - if self.policy.config.vision_encoder_name is not None: - if self.policy.config.freeze_vision_encoder: - params_dict: dict[str, torch.Tensor] = { - k: v - for k, v in params_dict.items() - if not k.startswith("encoder.") - } - else: - raise NotImplementedError( - "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model." - ) + # if self.policy.config.vision_encoder_name is not None: + # if self.policy.config.freeze_vision_encoder: + # params_dict: dict[str, torch.Tensor] = { + # k: v + # for k, v in params_dict.items() + # if not k.startswith("encoder.") + # } + # else: + # raise NotImplementedError( + # "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model." + # ) return move_state_dict_to_device(params_dict, device="cpu") diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index 105deeb4..e1c0840a 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -41,7 +41,6 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1) return_observations["observation.image"] = img - return_observations["observation.image.2"] = img return_observations["observation.state"] = state return return_observations diff --git a/pyproject.toml b/pyproject.toml index 89577a4f..6f884d44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ dependencies = [ "pyzmq>=26.2.1", "rerun-sdk>=0.21.0", "termcolor>=2.4.0", + "tensordict>=0.0.1", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))", "torchmetrics>=1.6.0", @@ -88,7 +89,7 @@ dora = [ ] dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] -hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"] +hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] mani_skill = ["mani-skill"] pi0 = ["transformers>=4.48.0"]