287 lines
14 KiB
Python
287 lines
14 KiB
Python
import torch
|
|
import unittest
|
|
from rsl_rl.modules.quantile_network import QuantileNetwork
|
|
|
|
|
|
class QuantileNetworkTest(unittest.TestCase):
|
|
def test_l1_loss(self):
|
|
qn = QuantileNetwork(10, 1, quantile_count=5)
|
|
|
|
prediction = torch.tensor(
|
|
[
|
|
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
|
|
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
|
|
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
|
|
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
|
|
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
|
|
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
|
|
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
|
|
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
|
|
]
|
|
)
|
|
target = torch.tensor(
|
|
[
|
|
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
|
|
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
|
|
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
|
|
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
|
|
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
|
|
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
|
|
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
|
|
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
|
|
]
|
|
)
|
|
|
|
loss = qn.quantile_l1_loss(prediction, target)
|
|
|
|
self.assertAlmostEqual(loss.item(), 0.16419549)
|
|
|
|
def test_l1_loss_3d(self):
|
|
qn = QuantileNetwork(10, 1, quantile_count=5)
|
|
|
|
prediction = torch.tensor(
|
|
[
|
|
[
|
|
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
|
|
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
|
|
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
|
|
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
|
|
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
|
|
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
|
|
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
|
|
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
|
|
],
|
|
[
|
|
[0.6874, 0.6214, 0.7796, 0.8148, 0.2070],
|
|
[0.0276, 0.5764, 0.5516, 0.9682, 0.6901],
|
|
[0.4020, 0.7084, 0.9965, 0.4311, 0.3789],
|
|
[0.5350, 0.9431, 0.1032, 0.6959, 0.4992],
|
|
[0.5059, 0.5479, 0.2302, 0.6753, 0.1593],
|
|
[0.6753, 0.4590, 0.9956, 0.6117, 0.1410],
|
|
[0.7464, 0.7184, 0.2972, 0.7694, 0.7999],
|
|
[0.3907, 0.2112, 0.6485, 0.0139, 0.6252],
|
|
],
|
|
]
|
|
)
|
|
target = torch.tensor(
|
|
[
|
|
[
|
|
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
|
|
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
|
|
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
|
|
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
|
|
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
|
|
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
|
|
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
|
|
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
|
|
],
|
|
[
|
|
[0.5120, 0.7683, 0.3579, 0.8640, 0.4374],
|
|
[0.2533, 0.3039, 0.2214, 0.7069, 0.3093],
|
|
[0.6993, 0.4288, 0.0827, 0.9156, 0.2043],
|
|
[0.6739, 0.2303, 0.3263, 0.6884, 0.3847],
|
|
[0.3990, 0.1458, 0.8918, 0.8036, 0.5012],
|
|
[0.9061, 0.2024, 0.7276, 0.8619, 0.1198],
|
|
[0.7379, 0.2005, 0.7634, 0.5691, 0.6132],
|
|
[0.4341, 0.5711, 0.1119, 0.4286, 0.7521],
|
|
],
|
|
]
|
|
)
|
|
|
|
loss = qn.quantile_l1_loss(prediction, target)
|
|
|
|
self.assertAlmostEqual(loss.item(), 0.15836075)
|
|
|
|
def test_l1_loss_multi_output(self):
|
|
qn = QuantileNetwork(10, 3, quantile_count=10)
|
|
|
|
prediction = torch.tensor(
|
|
[
|
|
[0.3003, 0.8692, 0.4608, 0.7158, 0.2640, 0.3928, 0.4557, 0.4620, 0.1331, 0.6356],
|
|
[0.8867, 0.1521, 0.5827, 0.0501, 0.4401, 0.7216, 0.6081, 0.5758, 0.2772, 0.6048],
|
|
[0.0763, 0.1609, 0.1860, 0.9173, 0.2121, 0.1920, 0.8509, 0.8588, 0.3321, 0.7202],
|
|
[0.8375, 0.5339, 0.4287, 0.9228, 0.8519, 0.0420, 0.5736, 0.9156, 0.4444, 0.2039],
|
|
[0.0704, 0.1833, 0.0839, 0.9573, 0.9852, 0.4191, 0.3562, 0.7225, 0.8481, 0.2096],
|
|
[0.4054, 0.8172, 0.8737, 0.2138, 0.4455, 0.7538, 0.1936, 0.9346, 0.8710, 0.0178],
|
|
[0.2139, 0.6619, 0.6889, 0.5726, 0.0595, 0.3278, 0.7673, 0.0803, 0.0374, 0.9011],
|
|
[0.2757, 0.0309, 0.8913, 0.0958, 0.1828, 0.9624, 0.6529, 0.7451, 0.9996, 0.8877],
|
|
[0.0722, 0.4240, 0.0716, 0.3199, 0.5570, 0.1056, 0.5950, 0.9926, 0.2991, 0.7334],
|
|
[0.0576, 0.6353, 0.5078, 0.4456, 0.9119, 0.6897, 0.1720, 0.5172, 0.9939, 0.5044],
|
|
[0.6300, 0.2304, 0.4064, 0.9195, 0.3299, 0.8631, 0.5842, 0.6751, 0.2964, 0.1215],
|
|
[0.7418, 0.5448, 0.7615, 0.6333, 0.9255, 0.1129, 0.0552, 0.4198, 0.9953, 0.7482],
|
|
[0.9910, 0.7644, 0.7047, 0.1395, 0.3688, 0.7688, 0.8574, 0.3494, 0.6153, 0.1286],
|
|
[0.2325, 0.7908, 0.3036, 0.4504, 0.3775, 0.6004, 0.0199, 0.9581, 0.8078, 0.8337],
|
|
[0.4038, 0.8313, 0.5441, 0.4778, 0.5777, 0.0580, 0.5314, 0.5336, 0.0740, 0.0094],
|
|
[0.9025, 0.5814, 0.4711, 0.2683, 0.4443, 0.5799, 0.6703, 0.2678, 0.7538, 0.1317],
|
|
[0.6755, 0.5696, 0.3334, 0.9146, 0.6203, 0.2080, 0.0799, 0.0059, 0.8347, 0.1874],
|
|
[0.0932, 0.0264, 0.9006, 0.3124, 0.3421, 0.8271, 0.3495, 0.2814, 0.9888, 0.5042],
|
|
[0.4893, 0.3514, 0.2564, 0.8117, 0.3738, 0.9085, 0.3055, 0.1456, 0.3624, 0.4095],
|
|
[0.0726, 0.2145, 0.6295, 0.7423, 0.1292, 0.7570, 0.4645, 0.0775, 0.1280, 0.7312],
|
|
[0.8763, 0.5302, 0.8627, 0.0429, 0.2833, 0.4745, 0.6308, 0.2245, 0.2755, 0.6823],
|
|
[0.9997, 0.3519, 0.0312, 0.1468, 0.5145, 0.0286, 0.6333, 0.1323, 0.2264, 0.9109],
|
|
[0.7742, 0.4857, 0.0413, 0.4523, 0.6847, 0.5774, 0.9478, 0.5861, 0.9834, 0.9437],
|
|
[0.7590, 0.5697, 0.7509, 0.3562, 0.9926, 0.3380, 0.0337, 0.7871, 0.1351, 0.9184],
|
|
[0.5701, 0.0234, 0.8088, 0.0681, 0.7090, 0.5925, 0.5266, 0.7198, 0.4121, 0.0268],
|
|
[0.5377, 0.1420, 0.2649, 0.0885, 0.1987, 0.1475, 0.1562, 0.2283, 0.9447, 0.4679],
|
|
[0.0306, 0.9763, 0.1234, 0.5009, 0.8800, 0.9409, 0.3525, 0.7264, 0.2209, 0.1436],
|
|
[0.2492, 0.4041, 0.9044, 0.3730, 0.3152, 0.7515, 0.2614, 0.9726, 0.6402, 0.5211],
|
|
[0.8626, 0.2828, 0.6946, 0.7066, 0.4395, 0.3015, 0.2643, 0.4421, 0.6036, 0.9009],
|
|
[0.7721, 0.1706, 0.7043, 0.4097, 0.7685, 0.3818, 0.1468, 0.6452, 0.1102, 0.1826],
|
|
[0.7156, 0.1795, 0.5574, 0.9478, 0.0058, 0.8037, 0.8712, 0.7730, 0.5638, 0.5843],
|
|
[0.8775, 0.6133, 0.4118, 0.3038, 0.2612, 0.2424, 0.8960, 0.8194, 0.3588, 0.3198],
|
|
]
|
|
)
|
|
|
|
target = torch.tensor(
|
|
[
|
|
[0.0986, 0.4029, 0.3110, 0.9976, 0.5668, 0.2658, 0.0660, 0.8492, 0.7872, 0.6368],
|
|
[0.3556, 0.9007, 0.0227, 0.7684, 0.0105, 0.9890, 0.7468, 0.0642, 0.5164, 0.1976],
|
|
[0.1331, 0.0998, 0.0959, 0.5596, 0.5984, 0.3880, 0.8050, 0.8320, 0.8977, 0.3486],
|
|
[0.3297, 0.8110, 0.2844, 0.4594, 0.0739, 0.2865, 0.2957, 0.9357, 0.9898, 0.4419],
|
|
[0.0495, 0.2826, 0.8306, 0.2968, 0.5690, 0.7251, 0.5947, 0.7526, 0.5076, 0.6480],
|
|
[0.0381, 0.8645, 0.7774, 0.9158, 0.9682, 0.5851, 0.0913, 0.8948, 0.1251, 0.1205],
|
|
[0.9059, 0.2758, 0.1948, 0.2694, 0.0946, 0.4381, 0.4667, 0.2176, 0.3494, 0.6073],
|
|
[0.1778, 0.8632, 0.3015, 0.2882, 0.4214, 0.2420, 0.8394, 0.1468, 0.9679, 0.6730],
|
|
[0.2400, 0.4344, 0.9765, 0.6544, 0.6338, 0.3434, 0.4776, 0.7981, 0.2008, 0.2267],
|
|
[0.5574, 0.8110, 0.0264, 0.4199, 0.8178, 0.8421, 0.8237, 0.2623, 0.8025, 0.9030],
|
|
[0.8652, 0.2872, 0.9463, 0.5543, 0.4866, 0.2842, 0.6692, 0.2306, 0.3136, 0.4570],
|
|
[0.0651, 0.8955, 0.7531, 0.9373, 0.0265, 0.0795, 0.7755, 0.1123, 0.1920, 0.3273],
|
|
[0.9824, 0.4177, 0.2729, 0.9447, 0.3987, 0.5495, 0.3674, 0.8067, 0.8668, 0.2394],
|
|
[0.4874, 0.3616, 0.7577, 0.6439, 0.2927, 0.8110, 0.6821, 0.0702, 0.5514, 0.7358],
|
|
[0.3627, 0.6392, 0.9085, 0.3646, 0.6051, 0.0586, 0.8763, 0.3899, 0.3242, 0.4598],
|
|
[0.0167, 0.0558, 0.3862, 0.7017, 0.0403, 0.6604, 0.9992, 0.2337, 0.5128, 0.1959],
|
|
[0.7774, 0.9201, 0.0405, 0.7894, 0.1406, 0.2458, 0.2616, 0.8787, 0.8158, 0.8591],
|
|
[0.3225, 0.9827, 0.4032, 0.2621, 0.7949, 0.9796, 0.9480, 0.3353, 0.1430, 0.5747],
|
|
[0.4734, 0.8714, 0.9320, 0.4265, 0.7765, 0.6980, 0.1587, 0.8784, 0.7119, 0.5141],
|
|
[0.7263, 0.4754, 0.8234, 0.0649, 0.4343, 0.5201, 0.8274, 0.9632, 0.3525, 0.8893],
|
|
[0.3324, 0.0142, 0.7222, 0.5026, 0.6011, 0.9275, 0.9351, 0.9236, 0.2621, 0.0768],
|
|
[0.8456, 0.1005, 0.5550, 0.0586, 0.3811, 0.0168, 0.9724, 0.9225, 0.7242, 0.0678],
|
|
[0.2167, 0.5423, 0.9059, 0.3320, 0.4026, 0.2128, 0.4562, 0.3564, 0.2573, 0.1076],
|
|
[0.8385, 0.2233, 0.0736, 0.3407, 0.4702, 0.1668, 0.5174, 0.4154, 0.4407, 0.1843],
|
|
[0.1828, 0.5321, 0.6651, 0.4108, 0.5736, 0.4012, 0.0434, 0.0034, 0.9282, 0.3111],
|
|
[0.1754, 0.8750, 0.6629, 0.7052, 0.9739, 0.7441, 0.8954, 0.9273, 0.3836, 0.5735],
|
|
[0.5586, 0.0381, 0.1493, 0.8575, 0.9351, 0.5222, 0.5600, 0.2369, 0.9217, 0.2545],
|
|
[0.1054, 0.8020, 0.8463, 0.6495, 0.3011, 0.3734, 0.7263, 0.8736, 0.9258, 0.5804],
|
|
[0.7614, 0.4748, 0.6588, 0.7717, 0.9811, 0.1659, 0.7851, 0.2135, 0.1767, 0.6724],
|
|
[0.7655, 0.8571, 0.4224, 0.9397, 0.1363, 0.9431, 0.9326, 0.3762, 0.1077, 0.9514],
|
|
[0.4115, 0.2169, 0.1340, 0.6564, 0.9989, 0.8068, 0.0387, 0.5064, 0.9964, 0.9427],
|
|
[0.5760, 0.2967, 0.3891, 0.6596, 0.8037, 0.1060, 0.0102, 0.8672, 0.5922, 0.6684],
|
|
]
|
|
)
|
|
|
|
loss = qn.quantile_l1_loss(prediction, target)
|
|
|
|
self.assertAlmostEqual(loss.item(), 0.17235948)
|
|
|
|
def test_quantile_huber_loss(self):
|
|
qn = QuantileNetwork(10, 1, quantile_count=5)
|
|
|
|
prediction = torch.tensor(
|
|
[
|
|
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
|
|
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
|
|
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
|
|
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
|
|
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
|
|
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
|
|
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
|
|
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
|
|
]
|
|
)
|
|
target = torch.tensor(
|
|
[
|
|
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
|
|
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
|
|
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
|
|
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
|
|
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
|
|
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
|
|
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
|
|
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
|
|
]
|
|
)
|
|
|
|
loss = qn.quantile_huber_loss(prediction, target)
|
|
|
|
self.assertAlmostEqual(loss.item(), 0.04035041)
|
|
|
|
def test_sample_energy_loss(self):
|
|
qn = QuantileNetwork(10, 1, quantile_count=5)
|
|
|
|
prediction = torch.tensor(
|
|
[
|
|
[0.9813, 0.5331, 0.3298, 0.2428, 0.0737],
|
|
[0.5442, 0.9623, 0.6070, 0.9360, 0.1145],
|
|
[0.3642, 0.0887, 0.1696, 0.8027, 0.7121],
|
|
[0.2005, 0.9889, 0.4350, 0.0301, 0.4546],
|
|
[0.8360, 0.6766, 0.2257, 0.7589, 0.3443],
|
|
[0.0835, 0.1747, 0.1734, 0.6668, 0.4522],
|
|
[0.0851, 0.3146, 0.0316, 0.2250, 0.5729],
|
|
[0.7725, 0.4596, 0.2495, 0.3633, 0.6340],
|
|
]
|
|
)
|
|
target = torch.tensor(
|
|
[
|
|
[0.5365, 0.1495, 0.8120, 0.2595, 0.1409],
|
|
[0.7784, 0.7070, 0.9066, 0.0123, 0.5587],
|
|
[0.9097, 0.0773, 0.9430, 0.2747, 0.1912],
|
|
[0.2307, 0.5068, 0.4624, 0.6708, 0.2844],
|
|
[0.3356, 0.5885, 0.2484, 0.8468, 0.1833],
|
|
[0.3354, 0.8831, 0.3489, 0.7165, 0.7953],
|
|
[0.7577, 0.8578, 0.2735, 0.1029, 0.5621],
|
|
[0.9124, 0.3476, 0.2012, 0.5830, 0.4615],
|
|
]
|
|
)
|
|
|
|
loss = qn.sample_energy_loss(prediction, target)
|
|
|
|
self.assertAlmostEqual(loss.item(), 0.09165202)
|
|
|
|
def test_cvar(self):
|
|
qn = QuantileNetwork(10, 1, quantile_count=5)
|
|
measure = qn.measures[qn.measure_cvar](qn, 0.5)
|
|
|
|
# Quantiles for 3 agents
|
|
input = torch.tensor(
|
|
[
|
|
[0.1056, 0.0609, 0.3523, 0.3033, 0.1779],
|
|
[0.2049, 0.1425, 0.0767, 0.1868, 0.3891],
|
|
[0.1899, 0.1527, 0.2420, 0.2623, 0.1532],
|
|
]
|
|
)
|
|
correct_output = torch.tensor(
|
|
[
|
|
(0.4 * 0.0609 + 0.4 * 0.1056 + 0.2 * 0.1779),
|
|
(0.4 * 0.0767 + 0.4 * 0.1425 + 0.2 * 0.1868),
|
|
(0.4 * 0.1527 + 0.4 * 0.1532 + 0.2 * 0.1899),
|
|
]
|
|
)
|
|
|
|
computed_output = measure(input)
|
|
|
|
self.assertTrue(torch.isclose(computed_output, correct_output).all())
|
|
|
|
def test_cvar_adaptive(self):
|
|
qn = QuantileNetwork(10, 1, quantile_count=5)
|
|
|
|
input = torch.tensor(
|
|
[
|
|
[0.95, 0.21, 0.27, 0.26, 0.19],
|
|
[0.38, 0.34, 0.18, 0.32, 0.97],
|
|
[0.70, 0.24, 0.38, 0.89, 0.96],
|
|
]
|
|
)
|
|
confidence_levels = torch.tensor([0.1, 0.7, 0.9])
|
|
correct_output = torch.tensor(
|
|
[
|
|
0.19,
|
|
(0.18 / 3.5 + 0.32 / 3.5 + 0.34 / 3.5 + 0.38 / 7.0),
|
|
(0.24 / 4.5 + 0.38 / 4.5 + 0.70 / 4.5 + 0.89 / 4.5 + 0.96 / 9.0),
|
|
]
|
|
)
|
|
|
|
measure = qn.measures[qn.measure_cvar](qn, confidence_levels)
|
|
computed_output = measure(input)
|
|
|
|
self.assertTrue(torch.isclose(computed_output, correct_output).all())
|