# RoboWaiter/robowaiter/algos/explore/rrt_star.py
#
# (Repository-viewer metadata: 289 lines, 10 KiB, Python; Raw / Blame / History.)
#
# Viewer note: this file contains ambiguous Unicode characters.
#
# This file contains Unicode characters that might be confused with other
# characters. If you think that this is intentional, you can safely ignore
# this warning. Use the Escape button to reveal them.

"""
Path planning Sample Code with RRT*
author: Atsushi Sakai(@Atsushi_twi)
"""
import math
import sys
import matplotlib.pyplot as plt
import pathlib
from rrt import RRT
sys.path.append(str(pathlib.Path(__file__).parent.parent))
class RRTStar(RRT):
    """
    Class for RRT Star planning.

    Extends the project's ``RRT`` planner with the two RRT* refinements:
    ``choose_parent`` (connect each new node through the cheapest nearby
    collision-free node) and ``rewire`` (re-route nearby nodes through the
    new node when that lowers their cost from the start).
    """

    # Default cap on random samples per requested tree node, used by
    # ``planning`` when ``max_samples`` is not given. The tree only grows when
    # a sample passes ``check_collision``, so without a cap a start region with
    # no collision-free expansions would make ``planning`` loop forever.
    MAX_SAMPLES_PER_NODE = 100

    class Node(RRT.Node):
        """Tree node that additionally stores its path cost from the start."""

        def __init__(self, x, y):
            super().__init__(x, y)
            self.cost = 0.0  # path cost from the start node

    def __init__(self,
                 rand_area,
                 map,
                 scale_ratio=5,
                 expand_dis=100.0,
                 path_resolution=10.0,
                 goal_sample_rate=10,
                 max_iter=300,
                 connect_circle_dist=50.0,  # new
                 search_until_max_iter=False,  # new
                 robot_radius=5.0):
        """
        All arguments except the two below are forwarded to ``RRT.__init__``.

        connect_circle_dist: float
            Scale factor of the neighbourhood radius used by
            ``find_near_nodes``.
        search_until_max_iter: bool
            If False, ``planning`` returns as soon as a collision-free
            connection to the goal exists. If True, it keeps growing the tree
            until it holds ``max_iter`` nodes before picking the best path.
        """
        super().__init__(rand_area, map, scale_ratio, expand_dis,
                         path_resolution, goal_sample_rate, max_iter,
                         robot_radius=robot_radius)
        self.connect_circle_dist = connect_circle_dist
        self.search_until_max_iter = search_until_max_iter
        self.node_list = []

    def planning(self, start, goal, path_smoothing=True, animation=False,
                 max_samples=None) -> "list[tuple[float, float]] | None":
        """
        RRT* path planning; return a path list, or None if no path was found.

        start, goal: (x, y)
            Start and goal positions.
        path_smoothing: bool
            If True, the raw path is passed through ``self.path_smoothing``.
        animation: bool
            Flag for animation on or off.
        max_samples: int or None
            Maximum number of random samples to draw. Defaults to
            ``MAX_SAMPLES_PER_NODE * max_iter``. This bounds the loop when
            (almost) every sample fails ``check_collision``.
        """
        (self.start.x, self.start.y) = start
        (self.goal.x, self.goal.y) = goal
        if max_samples is None:
            max_samples = self.MAX_SAMPLES_PER_NODE * self.max_iter

        self.node_list = [self.start]
        samples_drawn = 0
        while len(self.node_list) < self.max_iter:
            if samples_drawn >= max_samples:
                print("reached max samples")
                break
            samples_drawn += 1

            print("number of nodes:", len(self.node_list))
            rnd_node = self.get_random_node()
            nearest_ind = self.get_nearest_node_index(self.node_list, rnd_node)
            nearest_node = self.node_list[nearest_ind]
            new_node = self.steer(nearest_node, rnd_node, self.expand_dis)

            if animation:
                self.draw_graph(new_node)

            if self.check_collision(new_node):
                new_node.cost = nearest_node.cost + \
                    math.hypot(new_node.x - nearest_node.x,
                               new_node.y - nearest_node.y)
                near_inds = self.find_near_nodes(new_node)
                # new_node re-steered from its cheapest collision-free neighbour.
                node_with_updated_parent = self.choose_parent(new_node,
                                                              near_inds)
                if node_with_updated_parent:
                    self.rewire(node_with_updated_parent, near_inds)
                    self.node_list.append(node_with_updated_parent)
                else:
                    # Reachable: choose_parent returns None when no node lies
                    # within the search radius or every neighbour's edge
                    # collides. Keep new_node as steered from nearest_node.
                    self.node_list.append(new_node)

            # Goal check: stop at the first feasible connection to the goal.
            if not self.search_until_max_iter:
                last_index = self.search_best_goal_node()
                if last_index is not None:
                    return self._finalize_path(last_index, path_smoothing)

        print("reached max iteration")

        last_index = self.search_best_goal_node()
        if last_index is not None:
            return self._finalize_path(last_index, path_smoothing)

        return None

    def _finalize_path(self, last_index, path_smoothing):
        """Build the path ending at node ``last_index``, optionally smoothed."""
        path = self.generate_final_course(last_index)
        if path_smoothing:
            return self.path_smoothing(path)
        return path

    def choose_parent(self, new_node, near_inds):
        """
        Choose the parent for new_node that gives the lowest total cost
        from the start.

        Computes the cheapest node to new_node among the nodes in near_inds
        and sets that node as the parent of new_node.

        Arguments:
        ----------
        new_node: Node
            Randomly generated node with a path from its nearest node.
            There are no collisions between this node and the tree.
        near_inds: list
            Indices of the nodes that are near to new_node.

        Returns:
        --------
        Node or None
            A copy of new_node steered from the cheapest neighbour, with its
            cost updated. Returns None if near_inds is empty or every
            connection collides.
        """
        if not near_inds:
            return None

        # Cost through each neighbour; inf marks a failed or colliding edge.
        costs = []
        for i in near_inds:
            near_node = self.node_list[i]
            t_node = self.steer(near_node, new_node)
            if t_node and self.check_collision(t_node):
                costs.append(self.calc_new_cost(near_node, new_node))
            else:
                costs.append(float("inf"))  # the cost of collision node
        min_cost = min(costs)

        if min_cost == float("inf"):
            print("There is no good path.(min_cost is inf)")
            return None

        min_ind = near_inds[costs.index(min_cost)]
        # Steering from the chosen neighbour sets it as new_node's parent.
        new_node = self.steer(self.node_list[min_ind], new_node)
        new_node.cost = min_cost

        return new_node

    def search_best_goal_node(self):
        """
        Return the index of the node with the lowest total start-to-goal cost
        among the nodes that can reach the goal directly.

        "Directly" means within one ``expand_dis`` step of the goal and with a
        collision-free edge to it. Returns None if there is no such node.
        """
        # Nodes within one step of the goal. Use enumerate(), not
        # list.index(): index() returns the first equal value, so nodes at the
        # same distance from the goal would be collapsed onto a single index.
        goal_inds = [
            i for i, node in enumerate(self.node_list)
            if self.calc_dist_to_goal(node.x, node.y) <= self.expand_dis
        ]

        # Of those, the nodes whose edge to the goal is collision-free.
        safe_goal_inds = [
            i for i in goal_inds
            if self.check_collision(self.steer(self.node_list[i], self.goal))
        ]
        if not safe_goal_inds:
            return None

        def total_cost(i):
            # Cost from the start through node i and then on to the goal.
            node = self.node_list[i]
            return node.cost + self.calc_dist_to_goal(node.x, node.y)

        # min() keeps the first index on ties, like the original scan.
        return min(safe_goal_inds, key=total_cost)

    def find_near_nodes(self, new_node):
        """
        Find the tree nodes within a radius around new_node.

        1) Defines a ball centred on new_node whose radius shrinks as the
           tree grows.
        2) Returns all nodes of the tree that are inside this ball.

        Arguments:
        ----------
        new_node: Node
            New randomly generated node, without collisions between it and
            its nearest node.

        Returns:
        --------
        list
            Indices into self.node_list of the nodes inside the ball.
        """
        nnode = len(self.node_list) + 1
        r = self.connect_circle_dist * math.sqrt(math.log(nnode) / nnode)
        # If expand_dis exists, search no farther than one expansion step.
        if hasattr(self, 'expand_dis'):
            r = min(r, self.expand_dis)
        r_squared = r ** 2
        # enumerate() instead of list.index(): index() maps equal distances to
        # the first matching node, dropping the others and duplicating it.
        return [
            i for i, node in enumerate(self.node_list)
            if (node.x - new_node.x) ** 2 + (node.y - new_node.y) ** 2
            <= r_squared
        ]

    def rewire(self, new_node, near_inds):
        """
        Re-route nearby nodes through new_node when that is cheaper.

        For each node in near_inds, check whether it is cheaper to reach it
        from new_node. If so, replace that node in the tree with a copy whose
        parent is new_node, re-point its children to the copy, and propagate
        the cost change to its descendants.

        Parameters:
        ----------
        new_node: Node
            Node about to be added to the tree. Its parent was designated in
            choose_parent.
        near_inds: list of int
            Indices into self.node_list of the nodes within the search radius
            of new_node.
        """
        for i in near_inds:
            near_node = self.node_list[i]
            edge_node = self.steer(new_node, near_node)
            if not edge_node:
                continue
            edge_node.cost = self.calc_new_cost(new_node, near_node)

            no_collision = self.check_collision(edge_node)
            improved_cost = near_node.cost > edge_node.cost

            if no_collision and improved_cost:
                # edge_node replaces near_node, so re-point near_node's children.
                for node in self.node_list:
                    if node.parent == near_node:
                        node.parent = edge_node
                self.node_list[i] = edge_node
                self.propagate_cost_to_leaves(self.node_list[i])

    def calc_new_cost(self, from_node, to_node):
        """Return the cost of reaching to_node from the start via from_node."""
        d, _ = self.calc_distance_and_angle(from_node, to_node)
        return from_node.cost + d

    def propagate_cost_to_leaves(self, parent_node):
        """
        Push updated costs from parent_node down to all of its descendants.

        This is iterative rather than recursive so that deep trees (large
        max_iter) cannot exceed Python's recursion limit. Each child is
        updated after its parent, as in the recursive form.
        """
        pending = [parent_node]
        while pending:
            current = pending.pop()
            for node in self.node_list:
                if node.parent == current:
                    node.cost = self.calc_new_cost(current, node)
                    pending.append(node)
def main(planning_map=None, start=(0.0, 0.0), goal=(6.0, 10.0),
         show_animation=True):
    """
    Run a small RRT* demo and optionally plot the resulting path.

    planning_map:
        Map object forwarded to ``RRTStar`` as its required ``map`` argument.
        If None, the demo prints a message and returns None without planning.
    start, goal: (x, y)
        Start and goal positions passed to ``RRTStar.planning``.
    show_animation: bool
        Animate the search and plot the final path with matplotlib.

    Returns the planned path, or None if nothing was planned or found.
    """
    print("Start " + __file__)

    # RRTStar cannot be constructed without a map.
    if planning_map is None:
        print("No map given: RRTStar needs a map to plan on, aborting demo")
        return None

    # ====Search Path with RRT*====
    rrt_star = RRTStar(rand_area=[-2, 15], map=planning_map, expand_dis=1,
                       robot_radius=0.8)
    path = rrt_star.planning(start, goal, animation=show_animation)

    if path is None:
        print("Cannot find path")
    else:
        print("found path!!")

        # Draw final path
        if show_animation:
            rrt_star.draw_graph()
            plt.plot([x for (x, y) in path], [y for (x, y) in path], 'r--')
            plt.grid(True)
            plt.show()

    return path
# Script entry point: run the demo only when executed directly.
if __name__ == "__main__":
    main()