backup wip

Alexander Soare 2024-06-07 14:49:05 +01:00
parent 53b36dcaab
commit a48789d629
2 changed files with 27 additions and 10 deletions


@@ -29,6 +29,7 @@ TODO(alexander-soare): Use batch-first throughout.
 import logging
 from collections import deque
 from copy import deepcopy
+from functools import partial
 from typing import Callable
 
 import einops
@@ -310,7 +311,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
         return G
 
-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
+    def forward(self, batch: dict[str, Tensor], step) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
 
         Returns a dictionary with loss as a tensor, and other information as native floats.
@@ -328,7 +329,21 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             if not os.path.exists("/tmp/mutex.txt"):
                 sleep(0.01)
                 continue
-            batch.update(torch.load("/tmp/batch.pth"))
+            batch_ = torch.load("/tmp/batch.pth")
+            print(f"STEP {step}")
+            assert torch.equal(batch["index"], batch_["index"])
+            assert torch.equal(batch["episode_index"], batch_["episode_index"])
+            if not torch.equal(batch["observation.image"], batch_["observation.image"]):
+                import cv2
+                for b, fn in [(batch, "outputs/img.png"), (batch_, "outputs/img_.png")]:
+                    cv2.imwrite(
+                        fn,
+                        (b["observation.image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(
+                            np.uint8
+                        ),
+                    )
+                assert False
+            assert torch.equal(batch["observation.state"], batch_["observation.state"])
             os.remove("/tmp/mutex.txt")
             break
@@ -343,12 +358,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
-        # # Apply random image augmentations.
-        # if self.config.max_random_shift_ratio > 0:
-        #     observations["observation.image"] = flatten_forward_unflatten(
-        #         partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-        #         observations["observation.image"],
-        #     )
+        # Apply random image augmentations.
+        if self.config.max_random_shift_ratio > 0:
+            observations["observation.image"] = flatten_forward_unflatten(
+                partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
+                observations["observation.image"],
+            )
 
         # Get the current observation for predicting trajectories, and all future observations for use in
         # the latent consistency loss and TD loss.

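Note on the debugging pattern in the hunks above: the idea is to run two training processes on the same data and check, step by step, that they see identical batches. One process (assumed, not shown in this diff) saves each batch to /tmp/batch.pth and then creates /tmp/mutex.txt as a "ready" flag; the process shown here spins until the flag exists, loads the dump, asserts the tensors match (writing the first image pair to outputs/ when they differ), and deletes the flag. A minimal standalone sketch of that handshake follows; the writer side and the helper names are assumptions for illustration, not code from this commit.

import os
import time

import torch

MUTEX_PATH = "/tmp/mutex.txt"  # flag file: its existence signals "a batch is ready"
BATCH_PATH = "/tmp/batch.pth"  # where the reference process dumps each batch


def publish_batch(batch: dict) -> None:
    # Writer side (assumed): wait until the previous batch was consumed, dump the
    # new one, then create the flag file. Its content is never read.
    while os.path.exists(MUTEX_PATH):
        time.sleep(0.01)
    torch.save(batch, BATCH_PATH)
    with open(MUTEX_PATH, "w"):
        pass


def wait_and_compare(batch: dict, step: int) -> None:
    # Reader side (the role played by TDMPCPolicy.forward above): spin until the
    # flag exists, load the dump, compare key by key, then clear the flag.
    while not os.path.exists(MUTEX_PATH):
        time.sleep(0.01)
    batch_ = torch.load(BATCH_PATH)
    print(f"STEP {step}")
    for key in ("index", "episode_index", "observation.image", "observation.state"):
        assert torch.equal(batch[key], batch_[key]), f"batch mismatch on {key!r}"
    os.remove(MUTEX_PATH)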

@@ -103,13 +103,14 @@ def update_policy(
     grad_scaler: GradScaler,
     lr_scheduler=None,
     use_amp: bool = False,
+    step: int = 0
 ):
     """Returns a dictionary of items for logging."""
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
-        output_dict = policy.forward(batch)
+        output_dict = policy.forward(batch, step)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
         loss = output_dict["loss"]
     grad_scaler.scale(loss).backward()
@@ -445,7 +446,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         sampler = EpisodeAwareSampler(
             offline_dataset.episode_data_index,
             drop_n_last_frames=cfg.training.drop_n_last_frames,
-            shuffle=True,
+            shuffle=False,  # TODO(now)
         )
     else:
         shuffle = False  # TODO(now)
@@ -484,6 +485,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            step=step
         )
 
         train_info["dataloading_s"] = dataloading_s
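The changes in this second file support the same comparison: shuffling is disabled (marked TODO(now)), presumably so both processes draw batches in the same order, and the training step is threaded through update_policy into policy.forward so the debug print can label each comparison. A rough sketch of that flow with a dummy dataset and a stub policy; the names and shapes here are illustrative, not the repo's API.

import torch
from torch.utils.data import DataLoader, TensorDataset


def update_policy(policy, batch, step: int = 0) -> dict:
    # Mirrors the edited signature above: `step` is only forwarded so the
    # policy can tag its debug output.
    return policy(batch, step)


class StubPolicy(torch.nn.Module):
    def forward(self, batch, step):
        print(f"STEP {step}: first index = {batch[0, 0].item()}")
        return {"loss": batch.float().mean()}


dataset = TensorDataset(torch.arange(32).reshape(8, 4))
# shuffle=False keeps the batch order identical across runs/processes, which is
# what makes a per-step batch comparison meaningful.
loader = DataLoader(dataset, batch_size=2, shuffle=False)

policy = StubPolicy()
for step, (batch,) in enumerate(loader):
    update_policy(policy, batch, step=step)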