diff --git a/docs/tutorials/data/US_Inflation_rates_prepared.csv b/docs/tutorials/data/US_Inflation_rates_prepared.csv new file mode 100644 index 000000000..1c75d4c80 --- /dev/null +++ b/docs/tutorials/data/US_Inflation_rates_prepared.csv @@ -0,0 +1,905 @@ +value,target +-1.1512732989231769,0 +-1.149549704381617,0 +-1.1448713763402407,0 +-1.1448713763402407,0 +-1.1454869458193693,0 +-1.1438864651736351,0 +-1.1420397567362497,0 +-1.1399468205072132,0 +-1.1345298090908824,0 +-1.1336680118201026,0 +-1.1318213033827171,0 +-1.127512317028818,0 +-1.1241882418415243,0 +-1.1243113557373499,0 +-1.1264042919663866,0 +-1.1224646472999644,0 +-1.1201254832792762,0 +-1.1184018887377165,0 +-1.1153240413420742,0 +-1.1149546996545971,0 +-1.115816496925377,0 +-1.1164320664045053,0 +-1.1182787748418908,0 +-1.1196330276959734,0 +-1.1201254832792762,0 +-1.1213566222375333,0 +-1.1213566222375333,0 +-1.1212335083417075,0 +-1.1213566222375333,0 +-1.1212335083417075,0 +-1.1239420140498728,0 +-1.1239420140498728,0 +-1.1233264445707443,0 +-1.1243113557373499,0 +-1.1239420140498728,0 +-1.125050039112304,0 +-1.1262811780705608,0 +-1.125050039112304,0 +-1.124680697424827,0 +-1.1245575835290014,0 +-1.123080216779093,0 +-1.1217259639250103,0 +-1.1193867999043219,0 +-1.117786319258588,0 +-1.1160627247170283,0 +-1.1140929023838173,0 +-1.1128617634255602,0 +-1.108183435384184,0 +-1.1032588795511562,0 +-1.097718754239,0 +-1.0971031847598716,0 +-1.0966107291765685,0 +-1.095748931905789,0 +-1.096487615280743,0 +-1.0967338430723943,0 +-1.0973494125515229,0 +-1.0952564763224861,0 +-1.093655995676752,0 +-1.0916861733435408,0 +-1.0898394649061556,0 +-1.0900856926978069,0 +-1.0905781482811097,0 +-1.090824376072761,0 +-1.0899625788019813,0 +-1.0898394649061556,0 +-1.0891007815312013,0 +-1.087254073093816,0 +-1.0871309591979903,0 +-1.0878696425729444,0 +-1.0871309591979903,0 +-1.0871309591979903,0 +-1.086884731406339,0 +-1.0877465286771186,0 +-1.0883620981562472,0 +-1.0878696425729444,0 +-1.0871309591979903,0 +-1.0870078453021645,0 +-1.0861460480313847,0 +-1.0858998202397332,0 +-1.085161136864779,0 +-1.0846686812814763,0 +-1.0839299979065222,0 +-1.085161136864779,0 +-1.0849149090731276,0 +-1.084053111802348,0 +-1.0834375423232194,0 +-1.0841762256981737,0 +-1.0850380229689534,0 +-1.0841762256981737,0 +-1.084053111802348,0 +-1.0850380229689534,0 +-1.085161136864779,0 +-1.085653592448082,0 +-1.0867616175105133,0 +-1.086022934135559,0 +-1.0861460480313847,0 +-1.0861460480313847,0 +-1.0855304785522562,0 +-1.0858998202397332,0 +-1.0858998202397332,0 +-1.0861460480313847,0 +-1.086884731406339,0 +-1.0862691619272102,0 +-1.0867616175105133,0 +-1.085161136864779,0 +-1.0855304785522562,0 +-1.084791795177302,0 +-1.0849149090731276,0 +-1.0854073646564306,0 +-1.0850380229689534,0 +-1.0846686812814763,0 +-1.0841762256981737,0 +-1.0829450867399166,0 +-1.0814677199900085,0 +-1.0797441254484486,0 +-1.0794978976567973,0 +-1.0790054420734942,0 +-1.0770356197402833,0 +-1.0770356197402833,0 +-1.075558252990375,0 +-1.075065797407072,0 +-1.0734653167613382,0 +-1.072726633386384,0 +-1.0718648361156042,0 +-1.0710030388448244,0 +-1.0696487859907418,0 +-1.0686638748241362,0 +-1.0675558497617048,0 +-1.067063394178402,0 +-1.067063394178402,0 +-1.0659553691159709,0 +-1.0652166857410168,0 +-1.0631237495119799,0 +-1.0623850661370258,0 +-1.0602921299079888,0 +-1.059430332637209,0 +-1.059430332637209,0 +-1.0597996743246862,0 +-1.0600459021163375,0 +-1.059430332637209,0 +-1.0597996743246862,0 +-1.0597996743246862,0 +-1.0593072187413832,0 +-1.059060990949732,0 +-1.0585685353664291,0 +-1.058691649262255,0 +-1.059060990949732,0 +-1.0589378770539062,0 +-1.058199193678952,0 +-1.0573373964081723,0 +-1.0568449408248697,0 +-1.0564755991373924,0 +-1.0556138018666126,0 +-1.0543826629083555,0 +-1.0543826629083555,0 +-1.0536439795334014,0 +-1.0541364351167042,0 +-1.0536439795334014,0 +-1.0536439795334014,0 +-1.0520434988876672,0 +-1.0516741572001902,0 +-1.0511817016168876,0 +-1.0519203849918417,0 +-1.0511817016168876,0 +-1.0511817016168876,0 +-1.049458107075328,0 +-1.0490887653878507,0 +-1.0487194237003736,0 +-1.0483500820128966,0 +-1.0483500820128966,0 +-1.0483500820128966,0 +-1.0487194237003736,0 +-1.0483500820128966,0 +-1.0483500820128966,0 +-1.047365170846291,0 +-1.0471189430546397,0 +-1.046626487471337,0 +-1.046626487471337,0 +-1.046626487471337,0 +-1.0462571457838596,0 +-1.0458878040963826,0 +-1.0450260068256028,0 +-1.0442873234506487,0 +-1.043794867867346,0 +-1.0434255261798688,0 +-1.043794867867346,0 +-1.0436717539715203,0 +-1.042933070596566,0 +-1.0412094760550064,0 +-1.0417019316383092,0 +-1.0417019316383092,0 +-1.0417019316383092,0 +-1.040963248263355,0 +-1.040470792680052,0 +-1.040101450992575,0 +-1.040470792680052,0 +-1.040101450992575,0 +-1.0388703120343181,0 +-1.0378854008677125,0 +-1.0371467174927584,0 +-1.0375160591802355,0 +-1.0371467174927584,0 +-1.0367773758052814,0 +-1.0355462368470245,0 +-1.0348075534720702,0 +-1.0351768951595475,0 +-1.0348075534720702,0 +-1.0346844395762445,0 +-1.0343150978887674,0 +-1.0339457562012904,0 +-1.0338226423054648,0 +-1.0334533006179876,0 +-1.0330839589305105,0 +-1.0325915033472077,0 +-1.0314834782847766,0 +-1.0309910227014738,0 +-1.0306216810139965,0 +-1.0306216810139965,0 +-1.0302523393265195,0 +-1.0293905420557399,0 +-1.0281594030974828,0 +-1.0265589224517488,0 +-1.0269282641392259,0 +-1.027297605826703,0 +-1.026435808555923,0 +-1.026066466868446,0 +-1.024835327910189,0 +-1.023604188951932,0 +-1.023234847264455,0 +-1.0207725693479412,0 +-1.0195414303896844,0 +-1.0183102914314273,0 +-1.0174484941606472,0 +-1.0170791524731702,0 +-1.0162173552023903,0 +-1.0137550772858768,0 +-1.0125239383276197,0 +-1.0112927993693626,0 +-1.0109234576818855,0 +-1.0104310020985827,0 +-1.0106772298902342,0 +-1.0094460909319773,0 +-1.0094460909319773,0 +-1.0082149519737202,0 +-1.0082149519737202,0 +-1.0057526740572065,0 +-1.0045215350989496,0 +-1.0032903961406925,0 +-1.0020592571824356,0 +-1.0008281182241787,0 +-0.998365840307665,0 +-0.9971347013494078,0 +-0.9959035623911509,0 +-0.994672423432894,0 +-0.9934412844746371,0 +-0.9922101455163802,0 +-0.9909790065581232,0 +-0.9885167286416092,0 +-0.9860544507250956,0 +-0.9848233117668385,0 +-0.9835921728085815,0 +-0.9811298948920677,0 +-0.9798987559338108,0 +-0.9774364780172967,0 +-0.9762053390590398,0 +-0.974974200100783,0 +-0.971280783226012,0 +-0.9688185053094983,0 +-0.9675873663512414,0 +-0.9651250884347273,0 +-0.9626628105182136,0 +-0.9614316715599567,0 +-0.9589693936434426,0 +-0.9565071157269289,0 +-0.9540448378104149,0 +-0.9515825598939011,0 +-0.9491202819773873,0 +-0.9466580040608732,0 +-0.9441957261443594,0 +-0.9417334482278455,0 +-0.9405023092695886,0 +-0.9380400313530748,0 +-0.9368088923948179,0 +-0.9355777534365608,0 +-0.9331154755200469,0 +-0.9306531976035332,0 +-0.9281909196870191,0 +-0.9257286417705054,0 +-0.9244975028122485,0 +-0.9244975028122485,0 +-0.9232663638539914,0 +-0.9220352248957344,0 +-0.9195729469792207,0 +-0.9171106690627068,0 +-0.9158795301044497,0 +-0.9146483911461928,0 +-0.9134172521879359,0 +-0.9121861132296791,0 +-0.910954974271422,0 +-0.909723835313165,0 +-0.9084926963549081,0 +-0.9060304184383944,0 +-0.9060304184383944,0 +-0.9047992794801374,0 +-0.9035681405218803,0 +-0.9023370015636234,0 +-0.9011058626053665,0 +-0.8998747236471096,0 +-0.8974124457305956,0 +-0.8961813067723388,0 +-0.893719028855825,0 +-0.8924878898975679,0 +-0.890025611981054,0 +-0.8863321951062832,0 +-0.8814076392732556,0 +-0.8777142223984846,0 +-0.8752519444819709,0 +-0.8715585276071999,0 +-0.8715585276071999,0 +-0.8617094159411445,0 +-0.8592471380246305,0 +-0.8543225821916026,0 +-0.8506291653168321,0 +-0.8457046094838042,0 +-0.8395489146925196,0 +-0.8333932199012348,0 +-0.8272375251099501,0 +-0.8235441082351792,0 +-0.8173884134438945,0 +-0.8124638576108668,0 +-0.808770440736096,0 +-0.8013836069865544,0 +-0.7927656342787557,0 +-0.787841078445728,0 +-0.7816853836544433,0 +-0.7767608278214155,0 +-0.7718362719883878,0 +-0.7681428551136169,0 +-0.7656805771971031,0 +-0.7632182992805892,0 +-0.7619871603223323,0 +-0.7570626044893045,0 +-0.7509069096980198,0 +-0.7484446317815059,0 +-0.7435200759484781,0 +-0.7398266590737074,0 +-0.7349021032406796,0 +-0.7312086863659087,0 +-0.7287464084493949,0 +-0.7275152694911379,0 +-0.726284130532881,0 +-0.725052991574624,0 +-0.7213595746998532,0 +-0.7176661578250824,0 +-0.7139727409503116,0 +-0.7102793240755407,0 +-0.7065859072007699,0 +-0.7028924903259991,0 +-0.7004302124094852,0 +-0.6967367955347145,0 +-0.6930433786599436,0 +-0.685656544910402,0 +-0.6819631280356311,0 +-0.6770385722026033,0 +-0.6745762942860895,0 +-0.6708828774113187,0 +-0.6671894605365478,0 +-0.663496043661777,0 +-0.6610337657452632,0 +-0.6573403488704923,0 +-0.6524157930374646,0 +-0.6487223761626938,0 +-0.6437978203296659,0 +-0.6401044034548952,0 +-0.6351798476218674,0 +-0.6290241528305827,0 +-0.621637319081041,0 +-0.6154816242897563,0 +-0.6093259294984716,0 +-0.6044013736654438,0 +-0.5970145399159023,0 +-0.5896277061663606,0 +-0.5847031503333329,0 +-0.579778594500305,0 +-0.5723917607507634,0 +-0.5637737880429647,0 +-0.5551558153351662,0 +-0.5465378426273677,0 +-0.5366887309613121,0 +-0.5268396192952566,0 +-0.5169905076292011,0 +-0.5083725349214024,0 +-0.49975456221360387,0 +-0.48990545054754836,0 +-0.48005633888149285,0 +-0.46897608825718035,0 +-0.4554335597163541,0 +-0.44312217013378463,0 +-0.42957964159295836,0 +-0.41973052992690274,0 +-0.40988141826084723,0 +-0.4000323065947917,0 +-0.3988011676365349,0 +-0.3914143338869931,0 +-0.3827963611791945,0 +-0.372947249513139,0 +-0.3618669988888267,0 +-0.352017887222771,0 +-0.3421687755567155,0 +-0.33231966389066003,0 +-0.32493283014111846,0 +-0.31877713534983376,0 +-0.31139030160029196,0 +-0.3015411899342365,0 +-0.2892298003516671,0 +-0.2806118276438685,0 +-0.2695315770195561,0 +-0.2658381601447852,0 +-0.2609136043117575,0 +-0.2572201874369867,0 +-0.25352677056221573,0 +-0.24983335368744497,0 +-0.24983335368744497,0 +-0.24613993681267418,0 +-0.23505968618836165,0 +-0.22151715764753538,0 +-0.21536146285625066,0 +-0.21289918493973675,0 +-0.21289918493973675,0 +-0.20797462910670908,0 +-0.20920576806496596,0 +-0.21289918493973675,0 +-0.21043690702322285,0 +-0.20920576806496596,0 +-0.20797462910670908,0 +-0.19935665639891048,0 +-0.19443210056588264,0 +-0.19196982264936874,0 +-0.18704526681634107,0 +-0.18335184994157028,0 +-0.17965843306679932,0 +-0.17473387723377165,0 +-0.1710404603590009,0 +-0.1673470434842299,0 +-0.15872907077643147,0 +-0.15257337598514678,0 +-0.1488799591103758,0 +-0.14395540327734815,0 +-0.14149312536083425,0 +-0.13903084744432032,0 +-0.13410629161129267,0 +-0.1304128747365217,0 +-0.12671945786175093,0 +-0.12179490202872326,0 +-0.11933262411220934,0 +-0.11687034619569543,0 +-0.11440806827918151,0 +-0.10702123452963994,0 +-0.10086553973835523,0 +-0.09840326182184132,0 +-0.0959409839053274,0 +-0.09224756703055662,0 +-0.0897852891140427,0 +-0.08732301119752879,0 +-0.08486073328101504,0 +-0.07993617744798721,0 +-0.0737804826567025,0 +-0.06762478786541781,0 +-0.06270023203238997,0 +-0.0651625099489039,0 +-0.07254934369844564,0 +-0.0774738995314733,0 +-0.0737804826567025,0 +-0.06885592682367468,0 +-0.06762478786541781,0 +-0.06639364890716094,0 +-0.061469093074133105,0 +-0.059006815157619186,0 +-0.056544537241105274,0 +-0.051619981408077616,0 +-0.044233147658535865,0 +-0.03930859182550821,0 +-0.034384035992480376,0 +-0.028228341201195675,0 +-0.02453492432642489,0 +-0.018379229535140186,0 +-0.014685812660369401,0 +-0.008530117869084699,0 +-0.0036055620360568663,0 +0.00008785483871391996,0 +0.005012410671741752,0 +0.007474688588255493,0 +0.012399244421283324,0 +0.01486152233779724,0 +0.01855493921256803,0 +0.027172911920366645,0 +0.030866328795137433,0 +0.037022023586422134,0 +0.043177718377706835,0 +0.04933341316899154,0 +0.05548910796027624,0 +0.060413663793304075,0 +0.06533821962633173,0 +0.07026277545935956,0 +0.07641847025064427,0 +0.08134302608367192,0 +0.08872985983321367,0 +0.09981011045752602,0 +0.10719694420706778,0 +0.11212150004009544,0 +0.11704605587312326,0 +0.11704605587312326,0 +0.12073947274789405,0 +0.1281263064974358,0 +0.1342820012887205,0 +0.13920655712174815,0 +0.15398022462083147,0 +0.1601359194121162,0 +0.16752275316165777,0 +0.17121617003642872,0 +0.17367844795294246,0 +0.1835275596189981,0 +0.1909143933685397,0 +0.20445692190936596,0 +0.21553717253367852,0 +0.22661742315799105,0 +0.23031084003276164,0 +0.23646653482404636,0 +0.24262222961533106,0 +0.24385336857358828,0 +0.24385336857358828,0 +0.2475467854483589,0 +0.2537024802396436,0 +0.2586270360726714,0 +0.26108931398918517,0 +0.266013869822213,0 +0.27093842565524084,0 +0.27340070357175456,0 +0.28078753732129647,0 +0.285712093154324,0 +0.2869432321125812,0 +0.2906366489873518,0 +0.2967923437786365,0 +0.3004857606534075,0 +0.3041791775281781,0 +0.3091037333612059,0 +0.31402828919423376,0 +0.3177217060690047,0 +0.32141512294377533,0 +0.3288019566933169,0 +0.3337265125263447,0 +0.3361887904428588,0 +0.3423444852341435,0 +0.3460379021089141,0 +0.3485001800254282,0 +0.3546558748167129,0 +0.3595804306497404,0 +0.3608115696079976,0 +0.36327384752451136,0 +0.36696726439928234,0 +0.36942954231579606,0 +0.37681637606533763,0 +0.3817409318983655,0 +0.3854343487731364,0 +0.3854343487731364,0 +0.3903589046061639,0 +0.39528346043919177,0 +0.3965145993974486,0 +0.4002080162722196,0 +0.4051325721052474,0 +0.41128826689653214,0 +0.41867510064607366,0 +0.42236851752084464,0 +0.4235996564791015,0 +0.42852421231212934,0 +0.4322176291869,0 +0.4371421850199278,0 +0.4420667408529556,0 +0.44576015772772626,0 +0.45314699147726817,0 +0.45684040835203876,0 +0.46053382522680975,0 +0.46299610314332346,0 +0.46668952001809444,0 +0.46915179793460815,0 +0.474076353767636,0 +0.4765386316841497,0 +0.4790009096006638,0 +0.48885002126671917,0 +0.4925434381414901,0 +0.4986991329327748,0 +0.5060859666823164,0 +0.5097793835570874,0 +0.513472800431858,0 +0.5171662173066289,0 +0.5196284952231427,0 +0.5257841900144273,0 +0.5319398848057121,0 +0.5380955795969967,0 +0.5430201354300246,0 +0.5467135523047956,0 +0.5504069691795662,0 +0.5516381081378234,0 +0.5528692470960803,0 +0.5528692470960803,0 +0.5565626639708509,0 +0.559024941887365,0 +0.5639494977203928,0 +0.5688740535534202,0 +0.5725674704281912,0 +0.575029748344705,0 +0.5762608873029622,0 +0.578723165219476,0 +0.578723165219476,0 +0.578723165219476,0 +0.5811854431359897,0 +0.5861099989690175,0 +0.5885722768855316,0 +0.5934968327185591,0 +0.5959591106350732,0 +0.59719024959333,0 +0.6021148054263579,0 +0.6045770833428716,0 +0.6082705002176426,0 +0.6119639170924132,0 +0.6119639170924132,0 +0.6131950560506704,0 +0.6267375845914966,0 +0.6279687235497535,0 +0.6279687235497535,0 +0.636586696257552,0 +0.6415112520905798,0 +0.6501292247983786,0 +0.6538226416731493,0 +0.6575160585479202,0 +0.662440614380948,0 +0.6685963091722328,0 +0.6772142818800312,0 +0.6895256714626006,0 +0.6882945325043437,0 +0.6919879493791143,0 +0.7042993389616837,0 +0.7104550337529685,0 +0.7104550337529685,0 +0.721535284377281,0 +0.725228701252052,0 +0.7289221181268225,0 +0.7338466739598504,0 +0.7461580635424198,0 +0.7510826193754476,0 +0.7523137583337045,0 +0.7560071752084755,0 +0.767087425832788,0 +0.7720119816658154,0 +0.7683185647910449,0 +0.7683185647910449,0 +0.7769365374988433,0 +0.7707808427075586,0 +0.7695497037493018,0 +0.7683185647910449,0 +0.7720119816658154,0 +0.7757053985405864,0 +0.7818610933318711,0 +0.7917102049979268,0 +0.7941724829144405,0 +0.7954036218726974,0 +0.8003281777057253,0 +0.8064838724970099,0 +0.8101772893717809,0 +0.8151018452048083,0 +0.8187952620795793,0 +0.8224886789543503,0 +0.8323377906204056,0 +0.844649180202975,0 +0.848342597077746,0 +0.8397246243699472,0 +0.8360312074951766,0 +0.8384934854116903,0 +0.8458803191612319,0 +0.8557294308272876,0 +0.8631162645768291,0 +0.8606539866603153,0 +0.8618851256185722,0 +0.868040820409857,0 +0.8778899320759126,0 +0.8828144879089401,0 +0.887739043741968,0 +0.8914324606167389,0 +0.9012815722827943,0 +0.909899544990593,0 +0.9123618229071068,0 +0.9135929618653637,0 +0.9209797956149055,0 +0.933291185197475,0 +0.9443714358217872,0 +0.9443714358217872,0 +0.9431402968635303,0 +0.9529894085295859,0 +0.9616073812373843,0 +0.968994214986926,0 +0.9677630760286691,0 +0.968994214986926,0 +0.9837678824860094,0 +0.9985415499850926,0 +1.03178230185803,0 +1.0354757187328008,0 +1.0231643291502315,0 +1.0231643291502315,0 +1.037937996649315,0 +1.0391691356075718,0 +1.0428625524823423,0 +1.0551739420649118,0 +1.0625607758144537,0 +1.0687164706057384,0 +1.0822589991465648,0 +1.0933392497708772,0 +1.0810278601883079,0 +1.0699476095639953,0 +1.0711787485222521,0 +1.0847212770630785,0 +1.0888702153524046,0 +1.0985839017330516,0 +1.1116585974697404,0 +1.1192424134526031,0 +1.1297194059873696,0 +1.1356165615974205,0 +1.1401594643533886,0 +1.140947393286673,0 +1.151781416119334,0 +1.1596976396209262,0 +1.1799375640946703,0 +1.18745982312962,0 +1.1964348261353133,0 +1.2027505689911715,0 +1.2121195364635067,0 +1.218201362917296,0 +1.2337875821288289,0 +1.2615497656375227,0 +1.280669353659253,0 +1.2766558406553354,0 +1.278958070507276,0 +1.2557880353128805,0 +1.2084876765366486,0 +1.1868811878192393,0 +1.193467781245914,0 +1.2029721740036579,0 +1.2003867821913181,0 +1.203021419561988,0 +1.206874884501332,0 +1.2286414212833148,0 +1.2278534923500304,0 +1.2367053814598976,0 +1.2418269195262466,0 +1.2498046999757515,0 +1.2587304574231146,0 +1.2601216444459449,0 +1.261857550377087,0 +1.2593090927334953,0 +1.2601955127834403,0 +1.2608110822625687,0 +1.2594198952397382,0 +1.2582995587877246,0 +1.2632979829582476,0 +1.2672130048455048,0 +1.2715466139785694,0 +1.2809032700613219,0 +1.287736091279648,0 +1.2985947368914743,0 +1.3073973804430115,0 +1.316150778436218,0 +1.3302842536770076,0 +1.3431742785699579,0 +1.35195229934233,0 +1.35195229934233,0 +1.3592037078064636,0 +1.3679571057996702,0 +1.3740019980847118,0 +1.3758856406908448,0 +1.3810441129259416,0 +1.3817089279634005,0 +1.3893296781150108,0 +1.3953253248417221,0 +1.40121016906219,0 +1.4058884971035668,0 +1.4000528984414287,0 +1.3977260458103231,0 +1.3985385975227727,0 +1.414888122888425,0 +1.4283937172605032,0 +1.4360637129704443,0 +1.4312745824228246,0 +1.4309298635145127,0 +1.4365684799433296,0 +1.452056208038202,0 +1.443992247861619,0 +1.4380212239140726,0 +1.4392031173139994,0 +1.4459990043635775,0 +1.4516006866236468,0 +1.4584458192315553,0 +1.4595292215148217,0 +1.4610681452126428,0 +1.4663743541227299,0 +1.4739951042743404,0 +1.4810002849468227,0 +1.4841889348487078,0 +1.4901107132379237,0 +1.4955277246542542,0 +1.5010678499664107,0 +1.5049213149057548,0 +1.5082084559243008,0 +1.5077406231201633,0 +1.507949916743067,0 +1.5073712814326863,0 +1.5018680902892776,0 +1.4928684645044195,0 +1.4743398231826526,0 +1.4816650999842813,0 +1.4894705209796302,0 +1.4924991228169424,0 +1.502089695301764,0 +1.5101659668679295,0 +1.5148073607405579,0 +1.5147950493509754,0 +1.5082084559243008,0 +1.5111016324762048,0 +1.5145980671176544,0 +1.5114463513845167,0 +1.5101044099200165,0 +1.5062140108119249,0 +1.5153736846613564,0 +1.5266016719606594,0 +1.5335576070748111,0 +1.54174468114722,0 +1.540255003007729,0 +1.5457212599823897,0 +1.553489746808991,0 +1.560445681923143,0 +1.5639544279541753,0 +1.571476686989125,0 +1.5835541601696257,0 +1.5883309793276623,0 +1.5869274809152498,0 +1.590633209179603,0 +1.5883063565484972,0 +1.5902638674921263,0 +1.5912487786587317,0 +1.602821484866347,0 +1.6182353446237239,0 +1.6205868200339943,0 +1.6286877143793252,0 +1.635101948351844,0 +1.6480781529718718,0 +1.6563267839921936,0 +1.656917730692157,0 +1.6649201339208266,0 +1.6718760690349783,0 +1.674658443080639,0 +1.677071475438823,0 +1.6825992893613968,0 +1.6889889005547496,0 +1.696252620408466,0 +1.6940611930627687,0 +1.6961910634605533,0 +1.695587805371007,0 +1.7030238846788794,0 +1.7138579075115403,0 +1.7262800996003531,0 +1.7272526993773754,0 +1.7256399073420592,0 +1.7321156982624908,0 +1.736720157966372,0 +1.7422849060576935,0 +1.7513091546217168,0 +1.7581912213983728,0 +1.7682003811290015,0 +1.7733834761432632,0 +1.7759811793451854,0 +1.7621431774543779,0 +1.7371264338225965,0 +1.7343686825561009,0 +1.7481328161094134,0 +1.7640391314500936,0 +1.7779879358471444,0 +1.7855963746091723,0 +1.7887604017318928,0 +1.7952485040419066,0 +1.8102930221118068,0 +1.8178645267050864,0 +1.830028179612665,0 +1.8457375127200237,0 +1.8673563128270159,0 +1.8891967179464941,0 +1.9152353069136285,0 +1.9300705313606243,0 +1.9436869282389462,0 +1.9579435173755617,0 +1.9886481229944895,0 +2.015597754790734,0 +2.042387338522405,0 +2.0634644374877635,0 +2.088222641938311,0 +2.123457838923624,0 +2.137480511658171,0 diff --git a/src/safeds/ml/nn/layers/__init__.py b/src/safeds/ml/nn/layers/__init__.py index 71ef0ab64..8d195dde2 100644 --- a/src/safeds/ml/nn/layers/__init__.py +++ b/src/safeds/ml/nn/layers/__init__.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from ._convolutional2d_layer import Convolutional2DLayer, ConvolutionalTranspose2DLayer + from ._dropout_layer import DropoutLayer from ._flatten_layer import FlattenLayer from ._forward_layer import ForwardLayer from ._gru_layer import GRULayer @@ -25,6 +26,7 @@ "GRULayer": "._gru_layer:GRULayer", "AveragePooling2DLayer": "._pooling2d_layer:AveragePooling2DLayer", "MaxPooling2DLayer": "._pooling2d_layer:MaxPooling2DLayer", + "DropoutLayer": "._dropout_layer:DropoutLayer", }, ) @@ -38,4 +40,5 @@ "GRULayer", "AveragePooling2DLayer", "MaxPooling2DLayer", + "DropoutLayer", ] diff --git a/src/safeds/ml/nn/layers/_dropout_layer.py b/src/safeds/ml/nn/layers/_dropout_layer.py new file mode 100644 index 000000000..1814e4383 --- /dev/null +++ b/src/safeds/ml/nn/layers/_dropout_layer.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from safeds._utils import _structural_hash +from safeds._validation import _check_bounds, _ClosedBound +from safeds.ml.nn.typing import ModelImageSize + +from ._layer import Layer + +if TYPE_CHECKING: + from torch import nn + + +class DropoutLayer(Layer): + """ + A dropout layer. + + Parameters + ---------- + probability: + The probability that an input element gets zeroed out. + + Raises + ------ + OutOfBoundsError + If probability is not in the range [0, 1]. + """ + + def __init__(self, probability: float): + _check_bounds("probability", probability, lower_bound=_ClosedBound(0), upper_bound=_ClosedBound(1)) + self._probability = probability + self._input_size: int | ModelImageSize | None = None + + @property + def probability(self) -> float: + """The probability that an input element gets zeroed out.""" + return self._probability + + def _get_internal_layer(self, **_kwargs: Any) -> nn.Module: + from ._internal_layers import _InternalDropoutLayer # slow import on global level + + if self._input_size is None: + raise ValueError( + "The input_size is not yet set. The internal layer can only be created when the input_size is set.", + ) + return _InternalDropoutLayer(self.probability) + + @property + def input_size(self) -> int | ModelImageSize: + """ + Get the input_size of this layer. + + Returns + ------- + result: + The amount of values being passed into this layer. + + Raises + ------ + ValueError + If the input_size is not yet set + """ + if self._input_size is None: + raise ValueError("The input_size is not yet set.") + return self._input_size + + @property + def output_size(self) -> int | ModelImageSize: + """ + Get the output_size of this layer. + + Returns + ------- + result: + The amount of values being passed out of this layer. + + Raises + ------ + ValueError + If the input_size is not yet set + """ + if self._input_size is None: + raise ValueError("The input_size is not yet set.") + return self._input_size + + def _set_input_size(self, input_size: int | ModelImageSize) -> None: + self._input_size = input_size + + def __hash__(self) -> int: + return _structural_hash(self._input_size, self._probability) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DropoutLayer): + return NotImplemented + if self is other: + return True + return self._input_size == other._input_size and self._probability == other._probability + + def __sizeof__(self) -> int: + if self._input_size is None: + raise ValueError("The input_size is not yet set.") + if isinstance(self._input_size, int): + return int(self._input_size) + elif isinstance(self._input_size, ModelImageSize): + return self._input_size.__sizeof__() diff --git a/src/safeds/ml/nn/layers/_internal_layers.py b/src/safeds/ml/nn/layers/_internal_layers.py index 321ca21eb..c2587e2e8 100644 --- a/src/safeds/ml/nn/layers/_internal_layers.py +++ b/src/safeds/ml/nn/layers/_internal_layers.py @@ -56,6 +56,18 @@ def forward(self, x: Tensor) -> Tensor: return self._fn(self._layer(x)) +class _InternalDropoutLayer(nn.Module): + def __init__(self, probability: float) -> None: + super().__init__() + + _init_default_device() + + self._layer = nn.Dropout(probability) + + def forward(self, x: Tensor) -> Tensor: + return self._layer(x) + + class _InternalFlattenLayer(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/safeds/ml/nn/layers/test_dropout_layer.py b/tests/safeds/ml/nn/layers/test_dropout_layer.py new file mode 100644 index 000000000..5dad5f3a8 --- /dev/null +++ b/tests/safeds/ml/nn/layers/test_dropout_layer.py @@ -0,0 +1,69 @@ +import sys + +import pytest +from safeds.data.tabular.containers import Table +from safeds.exceptions import OutOfBoundsError +from safeds.ml.nn.layers import DropoutLayer +from safeds.ml.nn.typing import ConstantImageSize +from torch import nn + + +class TestProbability: + def test_should_be_accessible(self) -> None: + probability = 0.5 + layer = DropoutLayer(probability) + assert layer.probability == probability + + @pytest.mark.parametrize("probability", [-1, 2], ids=["too low", "too high"]) + def test_should_raise_if_out_of_bounds(self, probability: int) -> None: + with pytest.raises(OutOfBoundsError): + DropoutLayer(probability) + + +class TestDropoutLayer: + def test_should_create_dropout_layer(self) -> None: + size = 10 + layer = DropoutLayer(0.5) + layer._set_input_size(size) + assert layer.input_size == size + assert layer.output_size == size + assert isinstance(next(next(layer._get_internal_layer().modules()).children()), nn.Dropout) + + def test_input_size_should_be_set(self) -> None: + layer = DropoutLayer(0.5) + with pytest.raises(ValueError, match=r"The input_size is not yet set."): + layer.input_size # noqa: B018 + + with pytest.raises(ValueError, match=r"The input_size is not yet set."): + layer.output_size # noqa: B018 + + with pytest.raises(ValueError, match=r"The input_size is not yet set."): + layer._get_internal_layer() + + with pytest.raises(ValueError, match=r"The input_size is not yet set."): + layer.__sizeof__() + + +class TestEq: + def test_should_be_equal(self) -> None: + assert DropoutLayer(0.5) == DropoutLayer(0.5) + + def test_should_be_not_implemented(self) -> None: + assert DropoutLayer(0.5).__eq__(Table()) is NotImplemented + + +class TestHash: + def test_hash_should_be_equal(self) -> None: + assert hash(DropoutLayer(0.5)) == hash(DropoutLayer(0.5)) + + +class TestSizeOf: + def test_should_int_size_be_greater_than_normal_object(self) -> None: + layer = DropoutLayer(0.5) + layer._set_input_size(10) + assert sys.getsizeof(layer) > sys.getsizeof(object()) + + def test_should_model_image_size_be_greater_than_normal_object(self) -> None: + layer = DropoutLayer(0.5) + layer._set_input_size(ConstantImageSize(1, 1, 1)) + assert sys.getsizeof(layer) > sys.getsizeof(object()) diff --git a/tests/safeds/ml/nn/test_dropout_workflow.py b/tests/safeds/ml/nn/test_dropout_workflow.py new file mode 100644 index 000000000..b51f0708b --- /dev/null +++ b/tests/safeds/ml/nn/test_dropout_workflow.py @@ -0,0 +1,37 @@ +import pytest +from safeds._config import _get_device +from safeds.data.tabular.containers import Table +from safeds.ml.nn import ( + NeuralNetworkRegressor, +) +from safeds.ml.nn.converters import ( + InputConversionTable, +) +from safeds.ml.nn.layers import ( + DropoutLayer, + ForwardLayer, +) +from torch.types import Device + +from tests.helpers import configure_test_with_device, get_devices, get_devices_ids + + +@pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids()) +def test_forward_model(device: Device) -> None: + configure_test_with_device(device) + + train_table = Table( + { + "feature": [1, 2, 3], + "value": [1, 2, 3], + }, + ) + + model = NeuralNetworkRegressor( + InputConversionTable(), + [ForwardLayer(neuron_count=1), DropoutLayer(probability=0.5)], + ) + + fitted_model = model.fit(train_table.to_tabular_dataset("value"), epoch_size=1, learning_rate=0.01) + assert fitted_model._model is not None + assert fitted_model._model.state_dict()["_pytorch_layers.0._layer.weight"].device == _get_device()