Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.
Added tensordict dependencies Updated the version of torch and torchvision Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
304d7136df
commit
28361a1584
|
@ -153,7 +153,7 @@ class SACPolicy(
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def critic_forward(
|
def critic_forward(
|
||||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Forward pass through a critic network ensemble
|
"""Forward pass through a critic network ensemble
|
||||||
|
|
||||||
|
@ -166,7 +166,7 @@ class SACPolicy(
|
||||||
Tensor of Q-values from all critics
|
Tensor of Q-values from all critics
|
||||||
"""
|
"""
|
||||||
critics = self.critic_target if use_target else self.critic_ensemble
|
critics = self.critic_target if use_target else self.critic_ensemble
|
||||||
q_values = critics(observations, actions)
|
q_values = critics(observations, actions, observation_features)
|
||||||
return q_values
|
return q_values
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||||
|
@ -180,14 +180,14 @@ class SACPolicy(
|
||||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
|
||||||
temperature = self.log_alpha.exp().item()
|
temperature = self.log_alpha.exp().item()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||||
|
|
||||||
# 2- compute q targets
|
# 2- compute q targets
|
||||||
q_targets = self.critic_forward(
|
q_targets = self.critic_forward(
|
||||||
observations=next_observations, actions=next_action_preds, use_target=True
|
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
|
||||||
)
|
)
|
||||||
|
|
||||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||||
|
@ -204,7 +204,7 @@ class SACPolicy(
|
||||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||||
|
|
||||||
# 3- compute predicted qs
|
# 3- compute predicted qs
|
||||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
|
||||||
|
|
||||||
# 4- Calculate loss
|
# 4- Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
|
@ -219,20 +219,20 @@ class SACPolicy(
|
||||||
).sum()
|
).sum()
|
||||||
return critics_loss
|
return critics_loss
|
||||||
|
|
||||||
def compute_loss_temperature(self, observations) -> Tensor:
|
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||||
"""Compute the temperature loss"""
|
"""Compute the temperature loss"""
|
||||||
# calculate temperature loss
|
# calculate temperature loss
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, log_probs, _ = self.actor(observations)
|
_, log_probs, _ = self.actor(observations, observation_features)
|
||||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||||
return temperature_loss
|
return temperature_loss
|
||||||
|
|
||||||
def compute_loss_actor(self, observations) -> Tensor:
|
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||||
temperature = self.log_alpha.exp().item()
|
temperature = self.log_alpha.exp().item()
|
||||||
|
|
||||||
actions_pi, log_probs, _ = self.actor(observations)
|
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||||
|
|
||||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
|
||||||
min_q_preds = q_preds.min(dim=0)[0]
|
min_q_preds = q_preds.min(dim=0)[0]
|
||||||
|
|
||||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||||
|
@ -370,6 +370,7 @@ class CriticEnsemble(nn.Module):
|
||||||
self,
|
self,
|
||||||
observations: dict[str, torch.Tensor],
|
observations: dict[str, torch.Tensor],
|
||||||
actions: torch.Tensor,
|
actions: torch.Tensor,
|
||||||
|
observation_features: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
device = get_device_from_parameters(self)
|
device = get_device_from_parameters(self)
|
||||||
# Move each tensor in observations to device
|
# Move each tensor in observations to device
|
||||||
|
@ -380,7 +381,7 @@ class CriticEnsemble(nn.Module):
|
||||||
actions = self.output_normalization(actions)["action"]
|
actions = self.output_normalization(actions)["action"]
|
||||||
actions = actions.to(device)
|
actions = actions.to(device)
|
||||||
|
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||||
|
|
||||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||||
q_values = self.ensemble(inputs) # [num_critics, B, 1]
|
q_values = self.ensemble(inputs) # [num_critics, B, 1]
|
||||||
|
@ -441,9 +442,10 @@ class Policy(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
|
observation_features: torch.Tensor | None = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Encode observations if encoder exists
|
# Encode observations if encoder exists
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||||
|
|
||||||
# Get network outputs
|
# Get network outputs
|
||||||
outputs = self.network(obs_enc)
|
outputs = self.network(obs_enc)
|
||||||
|
|
|
@ -5,14 +5,14 @@ fps: 20
|
||||||
env:
|
env:
|
||||||
name: maniskill/pushcube
|
name: maniskill/pushcube
|
||||||
task: PushCube-v1
|
task: PushCube-v1
|
||||||
image_size: 64
|
image_size: 128
|
||||||
control_mode: pd_ee_delta_pose
|
control_mode: pd_ee_delta_pose
|
||||||
state_dim: 25
|
state_dim: 25
|
||||||
action_dim: 7
|
action_dim: 7
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
obs: rgb
|
obs: rgb
|
||||||
render_mode: rgb_array
|
render_mode: rgb_array
|
||||||
render_size: 64
|
render_size: 128
|
||||||
device: cuda
|
device: cuda
|
||||||
|
|
||||||
reward_classifier:
|
reward_classifier:
|
||||||
|
|
|
@ -31,7 +31,7 @@ training:
|
||||||
online_env_seed: 10000
|
online_env_seed: 10000
|
||||||
online_buffer_capacity: 1000000
|
online_buffer_capacity: 1000000
|
||||||
online_buffer_seed_size: 0
|
online_buffer_seed_size: 0
|
||||||
online_step_before_learning: 5000
|
online_step_before_learning: 500
|
||||||
do_online_rollout_async: false
|
do_online_rollout_async: false
|
||||||
policy_update_freq: 1
|
policy_update_freq: 1
|
||||||
|
|
||||||
|
@ -52,19 +52,16 @@ policy:
|
||||||
n_action_steps: 1
|
n_action_steps: 1
|
||||||
|
|
||||||
shared_encoder: true
|
shared_encoder: true
|
||||||
vision_encoder_name: null
|
vision_encoder_name: "helper2424/resnet10"
|
||||||
# vision_encoder_name: "helper2424/resnet10"
|
freeze_vision_encoder: true
|
||||||
# freeze_vision_encoder: true
|
|
||||||
freeze_vision_encoder: false
|
|
||||||
input_shapes:
|
input_shapes:
|
||||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
observation.state: ["${env.state_dim}"]
|
observation.state: ["${env.state_dim}"]
|
||||||
observation.image: [3, 64, 64]
|
observation.image: [3, 128, 128]
|
||||||
observation.image.2: [3, 64, 64]
|
|
||||||
output_shapes:
|
output_shapes:
|
||||||
action: [7]
|
action: [7]
|
||||||
|
|
||||||
camera_number: 2
|
camera_number: 1
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
input_normalization_modes: null
|
input_normalization_modes: null
|
||||||
|
|
|
@ -217,7 +217,7 @@ def learner_service_client(
|
||||||
{
|
{
|
||||||
"name": [{}], # Applies to ALL methods in ALL services
|
"name": [{}], # Applies to ALL methods in ALL services
|
||||||
"retryPolicy": {
|
"retryPolicy": {
|
||||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
"maxAttempts": 7, # Max retries (total attempts = 5)
|
||||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||||
"maxBackoff": "2s", # Max wait time between retries
|
"maxBackoff": "2s", # Max wait time between retries
|
||||||
"backoffMultiplier": 2, # Exponential backoff factor
|
"backoffMultiplier": 2, # Exponential backoff factor
|
||||||
|
|
|
@ -169,6 +169,25 @@ def initialize_replay_buffer(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||||
|
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
observation_features = (
|
||||||
|
policy.actor.encoder(observations)
|
||||||
|
if policy.actor.encoder is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
next_observation_features = (
|
||||||
|
policy.actor.encoder(next_observations)
|
||||||
|
if policy.actor.encoder is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return observation_features, next_observation_features
|
||||||
|
|
||||||
|
|
||||||
def start_learner_threads(
|
def start_learner_threads(
|
||||||
cfg: DictConfig,
|
cfg: DictConfig,
|
||||||
device: str,
|
device: str,
|
||||||
|
@ -345,9 +364,6 @@ def add_actor_information_and_train(
|
||||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
|
||||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
|
||||||
|
|
||||||
time_for_one_optimization_step = time.time()
|
time_for_one_optimization_step = time.time()
|
||||||
for _ in range(cfg.policy.utd_ratio - 1):
|
for _ in range(cfg.policy.utd_ratio - 1):
|
||||||
batch = replay_buffer.sample(batch_size)
|
batch = replay_buffer.sample(batch_size)
|
||||||
|
@ -356,6 +372,7 @@ def add_actor_information_and_train(
|
||||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||||
|
|
||||||
|
|
||||||
actions = batch["action"]
|
actions = batch["action"]
|
||||||
rewards = batch["reward"]
|
rewards = batch["reward"]
|
||||||
observations = batch["state"]
|
observations = batch["state"]
|
||||||
|
@ -365,6 +382,7 @@ def add_actor_information_and_train(
|
||||||
observations=observations, actions=actions, next_state=next_observations
|
observations=observations, actions=actions, next_state=next_observations
|
||||||
)
|
)
|
||||||
|
|
||||||
|
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
|
||||||
with policy_lock:
|
with policy_lock:
|
||||||
loss_critic = policy.compute_loss_critic(
|
loss_critic = policy.compute_loss_critic(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
|
@ -372,6 +390,8 @@ def add_actor_information_and_train(
|
||||||
rewards=rewards,
|
rewards=rewards,
|
||||||
next_observations=next_observations,
|
next_observations=next_observations,
|
||||||
done=done,
|
done=done,
|
||||||
|
observation_features=observation_features,
|
||||||
|
next_observation_features=next_observation_features,
|
||||||
)
|
)
|
||||||
optimizers["critic"].zero_grad()
|
optimizers["critic"].zero_grad()
|
||||||
loss_critic.backward()
|
loss_critic.backward()
|
||||||
|
@ -395,6 +415,7 @@ def add_actor_information_and_train(
|
||||||
observations=observations, actions=actions, next_state=next_observations
|
observations=observations, actions=actions, next_state=next_observations
|
||||||
)
|
)
|
||||||
|
|
||||||
|
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
|
||||||
with policy_lock:
|
with policy_lock:
|
||||||
loss_critic = policy.compute_loss_critic(
|
loss_critic = policy.compute_loss_critic(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
|
@ -402,6 +423,8 @@ def add_actor_information_and_train(
|
||||||
rewards=rewards,
|
rewards=rewards,
|
||||||
next_observations=next_observations,
|
next_observations=next_observations,
|
||||||
done=done,
|
done=done,
|
||||||
|
observation_features=observation_features,
|
||||||
|
next_observation_features=next_observation_features,
|
||||||
)
|
)
|
||||||
optimizers["critic"].zero_grad()
|
optimizers["critic"].zero_grad()
|
||||||
loss_critic.backward()
|
loss_critic.backward()
|
||||||
|
@ -413,7 +436,8 @@ def add_actor_information_and_train(
|
||||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||||
for _ in range(cfg.training.policy_update_freq):
|
for _ in range(cfg.training.policy_update_freq):
|
||||||
with policy_lock:
|
with policy_lock:
|
||||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
loss_actor = policy.compute_loss_actor(observations=observations,
|
||||||
|
observation_features=observation_features)
|
||||||
|
|
||||||
optimizers["actor"].zero_grad()
|
optimizers["actor"].zero_grad()
|
||||||
loss_actor.backward()
|
loss_actor.backward()
|
||||||
|
@ -422,7 +446,8 @@ def add_actor_information_and_train(
|
||||||
training_infos["loss_actor"] = loss_actor.item()
|
training_infos["loss_actor"] = loss_actor.item()
|
||||||
|
|
||||||
loss_temperature = policy.compute_loss_temperature(
|
loss_temperature = policy.compute_loss_temperature(
|
||||||
observations=observations
|
observations=observations,
|
||||||
|
observation_features=observation_features
|
||||||
)
|
)
|
||||||
optimizers["temperature"].zero_grad()
|
optimizers["temperature"].zero_grad()
|
||||||
loss_temperature.backward()
|
loss_temperature.backward()
|
||||||
|
|
|
@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||||
def _get_policy_state(self):
|
def _get_policy_state(self):
|
||||||
with self.policy_lock:
|
with self.policy_lock:
|
||||||
params_dict = self.policy.actor.state_dict()
|
params_dict = self.policy.actor.state_dict()
|
||||||
if self.policy.config.vision_encoder_name is not None:
|
# if self.policy.config.vision_encoder_name is not None:
|
||||||
if self.policy.config.freeze_vision_encoder:
|
# if self.policy.config.freeze_vision_encoder:
|
||||||
params_dict: dict[str, torch.Tensor] = {
|
# params_dict: dict[str, torch.Tensor] = {
|
||||||
k: v
|
# k: v
|
||||||
for k, v in params_dict.items()
|
# for k, v in params_dict.items()
|
||||||
if not k.startswith("encoder.")
|
# if not k.startswith("encoder.")
|
||||||
}
|
# }
|
||||||
else:
|
# else:
|
||||||
raise NotImplementedError(
|
# raise NotImplementedError(
|
||||||
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||||
)
|
# )
|
||||||
|
|
||||||
return move_state_dict_to_device(params_dict, device="cpu")
|
return move_state_dict_to_device(params_dict, device="cpu")
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,6 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
|
||||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||||
|
|
||||||
return_observations["observation.image"] = img
|
return_observations["observation.image"] = img
|
||||||
return_observations["observation.image.2"] = img
|
|
||||||
return_observations["observation.state"] = state
|
return_observations["observation.state"] = state
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
|
@ -71,6 +71,7 @@ dependencies = [
|
||||||
"pyzmq>=26.2.1",
|
"pyzmq>=26.2.1",
|
||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
|
"tensordict>=0.0.1",
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1",
|
||||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
||||||
"torchmetrics>=1.6.0",
|
"torchmetrics>=1.6.0",
|
||||||
|
@ -88,7 +89,7 @@ dora = [
|
||||||
]
|
]
|
||||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||||
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
|
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
|
||||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||||
mani_skill = ["mani-skill"]
|
mani_skill = ["mani-skill"]
|
||||||
pi0 = ["transformers>=4.48.0"]
|
pi0 = ["transformers>=4.48.0"]
|
||||||
|
|
Loading…
Reference in New Issue