diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index 5587a6e4..70a5b505 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -70,7 +70,7 @@ python lerobot/scripts/train.py policy=act env=aloha There are two things to note here: - Config overrides are passed as `param_name=param_value`. -- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/pusht.yaml`. +- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`. _As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._ diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 273f4f75..e0482143 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -239,10 +239,8 @@ class DiffusionModel(nn.Module): global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) # run sampling - sample = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond) - # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 end = start + self.config.n_action_steps diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index d638c541..9b055f7e 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -147,7 +147,7 @@ class Normalize(nn.Module): assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min) + batch[key] = (batch[key] - min) / (max - min + 1e-8) # normalize to [-1, 1] batch[key] = batch[key] * 2 - 1 else: diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 58da6a47..b5f40d11 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -224,7 +224,8 @@ def main(): help=( "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " - "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "'distant' creates a server on the distant machine where the data is stored. " + "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." ), ) parser.add_argument( @@ -245,8 +246,8 @@ def main(): default=0, help=( "Save a .rrd file in the directory provided by `--output-dir`. " - "It also deactivates the spawning of a viewer. ", - "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", + "It also deactivates the spawning of a viewer. " + "Visualize the data by running `rerun path/to/file.rrd` on your local machine." ), ) parser.add_argument(