Merge bc2665cd4d
into 1c873df5c0
This commit is contained in:
commit
cd3dda8533
|
@ -36,6 +36,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.common.policies.utils import smoothen_actions
|
||||||
|
|
||||||
|
|
||||||
class ACTPolicy(PreTrainedPolicy):
|
class ACTPolicy(PreTrainedPolicy):
|
||||||
|
@ -136,6 +137,8 @@ class ACTPolicy(PreTrainedPolicy):
|
||||||
|
|
||||||
# TODO(rcadene): make _forward return output dictionary?
|
# TODO(rcadene): make _forward return output dictionary?
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
# use low-pass filter to prevent jerky actions
|
||||||
|
actions = smoothen_actions(actions)
|
||||||
|
|
||||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||||
|
|
|
@ -42,6 +42,7 @@ from lerobot.common.policies.utils import (
|
||||||
get_dtype_from_parameters,
|
get_dtype_from_parameters,
|
||||||
get_output_shape,
|
get_output_shape,
|
||||||
populate_queues,
|
populate_queues,
|
||||||
|
smoothen_actions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,6 +138,8 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||||
|
|
||||||
# TODO(rcadene): make above methods return output dictionary?
|
# TODO(rcadene): make above methods return output dictionary?
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
# use low-pass filter to prevent jerky actions
|
||||||
|
actions = smoothen_actions(actions)
|
||||||
|
|
||||||
self._queues["action"].extend(actions.transpose(0, 1))
|
self._queues["action"].extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from scipy.signal import butter, filtfilt
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,3 +67,47 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
output = module(dummy_input)
|
output = module(dummy_input)
|
||||||
return tuple(output.shape)
|
return tuple(output.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def butterworth_lowpass_filter(
|
||||||
|
data: np.ndarray, cutoff_freq: float = 1.0, sampling_freq: float = 15.0, order=2
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Applies a low-pass Butterworth filter to the input data.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
data (np.array): Input data array.
|
||||||
|
cutoff (float): Cutoff frequency of the filter (Hz). Smoother for lower values.
|
||||||
|
fs (float): Sampling frequency of the data (Hz).
|
||||||
|
order (int): Order of the filter. Higher order may introduce phase distortions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
filtered_data (np.array): Filtered data array with same shape as data.
|
||||||
|
"""
|
||||||
|
nyquist = 0.5 * sampling_freq
|
||||||
|
normal_cutoff = cutoff_freq / nyquist
|
||||||
|
b, a = butter(order, normal_cutoff, btype="low", analog=False)
|
||||||
|
|
||||||
|
# apply the filter along axis 0
|
||||||
|
filtered_data = filtfilt(b, a, data, axis=0)
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
|
|
||||||
|
def smoothen_actions(actions: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Smoothens the provided action sequence tensor
|
||||||
|
Args:
|
||||||
|
actions (torch.Tensor): actions from policy
|
||||||
|
"""
|
||||||
|
if not isinstance(actions, torch.Tensor):
|
||||||
|
raise ValueError(f"Invalid input type for actions {type(actions)}. Expected torch.Tensor!")
|
||||||
|
|
||||||
|
if len(actions.shape) == 3 and not actions.shape[0] == 1:
|
||||||
|
raise NotImplementedError("Batch processing not implemented!!")
|
||||||
|
|
||||||
|
actions_np = actions.squeeze(0).cpu().numpy()
|
||||||
|
# apply the low-pass filter
|
||||||
|
filtered_actions_np = butterworth_lowpass_filter(actions_np.copy())
|
||||||
|
# disable filtering for the gripper joint
|
||||||
|
filtered_actions_np[:, -1] = actions_np[:, -1]
|
||||||
|
return torch.from_numpy(filtered_actions_np.copy()).unsqueeze(0).to(actions.device)
|
||||||
|
|
Loading…
Reference in New Issue