diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 72d4df03..d17394cb 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -36,6 +36,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.utils import smoothen_actions
 
 
 class ACTPolicy(PreTrainedPolicy):
@@ -136,6 +137,8 @@ class ACTPolicy(PreTrainedPolicy):
 
             # TODO(rcadene): make _forward return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
+            # use low-pass filter to prevent jerky actions
+            actions = smoothen_actions(actions)
 
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 9ecadcb0..615eaf7a 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -42,6 +42,7 @@ from lerobot.common.policies.utils import (
     get_dtype_from_parameters,
     get_output_shape,
     populate_queues,
+    smoothen_actions,
 )
 
 
@@ -137,6 +138,8 @@ class DiffusionPolicy(PreTrainedPolicy):
 
             # TODO(rcadene): make above methods return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
+            # use low-pass filter to prevent jerky actions
+            actions = smoothen_actions(actions)
 
             self._queues["action"].extend(actions.transpose(0, 1))
 
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index c06e620b..9bb46128 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import torch
+from scipy.signal import butter, filtfilt
 from torch import nn
 
 
@@ -65,3 +67,47 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
     with torch.inference_mode():
         output = module(dummy_input)
     return tuple(output.shape)
+
+
+def butterworth_lowpass_filter(
+    data: np.ndarray, cutoff_freq: float = 1.0, sampling_freq: float = 15.0, order: int = 2
+) -> np.ndarray:
+    """
+    Applies a low-pass Butterworth filter to the input data.
+
+    Parameters:
+        data (np.ndarray): Input data array.
+        cutoff_freq (float): Cutoff frequency of the filter (Hz). Lower values give smoother output.
+        sampling_freq (float): Sampling frequency of the data (Hz).
+        order (int): Order of the filter. Higher orders may introduce phase distortions.
+
+    Returns:
+        filtered_data (np.ndarray): Filtered data array with the same shape as `data`.
+    """
+    nyquist = 0.5 * sampling_freq
+    normal_cutoff = cutoff_freq / nyquist
+    b, a = butter(order, normal_cutoff, btype="low", analog=False)
+
+    # apply the filter along axis 0 (the time dimension)
+    filtered_data = filtfilt(b, a, data, axis=0)
+    return filtered_data
+
+
+def smoothen_actions(actions: torch.Tensor) -> torch.Tensor:
+    """
+    Smoothens the provided action sequence tensor with a low-pass Butterworth filter.
+    Args:
+        actions (torch.Tensor): actions from the policy, shape (1, n_action_steps, action_dim).
+    """
+    if not isinstance(actions, torch.Tensor):
+        raise ValueError(f"Invalid input type for actions: {type(actions)}. Expected torch.Tensor!")
+
+    if actions.ndim == 3 and actions.shape[0] != 1:
+        raise NotImplementedError("Batch processing is not implemented!")
+
+    actions_np = actions.squeeze(0).cpu().numpy()
+    # apply the low-pass filter
+    filtered_actions_np = butterworth_lowpass_filter(actions_np.copy())
+    # disable filtering for the gripper joint to keep the raw gripper commands
+    filtered_actions_np[:, -1] = actions_np[:, -1]
+    # copy to ensure a contiguous array before handing it to torch.from_numpy
+    return torch.from_numpy(filtered_actions_np.copy()).unsqueeze(0).to(actions.device)