fix: server predicts multiple actions for a given observation, VLA-like
This commit is contained in:
parent
485d64c8f4
commit
b9b7492132
|
@ -17,15 +17,11 @@ def get_device():
|
||||||
|
|
||||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||||
def __init__(self, policy: PreTrainedPolicy = None):
|
def __init__(self, policy: PreTrainedPolicy = None):
|
||||||
|
# TODO: Add code for loading and using policy for inference
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
|
# TODO: Add device specification for policy inference at init
|
||||||
|
self.observation = None
|
||||||
|
|
||||||
# TODO: Add device specification for policy inference
|
|
||||||
# self.observation = None
|
|
||||||
self.observation = async_inference_pb2.Observation(
|
|
||||||
transfer_state=2,
|
|
||||||
data=np.array([1], dtype=np.float32).tobytes()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
# keeping a list of all observations received from the robot client
|
# keeping a list of all observations received from the robot client
|
||||||
self.observations = []
|
self.observations = []
|
||||||
|
@ -62,7 +58,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||||
def StreamActions(self, request, context):
|
def StreamActions(self, request, context):
|
||||||
"""Stream actions to the robot client"""
|
"""Stream actions to the robot client"""
|
||||||
client_id = context.peer()
|
client_id = context.peer()
|
||||||
print(f"Client {client_id} connected for action streaming")
|
# print(f"Client {client_id} connected for action streaming")
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
yield self._generate_and_queue_action(self.observation)
|
yield self._generate_and_queue_action(self.observation)
|
||||||
|
@ -80,17 +76,20 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||||
raise NotImplementedError("Not implemented")
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
def _generate_and_queue_action(self, observation):
|
def _generate_and_queue_action(self, observation):
|
||||||
"""Generate an action based on the observation (dummy logic).
|
"""Generate a buffer of actions based on the observation (dummy logic).
|
||||||
Mainly used for testing purposes"""
|
Mainly used for testing purposes"""
|
||||||
# Debinarize the observation data
|
time.sleep(2)
|
||||||
|
# Debinarize observation data
|
||||||
data = np.frombuffer(
|
data = np.frombuffer(
|
||||||
observation.data,
|
observation.data,
|
||||||
dtype=np.float32
|
dtype=np.float32
|
||||||
)
|
)
|
||||||
# dummy transform on the observation data
|
# dummy transform on the observation data
|
||||||
action = (data * 1.4).sum()
|
action_content = (data * 2).sum().item()
|
||||||
# map action to bytes
|
action_data = (action_content * np.ones(
|
||||||
action_data = np.array([action], dtype=np.float32).tobytes()
|
shape=(10, 5), # 10 5-dimensional actions
|
||||||
|
dtype=np.float32
|
||||||
|
)).tobytes()
|
||||||
|
|
||||||
action = async_inference_pb2.Action(
|
action = async_inference_pb2.Action(
|
||||||
transfer_state=observation.transfer_state,
|
transfer_state=observation.transfer_state,
|
||||||
|
|
Loading…
Reference in New Issue