backup wip

parent d7ffcc9127
commit d36ca387e8
@@ -9,7 +9,7 @@ class TDMPCConfig:
     camera observations.

     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
-    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
+    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.

     Args:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
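As a side note on the docstring change above: the fields it points to are the ones tied to the environment and sensors. Below is a minimal override sketch; the import path, the concrete shape values, and the `max_random_shift_ratio` value are assumptions for illustration (chosen to line up with the xarm env config later in this diff), not something this commit sets.

from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig  # assumed module path

# Hypothetical override sketch; shapes must match your environment / sensors.
config = TDMPCConfig(
    input_shapes={
        "observation.image": [3, 84, 84],  # matches image_size: 84 in the env config below
        "observation.state": [4],          # matches state_dim: 4
    },
    output_shapes={"action": [4]},         # matches action_dim: 4
    max_random_shift_ratio=0.0476,         # random-shift augmentation as a ratio of image size (illustrative)
)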
@@ -298,8 +298,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
         return G

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Run the batch through the model and compute the loss."""
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
+        """Run the batch through the model and compute the loss.
+
+        Returns a dictionary with loss as a tensor, and scalar valued
+        """
         device = get_device_from_parameters(self)

         batch = self.normalize_inputs(batch)
@@ -8,7 +8,7 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  episode_length: 25
+  episode_length: 100
   fps: ${fps}
   state_dim: 4
   action_dim: 4
@@ -4,10 +4,12 @@ seed: 1
 dataset_repo_id: lerobot/xarm_lift_medium_replay

 training:
-  offline_steps: 25000
-  online_steps: 25000
+  offline_steps: 50000
+  online_steps: 50000
   eval_freq: 5000
-  online_steps_between_rollouts: 1
+  # This approximately matches the FOWM implementation. There though, they do as many steps as there were
+  # steps in the last sampled episode. TODO(now): hmmmm
+  online_steps_between_rollouts: 25
   online_sampling_ratio: 0.5
   online_env_seed: 10000
   dataset_use_cache: true
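For clarity on what the new `online_steps_between_rollouts: 25` controls, here is a schematic (not the repo's actual training loop) contrasting the fixed cadence set here with the FOWM-style cadence described in the comment above; `collect_episode` and `gradient_step` are made-up stand-ins.

import random

def collect_episode() -> list[int]:
    # Stand-in for one environment rollout; returns a dummy list of frames.
    return [0] * random.randint(20, 100)

def gradient_step() -> None:
    # Stand-in for one optimizer update on the replay buffer.
    pass

def online_loop(n_rollouts: int, fowm_style: bool = False) -> None:
    for _ in range(n_rollouts):
        episode = collect_episode()
        # FOWM: as many update steps as the last episode had frames.
        # This config: a fixed 25 steps per rollout.
        n_updates = len(episode) if fowm_style else 25
        for _ in range(n_updates):
            gradient_step()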
@@ -10,6 +10,7 @@ from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars

 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
@@ -100,6 +101,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
     }
+    info.update({k: v for k, v in output_dict.items() if k not in info})

     return info

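The new line above merges any extra entries from the policy's `forward()` output (now typed `dict[str, Tensor | float]` in an earlier hunk) into the logged `info` without clobbering keys that `update_policy` already set. A tiny self-contained check of that merge behaviour, with made-up key names and values:

# Illustrative values only; the real output_dict comes from policy.forward(batch).
output_dict = {"loss": 0.42, "consistency_loss": 0.10, "update_s": 999.0}
info = {"loss": 0.42, "grad_norm": 1.3, "lr": 1e-4, "update_s": 0.05}

info.update({k: v for k, v in output_dict.items() if k not in info})

assert info["consistency_loss"] == 0.10  # new scalar gets logged
assert info["update_s"] == 0.05          # existing timing key is not overwritten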
@@ -213,78 +215,80 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
     return -(n_off * pc_on) / (n_on * (pc_on - 1))


+# TODO(now): Should probably be unit tested.
 def add_episodes_inplace(
-    online_dataset: torch.utils.data.Dataset,
+    online_dataset: LeRobotDataset,
     concat_dataset: torch.utils.data.ConcatDataset,
     sampler: torch.utils.data.WeightedRandomSampler,
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
-    pc_online_samples: float,
+    new_hf_dataset: datasets.Dataset,
+    new_episode_data_index: dict[str, torch.Tensor],
+    online_sampling_ratio: float,
 ):
     """
     Modifies the online_dataset, concat_dataset, and sampler in place by integrating
-    new episodes from hf_dataset into the online_dataset, updating the concatenated
+    new episodes from new_hf_dataset into the online_dataset, updating the concatenated
     dataset's structure and adjusting the sampling strategy based on the specified
     percentage of online samples.

-    Parameters:
-    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
-    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
-      offline and online datasets, used for sampling purposes.
-    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
-      reflect changes in the dataset sizes and specified sampling weights.
-    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
-    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
-      They indicate the start index and end index of each episode in the dataset.
-    - pc_online_samples (float): The target percentage of samples that should come from
-      the online dataset during sampling operations.
-
-    Raises:
-    - AssertionError: If the first episode_id or index in hf_dataset is not 0
+    Args:
+        online_dataset: The existing online dataset to be updated.
+        concat_dataset: The concatenated dataset that combines offline and online datasets (in that order),
+            used for sampling purposes.
+        sampler: A sampler that will be updated to reflect changes in the dataset sizes and specified sampling
+            weights.
+        new_hf_dataset: A Hugging Face dataset containing the new episodes to be added.
+        new_episode_data_index: A dictionary containing two keys ("from" and "to") associated to dataset
+            indices. They indicate the start index and end index of each episode in the dataset.
+        online_sampling_ratio: The target percentage of samples that should come from the online dataset
+            during sampling operations.
     """
-    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
-    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
-    first_index = hf_dataset.select_columns("index")[0]["index"].item()
-    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
-    # sanity check
-    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
-    assert first_index == 0, f"{first_index=} is not 0"
-    assert first_index == episode_data_index["from"][first_episode_idx].item()
-    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
+    # Sanity check to make sure that new_hf_dataset starts from 0.
+    assert new_hf_dataset["episode_index"][0].item() == 0
+    assert new_hf_dataset["index"][0].item() == 0
+    # Sanity check to make sure that new_episode_data_index is aligned with new_hf_dataset.
+    assert new_episode_data_index["from"][0].item() == 0
+    assert new_episode_data_index["to"] - 1 == new_hf_dataset["index"][-1].item()

     if len(online_dataset) == 0:
-        # initialize online dataset
-        online_dataset.hf_dataset = hf_dataset
-        online_dataset.episode_data_index = episode_data_index
-    else:
-        # get the starting indices of the new episodes and frames to be added
-        start_episode_idx = last_episode_idx + 1
-        start_index = last_index + 1
+        # Initialize online dataset.
+        online_dataset.hf_dataset = new_hf_dataset
+        online_dataset.episode_data_index = new_episode_data_index
+    if len(online_dataset) > 0:
+        # Get the indices required to continue where the data in concat_dataset finishes.
+        start_episode_idx = concat_dataset.datasets[-1].hf_dataset["episode_index"][-1].item() + 1
+        start_index = concat_dataset.datasets[-1].hf_dataset["index"][-1].item() + 1

-        def shift_indices(episode_index, index):
-            # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
-            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
-            return example
-
-        disable_progress_bars()  # map has a tqdm progress bar
-        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
+        # Shift the indices of new_hf_dataset.
+        disable_progress_bars()  # Dataset.map has a tqdm progress bar
+        # note: we dont shift "frame_index" since it represents the index of the frame in the episode it
+        # belongs to
+        new_hf_dataset = new_hf_dataset.map(
+            lambda episode_index, data_index: {
+                "episode_index": episode_index + start_episode_idx,
+                "index": data_index + start_index,
+            },
+            input_columns=["episode_index", "index"],
+        )
         enable_progress_bars()

-        episode_data_index["from"] += start_index
-        episode_data_index["to"] += start_index
-
-        # extend online dataset
-        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
+        # Extend the online dataset with the new data.
+        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, new_hf_dataset])
+        online_dataset.episode_data_index = {
+            k: torch.cat([online_dataset.episode_data_index[k], new_episode_data_index[k] + start_index])
+            for k in ["from", "to"]
+        }

     # update the concatenated dataset length used during sampling
     concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)

-    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+    # update the sampling weights for each frame so that online frames get sampled a certain percentage of
+    # times
     len_online = len(online_dataset)
     len_offline = len(concat_dataset) - len_online
-    weight_offline = 1.0
-    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
-    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
+    sampler.weights = torch.tensor(
+        [(1 - online_sampling_ratio) / len_offline] * len_offline
+        + [online_sampling_ratio / len_online] * len_online
+    )

     # update the total number of samples used during sampling
     sampler.num_samples = len(concat_dataset)
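The last part of this hunk replaces the `calculate_online_sample_weight` helper (whose definition is the context line at the top of the hunk) with sampler weights written directly as per-frame probabilities. A quick check, using made-up dataset sizes, that the two formulations give online frames the same target share of samples; only relative weights matter to `WeightedRandomSampler`.

import torch

# Made-up sizes for illustration: 1000 offline frames, 50 online frames, 50% online sampling.
len_offline, len_online, online_sampling_ratio = 1000, 50, 0.5

def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float) -> float:
    # Solves n_on * w / (n_off * 1.0 + n_on * w) = pc_on for w (offline frames keep weight 1.0).
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

# Old formulation: offline weight 1.0, online weight solved for.
w_on = calculate_online_sample_weight(len_offline, len_online, online_sampling_ratio)
old_weights = torch.tensor([1.0] * len_offline + [w_on] * len_online)

# New formulation: per-frame probabilities written out directly.
new_weights = torch.tensor(
    [(1 - online_sampling_ratio) / len_offline] * len_offline
    + [online_sampling_ratio / len_online] * len_online
)

for weights in (old_weights, new_weights):
    online_share = weights[len_offline:].sum() / weights.sum()
    assert torch.isclose(online_share, torch.tensor(online_sampling_ratio))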
@@ -405,8 +409,10 @@ def train(cfg: dict, out_dir=None, job_name=None):

     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
+    # TODO(now): Consolidate the reset.
     online_dataset.hf_dataset = {}
     online_dataset.episode_data_index = {}
+    online_dataset.cache = {}

     # create dataloader for online training
     concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
@@ -416,8 +422,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
    )
    dataloader = torch.utils.data.DataLoader(
        concat_dataset,
-        num_workers=cfg.training.dataloader_num_workers,
-        persistent_workers=cfg.training.dataloader_persistent_workers,
+        num_workers=0,
        batch_size=cfg.training.batch_size,
        sampler=sampler,
        pin_memory=cfg.device != "cpu",
@@ -427,8 +432,8 @@ def train(cfg: dict, out_dir=None, job_name=None):

     online_step = 0
     is_offline = False
-    for env_step in range(cfg.training.online_steps):
-        if env_step == 0:
+    for online_step in range(cfg.training.online_steps):
+        if online_step == 0:
             logging.info("Start online training by interacting with environment")

         policy.eval()
@@ -439,16 +444,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 n_episodes=1,
                 return_episode_data=True,
                 start_seed=cfg.training.online_env_seed,
-                enable_progbar=True,
+                enable_progbar=False,
             )

         add_episodes_inplace(
             online_dataset,
             concat_dataset,
             sampler,
-            hf_dataset=eval_info["episodes"]["hf_dataset"],
-            episode_data_index=eval_info["episodes"]["episode_data_index"],
-            pc_online_samples=cfg.training.online_sampling_ratio,
+            new_hf_dataset=eval_info["episodes"]["hf_dataset"],
+            new_episode_data_index=eval_info["episodes"]["episode_data_index"],
+            online_sampling_ratio=cfg.training.online_sampling_ratio,
         )

         policy.train()