Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.
Added tensordict dependencies Updated the version of torch and torchvision Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
d48161da1b
commit
ff223c106d
|
@ -153,7 +153,7 @@ class SACPolicy(
|
|||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
|
@ -166,7 +166,7 @@ class SACPolicy(
|
|||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions)
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
|
@ -180,14 +180,14 @@ class SACPolicy(
|
|||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True
|
||||
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
|
@ -204,7 +204,7 @@ class SACPolicy(
|
|||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
|
@ -219,20 +219,20 @@ class SACPolicy(
|
|||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(self, observations) -> Tensor:
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations)
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations) -> Tensor:
|
||||
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
|
@ -370,6 +370,7 @@ class CriticEnsemble(nn.Module):
|
|||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
|
@ -380,7 +381,7 @@ class CriticEnsemble(nn.Module):
|
|||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
q_values = self.ensemble(inputs) # [num_critics, B, 1]
|
||||
|
@ -441,9 +442,10 @@ class Policy(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
|
|
|
@ -5,14 +5,14 @@ fps: 20
|
|||
env:
|
||||
name: maniskill/pushcube
|
||||
task: PushCube-v1
|
||||
image_size: 64
|
||||
image_size: 128
|
||||
control_mode: pd_ee_delta_pose
|
||||
state_dim: 25
|
||||
action_dim: 7
|
||||
fps: ${fps}
|
||||
obs: rgb
|
||||
render_mode: rgb_array
|
||||
render_size: 64
|
||||
render_size: 128
|
||||
device: cuda
|
||||
|
||||
reward_classifier:
|
||||
|
|
|
@ -31,7 +31,7 @@ training:
|
|||
online_env_seed: 10000
|
||||
online_buffer_capacity: 1000000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 5000
|
||||
online_step_before_learning: 500
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
|
@ -52,19 +52,16 @@ policy:
|
|||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
vision_encoder_name: null
|
||||
# vision_encoder_name: "helper2424/resnet10"
|
||||
# freeze_vision_encoder: true
|
||||
freeze_vision_encoder: false
|
||||
vision_encoder_name: "helper2424/resnet10"
|
||||
freeze_vision_encoder: true
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.image: [3, 64, 64]
|
||||
observation.image.2: [3, 64, 64]
|
||||
observation.image: [3, 128, 128]
|
||||
output_shapes:
|
||||
action: [7]
|
||||
|
||||
camera_number: 2
|
||||
camera_number: 1
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: null
|
||||
|
|
|
@ -217,7 +217,7 @@ def learner_service_client(
|
|||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"maxAttempts": 7, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
|
|
|
@ -169,6 +169,25 @@ def initialize_replay_buffer(
|
|||
)
|
||||
|
||||
|
||||
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_observation_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
cfg: DictConfig,
|
||||
device: str,
|
||||
|
@ -345,9 +364,6 @@ def add_actor_information_and_train(
|
|||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
|
||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
@ -356,6 +372,7 @@ def add_actor_information_and_train(
|
|||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
|
@ -365,6 +382,7 @@ def add_actor_information_and_train(
|
|||
observations=observations, actions=actions, next_state=next_observations
|
||||
)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
|
@ -372,6 +390,8 @@ def add_actor_information_and_train(
|
|||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
@ -395,6 +415,7 @@ def add_actor_information_and_train(
|
|||
observations=observations, actions=actions, next_state=next_observations
|
||||
)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
|
@ -402,6 +423,8 @@ def add_actor_information_and_train(
|
|||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
@ -413,7 +436,8 @@ def add_actor_information_and_train(
|
|||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
loss_actor = policy.compute_loss_actor(observations=observations,
|
||||
observation_features=observation_features)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
|
@ -422,7 +446,8 @@ def add_actor_information_and_train(
|
|||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations
|
||||
observations=observations,
|
||||
observation_features=observation_features
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
|
|
|
@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
|||
def _get_policy_state(self):
|
||||
with self.policy_lock:
|
||||
params_dict = self.policy.actor.state_dict()
|
||||
if self.policy.config.vision_encoder_name is not None:
|
||||
if self.policy.config.freeze_vision_encoder:
|
||||
params_dict: dict[str, torch.Tensor] = {
|
||||
k: v
|
||||
for k, v in params_dict.items()
|
||||
if not k.startswith("encoder.")
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
)
|
||||
# if self.policy.config.vision_encoder_name is not None:
|
||||
# if self.policy.config.freeze_vision_encoder:
|
||||
# params_dict: dict[str, torch.Tensor] = {
|
||||
# k: v
|
||||
# for k, v in params_dict.items()
|
||||
# if not k.startswith("encoder.")
|
||||
# }
|
||||
# else:
|
||||
# raise NotImplementedError(
|
||||
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
# )
|
||||
|
||||
return move_state_dict_to_device(params_dict, device="cpu")
|
||||
|
||||
|
|
|
@ -41,7 +41,6 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
|
|||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.image.2"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return return_observations
|
||||
|
||||
|
|
|
@ -71,6 +71,7 @@ dependencies = [
|
|||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"tensordict>=0.0.1",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
||||
"torchmetrics>=1.6.0",
|
||||
|
@ -88,7 +89,7 @@ dora = [
|
|||
]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
|
||||
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
mani_skill = ["mani-skill"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
|
|
Loading…
Reference in New Issue