40 lines
978 B
Python
40 lines
978 B
Python
|
# License: see [LICENSE, LICENSES/rsl_rl/LICENSE]
|
||
|
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
# minimal interface of the environment
|
||
|
class VecEnv(ABC):
|
||
|
num_envs: int
|
||
|
num_obs: int
|
||
|
num_privileged_obs: int
|
||
|
num_actions: int
|
||
|
max_episode_length: int
|
||
|
privileged_obs_buf: torch.Tensor
|
||
|
obs_buf: torch.Tensor
|
||
|
rew_buf: torch.Tensor
|
||
|
reset_buf: torch.Tensor
|
||
|
episode_length_buf: torch.Tensor # current episode duration
|
||
|
extras: dict
|
||
|
device: torch.device
|
||
|
|
||
|
@abstractmethod
|
||
|
def step(self, actions: torch.Tensor) -> Tuple[
|
||
|
torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def reset(self, env_ids: Union[list, torch.Tensor]):
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def get_observations(self) -> torch.Tensor:
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
|
||
|
pass
|