Add Tora example workflow

This commit is contained in:
kijai 2024-10-21 00:44:10 +03:00
parent 47f028e0bd
commit 4d437dae66
2 changed files with 1059 additions and 461 deletions

File diff suppressed because it is too large Load Diff

View File

@ -4,467 +4,6 @@ import torch
# Note that the coordinates passed to the model must not exceed 256. # Note that the coordinates passed to the model must not exceed 256.
# xy range 256 # xy range 256
PROVIDED_TRAJS = {
"circle1": [
[120, 194],
[144, 193],
[155, 189],
[158, 170],
[160, 153],
[159, 123],
[152, 113],
[136, 100],
[124, 100],
[108, 100],
[101, 106],
[90, 110],
[84, 129],
[79, 146],
[78, 165],
[83, 182],
[87, 189],
[94, 192],
[100, 194],
[106, 194],
[112, 194],
[118, 195],
],
"circle2": [
[100, 127],
[105, 117],
[122, 117],
[132, 129],
[133, 158],
[125, 181],
[108, 189],
[92, 185],
[84, 179],
[79, 163],
[75, 142],
[73, 118],
[75, 82],
[91, 63],
[115, 52],
[139, 46],
[154, 55],
[167, 93],
[175, 112],
[177, 137],
[177, 158],
[177, 171],
[175, 188],
[173, 204],
],
"coaster": [
[40, 208],
[40, 148],
[40, 100],
[52, 58],
[60, 57],
[74, 68],
[78, 90],
[84, 123],
[88, 148],
[96, 168],
[100, 181],
[102, 188],
[105, 192],
[113, 118],
[119, 80],
[128, 68],
[145, 109],
[149, 155],
[157, 175],
[161, 184],
[164, 184],
[172, 166],
[183, 107],
[189, 84],
[198, 76],
],
"dance": [
[81, 112],
[86, 112],
[92, 112],
[100, 113],
[102, 114],
[97, 115],
[92, 114],
[86, 112],
[81, 112],
[80, 112],
[84, 113],
[89, 114],
[95, 114],
[101, 114],
[102, 114],
[103, 124],
[105, 137],
[109, 156],
[114, 172],
[119, 180],
[124, 184],
[131, 181],
[140, 168],
[146, 152],
[150, 128],
[151, 117],
[152, 116],
[156, 116],
[163, 115],
[169, 116],
[175, 116],
[173, 116],
[167, 116],
[162, 114],
[157, 114],
[152, 115],
[156, 115],
[163, 115],
[168, 115],
[174, 116],
[175, 116],
[168, 116],
[162, 116],
[152, 114],
[149, 134],
[145, 156],
[139, 168],
[130, 183],
[118, 180],
[112, 170],
[107, 151],
[102, 128],
[103, 117],
[96, 113],
[88, 113],
[83, 112],
[80, 112],
],
"infinity": [
[60, 141],
[71, 127],
[92, 120],
[112, 123],
[130, 145],
[145, 163],
[167, 178],
[189, 187],
[206, 176],
[213, 147],
[208, 124],
[190, 112],
[176, 111],
[158, 124],
[145, 147],
[125, 172],
[104, 189],
[72, 189],
[59, 184],
[55, 153],
[57, 140],
[75, 119],
[112, 118],
[129, 142],
[149, 163],
[168, 180],
[194, 186],
[206, 175],
[211, 159],
[212, 149],
[212, 134],
[206, 122],
[180, 112],
[163, 116],
[149, 138],
[128, 170],
[108, 184],
[86, 190],
[63, 181],
[57, 152],
[57, 139],
],
"pause": [
[98, 186],
[100, 188],
[98, 186],
[100, 188],
[101, 187],
[104, 187],
[111, 184],
[116, 176],
[125, 162],
[132, 140],
[136, 119],
[137, 104],
[138, 96],
[139, 94],
[140, 94],
[140, 96],
[138, 98],
[138, 96],
[136, 94],
[137, 92],
[140, 92],
[144, 92],
[149, 92],
[152, 92],
[151, 92],
[147, 92],
[142, 92],
[140, 92],
[139, 95],
[139, 105],
[141, 122],
[142, 143],
[140, 167],
[136, 184],
[135, 188],
[132, 195],
[132, 192],
[131, 192],
[131, 192],
[130, 192],
[130, 195],
],
"shake": [
[103, 89],
[104, 89],
[106, 89],
[107, 89],
[108, 89],
[109, 89],
[110, 89],
[111, 89],
[112, 89],
[113, 89],
[114, 89],
[115, 89],
[116, 89],
[117, 89],
[118, 89],
[119, 89],
[120, 89],
[122, 89],
[123, 89],
[124, 89],
[125, 89],
[126, 89],
[127, 88],
[128, 88],
[129, 88],
[130, 88],
[131, 88],
[133, 87],
[136, 86],
[137, 86],
[138, 86],
[139, 86],
[140, 86],
[141, 86],
[142, 86],
[143, 86],
[144, 86],
[145, 86],
[146, 87],
[147, 87],
[148, 87],
[149, 87],
[148, 87],
[146, 87],
[145, 88],
[144, 88],
[142, 89],
[141, 89],
[140, 90],
[140, 91],
[138, 91],
[137, 92],
[136, 92],
[136, 93],
[135, 93],
[134, 93],
[133, 93],
[132, 93],
[131, 93],
[130, 93],
[129, 93],
[128, 93],
[127, 92],
[125, 92],
[124, 92],
[123, 92],
[122, 92],
[121, 92],
[120, 92],
[119, 92],
[118, 92],
[117, 92],
[116, 92],
[115, 92],
[113, 92],
[112, 92],
[111, 92],
[110, 92],
[109, 92],
[108, 92],
[108, 91],
[108, 90],
[109, 90],
[110, 90],
[111, 89],
[112, 89],
[113, 89],
[114, 89],
[115, 89],
[115, 88],
[116, 88],
[117, 88],
[118, 88],
[118, 87],
[119, 87],
[120, 87],
[121, 87],
[122, 86],
[123, 86],
[124, 86],
[125, 86],
[126, 85],
[127, 85],
[128, 85],
[129, 85],
[130, 85],
[131, 85],
[132, 85],
[133, 85],
[134, 85],
[135, 85],
[136, 85],
[137, 85],
[138, 85],
[139, 85],
[140, 85],
[141, 85],
[142, 85],
[143, 85],
[143, 84],
[144, 84],
[145, 84],
[146, 84],
[147, 84],
[148, 84],
[149, 84],
[148, 84],
[147, 84],
[145, 84],
[144, 84],
[143, 84],
[142, 84],
[141, 84],
[140, 85],
[139, 85],
[138, 85],
[137, 86],
[136, 86],
[136, 87],
[135, 87],
[134, 87],
[133, 87],
[132, 88],
[131, 88],
[130, 88],
[129, 88],
[129, 89],
[128, 89],
[127, 89],
[126, 89],
[125, 89],
[124, 90],
[123, 90],
[122, 90],
[121, 90],
[120, 91],
[119, 91],
[118, 91],
[117, 91],
[116, 91],
[115, 91],
[114, 91],
[113, 91],
[112, 91],
[111, 91],
[110, 91],
[109, 91],
[109, 90],
[108, 90],
[110, 90],
[111, 90],
[113, 90],
[114, 90],
[115, 90],
[116, 90],
[118, 90],
[120, 90],
[121, 90],
[122, 90],
[123, 90],
[124, 90],
[126, 90],
[127, 90],
[128, 90],
[129, 90],
[130, 90],
[131, 90],
[132, 90],
[133, 90],
[134, 90],
[135, 90],
[136, 90],
[137, 90],
[138, 90],
[139, 90],
[140, 90],
[141, 89],
[142, 89],
[143, 89],
[144, 89],
[145, 89],
[146, 89],
[147, 89],
[147, 89],
[147, 89],
],
"spiral": [
[16, 152],
[23, 138],
[39, 122],
[54, 115],
[75, 118],
[88, 130],
[93, 150],
[89, 176],
[75, 184],
[63, 177],
[65, 152],
[77, 135],
[98, 121],
[116, 120],
[135, 127],
[148, 136],
[156, 145],
[160, 165],
[158, 176],
[138, 187],
[133, 185],
[129, 148],
[140, 133],
[156, 120],
[177, 118],
[197, 118],
[214, 119],
[225, 118],
],
}
def pdf2(sigma_matrix, grid): def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution. """Calculate PDF of the bivariate Gaussian distribution.