backup wip

This commit is contained in:
Alexander Soare 2024-06-07 14:51:13 +01:00
parent a48789d629
commit df8f95f157
2 changed files with 7 additions and 8 deletions

View File

@@ -331,16 +331,15 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
continue
batch_ = torch.load("/tmp/batch.pth")
print(f"STEP {step}")
-assert torch.equal(batch['index'], batch_['index'])
-assert torch.equal(batch['episode_index'], batch_['episode_index'])
+assert torch.equal(batch["index"], batch_["index"])
+assert torch.equal(batch["episode_index"], batch_["episode_index"])
if not torch.equal(batch["observation.image"], batch_["observation.image"]):
import cv2
for b, fn in [(batch, "outputs/img.png"), (batch_, "outputs/img_.png")]:
cv2.imwrite(
fn,
-(b["observation.image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(
-np.uint8
-),
+(b["observation.image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8),
)
assert False
assert torch.equal(batch["observation.state"], batch_["observation.state"])

View File

@@ -103,7 +103,7 @@ def update_policy(
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
-step: int = 0
+step: int = 0,
):
"""Returns a dictionary of items for logging."""
start_time = time.perf_counter()
@@ -446,7 +446,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.training.drop_n_last_frames,
-shuffle=False, # TODO(now)
+shuffle=False, # TODO(now)
)
else:
shuffle = False # TODO(now)
@@ -485,7 +485,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
-step=step
+step=step,
)
train_info["dataloading_s"] = dataloading_s