This commit is contained in:
Vaishanth Ramaraj 2025-04-06 10:12:09 +08:00 committed by GitHub
commit cd3dda8533
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 0 deletions

View File

@ -36,6 +36,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import smoothen_actions
class ACTPolicy(PreTrainedPolicy):
@ -136,6 +137,8 @@ class ACTPolicy(PreTrainedPolicy):
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
# use low-pass filter to prevent jerky actions
actions = smoothen_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.

View File

@ -42,6 +42,7 @@ from lerobot.common.policies.utils import (
get_dtype_from_parameters,
get_output_shape,
populate_queues,
smoothen_actions,
)
@ -137,6 +138,8 @@ class DiffusionPolicy(PreTrainedPolicy):
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
# use low-pass filter to prevent jerky actions
actions = smoothen_actions(actions)
self._queues["action"].extend(actions.transpose(0, 1))

View File

@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from scipy.signal import butter, filtfilt
from torch import nn
@ -65,3 +67,47 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
with torch.inference_mode():
output = module(dummy_input)
return tuple(output.shape)
def butterworth_lowpass_filter(
data: np.ndarray, cutoff_freq: float = 1.0, sampling_freq: float = 15.0, order=2
) -> np.ndarray:
"""
Applies a low-pass Butterworth filter to the input data.
Parameters:
data (np.array): Input data array.
cutoff (float): Cutoff frequency of the filter (Hz). Smoother for lower values.
fs (float): Sampling frequency of the data (Hz).
order (int): Order of the filter. Higher order may introduce phase distortions.
Returns:
filtered_data (np.array): Filtered data array with same shape as data.
"""
nyquist = 0.5 * sampling_freq
normal_cutoff = cutoff_freq / nyquist
b, a = butter(order, normal_cutoff, btype="low", analog=False)
# apply the filter along axis 0
filtered_data = filtfilt(b, a, data, axis=0)
return filtered_data
def smoothen_actions(actions: torch.Tensor) -> torch.Tensor:
"""
Smoothens the provided action sequence tensor
Args:
actions (torch.Tensor): actions from policy
"""
if not isinstance(actions, torch.Tensor):
raise ValueError(f"Invalid input type for actions {type(actions)}. Expected torch.Tensor!")
if len(actions.shape) == 3 and not actions.shape[0] == 1:
raise NotImplementedError("Batch processing not implemented!!")
actions_np = actions.squeeze(0).cpu().numpy()
# apply the low-pass filter
filtered_actions_np = butterworth_lowpass_filter(actions_np.copy())
# disable filtering for the gripper joint
filtered_actions_np[:, -1] = actions_np[:, -1]
return torch.from_numpy(filtered_actions_np.copy()).unsqueeze(0).to(actions.device)