fix training

This commit is contained in:
Cadene 2024-03-11 12:33:15 +00:00
parent 816b2e9d63
commit ccd5dc5a42
2 changed files with 16 additions and 2 deletions

View File

@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from rl.torchrl.envs.transforms.transforms import Compose
class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
@ -78,7 +80,11 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
return self._transform
def set_transform(self, transform):
self._transform = transform
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
self._transform = Compose(transform)
else:
self._transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"

View File

@ -1,5 +1,7 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from rl.torchrl.envs.transforms.transforms import Compose, Transform
def make_env(cfg, transform=None):
kwargs = {
@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
if transform is not None:
# useful to add normalization
env.append_transform(transform)
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env