hop
This commit is contained in:
parent
594acbf136
commit
6966f1257f
|
@ -35,7 +35,8 @@ Run the `record_training_data.py` example, selecting the duration and number of
|
||||||
DATA_DIR='./data' python record_training_data.py \
|
DATA_DIR='./data' python record_training_data.py \
|
||||||
--repo-id=thomwolf/blue_red_sort \
|
--repo-id=thomwolf/blue_red_sort \
|
||||||
--num-episodes=50 \
|
--num-episodes=50 \
|
||||||
--num-frames=400
|
--num-frames=400 \
|
||||||
|
--gym-config=./train_config/env/gym_real_world.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
TODO:
|
TODO:
|
||||||
|
|
|
@ -129,6 +129,7 @@ class RealEnv(gym.Env):
|
||||||
self._observation = {}
|
self._observation = {}
|
||||||
self._terminated = False
|
self._terminated = False
|
||||||
self.timestamps = []
|
self.timestamps = []
|
||||||
|
self.observation_time = None
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
qpos = self.follower.read_position()
|
qpos = self.follower.read_position()
|
||||||
|
|
|
@ -29,18 +29,10 @@ from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_vide
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
|
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
|
||||||
parser.add_argument("--num-episodes", type=int, default=2)
|
parser.add_argument("--num-episodes", type=int, default=2)
|
||||||
parser.add_argument("--num-frames", type=int, default=400)
|
|
||||||
parser.add_argument("--num-workers", type=int, default=16)
|
parser.add_argument("--num-workers", type=int, default=16)
|
||||||
parser.add_argument("--keep-last", action="store_true")
|
parser.add_argument("--keep-last", action="store_true")
|
||||||
parser.add_argument("--data_dir", type=str, default=None)
|
parser.add_argument("--data_dir", type=str, default=None)
|
||||||
parser.add_argument("--push-to-hub", action="store_true")
|
parser.add_argument("--push-to-hub", action="store_true")
|
||||||
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--fps_tolerance",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="Tolerance in fps for the recording before dropping episodes.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
||||||
)
|
)
|
||||||
|
@ -50,11 +42,7 @@ args = parser.parse_args()
|
||||||
|
|
||||||
repo_id = args.repo_id
|
repo_id = args.repo_id
|
||||||
num_episodes = args.num_episodes
|
num_episodes = args.num_episodes
|
||||||
num_frames = args.num_frames
|
|
||||||
revision = args.revision
|
revision = args.revision
|
||||||
fps = args.fps
|
|
||||||
fps_tolerance = args.fps_tolerance
|
|
||||||
|
|
||||||
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
|
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
|
||||||
|
|
||||||
# During data collection, frames are stored as png images in `images_dir`
|
# During data collection, frames are stored as png images in `images_dir`
|
||||||
|
@ -64,8 +52,8 @@ videos_dir = out_data / "videos"
|
||||||
meta_data_dir = out_data / "meta_data"
|
meta_data_dir = out_data / "meta_data"
|
||||||
|
|
||||||
gym_config = None
|
gym_config = None
|
||||||
if args.config is not None:
|
if args.gym_config is not None:
|
||||||
gym_config = OmegaConf.load(args.config)
|
gym_config = OmegaConf.load(args.gym_config)
|
||||||
|
|
||||||
# Create image and video directories
|
# Create image and video directories
|
||||||
if not os.path.exists(images_dir):
|
if not os.path.exists(images_dir):
|
||||||
|
@ -76,12 +64,9 @@ if not os.path.exists(videos_dir):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
|
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
|
||||||
gym_handle = "gym_real_world/RealEnv-v0"
|
gym_handle = "gym_real_world/RealEnv-v0"
|
||||||
gym_kwargs = {}
|
gym_kwargs = OmegaConf.to_container(gym_config.env.gym)
|
||||||
if gym_config is not None:
|
env = gym.make(gym_handle, disable_env_checker=True, record=True, **gym_kwargs)
|
||||||
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
|
num_frames = gym_kwargs["max_episode_steps"]
|
||||||
env = gym.make(
|
|
||||||
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
|
|
||||||
)
|
|
||||||
|
|
||||||
ep_dicts = []
|
ep_dicts = []
|
||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
fps: 30
|
fps: 30
|
||||||
|
|
||||||
env:
|
env:
|
||||||
name: real_world
|
name: real_world
|
||||||
task: RealEnv-v0
|
task: RealEnv-v0
|
||||||
|
real_world: true
|
||||||
state_dim: 6
|
state_dim: 6
|
||||||
action_dim: 6
|
action_dim: 6
|
||||||
fps: ${fps}
|
|
||||||
episode_length: 200
|
|
||||||
real_world: true
|
|
||||||
gym:
|
gym:
|
||||||
cameras_shapes:
|
cameras_shapes:
|
||||||
images.high: [480, 640, 3]
|
images.high: [480, 640, 3]
|
||||||
|
@ -17,3 +14,8 @@ env:
|
||||||
cameras_ports:
|
cameras_ports:
|
||||||
images.high: /dev/video6
|
images.high: /dev/video6
|
||||||
images.low: /dev/video0
|
images.low: /dev/video0
|
||||||
|
num_joints: 6
|
||||||
|
fps: 30
|
||||||
|
max_episode_steps: 200
|
||||||
|
fps_tolerance: 0.5
|
||||||
|
mock: false
|
||||||
|
|
|
@ -353,7 +353,7 @@ class ACT(nn.Module):
|
||||||
images = batch["observation.images"]
|
images = batch["observation.images"]
|
||||||
|
|
||||||
for cam_index in range(images.shape[-4]):
|
for cam_index in range(images.shape[-4]):
|
||||||
torch.backends.cudnn.deterministic = True
|
# torch.backends.cudnn.deterministic = True
|
||||||
cam_features = self.backbone(images[:, cam_index])
|
cam_features = self.backbone(images[:, cam_index])
|
||||||
cam_features = cam_features["feature_map"]
|
cam_features = cam_features["feature_map"]
|
||||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||||
|
|
|
@ -93,12 +93,12 @@ def update_policy(
|
||||||
use_amp: bool = False,
|
use_amp: bool = False,
|
||||||
):
|
):
|
||||||
"""Returns a dictionary of items for logging."""
|
"""Returns a dictionary of items for logging."""
|
||||||
################## TODO remove this part
|
# ################## TODO remove this part
|
||||||
torch.backends.cudnn.deterministic = True
|
# torch.backends.cudnn.deterministic = True
|
||||||
# torch.use_deterministic_algorithms(True)
|
# # torch.use_deterministic_algorithms(True)
|
||||||
torch.backends.cudnn.benchmark = False
|
# torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
##################
|
# ##################
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
device = get_device_from_parameters(policy)
|
device = get_device_from_parameters(policy)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
Loading…
Reference in New Issue