Add temporary patch in TD-MPC

This commit is contained in:
Alexander Soare 2024-04-17 16:27:57 +01:00
parent 2298ddf226
commit dd9c6eed15
1 changed files with 4 additions and 0 deletions

View File

@ -330,6 +330,10 @@ class TDMPCPolicy(nn.Module):
return td_target return td_target
def forward(self, batch, step): def forward(self, batch, step):
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
raise NotImplementedError()
def update(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning.""" """Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time() start_time = time.time()