undo policy protocol format
This commit is contained in:
parent
9c0c82f1e9
commit
0aec7fd8f0
|
@ -49,7 +49,6 @@ class Policy(Protocol):
|
|||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
@ -57,7 +56,6 @@ class Policy(Protocol):
|
|||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
...
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
@ -65,7 +63,6 @@ class Policy(Protocol):
|
|||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -76,4 +73,3 @@ class PolicyWithUpdate(Policy, Protocol):
|
|||
Implements any additional updates the model parameters may need (for example, doing an EMA step for a
|
||||
target model, or incrementing an internal buffer).
|
||||
"""
|
||||
...
|
||||
|
|
Loading…
Reference in New Issue