add warnings, merge to main branch
parent 1d689512af
commit d8f8fa5918
@@ -84,7 +84,7 @@ class Logger:
         if self._wandb and not self._disable_wandb_artifact:
             # note wandb artifact does not accept ":" in its name
             artifact = self._wandb.Artifact(
-                self._group.replace(":", "_").replace("/", "_") + "-" + str(self._seed) + "-" + str(identifier),
+                self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
                 type="model",
             )
             artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
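Note: W&B artifact names only accept alphanumerics, dashes, underscores, and dots, which is why ":" (and, on one side of this change, "/") must be replaced before constructing the Artifact. A standalone sketch of the same sanitization, assuming a regex-based helper (the helper is illustrative, not part of this commit):

    import re

    def sanitize_artifact_name(name: str) -> str:
        # Replace every character W&B artifact names disallow with an underscore.
        return re.sub(r"[^A-Za-z0-9._-]", "_", name)

    print(sanitize_artifact_name("pusht:vqbet/run1"))  # pusht_vqbet_run1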
@@ -71,7 +71,6 @@ class VQBeTConfig:
 
     # Inputs / output structure.
     n_obs_steps: int = 5
-    # n_action_steps: int = 7
     n_action_pred_token: int = 3
     n_action_pred_chunk: int = 5
 
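Note: the comment deleted here (and the matching one in the YAML hunk below) encodes the relationship n_action_steps = n_action_pred_token + n_action_pred_window - 1. Treating n_action_pred_window as this file's n_action_pred_chunk (an inference; the two names never appear together), the defaults in this hunk are consistent, while the YAML's stale value 7 no longer matches its token count of 7, which may be why both comments were dropped:

    n_action_pred_token = 3
    n_action_pred_chunk = 5
    # n_action_pred_token readout tokens each predict a chunk of n_action_pred_chunk
    # actions; consecutive chunks overlap by all but one step.
    n_action_steps = n_action_pred_token + n_action_pred_chunk - 1
    assert n_action_steps == 7  # the value in the removed comment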
@@ -8,7 +8,7 @@ from random import randrange
 import math
 from math import ceil
 from dataclasses import dataclass
-
+import warnings
 
 import einops
 from einops import rearrange, repeat, reduce, pack, unpack
@@ -97,9 +97,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
 
         if not self.check_discretized():
             self.vqbet._action_head._vqvae_model.discretized = True
-            # raise NotImplementedError(
-            #     "Should train VQ-VAE before rollout."
-            # )
+            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
         assert "observation.image" in batch
         assert "observation.state" in batch
 
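Note: this is the "add warnings" half of the commit message: the commented-out NotImplementedError becomes a UserWarning, so rollout proceeds with the Residual VQ frozen instead of crashing. A minimal sketch of the pattern, with placeholder names rather than the repo's real attributes:

    import warnings

    class PolicySketch:
        def __init__(self):
            self.discretized = False  # True once the Residual VQ codebook is trained

        def select_action(self):
            if not self.discretized:
                # Freeze the codebook and warn instead of raising.
                self.discretized = True
                warnings.warn("Residual VQ was not trained; freezing it for rollout.")

    PolicySketch().select_action()  # emits a UserWarning, then continues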
@@ -195,7 +193,7 @@ class VQBeTModel(nn.Module):
         ], dim=-2).view(batch_size, -1, self.config.n_embd)
         if img_features.shape[1] != n_obs_steps:
             raise NotImplementedError
-        # eos_token = self._eos_token.repeat(batch_size, 1, 1)
+        # eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO remove EOS token
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
 
@@ -205,7 +203,7 @@ class VQBeTModel(nn.Module):
 
         # get action features
         features = self._policy(global_cond)
-        historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2
+        historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
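Note: the hard-coded `* 3 + 2` that the new TODO flags assumes each observation step contributes exactly three tokens to the flattened sequence, with the third token of each step serving as the action-prediction readout. For the default n_obs_steps the selected positions are:

    import numpy as np

    n_obs_steps = 5
    # stride 3 = tokens per observation step; offset 2 = last token of each step
    historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2
    print(historical_act_pred_index)  # [ 2  5  8 11 14]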
@@ -24,7 +24,7 @@ override_dataset_stats:
 training:
   offline_steps: 800000
   online_steps: 0
-  eval_freq: 20000 # jay
+  eval_freq: 20000
   save_freq: 20000
   log_freq: 250
   save_model: true
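Note: eval_freq, save_freq, and log_freq are step periods; dropping the stray "# jay" marker leaves behavior unchanged. Schematically, a training loop consumes them as modulo checks (the loop below is illustrative, not the repo's actual code):

    eval_freq, save_freq, log_freq = 20_000, 20_000, 250

    for step in range(1, 800_000 + 1):  # offline_steps
        if step % log_freq == 0:
            ...  # log training metrics
        if step % eval_freq == 0:
            ...  # run evaluation rollouts
        if step % save_freq == 0:
            ...  # save a checkpoint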
@@ -41,7 +41,7 @@ training:
 
   # VQ-BeT specific
   vqvae_lr: 1.0e-3
-  discretize_step: 20000 # jay
+  discretize_step: 20000
   bet_weight_decay: 2e-4
   bet_learning_rate: 5.5e-5
   bet_betas: [0.9, 0.999]
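Note: discretize_step is most plausibly the number of offline steps spent fitting the Residual VQ codebook before BeT training takes over, which is what the check_discretized() guard above protects. A sketch of that two-phase schedule, under that assumption:

    discretize_step = 20_000

    def training_phase(step: int) -> str:
        # Phase 1: fit the action discretizer (Residual VQ); phase 2: train BeT.
        return "vqvae" if step < discretize_step else "bet"

    assert training_phase(0) == "vqvae"
    assert training_phase(20_000) == "bet"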
@@ -60,7 +60,6 @@ policy:
 
   # Input / output structure.
   n_obs_steps: 5
-  # n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1
   n_action_pred_token: 7
   n_action_pred_chunk: 5
 
@@ -217,7 +217,6 @@ def eval_policy(
     """
     start = time.time()
     policy.eval()
-    max_episodes_rendered = 20
 
     # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
     # divisible by env.num_envs we end up discarding some data in the last batch.
@@ -230,7 +229,6 @@ def eval_policy(
     all_seeds = []
     threads = []  # for video saving threads
     n_episodes_rendered = 0  # for saving the correct number of videos
-    all_coverages = []
 
     # Callback for visualization.
     def render_frame(env: gym.vector.VectorEnv):
@@ -282,7 +280,6 @@ def eval_policy(
         max_rewards.extend(batch_max_rewards.tolist())
         batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
-        all_coverages.extend((rollout_data['reward'][:, -1]*0.95).tolist())
         all_seeds.extend(seeds)
 
         if return_episode_data:
@@ -335,9 +332,7 @@ def eval_policy(
             if n_episodes_rendered >= max_episodes_rendered:
                 break
             video_dir.mkdir(parents=True, exist_ok=True)
-            final_cover = "%.3f" % all_coverages[n_episodes_rendered]
-            max_cover = "%.3f" % ((max_rewards[n_episodes_rendered] * 0.95))
-            video_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"
+            video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
             video_paths.append(str(video_path))
             thread = threading.Thread(
                 target=write_video,
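Note: the rename drops the per-episode coverage metrics from the video filename in favor of a stable episode index, so filenames no longer vary with results between runs. The two schemes side by side (values illustrative):

    from pathlib import Path

    video_dir = Path("outputs/eval/videos")
    n_episodes_rendered, max_cover, final_cover = 0, "0.812", "0.754"

    old_path = video_dir / f"eval_episode_{max_cover}_{final_cover}.mp4"  # metric-dependent
    new_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"      # index-based
    print(old_path, new_path)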
@@ -358,44 +353,7 @@ def eval_policy(
     # Wait till all video rendering threads are done.
     for thread in threads:
         thread.join()
-    import matplotlib.pyplot as plt
-    plt.hist(all_coverages, bins=100)
-    plt.xlabel('Value')
-    plt.ylabel('Quantity')
-    plt.title('Histogram of Data')
-    plt.ylim(0, 500)
-    fig_path = video_dir / "final_coverage_histogram100.png"
-    plt.savefig(fig_path) # Save the plot as a PNG file
-    plt.close()
-
-    max_coverage = [0.95 * value for value in max_rewards]
-    plt.hist(max_coverage, bins=100)
-    plt.xlabel('Value')
-    plt.ylabel('Quantity')
-    plt.title('Histogram of Data')
-    plt.ylim(0, 500)
-    fig_path = video_dir / "max_coverage_histogram100.png"
-    plt.savefig(fig_path) # Save the plot as a PNG file
-    plt.close()
-
-    plt.hist(all_coverages, bins=50)
-    plt.xlabel('Value')
-    plt.ylabel('Quantity')
-    plt.title('Histogram of Data')
-    plt.ylim(0, 500)
-    fig_path = video_dir / "final_coverage_histogram.png"
-    plt.savefig(fig_path) # Save the plot as a PNG file
-    plt.close()
-
-    max_coverage = [0.95 * value for value in max_rewards]
-    plt.hist(max_coverage, bins=50)
-    plt.xlabel('Value')
-    plt.ylabel('Quantity')
-    plt.title('Histogram of Data')
-    plt.ylim(0, 500)
-    fig_path = video_dir / "max_coverage_histogram.png"
-    plt.savefig(fig_path) # Save the plot as a PNG file
-    plt.close()
 
     # Compile eval info.
     info = {
         "per_episode": [
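Note: the deleted block wrote four nearly identical debugging histograms (final and max coverage, at 50 and 100 bins); the 0.95 factor appears to rescale the environment's normalized reward back into a coverage estimate. If such plots are ever needed again, the repetition collapses into a helper like this sketch (name and signature are illustrative):

    from pathlib import Path

    import matplotlib.pyplot as plt

    def save_histogram(values: list, bins: int, out_path: Path) -> None:
        plt.hist(values, bins=bins)
        plt.xlabel("Value")
        plt.ylabel("Quantity")
        plt.ylim(0, 500)
        plt.savefig(out_path)  # save the figure as a PNG
        plt.close()            # free figure state between calls

    # e.g. save_histogram(all_coverages, 100, video_dir / "final_coverage_histogram100.png")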
@@ -643,4 +601,4 @@ if __name__ == "__main__":
             "repo ID, nor is it an existing local directory."
         )
 
-        eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
+    eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)