From e316eda262e545d2098f533644c038bb6c23e043 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 10:40:42 -0400 Subject: [PATCH 01/16] Run black --- tests/infer/enum_growth.ipynb | 640 ++++++++++++++++++++++++++++++++-- 1 file changed, 619 insertions(+), 21 deletions(-) diff --git a/tests/infer/enum_growth.ipynb b/tests/infer/enum_growth.ipynb index 642f309d99..c413011fa3 100644 --- a/tests/infer/enum_growth.ipynb +++ b/tests/infer/enum_growth.ipynb @@ -19,6 +19,7 @@ "outputs": [], "source": [ "from matplotlib import pyplot\n", + "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'" ] @@ -34,24 +35,25 @@ "times1 = None\n", "times2 = None\n", "\n", + "\n", "def plot(title):\n", - " pyplot.figure(figsize=(8,5)).patch.set_color('white')\n", - " pyplot.title('{} data structures'.format(title))\n", + " pyplot.figure(figsize=(8, 5)).patch.set_color(\"white\")\n", + " pyplot.title(\"{} data structures\".format(title))\n", " for name, series in sorted(costs.items()):\n", " pyplot.plot(sizes, series, label=name)\n", - " pyplot.xlabel('problem size')\n", + " pyplot.xlabel(\"problem size\")\n", " pyplot.xlim(0, max(sizes))\n", - " pyplot.legend(loc='best')\n", + " pyplot.legend(loc=\"best\")\n", " pyplot.tight_layout()\n", "\n", - " pyplot.figure(figsize=(8,5)).patch.set_color('white')\n", - " pyplot.title('{} run time'.format(title))\n", - " pyplot.plot(sizes, times1, label='optim + compute')\n", - " pyplot.plot(sizes, times2, label='compute')\n", + " pyplot.figure(figsize=(8, 5)).patch.set_color(\"white\")\n", + " pyplot.title(\"{} run time\".format(title))\n", + " pyplot.plot(sizes, times1, label=\"optim + compute\")\n", + " pyplot.plot(sizes, times2, label=\"compute\")\n", " pyplot.xlim(0, max(sizes))\n", - " pyplot.xlabel('problem size')\n", - " pyplot.ylabel('time (sec)')\n", - " pyplot.legend(loc='best')\n", + " pyplot.xlabel(\"problem size\")\n", + " pyplot.ylabel(\"time (sec)\")\n", + " pyplot.legend(loc=\"best\")\n", " pyplot.tight_layout()" ] }, @@ -61,10 +63,308 @@ "metadata": {}, "outputs": [], "source": [ - "sizes = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]\n", - "costs = {'einsum': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48], 'tensordot': [12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164, 168, 172, 176, 180, 184, 188, 192, 196, 200], 'tensor': [22, 31, 40, 49, 58, 67, 76, 85, 94, 103, 112, 121, 130, 139, 148, 157, 166, 175, 184, 193, 202, 211, 220, 229, 238, 247, 256, 265, 274, 283, 292, 301, 310, 319, 328, 337, 346, 355, 364, 373, 382, 391, 400, 409, 418, 427, 436, 445]}\n", - "times1 = [0.01864790916442871, 0.015166997909545898, 0.017799854278564453, 0.021364927291870117, 0.0234529972076416, 0.03243708610534668, 0.03485298156738281, 0.03809309005737305, 0.04254293441772461, 0.043493032455444336, 0.04782605171203613, 0.051072120666503906, 0.05495715141296387, 0.06077980995178223, 0.06451010704040527, 0.06647181510925293, 0.07750391960144043, 0.10012388229370117, 0.09436392784118652, 0.08780503273010254, 0.09475111961364746, 0.08931398391723633, 0.1099538803100586, 0.10660696029663086, 0.10943722724914551, 0.11156201362609863, 
0.11216998100280762, 0.11894893646240234, 0.12170791625976562, 0.1290268898010254, 0.13869500160217285, 0.1344318389892578, 0.13837814331054688, 0.14883112907409668, 0.14552593231201172, 0.1480569839477539, 0.14761590957641602, 0.15995121002197266, 0.16048288345336914, 0.16365408897399902, 0.16843199729919434, 0.2130718231201172, 0.17986297607421875, 0.1792001724243164, 0.1941969394683838, 0.2153019905090332, 0.20756793022155762, 0.19938111305236816]\n", - "times2 = [0.010827064514160156, 0.014249086380004883, 0.016450166702270508, 0.020006895065307617, 0.025799989700317383, 0.02879500389099121, 0.03235912322998047, 0.036743879318237305, 0.04072308540344238, 0.04432511329650879, 0.04558587074279785, 0.051867008209228516, 0.05726289749145508, 0.058149099349975586, 0.06532096862792969, 0.0634920597076416, 0.07218098640441895, 0.12434697151184082, 0.07972311973571777, 0.08487296104431152, 0.08191704750061035, 0.13434886932373047, 0.10629105567932129, 0.10842609405517578, 0.10170793533325195, 0.10760092735290527, 0.11115694046020508, 0.1158750057220459, 0.12462496757507324, 0.1272139549255371, 0.13429498672485352, 0.1305849552154541, 0.14617490768432617, 0.18872499465942383, 0.1460709571838379, 0.13549304008483887, 0.1373729705810547, 0.15271997451782227, 0.15703701972961426, 0.1608130931854248, 0.21175909042358398, 0.18168210983276367, 0.17579412460327148, 0.17799592018127441, 0.1961660385131836, 0.20264911651611328, 0.25041794776916504, 0.1808319091796875]" + "sizes = [\n", + " 3,\n", + " 4,\n", + " 5,\n", + " 6,\n", + " 7,\n", + " 8,\n", + " 9,\n", + " 10,\n", + " 11,\n", + " 12,\n", + " 13,\n", + " 14,\n", + " 15,\n", + " 16,\n", + " 17,\n", + " 18,\n", + " 19,\n", + " 20,\n", + " 21,\n", + " 22,\n", + " 23,\n", + " 24,\n", + " 25,\n", + " 26,\n", + " 27,\n", + " 28,\n", + " 29,\n", + " 30,\n", + " 31,\n", + " 32,\n", + " 33,\n", + " 34,\n", + " 35,\n", + " 36,\n", + " 37,\n", + " 38,\n", + " 39,\n", + " 40,\n", + " 41,\n", + " 42,\n", + " 43,\n", + " 44,\n", + " 45,\n", + " 46,\n", + " 47,\n", + " 48,\n", + " 49,\n", + " 50,\n", + "]\n", + "costs = {\n", + " \"einsum\": [\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " 5,\n", + " 6,\n", + " 7,\n", + " 8,\n", + " 9,\n", + " 10,\n", + " 11,\n", + " 12,\n", + " 13,\n", + " 14,\n", + " 15,\n", + " 16,\n", + " 17,\n", + " 18,\n", + " 19,\n", + " 20,\n", + " 21,\n", + " 22,\n", + " 23,\n", + " 24,\n", + " 25,\n", + " 26,\n", + " 27,\n", + " 28,\n", + " 29,\n", + " 30,\n", + " 31,\n", + " 32,\n", + " 33,\n", + " 34,\n", + " 35,\n", + " 36,\n", + " 37,\n", + " 38,\n", + " 39,\n", + " 40,\n", + " 41,\n", + " 42,\n", + " 43,\n", + " 44,\n", + " 45,\n", + " 46,\n", + " 47,\n", + " 48,\n", + " ],\n", + " \"tensordot\": [\n", + " 12,\n", + " 16,\n", + " 20,\n", + " 24,\n", + " 28,\n", + " 32,\n", + " 36,\n", + " 40,\n", + " 44,\n", + " 48,\n", + " 52,\n", + " 56,\n", + " 60,\n", + " 64,\n", + " 68,\n", + " 72,\n", + " 76,\n", + " 80,\n", + " 84,\n", + " 88,\n", + " 92,\n", + " 96,\n", + " 100,\n", + " 104,\n", + " 108,\n", + " 112,\n", + " 116,\n", + " 120,\n", + " 124,\n", + " 128,\n", + " 132,\n", + " 136,\n", + " 140,\n", + " 144,\n", + " 148,\n", + " 152,\n", + " 156,\n", + " 160,\n", + " 164,\n", + " 168,\n", + " 172,\n", + " 176,\n", + " 180,\n", + " 184,\n", + " 188,\n", + " 192,\n", + " 196,\n", + " 200,\n", + " ],\n", + " \"tensor\": [\n", + " 22,\n", + " 31,\n", + " 40,\n", + " 49,\n", + " 58,\n", + " 67,\n", + " 76,\n", + " 85,\n", + " 94,\n", + " 103,\n", + " 112,\n", + " 121,\n", + " 130,\n", + " 139,\n", + " 148,\n", + " 157,\n", + " 
166,\n", + " 175,\n", + " 184,\n", + " 193,\n", + " 202,\n", + " 211,\n", + " 220,\n", + " 229,\n", + " 238,\n", + " 247,\n", + " 256,\n", + " 265,\n", + " 274,\n", + " 283,\n", + " 292,\n", + " 301,\n", + " 310,\n", + " 319,\n", + " 328,\n", + " 337,\n", + " 346,\n", + " 355,\n", + " 364,\n", + " 373,\n", + " 382,\n", + " 391,\n", + " 400,\n", + " 409,\n", + " 418,\n", + " 427,\n", + " 436,\n", + " 445,\n", + " ],\n", + "}\n", + "times1 = [\n", + " 0.01864790916442871,\n", + " 0.015166997909545898,\n", + " 0.017799854278564453,\n", + " 0.021364927291870117,\n", + " 0.0234529972076416,\n", + " 0.03243708610534668,\n", + " 0.03485298156738281,\n", + " 0.03809309005737305,\n", + " 0.04254293441772461,\n", + " 0.043493032455444336,\n", + " 0.04782605171203613,\n", + " 0.051072120666503906,\n", + " 0.05495715141296387,\n", + " 0.06077980995178223,\n", + " 0.06451010704040527,\n", + " 0.06647181510925293,\n", + " 0.07750391960144043,\n", + " 0.10012388229370117,\n", + " 0.09436392784118652,\n", + " 0.08780503273010254,\n", + " 0.09475111961364746,\n", + " 0.08931398391723633,\n", + " 0.1099538803100586,\n", + " 0.10660696029663086,\n", + " 0.10943722724914551,\n", + " 0.11156201362609863,\n", + " 0.11216998100280762,\n", + " 0.11894893646240234,\n", + " 0.12170791625976562,\n", + " 0.1290268898010254,\n", + " 0.13869500160217285,\n", + " 0.1344318389892578,\n", + " 0.13837814331054688,\n", + " 0.14883112907409668,\n", + " 0.14552593231201172,\n", + " 0.1480569839477539,\n", + " 0.14761590957641602,\n", + " 0.15995121002197266,\n", + " 0.16048288345336914,\n", + " 0.16365408897399902,\n", + " 0.16843199729919434,\n", + " 0.2130718231201172,\n", + " 0.17986297607421875,\n", + " 0.1792001724243164,\n", + " 0.1941969394683838,\n", + " 0.2153019905090332,\n", + " 0.20756793022155762,\n", + " 0.19938111305236816,\n", + "]\n", + "times2 = [\n", + " 0.010827064514160156,\n", + " 0.014249086380004883,\n", + " 0.016450166702270508,\n", + " 0.020006895065307617,\n", + " 0.025799989700317383,\n", + " 0.02879500389099121,\n", + " 0.03235912322998047,\n", + " 0.036743879318237305,\n", + " 0.04072308540344238,\n", + " 0.04432511329650879,\n", + " 0.04558587074279785,\n", + " 0.051867008209228516,\n", + " 0.05726289749145508,\n", + " 0.058149099349975586,\n", + " 0.06532096862792969,\n", + " 0.0634920597076416,\n", + " 0.07218098640441895,\n", + " 0.12434697151184082,\n", + " 0.07972311973571777,\n", + " 0.08487296104431152,\n", + " 0.08191704750061035,\n", + " 0.13434886932373047,\n", + " 0.10629105567932129,\n", + " 0.10842609405517578,\n", + " 0.10170793533325195,\n", + " 0.10760092735290527,\n", + " 0.11115694046020508,\n", + " 0.1158750057220459,\n", + " 0.12462496757507324,\n", + " 0.1272139549255371,\n", + " 0.13429498672485352,\n", + " 0.1305849552154541,\n", + " 0.14617490768432617,\n", + " 0.18872499465942383,\n", + " 0.1460709571838379,\n", + " 0.13549304008483887,\n", + " 0.1373729705810547,\n", + " 0.15271997451782227,\n", + " 0.15703701972961426,\n", + " 0.1608130931854248,\n", + " 0.21175909042358398,\n", + " 0.18168210983276367,\n", + " 0.17579412460327148,\n", + " 0.17799592018127441,\n", + " 0.1961660385131836,\n", + " 0.20264911651611328,\n", + " 0.25041794776916504,\n", + " 0.1808319091796875,\n", + "]" ] }, { @@ -73,7 +373,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot('HMM')" + "plot(\"HMM\")" ] }, { @@ -82,10 +382,308 @@ "metadata": {}, "outputs": [], "source": [ - "sizes = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]\n", - "costs = {'einsum': [7, 9, 13, 15, 19, 21, 25, 27, 31, 33, 37, 39, 43, 45, 49, 51, 55, 57, 61, 63, 67, 69, 73, 75, 79, 81, 85, 87, 91, 93, 97, 99, 103, 105, 109, 111, 115, 117, 121, 123, 127, 129, 133, 135, 139, 141, 145, 147], 'tensordot': [18, 25, 30, 37, 42, 49, 54, 61, 66, 73, 78, 85, 90, 97, 102, 109, 114, 121, 126, 133, 138, 145, 150, 157, 162, 169, 174, 181, 186, 193, 198, 205, 210, 217, 222, 229, 234, 241, 246, 253, 258, 265, 270, 277, 282, 289, 294, 301], 'tensor': [46, 63, 80, 97, 114, 131, 148, 165, 182, 199, 216, 233, 250, 267, 284, 301, 318, 335, 352, 369, 386, 403, 420, 437, 454, 471, 488, 505, 522, 539, 556, 573, 590, 607, 624, 641, 658, 675, 692, 709, 726, 743, 760, 777, 794, 811, 828, 845]}\n", - "times1 = [0.02198004722595215, 0.03037405014038086, 0.03350090980529785, 0.04224896430969238, 0.04834318161010742, 0.05909299850463867, 0.06626009941101074, 0.08351302146911621, 0.09097099304199219, 0.08897876739501953, 0.09535503387451172, 0.10136294364929199, 0.13000011444091797, 0.12712597846984863, 0.13105392456054688, 0.1476750373840332, 0.14663481712341309, 0.15439701080322266, 0.15521693229675293, 0.1650080680847168, 0.1742238998413086, 0.17893004417419434, 0.18517208099365234, 0.19159197807312012, 0.20879316329956055, 0.2737429141998291, 0.23352789878845215, 0.22190213203430176, 0.23365497589111328, 0.23900103569030762, 0.2523791790008545, 0.26091718673706055, 0.2820899486541748, 0.3140451908111572, 0.28127598762512207, 0.2906830310821533, 0.34561610221862793, 0.4711790084838867, 0.3032550811767578, 0.31789112091064453, 0.34140491485595703, 0.34586501121520996, 0.3419170379638672, 0.35588693618774414, 0.36873412132263184, 0.36976003646850586, 0.3961608409881592, 0.3883850574493408]\n", - "times2 = [0.022389888763427734, 0.026437997817993164, 0.03232693672180176, 0.041667938232421875, 0.05160379409790039, 0.055931806564331055, 0.07128310203552246, 0.08053183555603027, 0.08122706413269043, 0.0810542106628418, 0.0922250747680664, 0.10212492942810059, 0.11983704566955566, 0.1128089427947998, 0.13935494422912598, 0.12748098373413086, 0.13879609107971191, 0.1859588623046875, 0.14890193939208984, 0.15740394592285156, 0.16302895545959473, 0.17653393745422363, 0.1802539825439453, 0.18121719360351562, 0.20098400115966797, 0.19684600830078125, 0.2023460865020752, 0.22677183151245117, 0.23773717880249023, 0.23118090629577637, 0.23914885520935059, 0.2430558204650879, 0.31301093101501465, 0.2789499759674072, 0.26804518699645996, 0.28461790084838867, 0.3887619972229004, 0.31357502937316895, 0.2947719097137451, 0.3141598701477051, 0.4249720573425293, 0.32235097885131836, 0.3292689323425293, 0.32982301712036133, 0.39942502975463867, 0.3410038948059082, 0.3757472038269043, 0.38117194175720215]" + "sizes = [\n", + " 3,\n", + " 4,\n", + " 5,\n", + " 6,\n", + " 7,\n", + " 8,\n", + " 9,\n", + " 10,\n", + " 11,\n", + " 12,\n", + " 13,\n", + " 14,\n", + " 15,\n", + " 16,\n", + " 17,\n", + " 18,\n", + " 19,\n", + " 20,\n", + " 21,\n", + " 22,\n", + " 23,\n", + " 24,\n", + " 25,\n", + " 26,\n", + " 27,\n", + " 28,\n", + " 29,\n", + " 30,\n", + " 31,\n", + " 32,\n", + " 33,\n", + " 34,\n", + " 35,\n", + " 36,\n", + " 37,\n", + " 38,\n", + " 39,\n", + " 40,\n", + " 41,\n", + " 42,\n", + " 43,\n", + " 44,\n", + " 45,\n", + " 46,\n", + " 47,\n", + " 48,\n", + " 49,\n", + " 50,\n", + "]\n", + "costs = {\n", + " \"einsum\": [\n", + " 7,\n", + " 9,\n", + " 13,\n", + " 15,\n", + " 19,\n", + " 21,\n", + " 25,\n", + " 
27,\n", + " 31,\n", + " 33,\n", + " 37,\n", + " 39,\n", + " 43,\n", + " 45,\n", + " 49,\n", + " 51,\n", + " 55,\n", + " 57,\n", + " 61,\n", + " 63,\n", + " 67,\n", + " 69,\n", + " 73,\n", + " 75,\n", + " 79,\n", + " 81,\n", + " 85,\n", + " 87,\n", + " 91,\n", + " 93,\n", + " 97,\n", + " 99,\n", + " 103,\n", + " 105,\n", + " 109,\n", + " 111,\n", + " 115,\n", + " 117,\n", + " 121,\n", + " 123,\n", + " 127,\n", + " 129,\n", + " 133,\n", + " 135,\n", + " 139,\n", + " 141,\n", + " 145,\n", + " 147,\n", + " ],\n", + " \"tensordot\": [\n", + " 18,\n", + " 25,\n", + " 30,\n", + " 37,\n", + " 42,\n", + " 49,\n", + " 54,\n", + " 61,\n", + " 66,\n", + " 73,\n", + " 78,\n", + " 85,\n", + " 90,\n", + " 97,\n", + " 102,\n", + " 109,\n", + " 114,\n", + " 121,\n", + " 126,\n", + " 133,\n", + " 138,\n", + " 145,\n", + " 150,\n", + " 157,\n", + " 162,\n", + " 169,\n", + " 174,\n", + " 181,\n", + " 186,\n", + " 193,\n", + " 198,\n", + " 205,\n", + " 210,\n", + " 217,\n", + " 222,\n", + " 229,\n", + " 234,\n", + " 241,\n", + " 246,\n", + " 253,\n", + " 258,\n", + " 265,\n", + " 270,\n", + " 277,\n", + " 282,\n", + " 289,\n", + " 294,\n", + " 301,\n", + " ],\n", + " \"tensor\": [\n", + " 46,\n", + " 63,\n", + " 80,\n", + " 97,\n", + " 114,\n", + " 131,\n", + " 148,\n", + " 165,\n", + " 182,\n", + " 199,\n", + " 216,\n", + " 233,\n", + " 250,\n", + " 267,\n", + " 284,\n", + " 301,\n", + " 318,\n", + " 335,\n", + " 352,\n", + " 369,\n", + " 386,\n", + " 403,\n", + " 420,\n", + " 437,\n", + " 454,\n", + " 471,\n", + " 488,\n", + " 505,\n", + " 522,\n", + " 539,\n", + " 556,\n", + " 573,\n", + " 590,\n", + " 607,\n", + " 624,\n", + " 641,\n", + " 658,\n", + " 675,\n", + " 692,\n", + " 709,\n", + " 726,\n", + " 743,\n", + " 760,\n", + " 777,\n", + " 794,\n", + " 811,\n", + " 828,\n", + " 845,\n", + " ],\n", + "}\n", + "times1 = [\n", + " 0.02198004722595215,\n", + " 0.03037405014038086,\n", + " 0.03350090980529785,\n", + " 0.04224896430969238,\n", + " 0.04834318161010742,\n", + " 0.05909299850463867,\n", + " 0.06626009941101074,\n", + " 0.08351302146911621,\n", + " 0.09097099304199219,\n", + " 0.08897876739501953,\n", + " 0.09535503387451172,\n", + " 0.10136294364929199,\n", + " 0.13000011444091797,\n", + " 0.12712597846984863,\n", + " 0.13105392456054688,\n", + " 0.1476750373840332,\n", + " 0.14663481712341309,\n", + " 0.15439701080322266,\n", + " 0.15521693229675293,\n", + " 0.1650080680847168,\n", + " 0.1742238998413086,\n", + " 0.17893004417419434,\n", + " 0.18517208099365234,\n", + " 0.19159197807312012,\n", + " 0.20879316329956055,\n", + " 0.2737429141998291,\n", + " 0.23352789878845215,\n", + " 0.22190213203430176,\n", + " 0.23365497589111328,\n", + " 0.23900103569030762,\n", + " 0.2523791790008545,\n", + " 0.26091718673706055,\n", + " 0.2820899486541748,\n", + " 0.3140451908111572,\n", + " 0.28127598762512207,\n", + " 0.2906830310821533,\n", + " 0.34561610221862793,\n", + " 0.4711790084838867,\n", + " 0.3032550811767578,\n", + " 0.31789112091064453,\n", + " 0.34140491485595703,\n", + " 0.34586501121520996,\n", + " 0.3419170379638672,\n", + " 0.35588693618774414,\n", + " 0.36873412132263184,\n", + " 0.36976003646850586,\n", + " 0.3961608409881592,\n", + " 0.3883850574493408,\n", + "]\n", + "times2 = [\n", + " 0.022389888763427734,\n", + " 0.026437997817993164,\n", + " 0.03232693672180176,\n", + " 0.041667938232421875,\n", + " 0.05160379409790039,\n", + " 0.055931806564331055,\n", + " 0.07128310203552246,\n", + " 0.08053183555603027,\n", + " 0.08122706413269043,\n", + " 0.0810542106628418,\n", + " 
0.0922250747680664,\n", + " 0.10212492942810059,\n", + " 0.11983704566955566,\n", + " 0.1128089427947998,\n", + " 0.13935494422912598,\n", + " 0.12748098373413086,\n", + " 0.13879609107971191,\n", + " 0.1859588623046875,\n", + " 0.14890193939208984,\n", + " 0.15740394592285156,\n", + " 0.16302895545959473,\n", + " 0.17653393745422363,\n", + " 0.1802539825439453,\n", + " 0.18121719360351562,\n", + " 0.20098400115966797,\n", + " 0.19684600830078125,\n", + " 0.2023460865020752,\n", + " 0.22677183151245117,\n", + " 0.23773717880249023,\n", + " 0.23118090629577637,\n", + " 0.23914885520935059,\n", + " 0.2430558204650879,\n", + " 0.31301093101501465,\n", + " 0.2789499759674072,\n", + " 0.26804518699645996,\n", + " 0.28461790084838867,\n", + " 0.3887619972229004,\n", + " 0.31357502937316895,\n", + " 0.2947719097137451,\n", + " 0.3141598701477051,\n", + " 0.4249720573425293,\n", + " 0.32235097885131836,\n", + " 0.3292689323425293,\n", + " 0.32982301712036133,\n", + " 0.39942502975463867,\n", + " 0.3410038948059082,\n", + " 0.3757472038269043,\n", + " 0.38117194175720215,\n", + "]" ] }, { @@ -94,7 +692,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot('DBN')" + "plot(\"DBN\")" ] }, { From d70b4105f977dc0f9e9353ece92554fd02c1ed59 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 10:40:55 -0400 Subject: [PATCH 02/16] Fix ProvenanceTensor --- pyro/ops/provenance.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyro/ops/provenance.py b/pyro/ops/provenance.py index b1ec32d9a1..51a1ec9ecb 100644 --- a/pyro/ops/provenance.py +++ b/pyro/ops/provenance.py @@ -46,9 +46,7 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs): assert not isinstance(data, ProvenanceTensor) if not provenance: return data - instance = torch.Tensor.__new__(cls) - instance.__init__(data, provenance) - return instance + return super().__new__(cls) def __init__(self, data, provenance=frozenset()): assert isinstance(provenance, frozenset) From ce6f704173ef461cf673ea98d2256ff5a573b8df Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 11:00:00 -0400 Subject: [PATCH 03/16] Replace torch.triangular_solve -> torch.linalg.solve_triangular --- pyro/contrib/gp/models/sgpr.py | 8 ++++---- pyro/contrib/gp/util.py | 4 ++-- pyro/distributions/multivariate_studentt.py | 9 +++------ pyro/distributions/omt_mvn.py | 4 +++- .../transforms/generalized_channel_permute.py | 4 ++-- .../transforms/lower_cholesky_affine.py | 6 +++--- pyro/infer/autoguide/gaussian.py | 6 +++--- pyro/infer/mcmc/adaptation.py | 2 +- pyro/ops/arrowhead.py | 2 +- pyro/ops/gamma_gaussian.py | 12 +++++------- pyro/ops/tensor_utils.py | 8 ++++---- 11 files changed, 31 insertions(+), 34 deletions(-) diff --git a/pyro/contrib/gp/models/sgpr.py b/pyro/contrib/gp/models/sgpr.py index 1ca509240b..6e1010933b 100644 --- a/pyro/contrib/gp/models/sgpr.py +++ b/pyro/contrib/gp/models/sgpr.py @@ -146,7 +146,7 @@ def model(self): Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal Luu = torch.linalg.cholesky(Kuu) Kuf = self.kernel(self.Xu, self.X) - W = Kuf.triangular_solve(Luu, upper=False)[0].t() + W = torch.linalg.solve_triangular(Luu, Kuf, upper=False).t() D = self.noise.expand(N) if self.approx == "FITC" or self.approx == "VFE": @@ -227,7 +227,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): Kuf = self.kernel(self.Xu, self.X) - W = Kuf.triangular_solve(Luu, upper=False)[0] + W = torch.linalg.solve_triangular(Luu, Kuf, upper=False) D = self.noise.expand(N) if self.approx == 
"FITC": Kffdiag = self.kernel(self.X, diag=True) @@ -247,9 +247,9 @@ def forward(self, Xnew, full_cov=False, noiseless=True): # End caching ---------- Kus = self.kernel(self.Xu, Xnew) - Ws = Kus.triangular_solve(Luu, upper=False)[0] + Ws = torch.linalg.solve_triangular(Luu, Kus, upper=False) pack = torch.cat((W_Dinv_y, Ws), dim=1) - Linv_pack = pack.triangular_solve(L, upper=False)[0] + Linv_pack = torch.linalg.solve_triangular(L, pack, upper=False) # unpack Linv_W_Dinv_y = Linv_pack[:, : W_Dinv_y.shape[1]] Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1] :] diff --git a/pyro/contrib/gp/util.py b/pyro/contrib/gp/util.py index 4309ec96f2..582cd28f43 100644 --- a/pyro/contrib/gp/util.py +++ b/pyro/contrib/gp/util.py @@ -107,7 +107,7 @@ def conditional( if whiten: v_2D = f_loc_2D - W = Kfs.triangular_solve(Lff, upper=False)[0].t() + W = torch.linalg.solve_triangular(Lff, Kfs, upper=False).t() if f_scale_tril is not None: S_2D = f_scale_tril_2D else: @@ -115,7 +115,7 @@ def conditional( if f_scale_tril is not None: pack = torch.cat((pack, f_scale_tril_2D), dim=1) - Lffinv_pack = pack.triangular_solve(Lff, upper=False)[0] + Lffinv_pack = torch.linalg.solve_triangular(Lff, pack, upper=False) # unpack v_2D = Lffinv_pack[:, : f_loc_2D.size(1)] W = Lffinv_pack[:, f_loc_2D.size(1) : f_loc_2D.size(1) + M].t() diff --git a/pyro/distributions/multivariate_studentt.py b/pyro/distributions/multivariate_studentt.py index 895ff66182..c71b631768 100644 --- a/pyro/distributions/multivariate_studentt.py +++ b/pyro/distributions/multivariate_studentt.py @@ -106,12 +106,9 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) n = self.loc.size(-1) - y = ( - (value - self.loc) - .unsqueeze(-1) - .triangular_solve(self.scale_tril, upper=False) - .solution.squeeze(-1) - ) + y = torch.linalg.solve_triangular( + self.scale_tril, (value - self.loc).unsqueeze(-1), upper=False + ).squeeze(-1) Z = ( self.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + 0.5 * n * self.df.log() diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index c87d52f561..a7e0a6a18b 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -59,7 +59,9 @@ def backward(ctx, grad_output): loc_grad = sum_leftmost(grad_output, -1) identity = eye_like(g, dim) - R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0] + R_inv = torch.linalg.solve_triangular( + L.t(), identity, transpose=False, upper=True + ) z_ja = z.unsqueeze(-1) g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2) diff --git a/pyro/distributions/transforms/generalized_channel_permute.py b/pyro/distributions/transforms/generalized_channel_permute.py index e56fe72f30..c04023c3ce 100644 --- a/pyro/distributions/transforms/generalized_channel_permute.py +++ b/pyro/distributions/transforms/generalized_channel_permute.py @@ -86,10 +86,10 @@ def _inverse(self, y): LUx = (y_flat.unsqueeze(-3) * self.permutation.T.unsqueeze(-1)).sum(-2) # Solve L(Ux) = P^1y - Ux, _ = torch.triangular_solve(LUx, self.L, upper=False) + Ux = torch.linalg.solve_triangular(self.L, LUx, upper=False) # Solve Ux = (PL)^-1y - x, _ = torch.triangular_solve(Ux, self.U) + x = torch.linalg.solve_triangular(self.U, Ux) # Unflatten x (works when context variable has batch dim) return x.reshape(x.shape[:-1] + y.shape[-2:]) diff --git a/pyro/distributions/transforms/lower_cholesky_affine.py b/pyro/distributions/transforms/lower_cholesky_affine.py index 5a17eeb179..188a9aed08 100644 --- 
a/pyro/distributions/transforms/lower_cholesky_affine.py +++ b/pyro/distributions/transforms/lower_cholesky_affine.py @@ -57,9 +57,9 @@ def _inverse(self, y): Inverts y => x. """ - return torch.triangular_solve( - (y - self.loc).unsqueeze(-1), self.scale_tril, upper=False, transpose=False - )[0].squeeze(-1) + return torch.linalg.solve_triangular( + self.scale_tril, (y - self.loc).unsqueeze(-1), upper=False, transpose=False + ).squeeze(-1) def log_abs_det_jacobian(self, x, y): """ diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 8bdc05dc95..a127d51d82 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -553,9 +553,9 @@ def _precision_to_scale_tril(P): # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) - L = torch.triangular_solve( - torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), L_inv, upper=False - )[0] + L = torch.solve_triangular( + L_inv, torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), upper=False + ) return L diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py index 6fe082cf3e..09dbe6872d 100644 --- a/pyro/infer/mcmc/adaptation.py +++ b/pyro/infer/mcmc/adaptation.py @@ -232,7 +232,7 @@ def _triu_inverse(x): return x.reciprocal() else: identity = torch.eye(x.size(-1), dtype=x.dtype, device=x.device) - return torch.triangular_solve(identity, x, upper=True)[0] + return torch.linalg.solve_triangular(x, identity, upper=True) class BlockMassMatrix: diff --git a/pyro/ops/arrowhead.py b/pyro/ops/arrowhead.py index e97c8872b3..70c07bcb99 100644 --- a/pyro/ops/arrowhead.py +++ b/pyro/ops/arrowhead.py @@ -74,7 +74,7 @@ def triu_inverse(x): B_Dinv = B / x.bottom_diag.unsqueeze(-2) identity = torch.eye(head_size, dtype=A.dtype, device=A.device) - top_left = torch.triangular_solve(identity, A, upper=True)[0] + top_left = torch.linalg.solve_triangular(A, identity, upper=True) top_right = -top_left.matmul(B_Dinv) # complexity: head_size^2 x N top = torch.cat([top_left, top_right], -1) bottom_diag = x.bottom_diag.reciprocal() diff --git a/pyro/ops/gamma_gaussian.py b/pyro/ops/gamma_gaussian.py index 66807aa577..edffe0755e 100644 --- a/pyro/ops/gamma_gaussian.py +++ b/pyro/ops/gamma_gaussian.py @@ -275,13 +275,13 @@ def marginalize(self, left=0, right=0): P_ba = self.precision[..., b, a] P_bb = self.precision[..., b, b] P_b = torch.linalg.cholesky(P_bb) - P_a = P_ba.triangular_solve(P_b, upper=False).solution + P_a = torch.linalg.solve_triangular(P_b, P_ba, upper=False) P_at = P_a.transpose(-1, -2) precision = P_aa - P_at.matmul(P_a) info_a = self.info_vec[..., a] info_b = self.info_vec[..., b] - b_tmp = info_b.unsqueeze(-1).triangular_solve(P_b, upper=False).solution + b_tmp = torch.linalg.solve_triangular(P_b, info_b.unsqueeze(-1), upper=False) info_vec = info_a if n_b < n: info_vec = info_vec - P_at.matmul(b_tmp).squeeze(-1) @@ -320,11 +320,9 @@ def event_logsumexp(self): """ n = self.dim() chol_P = torch.linalg.cholesky(self.precision) - chol_P_u = ( - self.info_vec.unsqueeze(-1) - .triangular_solve(chol_P, upper=False) - .solution.squeeze(-1) - ) + chol_P_u = torch.linalg.solve_triangular( + chol_P, self.info_vec.unsqueeze(-1), upper=False + ).squeeze(-1) u_P_u = chol_P_u.pow(2).sum(-1) # considering GammaGaussian as a Gaussian with precision = s * precision, info_vec = s * info_vec, # marginalize x variable, we get diff 
--git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 01eed87f46..581b91bdb0 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -420,15 +420,15 @@ def matvecmul(x, y): def triangular_solve(x, y, upper=False, transpose=False): if y.size(-1) == 1: return x / y - return x.triangular_solve(y, upper=upper, transpose=transpose).solution + return torch.linalg.solve_triangular(y, x, upper=upper, transpose=transpose) def precision_to_scale_tril(P): Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) - L = torch.triangular_solve( - torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), L_inv, upper=False - )[0] + L = torch.linalg.solve_triangular( + L_inv, torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), upper=False + ) return L From 8148d58a4aefce5c12177cff05c4d5f6a015a234 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 11:56:08 -0400 Subject: [PATCH 04/16] More fixes to torch.linalg.solve_triangular --- .../transforms/lower_cholesky_affine.py | 2 +- pyro/ops/tensor_utils.py | 4 ++- setup.cfg | 2 ++ tests/ops/test_linalg.py | 26 ++++++++++++++++++- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/transforms/lower_cholesky_affine.py b/pyro/distributions/transforms/lower_cholesky_affine.py index 188a9aed08..c078e0ef53 100644 --- a/pyro/distributions/transforms/lower_cholesky_affine.py +++ b/pyro/distributions/transforms/lower_cholesky_affine.py @@ -58,7 +58,7 @@ def _inverse(self, y): Inverts y => x. """ return torch.linalg.solve_triangular( - self.scale_tril, (y - self.loc).unsqueeze(-1), upper=False, transpose=False + self.scale_tril, (y - self.loc).unsqueeze(-1), upper=False ).squeeze(-1) def log_abs_det_jacobian(self, x, y): diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 581b91bdb0..3eb4a85b0f 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -420,7 +420,9 @@ def matvecmul(x, y): def triangular_solve(x, y, upper=False, transpose=False): if y.size(-1) == 1: return x / y - return torch.linalg.solve_triangular(y, x, upper=upper, transpose=transpose) + if transpose: + y = y.transpose(-1, -2) + return torch.linalg.solve_triangular(y, x, upper=upper) def precision_to_scale_tril(P): diff --git a/setup.cfg b/setup.cfg index 4e98b0f2fc..8cef20b674 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,8 +15,10 @@ filterwarnings = error ignore:numpy.dtype size changed:RuntimeWarning ignore:Mixed memory format inputs detected:UserWarning ignore:Setting attributes on ParameterDict:UserWarning + ignore:Creating a tensor from a list of numpy.ndarrays is extremely slow ignore::DeprecationWarning ignore:CUDA initialization:UserWarning + ignore:__floordiv__ is deprecated:UserWarning ignore:floor_divide is deprecated:UserWarning ignore:torch.tensor results are registered as constants in the trace once::DeprecationWarning diff --git a/tests/ops/test_linalg.py b/tests/ops/test_linalg.py index 5b4497567d..4287c9ad10 100644 --- a/tests/ops/test_linalg.py +++ b/tests/ops/test_linalg.py @@ -5,7 +5,7 @@ import torch from pyro.ops.linalg import rinverse -from tests.common import assert_equal +from tests.common import assert_close, assert_equal @pytest.mark.parametrize( @@ -35,3 +35,27 @@ def test_sym_rinverse(A, use_sym): batched_A = A.unsqueeze(0).unsqueeze(0).expand(5, 4, d, d) expected_A = torch.inverse(A).unsqueeze(0).unsqueeze(0).expand(5, 4, d, d) assert_equal(rinverse(batched_A, sym=use_sym), expected_A, prec=1e-8) + + +# Tests 
migration from torch.triangular_solve -> torch.linalg.solve_triangular +@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated") +@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"]) +def test_triangular_solve(upper): + b = torch.randn(5, 6) + A = torch.randn(5, 5) + A = A.triu() if upper else A.tril() + expected = torch.triangular_solve(b, A, upper=upper).solution + actual = torch.linalg.solve_triangular(A, b, upper=upper) + assert_close(actual, expected) + + +# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular +@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated") +@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"]) +def test_triangular_solve_transpose(upper): + b = torch.randn(5, 6) + A = torch.randn(5, 5) + A = A.triu() if upper else A.tril() + expected = torch.triangular_solve(b, A.T, upper=upper, transpose=True).solution + actual = torch.linalg.solve_triangular(A.T, b, upper=upper) + assert_close(actual, expected) From 9dabb2a8dd2019cdf8642c7d9fd3856c744f1f45 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 12:00:57 -0400 Subject: [PATCH 05/16] Bump torch version --- .github/workflows/ci.yml | 12 ++++++------ docs/source/conf.py | 2 +- setup.py | 6 ++++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0fd67dd5d7..b2247e5dd7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install -r docs/requirements.txt pip freeze @@ -82,7 +82,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install --upgrade coveralls pip freeze @@ -116,7 +116,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install --upgrade coveralls pip freeze @@ -150,7 +150,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install --upgrade coveralls pip freeze @@ -182,7 +182,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' # Keep track of pyro-api master 
branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install --upgrade coveralls pip freeze @@ -214,7 +214,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[test] pip install -e .[funsor] pip install --upgrade coveralls diff --git a/docs/source/conf.py b/docs/source/conf.py index 979c5998ba..3ff5c7c3c1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -222,6 +222,6 @@ def setup(app): if "READTHEDOCS" in os.environ: os.system("pip install numpy") os.system( - "pip install torch==1.9.0+cpu torchvision==0.10.0+cpu " + "pip install torch==1.11.0+cpu torchvision==0.12.0+cpu " "-f https://download.pytorch.org/whl/torch_stable.html" ) diff --git a/setup.py b/setup.py index f11dec61dd..4fe304695a 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ "jupyter>=1.0.0", "graphviz>=0.8", "matplotlib>=1.3", - "torchvision>=0.10.0", + "torchvision>=0.12.0", "visdom>=0.1.4", "pandas", "pillow==8.2.0", # https://github.com/pytorch/pytorch/issues/61125 @@ -98,7 +98,7 @@ "numpy>=1.7", "opt_einsum>=2.3.2", "pyro-api>=0.1.1", - "torch>=1.9.0", + "torch>=1.11.0", "tqdm>=4.36", ], extras_require={ @@ -151,6 +151,8 @@ "Operating System :: MacOS :: MacOS X", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], # yapf ) From adf5fcfa13f3570e60544d8d07afe63fdae16cf2 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 12:31:40 -0400 Subject: [PATCH 06/16] Bump Python version 3.6 -> 3.7 --- .github/workflows/ci.yml | 16 ++++++++-------- .readthedocs.yml | 2 +- setup.py | 3 +-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b2247e5dd7..e9e21f0967 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -39,7 +39,7 @@ jobs: needs: lint strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -67,7 +67,7 @@ jobs: needs: docs strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -99,7 +99,7 @@ jobs: needs: docs strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -135,7 +135,7 @@ jobs: needs: docs strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -167,7 +167,7 @@ jobs: needs: docs strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: 
actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -199,7 +199,7 @@ jobs: needs: docs strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -233,7 +233,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: [3.6] + python-version: [3.7] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.readthedocs.yml b/.readthedocs.yml index e468fcfb94..321000738b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -10,6 +10,6 @@ formats: all # Optionally set the version of Python and requirements required to build your docs python: - version: 3.6 + version: 3.7 install: - requirements: docs/requirements.txt diff --git a/setup.py b/setup.py index 4fe304695a..c9ad88b307 100644 --- a/setup.py +++ b/setup.py @@ -139,7 +139,7 @@ "funsor[torch]==0.4.2", ], }, - python_requires=">=3.6", + python_requires=">=3.7", keywords="machine learning statistics probabilistic programming bayesian modeling pytorch", license="Apache 2.0", classifiers=[ @@ -149,7 +149,6 @@ "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", From f8f1c7f135580d45abb9543d5cf0661d7ff95460 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 16:19:52 -0400 Subject: [PATCH 07/16] Fix some tests --- pyro/distributions/omt_mvn.py | 4 +--- .../transforms/generalized_channel_permute.py | 2 +- pyro/infer/autoguide/gaussian.py | 4 +++- pyro/ops/linalg.py | 10 ++++++++++ pyro/poutine/collapse_messenger.py | 2 ++ tests/poutine/test_poutines.py | 15 +++------------ tests/test_examples.py | 6 +++++- 7 files changed, 25 insertions(+), 18 deletions(-) diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index a7e0a6a18b..d997f806f5 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -59,9 +59,7 @@ def backward(ctx, grad_output): loc_grad = sum_leftmost(grad_output, -1) identity = eye_like(g, dim) - R_inv = torch.linalg.solve_triangular( - L.t(), identity, transpose=False, upper=True - ) + R_inv = torch.linalg.solve_triangular(L.t(), identity, upper=True) z_ja = z.unsqueeze(-1) g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2) diff --git a/pyro/distributions/transforms/generalized_channel_permute.py b/pyro/distributions/transforms/generalized_channel_permute.py index c04023c3ce..0231c63a46 100644 --- a/pyro/distributions/transforms/generalized_channel_permute.py +++ b/pyro/distributions/transforms/generalized_channel_permute.py @@ -89,7 +89,7 @@ def _inverse(self, y): Ux = torch.linalg.solve_triangular(self.L, LUx, upper=False) # Solve Ux = (PL)^-1y - x = torch.linalg.solve_triangular(self.U, Ux) + x = torch.linalg.solve_triangular(self.U, Ux, upper=True) # Unflatten x (works when context variable has batch dim) return x.reshape(x.shape[:-1] + y.shape[-2:]) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index a127d51d82..a6e7bb7bfc 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -17,6 +17,7 @@ from pyro.distributions import constraints from pyro.infer.inspect import get_dependencies, is_sample_site from pyro.nn.module import PyroModule, PyroParam +from pyro.ops.linalg import 
ignore_torch_deprecation_warnings from pyro.poutine.runtime import am_i_wrapped, get_plates from pyro.poutine.util import site_is_subsample @@ -553,12 +554,13 @@ def _precision_to_scale_tril(P): # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) - L = torch.solve_triangular( + L = torch.linalg.solve_triangular( L_inv, torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), upper=False ) return L +@ignore_torch_deprecation_warnings() def _try_possibly_intractable(fn, *args, **kwargs): # Convert ValueError into NotImplementedError. try: diff --git a/pyro/ops/linalg.py b/pyro/ops/linalg.py index d6d247c1fd..f3b3af798d 100644 --- a/pyro/ops/linalg.py +++ b/pyro/ops/linalg.py @@ -2,10 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import math +import warnings +from contextlib import contextmanager import torch +@contextmanager +def ignore_torch_deprecation_warnings(): + with warnings.catch_warnings(): + # Ignore deprecation warning until funsor updates to torch>=1.10. + warnings.filterwarnings("ignore", "torch.triangular_solve is deprecated") + yield + + def rinverse(M, sym=False): """Matrix inversion of rightmost dimensions (batched). diff --git a/pyro/poutine/collapse_messenger.py b/pyro/poutine/collapse_messenger.py index 826a7f957e..6206894943 100644 --- a/pyro/poutine/collapse_messenger.py +++ b/pyro/poutine/collapse_messenger.py @@ -5,6 +5,7 @@ import pyro from pyro.distributions.distribution import COERCIONS +from pyro.ops.linalg import ignore_torch_deprecation_warnings from pyro.poutine.util import site_is_subsample from .runtime import _PYRO_STACK @@ -166,6 +167,7 @@ def __exit__(self, *args): name, log_prob, _, _ = self._get_log_prob() pyro.factor(name, log_prob.data) + @ignore_torch_deprecation_warnings() def _get_log_prob(self): # Convert delayed statements to pyro.factor() reduced_vars = [] diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index cfea0a1822..2f57536213 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -893,10 +893,7 @@ def model(v): pyro.sample("test_site", dist.Beta(1.0, 1.0), obs=v) tr = poutine.trace(model).get_trace(torch.tensor(2.0)) - exp_msg = ( - r"Error while computing log_prob at site 'test_site':\s*" - r"The value argument must be within the support" - ) + exp_msg = r"Error while computing log_prob at site 'test_site':.*" with pytest.raises(ValueError, match=exp_msg): tr.compute_log_prob() @@ -906,10 +903,7 @@ def model(v): pyro.sample("test_site", dist.Beta(1.0, 1.0), obs=v) tr = poutine.trace(model).get_trace(torch.tensor(2.0)) - exp_msg = ( - r"Error while computing log_prob_sum at site 'test_site':\s*" - r"The value argument must be within the support" - ) + exp_msg = r"Error while computing log_prob_sum at site 'test_site':.*" with pytest.raises(ValueError, match=exp_msg): tr.log_prob_sum() @@ -919,10 +913,7 @@ def guide(v): pyro.sample("test_site", dist.Beta(1.0, 1.0), obs=v) tr = poutine.trace(guide).get_trace(torch.tensor(2.0)) - exp_msg = ( - r"Error while computing score_parts at site 'test_site':\s*" - r"The value argument must be within the support" - ) + exp_msg = r"Error while computing score_parts at site 'test_site':.*" with pytest.raises(ValueError, match=exp_msg): tr.compute_score_parts() diff --git a/tests/test_examples.py b/tests/test_examples.py index e4e65bdd4f..b46e5cedca 100644 --- a/tests/test_examples.py 
+++ b/tests/test_examples.py @@ -124,7 +124,11 @@ "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 ", - "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 ", + xfail_param( + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 ", + reason="OOM", + run=False, + ), "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 ", From 042f5401e952741fcf9f4386e8fd61d138bc047f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 17:28:49 -0400 Subject: [PATCH 08/16] Decrease tolerance on stable tests --- tests/infer/reparam/test_hmm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/infer/reparam/test_hmm.py b/tests/infer/reparam/test_hmm.py index 719210377e..2b210dc957 100644 --- a/tests/infer/reparam/test_hmm.py +++ b/tests/infer/reparam/test_hmm.py @@ -190,8 +190,7 @@ def get_hmm_moments(samples): delta = samples - loc cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).sqrt().mean(0) scale = cov.diagonal(dim1=-2, dim2=-1) - sigma = scale.sqrt() - corr = cov / (sigma.unsqueeze(-1) * sigma.unsqueeze(-2)) + corr = cov / (scale.unsqueeze(-1) * scale.unsqueeze(-2)).sqrt() return loc, scale, corr @@ -225,7 +224,7 @@ def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim) actual_loc, actual_scale, actual_corr = get_hmm_moments(actual_samples) assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05) - assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05) + assert_close(actual_scale, expected_scale, atol=0.2, rtol=0.2) assert_close(actual_corr, expected_corr, atol=0.01) From 302f4cbaf137f442288e376916c0e6d73a58e9a0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 17:32:00 -0400 Subject: [PATCH 09/16] Remove obsolete xfail --- tests/infer/test_jit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index b459ef6a74..a19080b20d 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -173,7 +173,6 @@ def f(y, mask): assert_equal(jit_f(y, mask), f(y, mask)) -@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/11614") def test_scatter(): def make_one_hot(x, i): return torch.zeros_like(x).scatter(-1, i.unsqueeze(-1), 1.0) From 28d37b35b924304377b8ee5268a8bfed00f69fe5 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 19:39:46 -0400 Subject: [PATCH 10/16] Resolve #3032 --- tutorial/source/gp.ipynb | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tutorial/source/gp.ipynb b/tutorial/source/gp.ipynb index b1cd36b3a0..589117f3b4 100644 --- a/tutorial/source/gp.ipynb +++ b/tutorial/source/gp.ipynb @@ -1205,14 +1205,7 @@ "source": [ "xs = torch.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, steps=100)\n", "ys = torch.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, steps=100)\n", - "try:\n", - " # torch 1.10 or greater defaults to using indexing\n", - " xx, yy = torch.meshgrid(xs, ys, indexing=\"xy\")\n", - "except:\n", - " xx, yy = torch.meshgrid(xs, ys)\n", - " xx = xx.t()\n", - " yy = yy.t()\n", - "\n", + "xx, yy = torch.meshgrid(xs, ys, indexing=\"xy\")\n", "\n", "with torch.no_grad():\n", " mean, var = model(torch.vstack((xx.ravel(), 
yy.ravel())).t())\n", From 240e7d2b849c4734100fdc0e922bb596d0f09b8c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 19:54:51 -0400 Subject: [PATCH 11/16] Fix solve_triangular in ops.gaussian, revert test weakening --- pyro/ops/tensor_utils.py | 1 + tests/infer/reparam/test_hmm.py | 5 +++-- tests/ops/test_linalg.py | 6 ++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 3eb4a85b0f..7efc847d59 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -422,6 +422,7 @@ def triangular_solve(x, y, upper=False, transpose=False): return x / y if transpose: y = y.transpose(-1, -2) + upper = not upper return torch.linalg.solve_triangular(y, x, upper=upper) diff --git a/tests/infer/reparam/test_hmm.py b/tests/infer/reparam/test_hmm.py index 2b210dc957..719210377e 100644 --- a/tests/infer/reparam/test_hmm.py +++ b/tests/infer/reparam/test_hmm.py @@ -190,7 +190,8 @@ def get_hmm_moments(samples): delta = samples - loc cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).sqrt().mean(0) scale = cov.diagonal(dim1=-2, dim2=-1) - corr = cov / (scale.unsqueeze(-1) * scale.unsqueeze(-2)).sqrt() + sigma = scale.sqrt() + corr = cov / (sigma.unsqueeze(-1) * sigma.unsqueeze(-2)) return loc, scale, corr @@ -224,7 +225,7 @@ def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim) actual_loc, actual_scale, actual_corr = get_hmm_moments(actual_samples) assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05) - assert_close(actual_scale, expected_scale, atol=0.2, rtol=0.2) + assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05) assert_close(actual_corr, expected_corr, atol=0.01) diff --git a/tests/ops/test_linalg.py b/tests/ops/test_linalg.py index 4287c9ad10..813666bc4d 100644 --- a/tests/ops/test_linalg.py +++ b/tests/ops/test_linalg.py @@ -43,7 +43,6 @@ def test_sym_rinverse(A, use_sym): def test_triangular_solve(upper): b = torch.randn(5, 6) A = torch.randn(5, 5) - A = A.triu() if upper else A.tril() expected = torch.triangular_solve(b, A, upper=upper).solution actual = torch.linalg.solve_triangular(A, b, upper=upper) assert_close(actual, expected) @@ -55,7 +54,6 @@ def test_triangular_solve(upper): def test_triangular_solve_transpose(upper): b = torch.randn(5, 6) A = torch.randn(5, 5) - A = A.triu() if upper else A.tril() - expected = torch.triangular_solve(b, A.T, upper=upper, transpose=True).solution - actual = torch.linalg.solve_triangular(A.T, b, upper=upper) + expected = torch.triangular_solve(b, A, upper=upper, transpose=True).solution + actual = torch.linalg.solve_triangular(A.T, b, upper=not upper) assert_close(actual, expected) From ff40b1116432163ae71bc431f3b45a4a550e2363 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 20:05:44 -0400 Subject: [PATCH 12/16] Fix catching of singular matrices in hmc --- pyro/ops/integrator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index 8fafd212d7..750090248f 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -76,7 +76,7 @@ def potential_grad(potential_fn, z): potential_energy = potential_fn(z) # deal with singular matrices except RuntimeError as e: - if "singular U" in str(e) or "input is not positive-definite" in str(e): + if "singular" in str(e) or "input is not positive-definite" in str(e): grads = {k: v.new_zeros(v.shape) for k, v in z.items()} return grads, z_nodes[0].new_tensor(float("nan")) else: From 
d6b2815a7492291afc4df19e8819b3286d08f94f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 16 Mar 2022 21:24:02 -0400 Subject: [PATCH 13/16] xfail funsor tests --- tests/contrib/funsor/test_enum_funsor.py | 11 +++++++++++ tests/ops/test_gaussian.py | 1 + tests/test_examples.py | 12 ++++++------ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py index e8dbc60c69..1dd6915610 100644 --- a/tests/contrib/funsor/test_enum_funsor.py +++ b/tests/contrib/funsor/test_enum_funsor.py @@ -264,6 +264,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pytest.mark.parametrize( "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] @@ -498,6 +499,7 @@ def hand_guide(data): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pytest.mark.parametrize( "outer_obs,inner_obs", [(False, True), (True, False), (True, True)] @@ -640,6 +642,7 @@ def guide_iplate(): _check_loss_and_grads(expected_loss, actual_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("enumerate1", ["parallel", "sequential"]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_6(enumerate1): @@ -702,6 +705,7 @@ def guide(): _check_loss_and_grads(expected_loss, actual_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_7(scale): @@ -873,6 +877,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_2(scale): @@ -929,6 +934,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_3(scale): @@ -981,6 +987,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_4(scale): @@ -1040,6 +1047,7 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_5(scale): @@ -1103,6 +1111,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_6(scale): @@ -1241,6 +1250,7 @@ def guide(data): elbo.differentiable_loss(model_plate_plate, guide, data) +@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_7(scale): @@ -1393,6 +1403,7 @@ def guide(data): _check_loss_and_grads(loss_iplate_iplate, loss_plate_plate) 
+@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3046") @pytest.mark.parametrize("guide_scale", [1]) @pytest.mark.parametrize("model_scale", [1]) @pytest.mark.parametrize( diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index b16fcfbfdc..fa5924b128 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -424,6 +424,7 @@ def test_gaussian_tensordot( @pytest.mark.stage("funsor") @pytest.mark.parametrize("batch_shape", [(), (5,), (4, 2)], ids=str) +@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated") def test_gaussian_funsor(batch_shape): # This tests sample distribution, rsample gradients, log_prob, and log_prob # gradients for both Pyro's and Funsor's Gaussian. diff --git a/tests/test_examples.py b/tests/test_examples.py index b46e5cedca..931a61ac14 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -124,11 +124,7 @@ "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 ", - xfail_param( - "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 ", - reason="OOM", - run=False, - ), + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 ", @@ -292,7 +288,11 @@ def xfail_jit(*args, **kwargs): "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --funsor", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --funsor", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --funsor", - "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --funsor", + xfail_param( + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --funsor", + reason="https://github.com/pyro-ppl/pyro/issues/3046", + run=False, + ), "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --funsor", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization --funsor", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --funsor", From 977a921364e892cae527b4ef8a9c756d4aa2913e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 17 Mar 2022 09:51:00 -0400 Subject: [PATCH 14/16] Work around pandas 1.3 bug --- examples/baseball.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/baseball.py b/examples/baseball.py index 8523b7ffd4..30d6164c4f 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -164,6 +164,7 @@ def get_summary_table( if site_summary["mean"].shape: site_df = pd.DataFrame(site_summary, index=player_names) else: + site_summary = {k: float(v) for k, v in site_summary.items()} site_df = pd.DataFrame(site_summary, index=[0]) if not diagnostics: site_df = site_df.drop(["n_eff", "r_hat"], axis=1) From b7e6580b6505ee2fd3f792331a22553460de0be9 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 17 Mar 2022 09:53:36 -0400 Subject: [PATCH 15/16] Allow mypy to install missing stubs --- Makefile | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 3c0e8ef221..f98e2cae72 100644 --- a/Makefile +++ b/Makefile @@ -22,9 +22,7 @@ lint: FORCE black --check *.py pyro examples tests scripts profiler isort --check . 
python scripts/update_headers.py --check - mypy pyro - # mypy examples # FIXME - mypy scripts + mypy --install-types --non-interactive pyro scripts license: FORCE python scripts/update_headers.py From e2a25c7934baac6017aa444ee7702c689d0d8ea9 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 17 Mar 2022 10:55:40 -0400 Subject: [PATCH 16/16] Clarify triangular_solve test --- tests/ops/test_linalg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/ops/test_linalg.py b/tests/ops/test_linalg.py index 813666bc4d..f476a9a1f2 100644 --- a/tests/ops/test_linalg.py +++ b/tests/ops/test_linalg.py @@ -46,6 +46,8 @@ def test_triangular_solve(upper): expected = torch.triangular_solve(b, A, upper=upper).solution actual = torch.linalg.solve_triangular(A, b, upper=upper) assert_close(actual, expected) + A = A.triu() if upper else A.tril() + assert_close(A @ actual, b) # Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular @@ -57,3 +59,5 @@ def test_triangular_solve_transpose(upper): expected = torch.triangular_solve(b, A, upper=upper, transpose=True).solution actual = torch.linalg.solve_triangular(A.T, b, upper=not upper) assert_close(actual, expected) + A = A.triu() if upper else A.tril() + assert_close(A.T @ actual, b)
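
Note on the recurring API change (a standalone sketch, not part of the patch series): patches 03, 04, 11, and 16 all migrate calls from the deprecated torch.triangular_solve to torch.linalg.solve_triangular. The snippet below summarizes how the two signatures map onto each other, assuming torch>=1.11; the tensors A and b are illustrative only, not taken from the Pyro code.

import torch

A = torch.randn(5, 5).tril() + 5 * torch.eye(5)  # well-conditioned lower-triangular matrix
b = torch.randn(5, 6)

# Deprecated API: right-hand side first, matrix second, solution inside a named
# tuple (emits a "torch.triangular_solve is deprecated" warning under torch 1.11).
x_old = torch.triangular_solve(b, A, upper=False).solution

# New API: matrix first, right-hand side second, a plain tensor is returned.
x_new = torch.linalg.solve_triangular(A, b, upper=False)
assert torch.allclose(x_old, x_new)

# The new API has no transpose= flag; solving A.T @ x = b is done by passing the
# transpose explicitly. Transposing a triangular factor also swaps which triangle
# holds the data, hence the "upper = not upper" fix to
# pyro.ops.tensor_utils.triangular_solve in patch 11/16, verified by the test
# added in patch 16/16.
x_t_old = torch.triangular_solve(b, A, upper=False, transpose=True).solution
x_t_new = torch.linalg.solve_triangular(A.T, b, upper=True)
assert torch.allclose(x_t_old, x_t_new)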