Compare commits
14 Commits
72ce7782ae...68f937894f
| Author | SHA1 | Date |
|---|---|---|
| | 68f937894f | |
| | b568de35ad | |
| | ae9c81ac39 | |
| | bc2665cd4d | |
| | 4c7de8d1c9 | |
| | 77dd0e5056 | |
| | f71e0a7068 | |
| | 387f5018d4 | |
| | e2aa4864e8 | |
| | 5d5a4186c2 | |
| | 66db8c668c | |
| | 14fd7715fd | |
| | fbc7adad22 | |
| | 7c6944f597 | |
````diff
@@ -98,7 +98,7 @@ conda create -y -n lerobot python=3.10
 conda activate lerobot
 ```
 
-When using `miniconda`, if you don't have `fffmpeg` in your environment:
+When using `miniconda`, if you don't have `ffmpeg` in your environment:
 ```bash
 conda install ffmpeg
 ```
````
```diff
@@ -257,6 +257,7 @@ def encode_video_frames(
 ) -> None:
     """More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
     video_path = Path(video_path)
+    imgs_dir = Path(imgs_dir)
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
     ffmpeg_args = OrderedDict(
```
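The added coercion guards against call sites that pass plain strings. A minimal sketch of the failure mode it prevents (hypothetical paths, not from this compare):

```python
from pathlib import Path

# Callers may pass plain strings rather than Path objects.
video_path = "outputs/videos/episode_0.mp4"
imgs_dir = "outputs/images/episode_0"

# str has no .parent attribute, so coercing both arguments up front
# keeps the rest of the function free to use Path-only methods.
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
video_path.parent.mkdir(parents=True, exist_ok=True)
```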
```diff
@@ -36,6 +36,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.utils import smoothen_actions
 
 
 class ACTPolicy(PreTrainedPolicy):
```
```diff
@@ -136,6 +137,8 @@ class ACTPolicy(PreTrainedPolicy):
 
             # TODO(rcadene): make _forward return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
+            # use low-pass filter to prevent jerky actions
+            actions = smoothen_actions(actions)
 
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
```
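The transpose noted in the trailing comment can be checked with a small shape sketch (illustrative sizes; the plain deque stands in for the policy's action queue):

```python
from collections import deque

import torch

batch_size, n_action_steps, action_dim = 1, 100, 7
actions = torch.randn(batch_size, n_action_steps, action_dim)

# The queue holds one entry per action step, each of shape (batch_size, action_dim),
# hence the transpose from (batch_size, n_action_steps, action_dim).
queue = deque(maxlen=n_action_steps)
queue.extend(actions.transpose(0, 1))

step_action = queue.popleft()
assert step_action.shape == (batch_size, action_dim)
```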
```diff
@@ -42,6 +42,7 @@ from lerobot.common.policies.utils import (
     get_dtype_from_parameters,
     get_output_shape,
     populate_queues,
+    smoothen_actions,
 )
 
 
```
```diff
@@ -137,6 +138,8 @@ class DiffusionPolicy(PreTrainedPolicy):
 
             # TODO(rcadene): make above methods return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
+            # use low-pass filter to prevent jerky actions
+            actions = smoothen_actions(actions)
 
             self._queues["action"].extend(actions.transpose(0, 1))
 
```
```diff
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import torch
+from scipy.signal import butter, filtfilt
 from torch import nn
 
 
```
```diff
@@ -65,3 +67,47 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
     with torch.inference_mode():
         output = module(dummy_input)
     return tuple(output.shape)
+
+
+def butterworth_lowpass_filter(
+    data: np.ndarray, cutoff_freq: float = 1.0, sampling_freq: float = 15.0, order=2
+) -> np.ndarray:
+    """
+    Applies a low-pass Butterworth filter to the input data.
+
+    Parameters:
+        data (np.ndarray): Input data array.
+        cutoff_freq (float): Cutoff frequency of the filter (Hz). Smoother for lower values.
+        sampling_freq (float): Sampling frequency of the data (Hz).
+        order (int): Order of the filter. Higher order may introduce phase distortions.
+
+    Returns:
+        filtered_data (np.ndarray): Filtered data array with the same shape as data.
+    """
+    nyquist = 0.5 * sampling_freq
+    normal_cutoff = cutoff_freq / nyquist
+    b, a = butter(order, normal_cutoff, btype="low", analog=False)
+
+    # apply the filter along axis 0 (time)
+    filtered_data = filtfilt(b, a, data, axis=0)
+    return filtered_data
+
+
+def smoothen_actions(actions: torch.Tensor) -> torch.Tensor:
+    """
+    Smoothens the provided action sequence tensor.
+    Args:
+        actions (torch.Tensor): actions from policy
+    """
+    if not isinstance(actions, torch.Tensor):
+        raise ValueError(f"Invalid input type for actions {type(actions)}. Expected torch.Tensor!")
+
+    if len(actions.shape) == 3 and not actions.shape[0] == 1:
+        raise NotImplementedError("Batch processing not implemented!")
+
+    actions_np = actions.squeeze(0).cpu().numpy()
+    # apply the low-pass filter
+    filtered_actions_np = butterworth_lowpass_filter(actions_np.copy())
+    # disable filtering for the gripper joint
+    filtered_actions_np[:, -1] = actions_np[:, -1]
+    return torch.from_numpy(filtered_actions_np.copy()).unsqueeze(0).to(actions.device)
```
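With the defaults above (`cutoff_freq=1.0`, `sampling_freq=15.0`), the normalized cutoff passed to `butter` is `1.0 / (0.5 * 15.0) ≈ 0.133` of the Nyquist frequency, and `filtfilt` runs the filter forward and backward so the smoothed trajectory has no phase lag. A minimal usage sketch of the new helper (the 100-step, 7-dim trajectory is illustrative, not from this compare; the last dimension is assumed to be the gripper):

```python
import torch

from lerobot.common.policies.utils import smoothen_actions

# Illustrative rollout: batch of 1, 100 action steps, 6 arm joints + 1 gripper.
actions = torch.randn(1, 100, 7)
smoothed = smoothen_actions(actions)

assert smoothed.shape == actions.shape
# The gripper (last) dimension is copied through unfiltered.
assert torch.equal(smoothed[..., -1], actions[..., -1].to(smoothed.dtype))
```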