From 4b7ec81dde7c4c567bae2b0e70d7d1508f753863 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 20 Mar 2024 14:49:41 +0000
Subject: [PATCH] remove abstractmethods, fix online training

---
 lerobot/common/envs/abstract.py     | 19 ++++++-------------
 lerobot/common/policies/abstract.py |  7 +++----
 lerobot/scripts/train.py            |  6 ++++--
 3 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py
index 8d1a09de..a449e23f 100644
--- a/lerobot/common/envs/abstract.py
+++ b/lerobot/common/envs/abstract.py
@@ -1,4 +1,3 @@
-import abc
 from collections import deque
 from typing import Optional
 
@@ -44,26 +43,20 @@ class AbstractEnv(EnvBase):
             raise NotImplementedError()
             # self._prev_action_queue = deque(maxlen=self.num_prev_action)
 
-    @abc.abstractmethod
     def render(self, mode="rgb_array", width=640, height=480):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _reset(self, tensordict: Optional[TensorDict] = None):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _step(self, tensordict: TensorDict):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _make_env(self):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _make_spec(self):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _set_seed(self, seed: Optional[int]):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index 1c300dbe..e9c331a0 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -1,11 +1,10 @@
-from abc import ABC, abstractmethod
 from collections import deque
 
 import torch
 from torch import Tensor, nn
 
 
-class AbstractPolicy(nn.Module, ABC):
+class AbstractPolicy(nn.Module):
     """Base policy which all policies should be derived from.
 
     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
@@ -22,9 +21,9 @@ class AbstractPolicy(nn.Module):
         self.n_action_steps = n_action_steps
         self.clear_action_queue()
 
-    @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
+        raise NotImplementedError("Abstract method")
 
     def save(self, fp):
         torch.save(self.state_dict(), fp)
@@ -33,13 +32,13 @@ class AbstractPolicy(nn.Module):
         d = torch.load(fp)
         self.load_state_dict(d)
 
-    @abstractmethod
     def select_actions(self, observation) -> Tensor:
         """Select an action (or trajectory of actions) based on an observation during rollout.
 
         If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor
         of actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
""" + raise NotImplementedError("Abstract method") def clear_action_queue(self): """This should be called whenever the environment is reset.""" diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5ecd616d..242c77bc 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -112,6 +112,8 @@ def train(cfg: dict, out_dir=None, job_name=None): raise NotImplementedError() if job_name is None: raise NotImplementedError() + if cfg.online_steps > 0: + assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps" init_logging() @@ -218,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None): # TODO: add configurable number of rollout? (default=1) with torch.no_grad(): rollout = env.rollout( - max_steps=cfg.env.episode_length // cfg.n_action_steps, + max_steps=cfg.env.episode_length, policy=td_policy, auto_cast_to_device=True, ) - assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps + assert len(rollout) <= cfg.env.episode_length # set same episode index for all time steps contained in this rollout rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) online_buffer.extend(rollout)