backup wip

Alexander Soare 2024-05-08 11:31:54 +01:00
parent d7ffcc9127
commit d36ca387e8
5 changed files with 76 additions and 66 deletions

View File

@@ -9,7 +9,7 @@ class TDMPCConfig:
     camera observations.
 
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
-    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
+    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
 
     Args:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google

View File

@@ -298,8 +298,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
         return G
 
-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Run the batch through the model and compute the loss."""
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
+        """Run the batch through the model and compute the loss.
+
+        Returns a dictionary with loss as a tensor, and scalar valued
+        """
         device = get_device_from_parameters(self)
         batch = self.normalize_inputs(batch)

View File

@@ -8,7 +8,7 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  episode_length: 25
+  episode_length: 100
   fps: ${fps}
   state_dim: 4
   action_dim: 4

View File

@@ -4,10 +4,12 @@ seed: 1
 dataset_repo_id: lerobot/xarm_lift_medium_replay
 
 training:
-  offline_steps: 25000
-  online_steps: 25000
+  offline_steps: 50000
+  online_steps: 50000
   eval_freq: 5000
-  online_steps_between_rollouts: 1
+  # This approximately matches the FOWM implementation. There though, they do as many steps as there were
+  # steps in the last sampled episode. TODO(now): hmmmm
+  online_steps_between_rollouts: 25
   online_sampling_ratio: 0.5
   online_env_seed: 10000
   dataset_use_cache: true
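Note on the `online_steps_between_rollouts` change: the value appears to set how many policy updates run between consecutive environment rollouts during online training, and the config comment contrasts this fixed count with FOWM, which does as many updates as the last sampled episode had frames. A minimal arithmetic sketch of the two schedules, with made-up episode counts rather than anything measured from the repo:

    # Illustrative numbers only: three online rollouts, each hitting the new episode_length of 100.
    episode_lengths = [100, 100, 100]

    fixed_schedule_updates = 25 * len(episode_lengths)  # this config: 25 updates after every rollout -> 75
    fowm_style_updates = sum(episode_lengths)           # FOWM-style: one update per frame of each episode -> 300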

View File

@@ -10,6 +10,7 @@ from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
 
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
@@ -100,6 +101,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
     }
+    info.update({k: v for k, v in output_dict.items() if k not in info})
 
     return info
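The new `info.update(...)` line ties into the `forward` signature change above: `TDMPCPolicy.forward` now returns `dict[str, Tensor | float]`, and `update_policy` copies the scalar entries into its logging dict without clobbering the keys it has already set. A small self-contained sketch of that merge, using hypothetical metric names rather than the policy's actual output keys:

    import torch

    # Hypothetical stand-ins for what update_policy sees: the dict returned by policy.forward()
    # (a loss tensor plus scalar-valued metrics) and the logging info it has already built.
    output_dict = {"loss": torch.tensor(0.5), "consistency_loss": 0.1, "q_value_loss": 0.3}
    info = {"loss": output_dict["loss"].item(), "lr": 5e-4, "update_s": 0.07}

    # The added line: merge the scalar metrics into the log without overwriting existing keys.
    info.update({k: v for k, v in output_dict.items() if k not in info})
    # info now also carries "consistency_loss" and "q_value_loss"; "loss" stays the float set above.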
@@ -213,78 +215,80 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
     return -(n_off * pc_on) / (n_on * (pc_on - 1))
 
 
+# TODO(now): Should probably be unit tested.
 def add_episodes_inplace(
-    online_dataset: torch.utils.data.Dataset,
+    online_dataset: LeRobotDataset,
     concat_dataset: torch.utils.data.ConcatDataset,
     sampler: torch.utils.data.WeightedRandomSampler,
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
-    pc_online_samples: float,
+    new_hf_dataset: datasets.Dataset,
+    new_episode_data_index: dict[str, torch.Tensor],
+    online_sampling_ratio: float,
 ):
     """
     Modifies the online_dataset, concat_dataset, and sampler in place by integrating
-    new episodes from hf_dataset into the online_dataset, updating the concatenated
+    new episodes from new_hf_dataset into the online_dataset, updating the concatenated
     dataset's structure and adjusting the sampling strategy based on the specified
     percentage of online samples.
 
-    Parameters:
-    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
-    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
-      offline and online datasets, used for sampling purposes.
-    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
-      reflect changes in the dataset sizes and specified sampling weights.
-    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
-    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
-      They indicate the start index and end index of each episode in the dataset.
-    - pc_online_samples (float): The target percentage of samples that should come from
-      the online dataset during sampling operations.
-
-    Raises:
-    - AssertionError: If the first episode_id or index in hf_dataset is not 0
+    Args:
+        online_dataset: The existing online dataset to be updated.
+        concat_dataset: The concatenated dataset that combines offline and online datasets (in that order),
+            used for sampling purposes.
+        sampler: A sampler that will be updated to reflect changes in the dataset sizes and specified sampling
+            weights.
+        new_hf_dataset: A Hugging Face dataset containing the new episodes to be added.
+        new_episode_data_index: A dictionary containing two keys ("from" and "to") associated to dataset
+            indices. They indicate the start index and end index of each episode in the dataset.
+        online_sampling_ratio: The target percentage of samples that should come from the online dataset
+            during sampling operations.
     """
-    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
-    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
-    first_index = hf_dataset.select_columns("index")[0]["index"].item()
-    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
-    # sanity check
-    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
-    assert first_index == 0, f"{first_index=} is not 0"
-    assert first_index == episode_data_index["from"][first_episode_idx].item()
-    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
+    # Sanity check to make sure that new_hf_dataset starts from 0.
+    assert new_hf_dataset["episode_index"][0].item() == 0
+    assert new_hf_dataset["index"][0].item() == 0
+    # Sanity check to make sure that new_episode_data_index is aligned with new_hf_dataset.
+    assert new_episode_data_index["from"][0].item() == 0
+    assert new_episode_data_index["to"] - 1 == new_hf_dataset["index"][-1].item()
 
     if len(online_dataset) == 0:
-        # initialize online dataset
-        online_dataset.hf_dataset = hf_dataset
-        online_dataset.episode_data_index = episode_data_index
-    else:
-        # get the starting indices of the new episodes and frames to be added
-        start_episode_idx = last_episode_idx + 1
-        start_index = last_index + 1
+        # Initialize online dataset.
+        online_dataset.hf_dataset = new_hf_dataset
+        online_dataset.episode_data_index = new_episode_data_index
+
+    if len(online_dataset) > 0:
+        # Get the indices required to continue where the data in concat_dataset finishes.
+        start_episode_idx = concat_dataset.datasets[-1].hf_dataset["episode_index"][-1].item() + 1
+        start_index = concat_dataset.datasets[-1].hf_dataset["index"][-1].item() + 1
 
-        def shift_indices(episode_index, index):
-            # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
-            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
-            return example
-
-        disable_progress_bars()  # map has a tqdm progress bar
-        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
+        # Shift the indices of new_hf_dataset.
+        disable_progress_bars()  # Dataset.map has a tqdm progress bar
+        # note: we dont shift "frame_index" since it represents the index of the frame in the episode it
+        # belongs to
+        new_hf_dataset = new_hf_dataset.map(
+            lambda episode_index, data_index: {
+                "episode_index": episode_index + start_episode_idx,
+                "index": data_index + start_index,
+            },
+            input_columns=["episode_index", "index"],
+        )
         enable_progress_bars()
 
-        episode_data_index["from"] += start_index
-        episode_data_index["to"] += start_index
-
-        # extend online dataset
-        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
+        # Extend the online dataset with the new data.
+        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, new_hf_dataset])
+        online_dataset.episode_data_index = {
+            k: torch.cat([online_dataset.episode_data_index[k], new_episode_data_index[k] + start_index])
+            for k in ["from", "to"]
+        }
 
     # update the concatenated dataset length used during sampling
     concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
 
-    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+    # update the sampling weights for each frame so that online frames get sampled a certain percentage of
+    # times
     len_online = len(online_dataset)
     len_offline = len(concat_dataset) - len_online
-    weight_offline = 1.0
-    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
-    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
+    sampler.weights = torch.tensor(
+        [(1 - online_sampling_ratio) / len_offline] * len_offline
+        + [online_sampling_ratio / len_online] * len_online
+    )
 
     # update the total number of samples used during sampling
     sampler.num_samples = len(concat_dataset)
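The rewritten weighting replaces `calculate_online_sample_weight` with per-frame weights that are normalized within each split. A quick sanity check, with made-up dataset sizes, that online frames end up drawn with probability `online_sampling_ratio` (WeightedRandomSampler draws proportionally to the weights, so only the total mass per split matters):

    import torch

    # Illustrative sizes, not taken from the repo.
    len_offline, len_online = 1000, 50
    online_sampling_ratio = 0.5

    weights = torch.tensor(
        [(1 - online_sampling_ratio) / len_offline] * len_offline
        + [online_sampling_ratio / len_online] * len_online
    )
    # Offline frames carry total mass 1 - ratio, online frames carry total mass ratio,
    # so an online frame is drawn with probability online_sampling_ratio (~0.5 here).
    print(weights[:len_offline].sum().item(), weights[len_offline:].sum().item())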
@@ -405,8 +409,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
+    # TODO(now): Consolidate the reset.
     online_dataset.hf_dataset = {}
     online_dataset.episode_data_index = {}
+    online_dataset.cache = {}
 
     # create dataloader for online training
     concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
@@ -416,8 +422,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     )
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        num_workers=cfg.training.dataloader_num_workers,
-        persistent_workers=cfg.training.dataloader_persistent_workers,
+        num_workers=0,
         batch_size=cfg.training.batch_size,
         sampler=sampler,
         pin_memory=cfg.device != "cpu",
@@ -427,8 +432,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_step = 0
     is_offline = False
 
-    for env_step in range(cfg.training.online_steps):
-        if env_step == 0:
+    for online_step in range(cfg.training.online_steps):
+        if online_step == 0:
             logging.info("Start online training by interacting with environment")
 
         policy.eval()
@@ -439,16 +444,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 n_episodes=1,
                 return_episode_data=True,
                 start_seed=cfg.training.online_env_seed,
-                enable_progbar=True,
+                enable_progbar=False,
             )
 
             add_episodes_inplace(
                 online_dataset,
                 concat_dataset,
                 sampler,
-                hf_dataset=eval_info["episodes"]["hf_dataset"],
-                episode_data_index=eval_info["episodes"]["episode_data_index"],
-                pc_online_samples=cfg.training.online_sampling_ratio,
+                new_hf_dataset=eval_info["episodes"]["hf_dataset"],
+                new_episode_data_index=eval_info["episodes"]["episode_data_index"],
+                online_sampling_ratio=cfg.training.online_sampling_ratio,
             )
 
         policy.train()