undo policy protocol format

This commit is contained in:
Wael Karkoub 2024-06-10 18:31:47 +01:00
parent 9c0c82f1e9
commit 0aec7fd8f0
1 changed file with 0 additions and 4 deletions

View File

@ -49,7 +49,6 @@ class Policy(Protocol):
Does things like clearing caches.
"""
...
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
@ -57,7 +56,6 @@ class Policy(Protocol):
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
...
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
@ -65,7 +63,6 @@ class Policy(Protocol):
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
...
@runtime_checkable
@ -76,4 +73,3 @@ class PolicyWithUpdate(Policy, Protocol):
Implements additional updates that the model parameters may need (for example, doing an EMA step for a
target model, or incrementing an internal buffer).
"""
...