add temporary eval part, optimizer
This commit is contained in:
parent
44edc6f648
commit
02d55c0b9a
|
@ -84,7 +84,7 @@ class Logger:
|
|||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
self._group.replace(":", "_").replace("/", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
|
|
|
@ -97,9 +97,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self._obs_queues = populate_queues(self._obs_queues, batch)
|
||||
|
||||
if not self.check_discretized():
|
||||
raise NotImplementedError(
|
||||
"Should train VQ-VAE before rollout."
|
||||
)
|
||||
self.vqbet._action_head._vqvae_model.discretized = True
|
||||
# raise NotImplementedError(
|
||||
# "Should train VQ-VAE before rollout."
|
||||
# )
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
|
@ -490,6 +491,14 @@ class VQBeTHead(nn.Module):
|
|||
cbet_logits[:, 1, :],
|
||||
action_bins[:, 1],
|
||||
)
|
||||
# cbet_loss3 = self._criterion( # F.cross_entropy
|
||||
# cbet_logits[:, 2, :],
|
||||
# action_bins[:, 2],
|
||||
# )
|
||||
# cbet_loss4 = self._criterion( # F.cross_entropy
|
||||
# cbet_logits[:, 3, :],
|
||||
# action_bins[:, 3],
|
||||
# )
|
||||
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier
|
||||
|
||||
equal_total_code_rate = (
|
||||
|
@ -527,6 +536,140 @@ class VQBeTOptimizer:
|
|||
self.offline_steps = cfg.training.offline_steps
|
||||
self.optimizing_step = 0
|
||||
|
||||
# Option 1
|
||||
|
||||
# vqvae_params = (
|
||||
# list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
|
||||
# + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
|
||||
# + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
|
||||
# )
|
||||
# self.vqvae_optimizer = torch.optim.Adam(
|
||||
# vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
|
||||
# )
|
||||
|
||||
# self.encoder_optimizer = torch.optim.Adam(
|
||||
# policy.vqbet.parameters(),
|
||||
# cfg.training.lr,
|
||||
# cfg.training.adam_betas,
|
||||
# cfg.training.adam_eps,
|
||||
# cfg.training.adam_weight_decay,
|
||||
# )
|
||||
|
||||
# self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
|
||||
# weight_decay=cfg.training.bet_weight_decay,
|
||||
# learning_rate=cfg.training.bet_learning_rate,
|
||||
# betas=cfg.training.bet_betas,
|
||||
# )
|
||||
# if policy.vqbet._action_head.sequentially_select:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
|
||||
# )
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
|
||||
# )
|
||||
# else:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
|
||||
# )
|
||||
|
||||
# self.bet_optimizer2 = torch.optim.AdamW(
|
||||
# policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
|
||||
# lr=cfg.training.bet_learning_rate,
|
||||
# weight_decay=cfg.training.bet_weight_decay,
|
||||
# betas=cfg.training.bet_betas,
|
||||
# )
|
||||
|
||||
# Option 2
|
||||
|
||||
# vqvae_params = (
|
||||
# list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
|
||||
# + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
|
||||
# + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
|
||||
# )
|
||||
# self.vqvae_optimizer = torch.optim.Adam(
|
||||
# vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
|
||||
# )
|
||||
|
||||
# self.encoder_optimizer = torch.optim.Adam(
|
||||
# policy.vqbet.parameters(),
|
||||
# cfg.training.lr,
|
||||
# cfg.training.adam_betas,
|
||||
# cfg.training.adam_eps,
|
||||
# cfg.training.adam_weight_decay,
|
||||
# )
|
||||
|
||||
# self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
|
||||
# weight_decay=cfg.training.adam_weight_decay,
|
||||
# learning_rate=cfg.training.lr,
|
||||
# betas=cfg.training.adam_betas,
|
||||
# optimizer = "Adam",
|
||||
# eps=cfg.training.adam_eps,
|
||||
# )
|
||||
# if policy.vqbet._action_head.sequentially_select:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
|
||||
# )
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
|
||||
# )
|
||||
# else:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
|
||||
# )
|
||||
|
||||
# self.bet_optimizer2 = torch.optim.Adam(
|
||||
# policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
|
||||
# cfg.training.lr,
|
||||
# cfg.training.adam_betas,
|
||||
# cfg.training.adam_eps,
|
||||
# cfg.training.adam_weight_decay,
|
||||
# )
|
||||
|
||||
# Option 3
|
||||
|
||||
# vqvae_params = (
|
||||
# list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
|
||||
# + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
|
||||
# + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
|
||||
# )
|
||||
# self.vqvae_optimizer = torch.optim.Adam(
|
||||
# vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
|
||||
# )
|
||||
|
||||
# self.encoder_optimizer = torch.optim.AdamW(
|
||||
# policy.vqbet.parameters(),
|
||||
# lr=cfg.training.bet_learning_rate,
|
||||
# weight_decay=cfg.training.bet_weight_decay,
|
||||
# betas=cfg.training.bet_betas,
|
||||
# )
|
||||
|
||||
# self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
|
||||
# weight_decay=cfg.training.bet_weight_decay,
|
||||
# learning_rate=cfg.training.bet_learning_rate,
|
||||
# betas=cfg.training.bet_betas,
|
||||
# )
|
||||
# if policy.vqbet._action_head.sequentially_select:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
|
||||
# )
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
|
||||
# )
|
||||
# else:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
|
||||
# )
|
||||
|
||||
# self.bet_optimizer2 = torch.optim.AdamW(
|
||||
# policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
|
||||
# lr=cfg.training.bet_learning_rate,
|
||||
# weight_decay=cfg.training.bet_weight_decay,
|
||||
# betas=cfg.training.bet_betas,
|
||||
# )
|
||||
|
||||
|
||||
# Option 4
|
||||
|
||||
vqvae_params = (
|
||||
list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
|
||||
|
@ -549,17 +692,24 @@ class VQBeTOptimizer:
|
|||
learning_rate=cfg.training.bet_learning_rate,
|
||||
betas=cfg.training.bet_betas,
|
||||
)
|
||||
if policy.vqbet._action_head.sequentially_select:
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
|
||||
)
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
|
||||
)
|
||||
else:
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
|
||||
)
|
||||
# if policy.vqbet._action_head.sequentially_select:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
|
||||
# )
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
|
||||
# )
|
||||
# else:
|
||||
# self.bet_optimizer1.add_param_group(
|
||||
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
|
||||
# )
|
||||
|
||||
self.bet_optimizer0 = torch.optim.AdamW(
|
||||
policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
|
||||
lr=cfg.training.bet_learning_rate,
|
||||
weight_decay=cfg.training.bet_weight_decay,
|
||||
betas=cfg.training.bet_betas,
|
||||
)
|
||||
|
||||
self.bet_optimizer2 = torch.optim.AdamW(
|
||||
policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
|
||||
|
@ -580,6 +730,7 @@ class VQBeTOptimizer:
|
|||
if self.optimizing_step < 0.6 * self.offline_steps:
|
||||
self.encoder_optimizer.step()
|
||||
self.bet_optimizer1.step()
|
||||
self.bet_optimizer0.step()
|
||||
self.bet_optimizer2.step()
|
||||
else:
|
||||
self.bet_optimizer2.step()
|
||||
|
@ -593,6 +744,7 @@ class VQBeTOptimizer:
|
|||
if self.optimizing_step < 0.6 * self.offline_steps:
|
||||
self.encoder_optimizer.zero_grad()
|
||||
self.bet_optimizer1.zero_grad()
|
||||
self.bet_optimizer0.zero_grad()
|
||||
self.bet_optimizer2.zero_grad()
|
||||
else:
|
||||
self.bet_optimizer2.zero_grad()
|
||||
|
@ -604,17 +756,33 @@ class VQBeTScheduler:
|
|||
self.offline_steps = cfg.training.offline_steps
|
||||
self.optimizing_step = 0
|
||||
|
||||
self.lr_scheduler = get_scheduler(
|
||||
self.lr_scheduler1 = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer.encoder_optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
|
||||
# self.lr_scheduler2 = get_scheduler(
|
||||
# cfg.training.lr_scheduler,
|
||||
# optimizer=optimizer.bet_optimizer1,
|
||||
# num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
# num_training_steps=cfg.training.offline_steps,
|
||||
# )
|
||||
|
||||
# self.lr_scheduler3 = get_scheduler(
|
||||
# cfg.training.lr_scheduler,
|
||||
# optimizer=optimizer.bet_optimizer2,
|
||||
# num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
# num_training_steps=cfg.training.offline_steps,
|
||||
# )
|
||||
|
||||
def step(self):
|
||||
self.optimizing_step +=1
|
||||
if self.optimizing_step >= self.discretize_step:
|
||||
self.lr_scheduler.step()
|
||||
self.lr_scheduler1.step()
|
||||
# self.lr_scheduler2.step()
|
||||
# self.lr_scheduler3.step()
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
|
@ -2848,7 +3016,7 @@ class GPT(nn.Module):
|
|||
for block in self.transformer.h:
|
||||
block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
|
||||
|
||||
def configure_optimizers(self, weight_decay, learning_rate, betas):
|
||||
def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
|
||||
"""
|
||||
This long function is unfortunately doing something very simple and is being very defensive:
|
||||
We are separating out all parameters of the model into two buckets: those that will experience
|
||||
|
@ -2898,5 +3066,10 @@ class GPT(nn.Module):
|
|||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
if optimizer=="Adamw":
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
elif optimizer=="Adam":
|
||||
optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return optimizer
|
||||
|
|
|
@ -217,6 +217,7 @@ def eval_policy(
|
|||
"""
|
||||
start = time.time()
|
||||
policy.eval()
|
||||
max_episodes_rendered = 20
|
||||
|
||||
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
|
||||
# divisible by env.num_envs we end up discarding some data in the last batch.
|
||||
|
@ -229,6 +230,7 @@ def eval_policy(
|
|||
all_seeds = []
|
||||
threads = [] # for video saving threads
|
||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||
all_coverages = []
|
||||
|
||||
# Callback for visualization.
|
||||
def render_frame(env: gym.vector.VectorEnv):
|
||||
|
@ -280,6 +282,7 @@ def eval_policy(
|
|||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
all_coverages.extend((rollout_data['reward'][:, -1]*0.95).tolist())
|
||||
all_seeds.extend(seeds)
|
||||
|
||||
if return_episode_data:
|
||||
|
@ -332,7 +335,9 @@ def eval_policy(
|
|||
if n_episodes_rendered >= max_episodes_rendered:
|
||||
break
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
final_cover = "%.3f" % all_coverages[n_episodes_rendered]
|
||||
max_cover = "%.3f" % ((max_rewards[n_episodes_rendered] * 0.95))
|
||||
video_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
|
@ -353,7 +358,44 @@ def eval_policy(
|
|||
# Wait till all video rendering threads are done.
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
import matplotlib.pyplot as plt
|
||||
plt.hist(all_coverages, bins=100)
|
||||
plt.xlabel('Value')
|
||||
plt.ylabel('Quantity')
|
||||
plt.title('Histogram of Data')
|
||||
plt.ylim(0, 500)
|
||||
fig_path = video_dir / "final_coverage_histogram100.png"
|
||||
plt.savefig(fig_path) # Save the plot as a PNG file
|
||||
plt.close()
|
||||
|
||||
max_coverage = [0.95 * value for value in max_rewards]
|
||||
plt.hist(max_coverage, bins=100)
|
||||
plt.xlabel('Value')
|
||||
plt.ylabel('Quantity')
|
||||
plt.title('Histogram of Data')
|
||||
plt.ylim(0, 500)
|
||||
fig_path = video_dir / "max_coverage_histogram100.png"
|
||||
plt.savefig(fig_path) # Save the plot as a PNG file
|
||||
plt.close()
|
||||
|
||||
plt.hist(all_coverages, bins=50)
|
||||
plt.xlabel('Value')
|
||||
plt.ylabel('Quantity')
|
||||
plt.title('Histogram of Data')
|
||||
plt.ylim(0, 500)
|
||||
fig_path = video_dir / "final_coverage_histogram.png"
|
||||
plt.savefig(fig_path) # Save the plot as a PNG file
|
||||
plt.close()
|
||||
|
||||
max_coverage = [0.95 * value for value in max_rewards]
|
||||
plt.hist(max_coverage, bins=50)
|
||||
plt.xlabel('Value')
|
||||
plt.ylabel('Quantity')
|
||||
plt.title('Histogram of Data')
|
||||
plt.ylim(0, 500)
|
||||
fig_path = video_dir / "max_coverage_histogram.png"
|
||||
plt.savefig(fig_path) # Save the plot as a PNG file
|
||||
plt.close()
|
||||
# Compile eval info.
|
||||
info = {
|
||||
"per_episode": [
|
||||
|
|
Loading…
Reference in New Issue