diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 82e3a507..4bbbe78a 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -9,7 +9,7 @@ class TDMPCConfig: camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`. + Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`. Args: n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 1fba43d0..a421754f 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -298,8 +298,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) return G - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Run the batch through the model and compute the loss.""" + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: + """Run the batch through the model and compute the loss. + + Returns a dictionary with loss as a tensor, and scalar valued + """ device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 098b0396..a18c13d3 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -8,7 +8,7 @@ env: from_pixels: True pixels_only: False image_size: 84 - episode_length: 25 + episode_length: 100 fps: ${fps} state_dim: 4 action_dim: 4 diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 2351a976..43e841eb 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -4,10 +4,12 @@ seed: 1 dataset_repo_id: lerobot/xarm_lift_medium_replay training: - offline_steps: 25000 - online_steps: 25000 + offline_steps: 50000 + online_steps: 50000 eval_freq: 5000 - online_steps_between_rollouts: 1 + # This approximately matches the FOWM implementation. There though, they do as many steps as there were + # steps in the last sampled episode. TODO(now): hmmmm + online_steps_between_rollouts: 25 online_sampling_ratio: 0.5 online_env_seed: 10000 dataset_use_cache: true diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 751181ea..7f0d36eb 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -10,6 +10,7 @@ from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir @@ -100,6 +101,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): "lr": optimizer.param_groups[0]["lr"], "update_s": time.time() - start_time, } + info.update({k: v for k, v in output_dict.items() if k not in info}) return info @@ -213,78 +215,80 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): return -(n_off * pc_on) / (n_on * (pc_on - 1)) +# TODO(now): Should probably be unit tested. def add_episodes_inplace( - online_dataset: torch.utils.data.Dataset, + online_dataset: LeRobotDataset, concat_dataset: torch.utils.data.ConcatDataset, sampler: torch.utils.data.WeightedRandomSampler, - hf_dataset: datasets.Dataset, - episode_data_index: dict[str, torch.Tensor], - pc_online_samples: float, + new_hf_dataset: datasets.Dataset, + new_episode_data_index: dict[str, torch.Tensor], + online_sampling_ratio: float, ): """ Modifies the online_dataset, concat_dataset, and sampler in place by integrating - new episodes from hf_dataset into the online_dataset, updating the concatenated + new episodes from new_hf_dataset into the online_dataset, updating the concatenated dataset's structure and adjusting the sampling strategy based on the specified percentage of online samples. - Parameters: - - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines - offline and online datasets, used for sampling purposes. - - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to - reflect changes in the dataset sizes and specified sampling weights. - - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. - They indicate the start index and end index of each episode in the dataset. - - pc_online_samples (float): The target percentage of samples that should come from - the online dataset during sampling operations. - - Raises: - - AssertionError: If the first episode_id or index in hf_dataset is not 0 + Args: + online_dataset: The existing online dataset to be updated. + concat_dataset: The concatenated dataset that combines offline and online datasets (in that order), + used for sampling purposes. + sampler: A sampler that will be updated to reflect changes in the dataset sizes and specified sampling + weights. + new_hf_dataset: A Hugging Face dataset containing the new episodes to be added. + new_episode_data_index: A dictionary containing two keys ("from" and "to") associated to dataset + indices. They indicate the start index and end index of each episode in the dataset. + online_sampling_ratio: The target percentage of samples that should come from the online dataset + during sampling operations. """ - first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() - last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item() - first_index = hf_dataset.select_columns("index")[0]["index"].item() - last_index = hf_dataset.select_columns("index")[-1]["index"].item() - # sanity check - assert first_episode_idx == 0, f"{first_episode_idx=} is not 0" - assert first_index == 0, f"{first_index=} is not 0" - assert first_index == episode_data_index["from"][first_episode_idx].item() - assert last_index == episode_data_index["to"][last_episode_idx].item() - 1 + # Sanity check to make sure that new_hf_dataset starts from 0. + assert new_hf_dataset["episode_index"][0].item() == 0 + assert new_hf_dataset["index"][0].item() == 0 + # Sanity check to make sure that new_episode_data_index is aligned with new_hf_dataset. + assert new_episode_data_index["from"][0].item() == 0 + assert new_episode_data_index["to"] - 1 == new_hf_dataset["index"][-1].item() if len(online_dataset) == 0: - # initialize online dataset - online_dataset.hf_dataset = hf_dataset - online_dataset.episode_data_index = episode_data_index - else: - # get the starting indices of the new episodes and frames to be added - start_episode_idx = last_episode_idx + 1 - start_index = last_index + 1 + # Initialize online dataset. + online_dataset.hf_dataset = new_hf_dataset + online_dataset.episode_data_index = new_episode_data_index + if len(online_dataset) > 0: + # Get the indices required to continue where the data in concat_dataset finishes. + start_episode_idx = concat_dataset.datasets[-1].hf_dataset["episode_index"][-1].item() + 1 + start_index = concat_dataset.datasets[-1].hf_dataset["index"][-1].item() + 1 - def shift_indices(episode_index, index): - # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to - example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index} - return example - - disable_progress_bars() # map has a tqdm progress bar - hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"]) + # Shift the indices of new_hf_dataset. + disable_progress_bars() # Dataset.map has a tqdm progress bar + # note: we dont shift "frame_index" since it represents the index of the frame in the episode it + # belongs to + new_hf_dataset = new_hf_dataset.map( + lambda episode_index, data_index: { + "episode_index": episode_index + start_episode_idx, + "index": data_index + start_index, + }, + input_columns=["episode_index", "index"], + ) enable_progress_bars() - episode_data_index["from"] += start_index - episode_data_index["to"] += start_index - - # extend online dataset - online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset]) + # Extend the online dataset with the new data. + online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, new_hf_dataset]) + online_dataset.episode_data_index = { + k: torch.cat([online_dataset.episode_data_index[k], new_episode_data_index[k] + start_index]) + for k in ["from", "to"] + } # update the concatenated dataset length used during sampling concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) - # update the sampling weights for each frame so that online frames get sampled a certain percentage of times + # update the sampling weights for each frame so that online frames get sampled a certain percentage of + # times len_online = len(online_dataset) len_offline = len(concat_dataset) - len_online - weight_offline = 1.0 - weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) - sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) + sampler.weights = torch.tensor( + [(1 - online_sampling_ratio) / len_offline] * len_offline + + [online_sampling_ratio / len_online] * len_online + ) # update the total number of samples used during sampling sampler.num_samples = len(concat_dataset) @@ -405,8 +409,10 @@ def train(cfg: dict, out_dir=None, job_name=None): # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) + # TODO(now): Consolidate the reset. online_dataset.hf_dataset = {} online_dataset.episode_data_index = {} + online_dataset.cache = {} # create dataloader for online training concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) @@ -416,8 +422,7 @@ def train(cfg: dict, out_dir=None, job_name=None): ) dataloader = torch.utils.data.DataLoader( concat_dataset, - num_workers=cfg.training.dataloader_num_workers, - persistent_workers=cfg.training.dataloader_persistent_workers, + num_workers=0, batch_size=cfg.training.batch_size, sampler=sampler, pin_memory=cfg.device != "cpu", @@ -427,8 +432,8 @@ def train(cfg: dict, out_dir=None, job_name=None): online_step = 0 is_offline = False - for env_step in range(cfg.training.online_steps): - if env_step == 0: + for online_step in range(cfg.training.online_steps): + if online_step == 0: logging.info("Start online training by interacting with environment") policy.eval() @@ -439,16 +444,16 @@ def train(cfg: dict, out_dir=None, job_name=None): n_episodes=1, return_episode_data=True, start_seed=cfg.training.online_env_seed, - enable_progbar=True, + enable_progbar=False, ) add_episodes_inplace( online_dataset, concat_dataset, sampler, - hf_dataset=eval_info["episodes"]["hf_dataset"], - episode_data_index=eval_info["episodes"]["episode_data_index"], - pc_online_samples=cfg.training.online_sampling_ratio, + new_hf_dataset=eval_info["episodes"]["hf_dataset"], + new_episode_data_index=eval_info["episodes"]["episode_data_index"], + online_sampling_ratio=cfg.training.online_sampling_ratio, ) policy.train()