add temporary eval part, optimizer
parent 44edc6f648
commit 02d55c0b9a
@@ -84,7 +84,7 @@ class Logger:
         if self._wandb and not self._disable_wandb_artifact:
             # note wandb artifact does not accept ":" in its name
             artifact = self._wandb.Artifact(
-                self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
+                self._group.replace(":", "_").replace("/", "_") + "-" + str(self._seed) + "-" + str(identifier),
                 type="model",
             )
             artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
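The hunk above extends the artifact-name sanitization: W&B artifact names reject ":" (as the existing comment notes), and group names containing "/" are now normalized the same way. A minimal sketch of the same idea as a reusable helper; the function name is illustrative, not part of the codebase:

def sanitize_artifact_name(name: str) -> str:
    # Replace characters that W&B artifact names do not accept.
    for ch in (":", "/"):
        name = name.replace(ch, "_")
    return name

# e.g. sanitize_artifact_name("outputs/train:pusht") -> "outputs_train_pusht"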
@@ -97,9 +97,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         self._obs_queues = populate_queues(self._obs_queues, batch)

         if not self.check_discretized():
-            raise NotImplementedError(
-                "Should train VQ-VAE before rollout."
-            )
+            self.vqbet._action_head._vqvae_model.discretized = True
+            # raise NotImplementedError(
+            #     "Should train VQ-VAE before rollout."
+            # )
         assert "observation.image" in batch
         assert "observation.state" in batch

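The rollout-time guard is temporarily relaxed here: instead of raising "Should train VQ-VAE before rollout.", the policy force-sets the action head's `discretized` flag so evaluation can proceed. A gentler variant is sketched below as an assumption only (it keeps the bypass but makes it visible in the logs; the helper name is hypothetical):

import warnings

def ensure_discretized(policy) -> None:
    # Hypothetical helper: keep the temporary bypass, but warn instead of failing silently.
    vqvae = policy.vqbet._action_head._vqvae_model
    if not getattr(vqvae, "discretized", False):
        warnings.warn("VQ-VAE not discretized; forcing discretized=True for this rollout.")
        vqvae.discretized = True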
@@ -490,6 +491,14 @@ class VQBeTHead(nn.Module):
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
+        # cbet_loss3 = self._criterion(  # F.cross_entropy
+        #     cbet_logits[:, 2, :],
+        #     action_bins[:, 2],
+        # )
+        # cbet_loss4 = self._criterion(  # F.cross_entropy
+        #     cbet_logits[:, 3, :],
+        #     action_bins[:, 3],
+        # )
         cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier

         equal_total_code_rate = (
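The commented `cbet_loss3`/`cbet_loss4` terms mirror the existing cross-entropy loss for a third and fourth code index, while the active loss still combines only the first two predictions (`cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier`). If more residual codes were enabled, the per-code term could be written as a loop instead of duplicated blocks; a sketch under that assumption (the weights below are illustrative, not the repository's values):

import torch.nn.functional as F

def cbet_classification_loss(cbet_logits, action_bins, weights):
    # cbet_logits: (batch, num_codes, vocab); action_bins: (batch, num_codes)
    # weights: one multiplier per code index, e.g. [5.0, secondary_code_multiplier, ...]
    loss = 0.0
    for i, w in enumerate(weights):
        loss = loss + w * F.cross_entropy(cbet_logits[:, i, :], action_bins[:, i])
    return loss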
@@ -527,6 +536,140 @@ class VQBeTOptimizer:
         self.offline_steps = cfg.training.offline_steps
         self.optimizing_step = 0

+        # Option 1
+
+        # vqvae_params = (
+        #     list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
+        #     + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
+        #     + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
+        # )
+        # self.vqvae_optimizer = torch.optim.Adam(
+        #     vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
+        # )
+
+        # self.encoder_optimizer = torch.optim.Adam(
+        #     policy.vqbet.parameters(),
+        #     cfg.training.lr,
+        #     cfg.training.adam_betas,
+        #     cfg.training.adam_eps,
+        #     cfg.training.adam_weight_decay,
+        # )
+
+        # self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
+        #     weight_decay=cfg.training.bet_weight_decay,
+        #     learning_rate=cfg.training.bet_learning_rate,
+        #     betas=cfg.training.bet_betas,
+        # )
+        # if policy.vqbet._action_head.sequentially_select:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
+        #     )
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
+        #     )
+        # else:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
+        #     )
+
+        # self.bet_optimizer2 = torch.optim.AdamW(
+        #     policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
+        #     lr=cfg.training.bet_learning_rate,
+        #     weight_decay=cfg.training.bet_weight_decay,
+        #     betas=cfg.training.bet_betas,
+        # )
+
+        # Option 2
+
+        # vqvae_params = (
+        #     list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
+        #     + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
+        #     + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
+        # )
+        # self.vqvae_optimizer = torch.optim.Adam(
+        #     vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
+        # )
+
+        # self.encoder_optimizer = torch.optim.Adam(
+        #     policy.vqbet.parameters(),
+        #     cfg.training.lr,
+        #     cfg.training.adam_betas,
+        #     cfg.training.adam_eps,
+        #     cfg.training.adam_weight_decay,
+        # )
+
+        # self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
+        #     weight_decay=cfg.training.adam_weight_decay,
+        #     learning_rate=cfg.training.lr,
+        #     betas=cfg.training.adam_betas,
+        #     optimizer = "Adam",
+        #     eps=cfg.training.adam_eps,
+        # )
+        # if policy.vqbet._action_head.sequentially_select:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
+        #     )
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
+        #     )
+        # else:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
+        #     )
+
+        # self.bet_optimizer2 = torch.optim.Adam(
+        #     policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
+        #     cfg.training.lr,
+        #     cfg.training.adam_betas,
+        #     cfg.training.adam_eps,
+        #     cfg.training.adam_weight_decay,
+        # )
+
+        # Option 3
+
+        # vqvae_params = (
+        #     list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
+        #     + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
+        #     + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
+        # )
+        # self.vqvae_optimizer = torch.optim.Adam(
+        #     vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
+        # )
+
+        # self.encoder_optimizer = torch.optim.AdamW(
+        #     policy.vqbet.parameters(),
+        #     lr=cfg.training.bet_learning_rate,
+        #     weight_decay=cfg.training.bet_weight_decay,
+        #     betas=cfg.training.bet_betas,
+        # )
+
+        # self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
+        #     weight_decay=cfg.training.bet_weight_decay,
+        #     learning_rate=cfg.training.bet_learning_rate,
+        #     betas=cfg.training.bet_betas,
+        # )
+        # if policy.vqbet._action_head.sequentially_select:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
+        #     )
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
+        #     )
+        # else:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
+        #     )
+
+        # self.bet_optimizer2 = torch.optim.AdamW(
+        #     policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
+        #     lr=cfg.training.bet_learning_rate,
+        #     weight_decay=cfg.training.bet_weight_decay,
+        #     betas=cfg.training.bet_betas,
+        # )
+
+
+        # Option 4
+
         vqvae_params = (
             list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
             + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
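The large commented block records the optimizer layouts that were tried: "Option 1" through "Option 3" stay commented out, and the uncommented code following "# Option 4" is what actually runs. In every variant the VQ-VAE parameters are gathered from the encoder, decoder and vector-quantization layer and given their own Adam optimizer. A condensed sketch of that shared pattern, assuming the attribute paths shown in the diff (the helper name is illustrative):

import itertools
import torch

def build_vqvae_optimizer(policy, vqvae_lr: float) -> torch.optim.Adam:
    # Collect the VQ-VAE sub-modules that appear in every "Option" above.
    vqvae = policy.vqbet._action_head._vqvae_model
    params = itertools.chain(
        vqvae.encoder.parameters(),
        vqvae.decoder.parameters(),
        vqvae.vq_layer.parameters(),
    )
    return torch.optim.Adam(params, lr=vqvae_lr, weight_decay=0.0001)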
@@ -549,17 +692,24 @@ class VQBeTOptimizer:
             learning_rate=cfg.training.bet_learning_rate,
             betas=cfg.training.bet_betas,
         )
-        if policy.vqbet._action_head.sequentially_select:
-            self.bet_optimizer1.add_param_group(
-                {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
-            )
-            self.bet_optimizer1.add_param_group(
-                {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
-            )
-        else:
-            self.bet_optimizer1.add_param_group(
-                {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
-            )
+        # if policy.vqbet._action_head.sequentially_select:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
+        #     )
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()}
+        #     )
+        # else:
+        #     self.bet_optimizer1.add_param_group(
+        #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()}
+        #     )

+        self.bet_optimizer0 = torch.optim.AdamW(
+            policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
+            lr=cfg.training.bet_learning_rate,
+            weight_decay=cfg.training.bet_weight_decay,
+            betas=cfg.training.bet_betas,
+        )
+
         self.bet_optimizer2 = torch.optim.AdamW(
             policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
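With the `sequentially_select` branches commented out, the bin-prediction head is no longer folded into `bet_optimizer1` via `add_param_group`; it gets its own AdamW instance, `bet_optimizer0`, with the same BeT hyperparameters, alongside `bet_optimizer2` for the offset head. A minimal sketch of the resulting split, using the config fields and attribute paths visible in the diff (the factory function itself is only an assumption, and `configure_optimizers` here is the GPT method referenced in the commented options):

import torch

def build_bet_optimizers(policy, cfg):
    head = policy.vqbet._action_head
    bet_kwargs = dict(
        lr=cfg.training.bet_learning_rate,
        weight_decay=cfg.training.bet_weight_decay,
        betas=cfg.training.bet_betas,
    )
    bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
        weight_decay=cfg.training.bet_weight_decay,
        learning_rate=cfg.training.bet_learning_rate,
        betas=cfg.training.bet_betas,
    )
    bet_optimizer0 = torch.optim.AdamW(head._map_to_cbet_preds_bin.parameters(), **bet_kwargs)
    bet_optimizer2 = torch.optim.AdamW(head._map_to_cbet_preds_offset.parameters(), **bet_kwargs)
    return bet_optimizer0, bet_optimizer1, bet_optimizer2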
@@ -580,6 +730,7 @@ class VQBeTOptimizer:
         if self.optimizing_step < 0.6 * self.offline_steps:
             self.encoder_optimizer.step()
             self.bet_optimizer1.step()
+            self.bet_optimizer0.step()
             self.bet_optimizer2.step()
         else:
             self.bet_optimizer2.step()
@@ -593,6 +744,7 @@ class VQBeTOptimizer:
         if self.optimizing_step < 0.6 * self.offline_steps:
             self.encoder_optimizer.zero_grad()
             self.bet_optimizer1.zero_grad()
+            self.bet_optimizer0.zero_grad()
             self.bet_optimizer2.zero_grad()
         else:
             self.bet_optimizer2.zero_grad()
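Both `step()` and `zero_grad()` now drive `bet_optimizer0` together with the encoder and GPT optimizers, but only during the first 60% of the offline steps; after that only `bet_optimizer2` (the offset head) keeps updating. A compact sketch of the two-phase schedule, assuming the same 0.6 cutoff (the wrapper function is illustrative):

def step_optimizers(optimizing_step, offline_steps, encoder_opt, bet_opt1, bet_opt0, bet_opt2):
    # Phase 1 (first 60% of offline training): update everything.
    # Phase 2: stop updating encoder, GPT and bin head; keep refining only the offset head.
    warm = optimizing_step < 0.6 * offline_steps
    active = [encoder_opt, bet_opt1, bet_opt0, bet_opt2] if warm else [bet_opt2]
    for opt in active:
        opt.step()
    for opt in active:
        opt.zero_grad()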
@@ -604,17 +756,33 @@ class VQBeTScheduler:
         self.offline_steps = cfg.training.offline_steps
         self.optimizing_step = 0

-        self.lr_scheduler = get_scheduler(
+        self.lr_scheduler1 = get_scheduler(
             cfg.training.lr_scheduler,
             optimizer=optimizer.encoder_optimizer,
             num_warmup_steps=cfg.training.lr_warmup_steps,
             num_training_steps=cfg.training.offline_steps,
         )

+        # self.lr_scheduler2 = get_scheduler(
+        #     cfg.training.lr_scheduler,
+        #     optimizer=optimizer.bet_optimizer1,
+        #     num_warmup_steps=cfg.training.lr_warmup_steps,
+        #     num_training_steps=cfg.training.offline_steps,
+        # )
+
+        # self.lr_scheduler3 = get_scheduler(
+        #     cfg.training.lr_scheduler,
+        #     optimizer=optimizer.bet_optimizer2,
+        #     num_warmup_steps=cfg.training.lr_warmup_steps,
+        #     num_training_steps=cfg.training.offline_steps,
+        # )
+
     def step(self):
         self.optimizing_step +=1
         if self.optimizing_step >= self.discretize_step:
-            self.lr_scheduler.step()
+            self.lr_scheduler1.step()
+            # self.lr_scheduler2.step()
+            # self.lr_scheduler3.step()


 class DiffusionRgbEncoder(nn.Module):
     """Encoder an RGB image into a 1D feature vector.
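The single scheduler is renamed `lr_scheduler1` and still wraps only `encoder_optimizer`; the commented `lr_scheduler2`/`lr_scheduler3` show the same `get_scheduler` call repeated for `bet_optimizer1` and `bet_optimizer2`. Assuming `get_scheduler` is the diffusers helper already used in this module, the repetition could be collapsed into one loop; a sketch under that assumption:

# Assumes get_scheduler is the helper already imported in this module
# (diffusers.optimization.get_scheduler has this signature).
def build_schedulers(cfg, optimizers):
    return [
        get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=opt,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
        for opt in optimizers
    ]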
@@ -2848,7 +3016,7 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

-    def configure_optimizers(self, weight_decay, learning_rate, betas):
+    def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
         """
         This long function is unfortunately doing something very simple and is being very defensive:
         We are separating out all parameters of the model into two buckets: those that will experience
@@ -2898,5 +3066,10 @@ class GPT(nn.Module):
                 "weight_decay": 0.0,
             },
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        if optimizer=="Adamw":
+            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        elif optimizer=="Adam":
+            optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
+        else:
+            raise NotImplementedError
         return optimizer
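`configure_optimizers` now takes an `optimizer` selector ("Adamw" by default, "Adam" as the alternative that also forwards `eps`) and raises `NotImplementedError` for anything else; the spelling has to match those string literals exactly. A usage sketch against the new signature, where the hyperparameter values are placeholders rather than the repository's defaults:

# gpt is an instance of the GPT class from this file.
adamw_opt = gpt.configure_optimizers(
    weight_decay=0.1, learning_rate=1e-4, betas=(0.9, 0.999)
)  # optimizer defaults to "Adamw"

adam_opt = gpt.configure_optimizers(
    weight_decay=0.1, learning_rate=1e-4, betas=(0.9, 0.999),
    optimizer="Adam", eps=1e-8,
)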
@@ -217,6 +217,7 @@ def eval_policy(
     """
     start = time.time()
     policy.eval()
+    max_episodes_rendered = 20

     # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
     # divisible by env.num_envs we end up discarding some data in the last batch.
@@ -229,6 +230,7 @@ def eval_policy(
     all_seeds = []
     threads = []  # for video saving threads
     n_episodes_rendered = 0  # for saving the correct number of videos
+    all_coverages = []

     # Callback for visualization.
     def render_frame(env: gym.vector.VectorEnv):
@@ -280,6 +282,7 @@ def eval_policy(
         max_rewards.extend(batch_max_rewards.tolist())
         batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
+        all_coverages.extend((rollout_data['reward'][:, -1]*0.95).tolist())
         all_seeds.extend(seeds)

         if return_episode_data:
@@ -332,7 +335,9 @@ def eval_policy(
            if n_episodes_rendered >= max_episodes_rendered:
                break
            video_dir.mkdir(parents=True, exist_ok=True)
-            video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
+            final_cover = "%.3f" % all_coverages[n_episodes_rendered]
+            max_cover = "%.3f" % ((max_rewards[n_episodes_rendered] * 0.95))
+            video_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"
            video_paths.append(str(video_path))
            thread = threading.Thread(
                target=write_video,
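Evaluation videos are now named after coverage rather than the episode index: `max_cover` and `final_cover` are the episode's max and final rewards scaled by 0.95 and formatted to three decimals. An equivalent sketch using f-string formatting instead of the `%` operator, assuming the same variables are in scope:

final_cover = f"{all_coverages[n_episodes_rendered]:.3f}"
max_cover = f"{max_rewards[n_episodes_rendered] * 0.95:.3f}"
video_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"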
@@ -353,7 +358,44 @@ def eval_policy(
     # Wait till all video rendering threads are done.
     for thread in threads:
         thread.join()
+    import matplotlib.pyplot as plt
+    plt.hist(all_coverages, bins=100)
+    plt.xlabel('Value')
+    plt.ylabel('Quantity')
+    plt.title('Histogram of Data')
+    plt.ylim(0, 500)
+    fig_path = video_dir / "final_coverage_histogram100.png"
+    plt.savefig(fig_path)  # Save the plot as a PNG file
+    plt.close()
+
+    max_coverage = [0.95 * value for value in max_rewards]
+    plt.hist(max_coverage, bins=100)
+    plt.xlabel('Value')
+    plt.ylabel('Quantity')
+    plt.title('Histogram of Data')
+    plt.ylim(0, 500)
+    fig_path = video_dir / "max_coverage_histogram100.png"
+    plt.savefig(fig_path)  # Save the plot as a PNG file
+    plt.close()
+
+    plt.hist(all_coverages, bins=50)
+    plt.xlabel('Value')
+    plt.ylabel('Quantity')
+    plt.title('Histogram of Data')
+    plt.ylim(0, 500)
+    fig_path = video_dir / "final_coverage_histogram.png"
+    plt.savefig(fig_path)  # Save the plot as a PNG file
+    plt.close()
+
+    max_coverage = [0.95 * value for value in max_rewards]
+    plt.hist(max_coverage, bins=50)
+    plt.xlabel('Value')
+    plt.ylabel('Quantity')
+    plt.title('Histogram of Data')
+    plt.ylim(0, 500)
+    fig_path = video_dir / "max_coverage_histogram.png"
+    plt.savefig(fig_path)  # Save the plot as a PNG file
+    plt.close()
+
     # Compile eval info.
     info = {
         "per_episode": [
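The added evaluation code saves four histograms (final and max coverage, at 100 and 50 bins) by repeating the same matplotlib block. A sketch of the same output factored into one helper; the function name and the loop are illustrative, not part of the codebase:

import matplotlib.pyplot as plt

def save_coverage_histogram(values, bins, path):
    # Same plot settings as the repeated blocks above.
    plt.hist(values, bins=bins)
    plt.xlabel("Value")
    plt.ylabel("Quantity")
    plt.title("Histogram of Data")
    plt.ylim(0, 500)
    plt.savefig(path)
    plt.close()

max_coverage = [0.95 * value for value in max_rewards]
for bins, suffix in ((100, "100"), (50, "")):
    save_coverage_histogram(all_coverages, bins, video_dir / f"final_coverage_histogram{suffix}.png")
    save_coverage_histogram(max_coverage, bins, video_dir / f"max_coverage_histogram{suffix}.png")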