add temporary eval part, optimizer

This commit is contained in:
jayLEE0301 2024-05-13 20:31:06 -04:00
parent 44edc6f648
commit 02d55c0b9a
3 changed files with 235 additions and 20 deletions

View File

@ -84,7 +84,7 @@ class Logger:
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
self._group.replace(":", "_").replace("/", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)

View File

@ -97,9 +97,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self._obs_queues = populate_queues(self._obs_queues, batch)
if not self.check_discretized():
raise NotImplementedError(
"Should train VQ-VAE before rollout."
)
self.vqbet._action_head._vqvae_model.discretized = True
# raise NotImplementedError(
# "Should train VQ-VAE before rollout."
# )
assert "observation.image" in batch
assert "observation.state" in batch
@ -490,6 +491,14 @@ class VQBeTHead(nn.Module):
cbet_logits[:, 1, :],
action_bins[:, 1],
)
# cbet_loss3 = self._criterion( # F.cross_entropy
# cbet_logits[:, 2, :],
# action_bins[:, 2],
# )
# cbet_loss4 = self._criterion( # F.cross_entropy
# cbet_logits[:, 3, :],
# action_bins[:, 3],
# )
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier
equal_total_code_rate = (
@ -527,6 +536,140 @@ class VQBeTOptimizer:
self.offline_steps = cfg.training.offline_steps
self.optimizing_step = 0
# Option 1
# vqvae_params = (
# list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
# + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
# + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
# )
# self.vqvae_optimizer = torch.optim.Adam(
# vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
# )
# self.encoder_optimizer = torch.optim.Adam(
# policy.vqbet.parameters(),
# cfg.training.lr,
# cfg.training.adam_betas,
# cfg.training.adam_eps,
# cfg.training.adam_weight_decay,
# )
# self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
# weight_decay=cfg.training.bet_weight_decay,
# learning_rate=cfg.training.bet_learning_rate,
# betas=cfg.training.bet_betas,
# )
# if policy.vqbet._action_head.sequentially_select:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
# )
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
# )
# else:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
# )
# self.bet_optimizer2 = torch.optim.AdamW(
# policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
# lr=cfg.training.bet_learning_rate,
# weight_decay=cfg.training.bet_weight_decay,
# betas=cfg.training.bet_betas,
# )
# Option 2
# vqvae_params = (
# list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
# + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
# + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
# )
# self.vqvae_optimizer = torch.optim.Adam(
# vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
# )
# self.encoder_optimizer = torch.optim.Adam(
# policy.vqbet.parameters(),
# cfg.training.lr,
# cfg.training.adam_betas,
# cfg.training.adam_eps,
# cfg.training.adam_weight_decay,
# )
# self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
# weight_decay=cfg.training.adam_weight_decay,
# learning_rate=cfg.training.lr,
# betas=cfg.training.adam_betas,
# optimizer = "Adam",
# eps=cfg.training.adam_eps,
# )
# if policy.vqbet._action_head.sequentially_select:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
# )
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
# )
# else:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
# )
# self.bet_optimizer2 = torch.optim.Adam(
# policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
# cfg.training.lr,
# cfg.training.adam_betas,
# cfg.training.adam_eps,
# cfg.training.adam_weight_decay,
# )
# Option 3
# vqvae_params = (
# list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
# + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
# + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
# )
# self.vqvae_optimizer = torch.optim.Adam(
# vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
# )
# self.encoder_optimizer = torch.optim.AdamW(
# policy.vqbet.parameters(),
# lr=cfg.training.bet_learning_rate,
# weight_decay=cfg.training.bet_weight_decay,
# betas=cfg.training.bet_betas,
# )
# self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
# weight_decay=cfg.training.bet_weight_decay,
# learning_rate=cfg.training.bet_learning_rate,
# betas=cfg.training.bet_betas,
# )
# if policy.vqbet._action_head.sequentially_select:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
# )
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
# )
# else:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
# )
# self.bet_optimizer2 = torch.optim.AdamW(
# policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
# lr=cfg.training.bet_learning_rate,
# weight_decay=cfg.training.bet_weight_decay,
# betas=cfg.training.bet_betas,
# )
# Option 4
vqvae_params = (
list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
+ list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
@ -549,17 +692,24 @@ class VQBeTOptimizer:
learning_rate=cfg.training.bet_learning_rate,
betas=cfg.training.bet_betas,
)
if policy.vqbet._action_head.sequentially_select:
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
)
else:
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
)
# if policy.vqbet._action_head.sequentially_select:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
# )
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
# )
# else:
# self.bet_optimizer1.add_param_group(
# {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
# )
self.bet_optimizer0 = torch.optim.AdamW(
policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
lr=cfg.training.bet_learning_rate,
weight_decay=cfg.training.bet_weight_decay,
betas=cfg.training.bet_betas,
)
self.bet_optimizer2 = torch.optim.AdamW(
policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
@ -580,6 +730,7 @@ class VQBeTOptimizer:
if self.optimizing_step < 0.6 * self.offline_steps:
self.encoder_optimizer.step()
self.bet_optimizer1.step()
self.bet_optimizer0.step()
self.bet_optimizer2.step()
else:
self.bet_optimizer2.step()
@ -593,6 +744,7 @@ class VQBeTOptimizer:
if self.optimizing_step < 0.6 * self.offline_steps:
self.encoder_optimizer.zero_grad()
self.bet_optimizer1.zero_grad()
self.bet_optimizer0.zero_grad()
self.bet_optimizer2.zero_grad()
else:
self.bet_optimizer2.zero_grad()
@ -604,17 +756,33 @@ class VQBeTScheduler:
self.offline_steps = cfg.training.offline_steps
self.optimizing_step = 0
self.lr_scheduler = get_scheduler(
self.lr_scheduler1 = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer.encoder_optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
# self.lr_scheduler2 = get_scheduler(
# cfg.training.lr_scheduler,
# optimizer=optimizer.bet_optimizer1,
# num_warmup_steps=cfg.training.lr_warmup_steps,
# num_training_steps=cfg.training.offline_steps,
# )
# self.lr_scheduler3 = get_scheduler(
# cfg.training.lr_scheduler,
# optimizer=optimizer.bet_optimizer2,
# num_warmup_steps=cfg.training.lr_warmup_steps,
# num_training_steps=cfg.training.offline_steps,
# )
def step(self):
self.optimizing_step +=1
if self.optimizing_step >= self.discretize_step:
self.lr_scheduler.step()
self.lr_scheduler1.step()
# self.lr_scheduler2.step()
# self.lr_scheduler3.step()
class DiffusionRgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
@ -2848,7 +3016,7 @@ class GPT(nn.Module):
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
def configure_optimizers(self, weight_decay, learning_rate, betas):
def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
@ -2898,5 +3066,10 @@ class GPT(nn.Module):
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
if optimizer=="Adamw":
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
elif optimizer=="Adam":
optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
else:
raise NotImplementedError
return optimizer

View File

@ -217,6 +217,7 @@ def eval_policy(
"""
start = time.time()
policy.eval()
max_episodes_rendered = 20
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
# divisible by env.num_envs we end up discarding some data in the last batch.
@ -229,6 +230,7 @@ def eval_policy(
all_seeds = []
threads = [] # for video saving threads
n_episodes_rendered = 0 # for saving the correct number of videos
all_coverages = []
# Callback for visualization.
def render_frame(env: gym.vector.VectorEnv):
@ -280,6 +282,7 @@ def eval_policy(
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
all_coverages.extend((rollout_data['reward'][:, -1]*0.95).tolist())
all_seeds.extend(seeds)
if return_episode_data:
@ -332,7 +335,9 @@ def eval_policy(
if n_episodes_rendered >= max_episodes_rendered:
break
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
final_cover = "%.3f" % all_coverages[n_episodes_rendered]
max_cover = "%.3f" % ((max_rewards[n_episodes_rendered] * 0.95))
video_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=write_video,
@ -353,7 +358,44 @@ def eval_policy(
# Wait till all video rendering threads are done.
for thread in threads:
thread.join()
import matplotlib.pyplot as plt
plt.hist(all_coverages, bins=100)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "final_coverage_histogram100.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
max_coverage = [0.95 * value for value in max_rewards]
plt.hist(max_coverage, bins=100)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "max_coverage_histogram100.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
plt.hist(all_coverages, bins=50)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "final_coverage_histogram.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
max_coverage = [0.95 * value for value in max_rewards]
plt.hist(max_coverage, bins=50)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "max_coverage_histogram.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
# Compile eval info.
info = {
"per_episode": [