diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 22cd0116..bd14e6d8 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -48,8 +48,15 @@ class AlohaEnv(gym.Env): dtype=np.float64, ) elif self.obs_type == "pixels": - self.observation_space = spaces.Box( - low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8 + self.observation_space = spaces.Dict( + { + "top": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + } ) elif self.obs_type == "pixels_agent_pos": self.observation_space = spaces.Dict(