fix training
This commit is contained in:
parent
816b2e9d63
commit
ccd5dc5a42
|
@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.samplers import SliceSampler
|
|||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
from rl.torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
|
@ -78,7 +80,11 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||
return self._transform
|
||||
|
||||
def set_transform(self, transform):
|
||||
self._transform = transform
|
||||
if not isinstance(transform, Compose):
|
||||
# required since torchrl calls `len(self._transform)` downstream
|
||||
self._transform = Compose(transform)
|
||||
else:
|
||||
self._transform = transform
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
from rl.torchrl.envs.transforms.transforms import Compose, Transform
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
kwargs = {
|
||||
|
@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
|
|||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
|
||||
|
|
Loading…
Reference in New Issue