diff --git a/gym_dora/gym_dora/env.py b/gym_dora/gym_dora/env.py
--- a/gym_dora/gym_dora/env.py
+++ b/gym_dora/gym_dora/env.py
@@ -1,53 +1,73 @@
+import os
+
 import gymnasium as gym
 import numpy as np
+import pyarrow as pa
+from dora import Node
+
+# Frame size and render FPS, overridable through environment variables.
+IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
+IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
+FPS = int(os.getenv("FPS", "30"))
 
 
 class DoraEnv(gym.Env):
-    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
+    """Gymnasium environment that exchanges data with a dora node."""
+
+    metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
 
     def __init__(self, model="aloha"):
-        ...
-        # TODO: add code to connect with dora client here
+        # `model` is kept for interface compatibility; it is not read yet.
+        self.node = Node()
+        self.observation = {"pixels": {}, "terminated": False}
+
+    def _update(self) -> None:
+        """Drain pending node events into `self.observation`."""
+        while True:
+            event = self.node.next(timeout=0.001)
+
+            # No event returned: mark the episode as terminated and stop.
+            if event is None:
+                self.observation["terminated"] = True
+                break
+            if event["type"] == "INPUT":
+                if "cam" in event["id"]:
+                    # Camera inputs are reshaped to (height, width, 3).
+                    frame = event["value"].to_numpy()
+                    self.observation["pixels"][event["id"]] = frame.reshape(
+                        IMAGE_HEIGHT, IMAGE_WIDTH, 3
+                    )
+                else:
+                    self.observation[event["id"]] = event["value"].to_numpy()
+            elif event["type"] == "ERROR":
+                # NOTE(review): presumably the timeout surfaces as an ERROR event -- confirm.
+                break
 
     def reset(self, seed: int | None = None):
-        ...
-        # TODO: same as `step` but doesn't take `actions`
-        observation = {
-            "pixels": {
-                "top": ...,
-                "bottom": ...,
-                "left": ...,
-                "right": ...,
-            },
-            "agent_pos": ...,
-            # "agent_vel": ...,  # will be added later
-        }
+        """Seed the env RNG and refresh the observation without sending an action."""
+        # NOTE(review): gymnasium's reset contract is `(observation, info)`; the
+        # 5-tuple below is kept so existing callers of this env keep working.
+        super().reset(seed=seed)
+        self._update()
+
         reward = 0
-        terminated = truncated = False
+        terminated = truncated = self.observation["terminated"]
         info = {}
-        return observation, reward, terminated, truncated, info
+        return self.observation, reward, terminated, truncated, info
 
     def render(self): ...
 
     def step(self, action: np.ndarray):
-        ...
-        # TODO: this is the important bit: the data to be return by Dora to the policy.
-        observation = {
-            "pixels": {
-                "top": ...,
-                "bottom": ...,
-                "left": ...,
-                "right": ...,
-            },
-            "agent_pos": ...,
-            # "agent_vel": ...,  # will be added later
-        }
+        """Refresh the observation, then send `action` as the "action" output."""
+        self._update()
+        self.node.send_output("action", pa.array(action))
         reward = 0
-        terminated = truncated = False
+        terminated = truncated = self.observation["terminated"]
         info = {}
-        return observation, reward, terminated, truncated, info
+        return self.observation, reward, terminated, truncated, info
 
     def close(self):
-        pass
-        # TODO: If code needs to be run when closing the env (e.g. shutting down Dora client),
-        # this is the place to do it. Otherwise this can stay as is.
+        """Drop the dora node; calling this more than once is harmless."""
+        # A bare `del self.node` raises AttributeError on a second close().
+        if hasattr(self, "node"):
+            del self.node
diff --git a/gym_dora/pyproject.toml b/gym_dora/pyproject.toml
index 7cc210f9..d4dc9975 100644
--- a/gym_dora/pyproject.toml
+++ b/gym_dora/pyproject.toml
@@ -4,12 +4,13 @@ version = "0.1.0"
 description = ""
 authors = ["Simon Alibert "]
 readme = "README.md"
-packages = [{include = "gym_dora"}]
+packages = [{ include = "gym_dora" }]
 
 [tool.poetry.dependencies]
 python = "^3.10"
 gymnasium = ">=0.29.1"
-
+dora-rs = ">=0.3.4"
+pyarrow = ">=12.0.0"
 
 [build-system]
 requires = ["poetry-core"]