hop
This commit is contained in:
parent
594acbf136
commit
6966f1257f
|
@ -35,7 +35,8 @@ Run the `record_training_data.py` example, selecting the duration and number of
|
|||
DATA_DIR='./data' python record_training_data.py \
|
||||
--repo-id=thomwolf/blue_red_sort \
|
||||
--num-episodes=50 \
|
||||
--num-frames=400
|
||||
--num-frames=400 \
|
||||
--gym-config=./train_config/env/gym_real_world.yaml
|
||||
```
|
||||
|
||||
TODO:
|
||||
|
|
|
@ -129,6 +129,7 @@ class RealEnv(gym.Env):
|
|||
self._observation = {}
|
||||
self._terminated = False
|
||||
self.timestamps = []
|
||||
self.observation_time = None
|
||||
|
||||
def _get_obs(self):
|
||||
qpos = self.follower.read_position()
|
||||
|
|
|
@ -29,18 +29,10 @@ from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_vide
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
|
||||
parser.add_argument("--num-episodes", type=int, default=2)
|
||||
parser.add_argument("--num-frames", type=int, default=400)
|
||||
parser.add_argument("--num-workers", type=int, default=16)
|
||||
parser.add_argument("--keep-last", action="store_true")
|
||||
parser.add_argument("--data_dir", type=str, default=None)
|
||||
parser.add_argument("--push-to-hub", action="store_true")
|
||||
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
|
||||
parser.add_argument(
|
||||
"--fps_tolerance",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Tolerance in fps for the recording before dropping episodes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
||||
)
|
||||
|
@ -50,11 +42,7 @@ args = parser.parse_args()
|
|||
|
||||
repo_id = args.repo_id
|
||||
num_episodes = args.num_episodes
|
||||
num_frames = args.num_frames
|
||||
revision = args.revision
|
||||
fps = args.fps
|
||||
fps_tolerance = args.fps_tolerance
|
||||
|
||||
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
|
||||
|
||||
# During data collection, frames are stored as png images in `images_dir`
|
||||
|
@ -64,8 +52,8 @@ videos_dir = out_data / "videos"
|
|||
meta_data_dir = out_data / "meta_data"
|
||||
|
||||
gym_config = None
|
||||
if args.config is not None:
|
||||
gym_config = OmegaConf.load(args.config)
|
||||
if args.gym_config is not None:
|
||||
gym_config = OmegaConf.load(args.gym_config)
|
||||
|
||||
# Create image and video directories
|
||||
if not os.path.exists(images_dir):
|
||||
|
@ -76,12 +64,9 @@ if not os.path.exists(videos_dir):
|
|||
if __name__ == "__main__":
|
||||
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
|
||||
gym_handle = "gym_real_world/RealEnv-v0"
|
||||
gym_kwargs = {}
|
||||
if gym_config is not None:
|
||||
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
|
||||
env = gym.make(
|
||||
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
|
||||
)
|
||||
gym_kwargs = OmegaConf.to_container(gym_config.env.gym)
|
||||
env = gym.make(gym_handle, disable_env_checker=True, record=True, **gym_kwargs)
|
||||
num_frames = gym_kwargs["max_episode_steps"]
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: RealEnv-v0
|
||||
real_world: true
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
episode_length: 200
|
||||
real_world: true
|
||||
gym:
|
||||
cameras_shapes:
|
||||
images.high: [480, 640, 3]
|
||||
|
@ -17,3 +14,8 @@ env:
|
|||
cameras_ports:
|
||||
images.high: /dev/video6
|
||||
images.low: /dev/video0
|
||||
num_joints: 6
|
||||
fps: 30
|
||||
max_episode_steps: 200
|
||||
fps_tolerance: 0.5
|
||||
mock: false
|
||||
|
|
|
@ -353,7 +353,7 @@ class ACT(nn.Module):
|
|||
images = batch["observation.images"]
|
||||
|
||||
for cam_index in range(images.shape[-4]):
|
||||
torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
cam_features = self.backbone(images[:, cam_index])
|
||||
cam_features = cam_features["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
|
|
|
@ -93,12 +93,12 @@ def update_policy(
|
|||
use_amp: bool = False,
|
||||
):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
################## TODO remove this part
|
||||
torch.backends.cudnn.deterministic = True
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
##################
|
||||
# ################## TODO remove this part
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# # torch.use_deterministic_algorithms(True)
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
# ##################
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
|
|
Loading…
Reference in New Issue