diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 7cb41ebd..db596982 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -153,7 +153,7 @@ class SACPolicy(
         return actions

     def critic_forward(
-        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
+        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
     ) -> Tensor:
         """Forward pass through a critic network ensemble

@@ -166,7 +166,7 @@ class SACPolicy(
             Tensor of Q-values from all critics
         """
         critics = self.critic_target if use_target else self.critic_ensemble
-        q_values = critics(observations, actions)
+        q_values = critics(observations, actions, observation_features)
         return q_values

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -180,14 +180,14 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )

-    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
+    def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
         temperature = self.log_alpha.exp().item()
         with torch.no_grad():
-            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)

             # 2- compute q targets
             q_targets = self.critic_forward(
-                observations=next_observations, actions=next_action_preds, use_target=True
+                observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
             )

             # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -204,7 +204,7 @@ class SACPolicy(
             td_target = rewards + (1 - done) * self.config.discount * min_q

         # 3- compute predicted qs
-        q_preds = self.critic_forward(observations, actions, use_target=False)
+        q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -219,20 +219,20 @@ class SACPolicy(
         ).sum()
         return critics_loss

-    def compute_loss_temperature(self, observations) -> Tensor:
+    def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
         with torch.no_grad():
-            _, log_probs, _ = self.actor(observations)
+            _, log_probs, _ = self.actor(observations, observation_features)
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss

-    def compute_loss_actor(self, observations) -> Tensor:
+    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
         temperature = self.log_alpha.exp().item()

-        actions_pi, log_probs, _ = self.actor(observations)
+        actions_pi, log_probs, _ = self.actor(observations, observation_features)

-        q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+        q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
         min_q_preds = q_preds.min(dim=0)[0]

         actor_loss = ((temperature * log_probs) - min_q_preds).mean()
@@ -370,6 +370,7 @@ class CriticEnsemble(nn.Module):
         self,
         observations: dict[str, torch.Tensor],
         actions: torch.Tensor,
+        observation_features: torch.Tensor | None = None,
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
@@ -380,7 +381,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

-        obs_enc = observations if self.encoder is None else self.encoder(observations)
+        obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))

         inputs = torch.cat([obs_enc, actions], dim=-1)
         q_values = self.ensemble(inputs)  # [num_critics, B, 1]
@@ -441,9 +442,10 @@ class Policy(nn.Module):
     def forward(
         self,
         observations: torch.Tensor,
+        observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = observations if self.encoder is None else self.encoder(observations)
+        obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))

         # Get network outputs
         outputs = self.network(obs_enc)
diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml
index 2b9966c9..03814614 100644
--- a/lerobot/configs/env/maniskill_example.yaml
+++ b/lerobot/configs/env/maniskill_example.yaml
@@ -5,14 +5,14 @@ fps: 20
 env:
   name: maniskill/pushcube
   task: PushCube-v1
-  image_size: 64
+  image_size: 128
   control_mode: pd_ee_delta_pose
   state_dim: 25
   action_dim: 7
   fps: ${fps}
   obs: rgb
   render_mode: rgb_array
-  render_size: 64
+  render_size: 128
   device: cuda

 reward_classifier:
diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index 3e0dbe61..d23c0017 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -31,7 +31,7 @@ training:
   online_env_seed: 10000
   online_buffer_capacity: 1000000
   online_buffer_seed_size: 0
-  online_step_before_learning: 5000
+  online_step_before_learning: 500
   do_online_rollout_async: false
   policy_update_freq: 1

@@ -52,19 +52,16 @@ policy:
   n_action_steps: 1

   shared_encoder: true
-  vision_encoder_name: null
-  # vision_encoder_name: "helper2424/resnet10"
-  # freeze_vision_encoder: true
-  freeze_vision_encoder: false
+  vision_encoder_name: "helper2424/resnet10"
+  freeze_vision_encoder: true
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.image: [3, 64, 64]
-    observation.image.2: [3, 64, 64]
+    observation.image: [3, 128, 128]
   output_shapes:
     action: [7]

-  camera_number: 2
+  camera_number: 1

   # Normalization / Unnormalization
   input_normalization_modes: null
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index f0c6f2a9..64091883 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -217,7 +217,7 @@ def learner_service_client(
             {
                 "name": [{}],  # Applies to ALL methods in ALL services
                 "retryPolicy": {
-                    "maxAttempts": 5,  # Max retries (total attempts = 5)
+                    "maxAttempts": 7,  # Max retries (total attempts = 7)
                     "initialBackoff": "0.1s",  # First retry after 0.1s
                     "maxBackoff": "2s",  # Max wait time between retries
                     "backoffMultiplier": 2,  # Exponential backoff factor
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 2d00e7ed..e46681f9 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -169,6 +169,25 @@ def initialize_replay_buffer(
     )


+def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
+        return None, None
+
+    with torch.no_grad():
+        observation_features = (
+            policy.actor.encoder(observations)
+            if policy.actor.encoder is not None
+            else None
+        )
+        next_observation_features = (
+            policy.actor.encoder(next_observations)
+            if policy.actor.encoder is not None
+            else None
+        )
+
+    return observation_features, next_observation_features
+
+
 def start_learner_threads(
     cfg: DictConfig,
     device: str,
@@ -345,9 +364,6 @@ def add_actor_information_and_train(
         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue

-        # logging.info(f"Size of replay buffer: {len(replay_buffer)}")
-        # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
-
         time_for_one_optimization_step = time.time()
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
@@ -356,6 +372,7 @@
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)

+            actions = batch["action"]
             rewards = batch["reward"]
             observations = batch["state"]

@@ -365,6 +382,7 @@
                 observations=observations, actions=actions, next_state=next_observations
             )

+            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
@@ -372,6 +390,8 @@
                     rewards=rewards,
                     next_observations=next_observations,
                     done=done,
+                    observation_features=observation_features,
+                    next_observation_features=next_observation_features,
                 )
                 optimizers["critic"].zero_grad()
                 loss_critic.backward()
@@ -395,6 +415,7 @@
             observations=observations, actions=actions, next_state=next_observations
         )

+        observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -402,6 +423,8 @@
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
             )
             optimizers["critic"].zero_grad()
             loss_critic.backward()
@@ -413,7 +436,8 @@
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations)
+                    loss_actor = policy.compute_loss_actor(observations=observations,
+                                                           observation_features=observation_features)

                     optimizers["actor"].zero_grad()
                     loss_actor.backward()
@@ -422,7 +446,8 @@
                     training_infos["loss_actor"] = loss_actor.item()

                 loss_temperature = policy.compute_loss_temperature(
-                    observations=observations
+                    observations=observations,
+                    observation_features=observation_features
                 )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/server/learner_service.py
index 97601528..d6e6b5b7 100644
--- a/lerobot/scripts/server/learner_service.py
+++ b/lerobot/scripts/server/learner_service.py
@@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
     def _get_policy_state(self):
         with self.policy_lock:
             params_dict = self.policy.actor.state_dict()
-            if self.policy.config.vision_encoder_name is not None:
-                if self.policy.config.freeze_vision_encoder:
-                    params_dict: dict[str, torch.Tensor] = {
-                        k: v
-                        for k, v in params_dict.items()
-                        if not k.startswith("encoder.")
-                    }
-                else:
-                    raise NotImplementedError(
-                        "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
-                    )
+            # if self.policy.config.vision_encoder_name is not None:
+            #     if self.policy.config.freeze_vision_encoder:
+            #         params_dict: dict[str, torch.Tensor] = {
+            #             k: v
+            #             for k, v in params_dict.items()
+            #             if not k.startswith("encoder.")
+            #         }
+            #     else:
+            #         raise NotImplementedError(
+            #             "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
+            #         )

         return move_state_dict_to_device(params_dict, device="cpu")

diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index 105deeb4..e1c0840a 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -41,7 +41,6 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
     state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

     return_observations["observation.image"] = img
-    return_observations["observation.image.2"] = img
     return_observations["observation.state"] = state
     return return_observations

diff --git a/pyproject.toml b/pyproject.toml
index 3c7c9c96..10e0e86b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,6 +71,7 @@ dependencies = [
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
     "termcolor>=2.4.0",
+    "tensordict>=0.0.1",
     "torch>=2.2.1",
     "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
     "torchmetrics>=1.6.0",
@@ -88,7 +89,7 @@ dora = [
 ]
 dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
 feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
-hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
+hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
 intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
 mani_skill = ["mani-skill"]
 pi0 = ["transformers>=4.48.0"]
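
Usage sketch (not part of the patch): a minimal illustration of how the cached-feature path above is meant to be wired into one learner update, assuming `policy`, per-loss `optimizers`, and a sampled `batch` laid out as in lerobot/scripts/server/learner_server.py. The function name `sac_update_step` and the `batch["done"]` key are assumptions made for this sketch, not code from the patch.

import torch

from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.scripts.server.learner_server import get_observation_features


def sac_update_step(policy: SACPolicy, optimizers: dict, batch: dict[str, torch.Tensor]) -> dict[str, float]:
    # Hypothetical helper: encode images once with the frozen vision encoder and
    # reuse the cached features for the critic, actor, and temperature losses.
    observations = batch["state"]
    next_observations = batch["next_state"]
    obs_feats, next_obs_feats = get_observation_features(policy, observations, next_observations)

    loss_critic = policy.compute_loss_critic(
        observations=observations,
        actions=batch["action"],
        rewards=batch["reward"],
        next_observations=next_observations,
        done=batch["done"],
        observation_features=obs_feats,
        next_observation_features=next_obs_feats,
    )
    optimizers["critic"].zero_grad()
    loss_critic.backward()
    optimizers["critic"].step()

    loss_actor = policy.compute_loss_actor(observations=observations, observation_features=obs_feats)
    optimizers["actor"].zero_grad()
    loss_actor.backward()
    optimizers["actor"].step()

    loss_temperature = policy.compute_loss_temperature(observations=observations, observation_features=obs_feats)
    optimizers["temperature"].zero_grad()
    loss_temperature.backward()
    optimizers["temperature"].step()

    return {
        "loss_critic": loss_critic.item(),
        "loss_actor": loss_actor.item(),
        "loss_temperature": loss_temperature.item(),
    }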