40 lines
978 B
Python
Executable File
40 lines
978 B
Python
Executable File
# License: see [LICENSE, LICENSES/rsl_rl/LICENSE]
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Tuple, Union
|
|
|
|
import torch
|
|
|
|
|
|
# minimal interface of the environment
|
|
class VecEnv(ABC):
    """Abstract base class describing the minimal environment interface.

    The class body only *declares* attributes (annotations without values),
    so concrete subclasses are responsible for assigning them.  Subclasses
    must also implement every method marked ``@abstractmethod`` before they
    can be instantiated.
    """

    # --- integer attributes -------------------------------------------------
    # NOTE(review): presumably counts/sizes (environments, observation and
    # action dimensions, step limit) -- confirm against a concrete subclass.
    num_envs: int
    num_obs: int
    num_privileged_obs: int
    num_actions: int
    max_episode_length: int

    # --- tensor buffers -----------------------------------------------------
    # NOTE(review): shapes and dtypes are not fixed by this interface.
    privileged_obs_buf: torch.Tensor
    obs_buf: torch.Tensor
    rew_buf: torch.Tensor
    reset_buf: torch.Tensor
    episode_length_buf: torch.Tensor  # current episode duration

    # --- miscellaneous ------------------------------------------------------
    extras: dict
    device: torch.device

    @abstractmethod
    def step(
        self, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
        """Apply ``actions`` and return a 5-tuple.

        The annotated return layout is ``(Tensor, Tensor | None, Tensor,
        Tensor, dict)``.
        NOTE(review): presumably (observations, privileged observations,
        rewards, resets, extras), mirroring the buffer attributes -- confirm
        against a concrete implementation.
        """

    @abstractmethod
    def reset(self, env_ids: Union[list, torch.Tensor]):
        """Reset using ``env_ids`` (a list or a tensor).

        NOTE(review): ``env_ids`` presumably selects which environments to
        reset; the return value is not specified by this interface.
        """

    @abstractmethod
    def get_observations(self) -> torch.Tensor:
        """Return the observations as a tensor."""

    @abstractmethod
    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        """Return the privileged observations as a tensor, or ``None``."""
|