fix training
This commit is contained in:
parent
816b2e9d63
commit
ccd5dc5a42
|
@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||||
|
|
||||||
|
from rl.torchrl.envs.transforms.transforms import Compose
|
||||||
|
|
||||||
|
|
||||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -78,7 +80,11 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||||
return self._transform
|
return self._transform
|
||||||
|
|
||||||
def set_transform(self, transform):
|
def set_transform(self, transform):
|
||||||
self._transform = transform
|
if not isinstance(transform, Compose):
|
||||||
|
# required since torchrl calls `len(self._transform)` downstream
|
||||||
|
self._transform = Compose(transform)
|
||||||
|
else:
|
||||||
|
self._transform = transform
|
||||||
|
|
||||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||||
stats_path = self.data_dir / "stats.pth"
|
stats_path = self.data_dir / "stats.pth"
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||||
|
|
||||||
|
from rl.torchrl.envs.transforms.transforms import Compose, Transform
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg, transform=None):
|
def make_env(cfg, transform=None):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
|
||||||
|
|
||||||
if transform is not None:
|
if transform is not None:
|
||||||
# useful to add normalization
|
# useful to add normalization
|
||||||
env.append_transform(transform)
|
if isinstance(transform, Compose):
|
||||||
|
for tf in transform:
|
||||||
|
env.append_transform(tf.clone())
|
||||||
|
elif isinstance(transform, Transform):
|
||||||
|
env.append_transform(transform.clone())
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue