Added a caching function in the learner_server and modeling_sac to limit the number of forward passes through the pretrained encoder when it is frozen.

Added the tensordict dependency
Updated the versions of torch and torchvision
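
The caching boils down to running the frozen encoder once per sampled batch under torch.no_grad() and handing the resulting features to every subsequent actor/critic call. A condensed sketch of the helper introduced in the learner_server diff below (function name and config fields follow that diff; type annotations omitted for brevity):

import torch

def get_observation_features(policy, observations, next_observations):
    # Caching only applies when a pretrained vision encoder is configured and frozen;
    # otherwise the encoder must stay in the gradient path and is re-run per loss term.
    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
        return None, None
    with torch.no_grad():  # frozen encoder: a single forward pass per batch is enough
        observation_features = (
            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
        )
        next_observation_features = (
            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
        )
    return observation_features, next_observation_features

The actor and critic forward passes then short-circuit their own encoder whenever these cached features are supplied (obs_enc = observation_features if observation_features is not None else ...), so the pretrained backbone is evaluated once per batch instead of once per loss computation.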

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-21 10:13:43 +00:00 committed by AdilZouitine
parent d48161da1b
commit ff223c106d
8 changed files with 66 additions and 42 deletions

View File

@@ -153,7 +153,7 @@ class SACPolicy(
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
) -> Tensor:
"""Forward pass through a critic network ensemble
@@ -166,7 +166,7 @@ class SACPolicy(
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions)
q_values = critics(observations, actions, observation_features)
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -180,14 +180,14 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
@@ -204,7 +204,7 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -219,20 +219,20 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(self, observations) -> Tensor:
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations) -> Tensor:
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations)
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
@@ -370,6 +370,7 @@ class CriticEnsemble(nn.Module):
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
@@ -380,7 +381,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
inputs = torch.cat([obs_enc, actions], dim=-1)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
@@ -441,9 +442,10 @@ class Policy(nn.Module):
def forward(
self,
observations: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
# Get network outputs
outputs = self.network(obs_enc)

View File

@@ -5,14 +5,14 @@ fps: 20
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 64
image_size: 128
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 64
render_size: 128
device: cuda
reward_classifier:

View File

@@ -31,7 +31,7 @@ training:
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_seed_size: 0
online_step_before_learning: 5000
online_step_before_learning: 500
do_online_rollout_async: false
policy_update_freq: 1
@@ -52,19 +52,16 @@ policy:
n_action_steps: 1
shared_encoder: true
vision_encoder_name: null
# vision_encoder_name: "helper2424/resnet10"
# freeze_vision_encoder: true
freeze_vision_encoder: false
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 64, 64]
observation.image.2: [3, 64, 64]
observation.image: [3, 128, 128]
output_shapes:
action: [7]
camera_number: 2
camera_number: 1
# Normalization / Unnormalization
input_normalization_modes: null

View File

@@ -217,7 +217,7 @@ def learner_service_client(
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"maxAttempts": 7, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor

View File

@@ -169,6 +169,25 @@ def initialize_replay_buffer(
)
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_observation_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
return observation_features, next_observation_features
def start_learner_threads(
cfg: DictConfig,
device: str,
@@ -345,9 +364,6 @@ def add_actor_information_and_train(
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@@ -356,6 +372,7 @@ def add_actor_information_and_train(
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
@@ -365,6 +382,7 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -372,6 +390,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -395,6 +415,7 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -402,6 +423,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -413,7 +436,8 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(observations=observations,
observation_features=observation_features)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -422,7 +446,8 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations
observations=observations,
observation_features=observation_features
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()

View File

@@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def _get_policy_state(self):
with self.policy_lock:
params_dict = self.policy.actor.state_dict()
if self.policy.config.vision_encoder_name is not None:
if self.policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v
for k, v in params_dict.items()
if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
# if self.policy.config.vision_encoder_name is not None:
# if self.policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v
# for k, v in params_dict.items()
# if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
return move_state_dict_to_device(params_dict, device="cpu")

View File

@@ -41,7 +41,6 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.image.2"] = img
return_observations["observation.state"] = state
return return_observations

View File

@@ -71,6 +71,7 @@ dependencies = [
"pyzmq>=26.2.1",
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"tensordict>=0.0.1",
"torch>=2.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
"torchmetrics>=1.6.0",
@@ -88,7 +89,7 @@ dora = [
]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
mani_skill = ["mani-skill"]
pi0 = ["transformers>=4.48.0"]