add warnings, merge to main branch

This commit is contained in:
jayLEE0301 2024-05-22 18:15:45 -04:00
parent 1d689512af
commit d8f8fa5918
5 changed files with 9 additions and 55 deletions

View File

@ -84,7 +84,7 @@ class Logger:
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(
self._group.replace(":", "_").replace("/", "_") + "-" + str(self._seed) + "-" + str(identifier),
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)

View File

@ -71,7 +71,6 @@ class VQBeTConfig:
# Inputs / output structure.
n_obs_steps: int = 5
# n_action_steps: int = 7
n_action_pred_token: int = 3
n_action_pred_chunk: int = 5

View File

@ -8,7 +8,7 @@ from random import randrange
import math
from math import ceil
from dataclasses import dataclass
import warnings
import einops
from einops import rearrange, repeat, reduce, pack, unpack
@ -97,9 +97,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
if not self.check_discretized():
self.vqbet._action_head._vqvae_model.discretized = True
# raise NotImplementedError(
# "Should train VQ-VAE before rollout."
# )
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
assert "observation.image" in batch
assert "observation.state" in batch
@ -195,7 +193,7 @@ class VQBeTModel(nn.Module):
], dim=-2).view(batch_size, -1, self.config.n_embd)
if img_features.shape[1] != n_obs_steps:
raise NotImplementedError
# eos_token = self._eos_token.repeat(batch_size, 1, 1)
# eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO remove EOS token
len_additional_action_token = self.config.n_action_pred_token-1
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
@ -205,7 +203,7 @@ class VQBeTModel(nn.Module):
# get action features
features = self._policy(global_cond)
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO make it compatible with other values
features = torch.cat([
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:]

View File

@ -24,7 +24,7 @@ override_dataset_stats:
training:
offline_steps: 800000
online_steps: 0
eval_freq: 20000 # jay
eval_freq: 20000
save_freq: 20000
log_freq: 250
save_model: true
@ -41,7 +41,7 @@ training:
# VQ-BeT specific
vqvae_lr: 1.0e-3
discretize_step: 20000 # jay
discretize_step: 20000
bet_weight_decay: 2e-4
bet_learning_rate: 5.5e-5
bet_betas: [0.9, 0.999]
@ -60,7 +60,6 @@ policy:
# Input / output structure.
n_obs_steps: 5
# n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1
n_action_pred_token: 7
n_action_pred_chunk: 5

View File

@ -217,7 +217,6 @@ def eval_policy(
"""
start = time.time()
policy.eval()
max_episodes_rendered = 20
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
# divisible by env.num_envs we end up discarding some data in the last batch.
@ -230,7 +229,6 @@ def eval_policy(
all_seeds = []
threads = [] # for video saving threads
n_episodes_rendered = 0 # for saving the correct number of videos
all_coverages = []
# Callback for visualization.
def render_frame(env: gym.vector.VectorEnv):
@ -282,7 +280,6 @@ def eval_policy(
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
all_coverages.extend((rollout_data['reward'][:, -1]*0.95).tolist())
all_seeds.extend(seeds)
if return_episode_data:
@ -335,9 +332,7 @@ def eval_policy(
if n_episodes_rendered >= max_episodes_rendered:
break
video_dir.mkdir(parents=True, exist_ok=True)
final_cover = "%.3f" % all_coverages[n_episodes_rendered]
max_cover = "%.3f" % ((max_rewards[n_episodes_rendered] * 0.95))
video_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=write_video,
@ -358,44 +353,7 @@ def eval_policy(
# Wait till all video rendering threads are done.
for thread in threads:
thread.join()
import matplotlib.pyplot as plt
plt.hist(all_coverages, bins=100)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "final_coverage_histogram100.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
max_coverage = [0.95 * value for value in max_rewards]
plt.hist(max_coverage, bins=100)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "max_coverage_histogram100.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
plt.hist(all_coverages, bins=50)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "final_coverage_histogram.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
max_coverage = [0.95 * value for value in max_rewards]
plt.hist(max_coverage, bins=50)
plt.xlabel('Value')
plt.ylabel('Quantity')
plt.title('Histogram of Data')
plt.ylim(0, 500)
fig_path = video_dir / "max_coverage_histogram.png"
plt.savefig(fig_path) # Save the plot as a PNG file
plt.close()
# Compile eval info.
info = {
"per_episode": [
@ -643,4 +601,4 @@ if __name__ == "__main__":
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)