Merge branch 'huggingface:main' into 2024_05_30_add_data_augmentation
This commit is contained in:
commit
7be2c35c0a
|
@ -70,7 +70,7 @@ python lerobot/scripts/train.py policy=act env=aloha
|
||||||
|
|
||||||
There are two things to note here:
|
There are two things to note here:
|
||||||
- Config overrides are passed as `param_name=param_value`.
|
- Config overrides are passed as `param_name=param_value`.
|
||||||
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/pusht.yaml`.
|
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
|
||||||
|
|
||||||
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
||||||
|
|
||||||
|
|
|
@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
|
||||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
# run sampling
|
# run sampling
|
||||||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||||
|
|
||||||
# `horizon` steps worth of actions (from the first observation).
|
|
||||||
actions = sample[..., : self.config.output_shapes["action"][0]]
|
|
||||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||||
start = n_obs_steps - 1
|
start = n_obs_steps - 1
|
||||||
end = start + self.config.n_action_steps
|
end = start + self.config.n_action_steps
|
||||||
|
|
|
@ -147,7 +147,7 @@ class Normalize(nn.Module):
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min)
|
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
batch[key] = batch[key] * 2 - 1
|
batch[key] = batch[key] * 2 - 1
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -224,7 +224,8 @@ def main():
|
||||||
help=(
|
help=(
|
||||||
"Mode of viewing between 'local' or 'distant'. "
|
"Mode of viewing between 'local' or 'distant'. "
|
||||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||||
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
"'distant' creates a server on the distant machine where the data is stored. "
|
||||||
|
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -245,8 +246,8 @@ def main():
|
||||||
default=0,
|
default=0,
|
||||||
help=(
|
help=(
|
||||||
"Save a .rrd file in the directory provided by `--output-dir`. "
|
"Save a .rrd file in the directory provided by `--output-dir`. "
|
||||||
"It also deactivates the spawning of a viewer. ",
|
"It also deactivates the spawning of a viewer. "
|
||||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
|
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
Loading…
Reference in New Issue