//
// Created by biao on 24-10-6.
//

#ifndef OBSERVATIONBUFFER_H
|
||
|
#define OBSERVATIONBUFFER_H
|
||
|
|
||
|
#include <torch/torch.h>
|
||
|
#include <vector>
|
||
|
|
||
|
class ObservationBuffer {
|
||
|
public:
|
||
|
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||
|
|
||
|
~ObservationBuffer() = default;
|
||
|
|
||
|
void reset(const std::vector<int>& reset_index, const torch::Tensor &new_obs);
|
||
|
|
||
|
void insert(const torch::Tensor &new_obs);
|
||
|
|
||
|
[[nodiscard]] torch::Tensor getObsVec(const std::vector<int> &obs_ids) const;
|
||
|
|
||
|
private:
|
||
|
int num_envs_;
|
||
|
int num_obs_;
|
||
|
int include_history_steps_;
|
||
|
int num_obs_total_;
|
||
|
torch::Tensor obs_buffer_;
|
||
|
};
|
||
|
|
||
|
|
||
|
#endif //OBSERVATIONBUFFER_H
|