remove abstractmethods, fix online training

Alexander Soare 2024-03-20 14:49:41 +00:00
parent 5332766a82
commit 4b7ec81dde
3 changed files with 13 additions and 19 deletions

View File

@@ -1,4 +1,3 @@
-import abc
 from collections import deque
 from typing import Optional
@@ -44,26 +43,20 @@ class AbstractEnv(EnvBase):
             raise NotImplementedError()
             # self._prev_action_queue = deque(maxlen=self.num_prev_action)
 
-    @abc.abstractmethod
     def render(self, mode="rgb_array", width=640, height=480):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _reset(self, tensordict: Optional[TensorDict] = None):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _step(self, tensordict: TensorDict):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _make_env(self):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _make_spec(self):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _set_seed(self, seed: Optional[int]):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")

View File

@@ -1,11 +1,10 @@
-from abc import ABC, abstractmethod
 from collections import deque
 
 import torch
 from torch import Tensor, nn
 
 
-class AbstractPolicy(nn.Module, ABC):
+class AbstractPolicy(nn.Module):
     """Base policy which all policies should be derived from.
 
     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
@@ -22,9 +21,9 @@ class AbstractPolicy(nn.Module, ABC):
         self.n_action_steps = n_action_steps
         self.clear_action_queue()
 
-    @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
+        raise NotImplementedError("Abstract method")
 
     def save(self, fp):
         torch.save(self.state_dict(), fp)
@@ -33,13 +32,13 @@ class AbstractPolicy(nn.Module, ABC):
         d = torch.load(fp)
         self.load_state_dict(d)
 
-    @abstractmethod
     def select_actions(self, observation) -> Tensor:
         """Select an action (or trajectory of actions) based on an observation during rollout.
 
         If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
         actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
         """
+        raise NotImplementedError("Abstract method")
 
     def clear_action_queue(self):
         """This should be called whenever the environment is reset."""

View File

@@ -112,6 +112,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
         raise NotImplementedError()
     if job_name is None:
         raise NotImplementedError()
+    if cfg.online_steps > 0:
+        assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
 
     init_logging()
@@ -218,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
-                max_steps=cfg.env.episode_length // cfg.n_action_steps,
+                max_steps=cfg.env.episode_length,
                 policy=td_policy,
                 auto_cast_to_device=True,
            )
-        assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
+        assert len(rollout) <= cfg.env.episode_length
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
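Editor's note, not part of the commit: a hedged sketch of the pattern in this last hunk, with stand-in names and synthetic data rather than the actual training script. Every step of a rollout is tagged with the same episode index, then the whole rollout is appended to the online buffer. It assumes torchrl's TensorDictReplayBuffer / LazyTensorStorage and the tensordict package:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

# In-memory buffer standing in for online_buffer in the training loop.
online_buffer = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=10_000))

# Stand-in for the output of env.rollout(...): one entry per environment step.
episode_length = 25
rollout = TensorDict(
    {
        "observation": torch.randn(episode_length, 4),
        "action": torch.randn(episode_length, 2),
    },
    batch_size=[episode_length],
)

env_step = 0  # the current online iteration doubles as the episode id
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)
print(len(online_buffer))  # 25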