remove abstractmethods, fix online training
This commit is contained in:
parent
5332766a82
commit
4b7ec81dde
|
@ -1,4 +1,3 @@
|
||||||
import abc
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -44,26 +43,20 @@ class AbstractEnv(EnvBase):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def render(self, mode="rgb_array", width=640, height=480):
|
def render(self, mode="rgb_array", width=640, height=480):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _step(self, tensordict: TensorDict):
|
def _step(self, tensordict: TensorDict):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _make_env(self):
|
def _make_env(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _make_spec(self):
|
def _make_spec(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _set_seed(self, seed: Optional[int]):
|
def _set_seed(self, seed: Optional[int]):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class AbstractPolicy(nn.Module, ABC):
|
class AbstractPolicy(nn.Module):
|
||||||
"""Base policy which all policies should be derived from.
|
"""Base policy which all policies should be derived from.
|
||||||
|
|
||||||
The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
|
The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
|
||||||
|
@ -22,9 +21,9 @@ class AbstractPolicy(nn.Module, ABC):
|
||||||
self.n_action_steps = n_action_steps
|
self.n_action_steps = n_action_steps
|
||||||
self.clear_action_queue()
|
self.clear_action_queue()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(self, replay_buffer, step):
|
def update(self, replay_buffer, step):
|
||||||
"""One step of the policy's learning algorithm."""
|
"""One step of the policy's learning algorithm."""
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
def save(self, fp):
|
def save(self, fp):
|
||||||
torch.save(self.state_dict(), fp)
|
torch.save(self.state_dict(), fp)
|
||||||
|
@ -33,13 +32,13 @@ class AbstractPolicy(nn.Module, ABC):
|
||||||
d = torch.load(fp)
|
d = torch.load(fp)
|
||||||
self.load_state_dict(d)
|
self.load_state_dict(d)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def select_actions(self, observation) -> Tensor:
|
def select_actions(self, observation) -> Tensor:
|
||||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||||
|
|
||||||
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
||||||
actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
|
actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
def clear_action_queue(self):
|
def clear_action_queue(self):
|
||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
|
|
|
@ -112,6 +112,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if job_name is None:
|
if job_name is None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
if cfg.online_steps > 0:
|
||||||
|
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
|
@ -218,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
# TODO: add configurable number of rollouts? (default=1)
|
# TODO: add configurable number of rollouts? (default=1)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
rollout = env.rollout(
|
rollout = env.rollout(
|
||||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
max_steps=cfg.env.episode_length,
|
||||||
policy=td_policy,
|
policy=td_policy,
|
||||||
auto_cast_to_device=True,
|
auto_cast_to_device=True,
|
||||||
)
|
)
|
||||||
assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
|
assert len(rollout) <= cfg.env.episode_length
|
||||||
# set same episode index for all time steps contained in this rollout
|
# set same episode index for all time steps contained in this rollout
|
||||||
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
|
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
|
||||||
online_buffer.extend(rollout)
|
online_buffer.extend(rollout)
|
||||||
|
|
Loading…
Reference in New Issue