This commit is contained in:
Thomas Wolf 2024-06-19 09:56:36 +02:00
parent 594acbf136
commit 6966f1257f
6 changed files with 21 additions and 32 deletions

View File

@ -35,7 +35,8 @@ Run the `record_training_data.py` example, selecting the duration and number of
DATA_DIR='./data' python record_training_data.py \ DATA_DIR='./data' python record_training_data.py \
--repo-id=thomwolf/blue_red_sort \ --repo-id=thomwolf/blue_red_sort \
--num-episodes=50 \ --num-episodes=50 \
--num-frames=400 --num-frames=400 \
--gym-config=./train_config/env/gym_real_world.yaml
``` ```
TODO: TODO:

View File

@ -129,6 +129,7 @@ class RealEnv(gym.Env):
self._observation = {} self._observation = {}
self._terminated = False self._terminated = False
self.timestamps = [] self.timestamps = []
self.observation_time = None
def _get_obs(self): def _get_obs(self):
qpos = self.follower.read_position() qpos = self.follower.read_position()

View File

@ -29,18 +29,10 @@ from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_vide
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort") parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
parser.add_argument("--num-episodes", type=int, default=2) parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16) parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true") parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--data_dir", type=str, default=None) parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--push-to-hub", action="store_true") parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
"--fps_tolerance",
type=float,
default=0.5,
help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument( parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset." "--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
) )
@ -50,11 +42,7 @@ args = parser.parse_args()
repo_id = args.repo_id repo_id = args.repo_id
num_episodes = args.num_episodes num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision revision = args.revision
fps = args.fps
fps_tolerance = args.fps_tolerance
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir) out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
# During data collection, frames are stored as png images in `images_dir` # During data collection, frames are stored as png images in `images_dir`
@ -64,8 +52,8 @@ videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data" meta_data_dir = out_data / "meta_data"
gym_config = None gym_config = None
if args.config is not None: if args.gym_config is not None:
gym_config = OmegaConf.load(args.config) gym_config = OmegaConf.load(args.gym_config)
# Create image and video directories # Create image and video directories
if not os.path.exists(images_dir): if not os.path.exists(images_dir):
@ -76,12 +64,9 @@ if not os.path.exists(videos_dir):
if __name__ == "__main__": if __name__ == "__main__":
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py # Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
gym_handle = "gym_real_world/RealEnv-v0" gym_handle = "gym_real_world/RealEnv-v0"
gym_kwargs = {} gym_kwargs = OmegaConf.to_container(gym_config.env.gym)
if gym_config is not None: env = gym.make(gym_handle, disable_env_checker=True, record=True, **gym_kwargs)
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs) num_frames = gym_kwargs["max_episode_steps"]
env = gym.make(
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
)
ep_dicts = [] ep_dicts = []
episode_data_index = {"from": [], "to": []} episode_data_index = {"from": [], "to": []}

View File

@ -1,15 +1,12 @@
# @package _global_ # @package _global_
fps: 30 fps: 30
env: env:
name: real_world name: real_world
task: RealEnv-v0 task: RealEnv-v0
real_world: true
state_dim: 6 state_dim: 6
action_dim: 6 action_dim: 6
fps: ${fps}
episode_length: 200
real_world: true
gym: gym:
cameras_shapes: cameras_shapes:
images.high: [480, 640, 3] images.high: [480, 640, 3]
@ -17,3 +14,8 @@ env:
cameras_ports: cameras_ports:
images.high: /dev/video6 images.high: /dev/video6
images.low: /dev/video0 images.low: /dev/video0
num_joints: 6
fps: 30
max_episode_steps: 200
fps_tolerance: 0.5
mock: false

View File

@ -353,7 +353,7 @@ class ACT(nn.Module):
images = batch["observation.images"] images = batch["observation.images"]
for cam_index in range(images.shape[-4]): for cam_index in range(images.shape[-4]):
torch.backends.cudnn.deterministic = True # torch.backends.cudnn.deterministic = True
cam_features = self.backbone(images[:, cam_index]) cam_features = self.backbone(images[:, cam_index])
cam_features = cam_features["feature_map"] cam_features = cam_features["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer

View File

@ -93,12 +93,12 @@ def update_policy(
use_amp: bool = False, use_amp: bool = False,
): ):
"""Returns a dictionary of items for logging.""" """Returns a dictionary of items for logging."""
################## TODO remove this part # ################## TODO remove this part
torch.backends.cudnn.deterministic = True # torch.backends.cudnn.deterministic = True
# torch.use_deterministic_algorithms(True) # # torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False # torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True # torch.backends.cuda.matmul.allow_tf32 = True
################## # ##################
start_time = time.perf_counter() start_time = time.perf_counter()
device = get_device_from_parameters(policy) device = get_device_from_parameters(policy)
policy.train() policy.train()