From d2e3d24d2c6e82c93b8396118cb121ca3e78a1bd Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Tue, 20 Feb 2024 23:59:53 +0000 Subject: [PATCH] chore: Initial Commit --- .gitignore | 7 + Lab1_2/A1_template.csv | 2 + Lab1_2/A1_template_template.csv | 2 + Lab1_2/Assignment1.csv | 50 + Lab1_2/Lab1&2_Transformers-base.ipynb | 1146 ++++ Lab1_2/Lab1&2_Transformers.ipynb | 93 + Lab3/Week3_Autoencoder+MAE - Copy.ipynb | 7058 +++++++++++++++++++++++ Lab3/Week3_Autoencoder+MAE - Copy.py | 562 ++ 8 files changed, 8920 insertions(+) create mode 100644 .gitignore create mode 100644 Lab1_2/A1_template.csv create mode 100644 Lab1_2/A1_template_template.csv create mode 100644 Lab1_2/Assignment1.csv create mode 100644 Lab1_2/Lab1&2_Transformers-base.ipynb create mode 100644 Lab1_2/Lab1&2_Transformers.ipynb create mode 100644 Lab3/Week3_Autoencoder+MAE - Copy.ipynb create mode 100644 Lab3/Week3_Autoencoder+MAE - Copy.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9197318 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +Lab3/dataset/ +Lab3/mae +Lab3/masked +Lab3/st +Lab3/st2 +*.pth +.ipynb_checkpoints diff --git a/Lab1_2/A1_template.csv b/Lab1_2/A1_template.csv new file mode 100644 index 0000000..298104f --- /dev/null +++ b/Lab1_2/A1_template.csv @@ -0,0 +1,2 @@ +URN,Q,K,V,ANSWER +6644818,"[[-0.19737370312213898, -1.0540887117385864, 0.02383515052497387, 0.46185705065727234], [-1.2415547370910645, 0.8366656303405762, 0.3741966784000397, 0.9099264740943909], [0.3436168134212494, 0.6154376268386841, 1.1926648616790771, 1.6477248668670654]]","[[1.9663442373275757, 0.15551914274692535, -0.8715013861656189, 0.32070425152778625], [-5.85474967956543, 1.7047394514083862, -1.0024793148040771, 1.3307985067367554], [0.06319630891084671, -2.030783176422119, -5.436811447143555, -0.42979586124420166]]","[[-82.127197265625, 0.9534303545951843, -28.78610610961914, -10.762138366699219], [-16.467313766479492, 60.92831802368164, -36.08392333984375, 31.648052215576172], 
[20.485767364501953, 45.4570198059082, 15.208494186401367, 31.43212890625]]","[[-7.56060266494751, 40.530540466308594, -4.961359024047852, 23.440505981445312], [-16.6014461517334, 60.75471496582031, -36.01152420043945, 31.536420822143555], [-50.659423828125, 29.360170364379883, -31.904930114746094, 9.3984956741333]]" diff --git a/Lab1_2/A1_template_template.csv b/Lab1_2/A1_template_template.csv new file mode 100644 index 0000000..cca8198 --- /dev/null +++ b/Lab1_2/A1_template_template.csv @@ -0,0 +1,2 @@ +URN,Q,K,V,ANSWER +Copy your URN here,Copy your Q value here,Copy your K value here,Copy your V value here,Fill in your Answer Here diff --git a/Lab1_2/Assignment1.csv b/Lab1_2/Assignment1.csv new file mode 100644 index 0000000..557aa17 --- /dev/null +++ b/Lab1_2/Assignment1.csv @@ -0,0 +1,50 @@ +URN,Q,K,V,ANSWER +6424515,"[[61.812015533447266, 40.04122543334961, -1.3776825666427612, 5.644913196563721], [-4.9865241050720215, -9.148029327392578, -90.78352355957031, 27.30191993713379], [21.7161808013916, -63.25381851196289, -28.20044708251953, 14.372629165649414]]","[[-26.79865264892578, 42.945430755615234, -14.29906177520752, 10.068202018737793], [-18.409151077270508, 16.403100967407227, 14.94759750366211, 6.1012983322143555], [-3.0599536895751953, -7.43846321105957, -29.793472290039062, -6.154738903045654]]","[[0.6668468117713928, 0.0321093387901783, 0.06967663019895554, -1.0507230758666992], [0.5716139674186707, -0.160260871052742, -0.08285751193761826, -0.5788695812225342], [0.8007872700691223, -0.27879026532173157, 1.2085530757904053, 1.4593729972839355]]", +6483559,"[[5.575249671936035, 23.299489974975586, -10.568557739257812, 14.052878379821777], [1.1961194276809692, 7.634078502655029, -3.072789430618286, 19.930212020874023], [10.905213356018066, 5.266031265258789, 3.967888355255127, -19.720064163208008]]","[[7.83995246887207, 3.862151622772217, 3.054084300994873, 1.4503860473632812], [9.037196159362793, 7.029181480407715, -0.9516154527664185, 
-6.377782821655273], [-3.326775550842285, -1.8302644491195679, -5.7486701011657715, 8.669612884521484]]","[[-32.508087158203125, 50.03919982910156, -20.734249114990234, 9.130086898803711], [25.76167869567871, -12.930599212646484, -10.532285690307617, -23.300447463989258], [12.573369026184082, -76.31478881835938, 0.0935261994600296, -23.858386993408203]]", +6488397,"[[2.6853091716766357, 17.725229263305664, 1.788853645324707, 0.6750410795211792], [-32.531497955322266, 18.576560974121094, -34.81031036376953, -10.207343101501465], [-19.915924072265625, -9.083245277404785, -4.631546974182129, 25.35552406311035]]","[[-29.389095306396484, 24.075138092041016, -9.420976638793945, -7.860787391662598], [-13.026573181152344, -20.35844612121582, -4.833100318908691, 16.232547760009766], [-24.538742065429688, 2.2693328857421875, -13.806289672851562, 33.24972915649414]]","[[-3.366933822631836, -1.0369418859481812, -4.604970932006836, 2.527782917022705], [-3.817992925643921, -3.409868001937866, 8.358521461486816, -3.794895648956299], [-1.9248569011688232, 3.9590156078338623, -0.23751775920391083, -5.782866954803467]]", +6541000,"[[25.85124397277832, -14.847068786621094, -8.007427215576172, 13.616039276123047], [22.194473266601562, -41.049625396728516, -48.59806442260742, -30.103748321533203], [-0.9766722321510315, 16.258834838867188, 5.409304141998291, -34.920753479003906]]","[[-19.665332794189453, -9.125407218933105, 25.826215744018555, -24.121002197265625], [-0.6768970489501953, 32.871368408203125, 28.883134841918945, 15.97001838684082], [-32.61582946777344, -26.842838287353516, 29.825098037719727, 25.55153465270996]]","[[-22.1315860748291, -19.117671966552734, -86.41238403320312, -41.470970153808594], [-67.3498306274414, -20.870065689086914, 46.31827926635742, -49.43082809448242], [28.707395553588867, 49.85984802246094, -20.697355270385742, -19.470643997192383]]", +6564898,"[[17.990629196166992, 18.082250595092773, 6.080262660980225, -1.9313312768936157], [-1.1121807098388672, 
14.952383041381836, 5.227292060852051, 11.955268859863281], [-12.968427658081055, 10.1666841506958, 9.647360801696777, 10.25912094116211]]","[[33.59651184082031, -2.7849419116973877, -4.290157318115234, 16.049331665039062], [16.138713836669922, -3.1907596588134766, 1.9617745876312256, -13.963973999023438], [1.2874808311462402, -11.218521118164062, -4.645533561706543, -21.415369033813477]]","[[29.302244186401367, 37.714359283447266, -14.346443176269531, 26.861482620239258], [15.229442596435547, -30.665781021118164, 27.858675003051758, -3.7787418365478516], [19.023386001586914, 33.741153717041016, 17.80762481689453, -14.525671005249023]]", +6595203,"[[-4.682967185974121, -0.46032536029815674, 1.9287296533584595, 1.098872423171997], [1.813373327255249, -0.17624400556087494, -7.465083122253418, 4.692303657531738], [-11.090826034545898, 7.349782943725586, 4.164590835571289, -4.623814582824707]]","[[52.20319747924805, 39.19321060180664, 4.55007791519165, 32.2530403137207], [61.92286682128906, -44.482208251953125, -35.478302001953125, -68.6395263671875], [20.798810958862305, -43.60276412963867, 8.565412521362305, 12.54694938659668]]","[[0.6951367855072021, -0.21053913235664368, 1.9876152276992798, 0.10447879880666733], [0.9846767783164978, 0.6022341847419739, -0.6896607279777527, -1.6564579010009766], [-0.7948723435401917, 0.6899239420890808, -1.8456658124923706, 0.6393752098083496]]", +6595493,"[[0.1511625498533249, -0.9745533466339111, -2.2466001510620117, 3.180349349975586], [-5.036211967468262, -3.2606935501098633, 0.8353381156921387, -0.9949823617935181], [2.559250593185425, 1.6193833351135254, -1.08794367313385, -0.50386643409729]]","[[0.7303028106689453, -3.6174004077911377, -2.7354886531829834, 2.467529296875], [-2.1577024459838867, -0.8152199387550354, 9.242465019226074, -0.7949981689453125], [-1.6239731311798096, 5.8632731437683105, -7.212184906005859, 1.169894814491272]]","[[-8.279622077941895, 6.367185115814209, -0.10335130989551544, -15.839896202087402], 
[38.9275016784668, -5.544412612915039, -17.214385986328125, -35.62171173095703], [7.223878383636475, 2.735114812850952, 7.348245143890381, 13.252175331115723]]", +6620065,"[[-8.823065757751465, -0.38451871275901794, 11.639081001281738, -1.696522831916809], [10.205957412719727, -6.988275051116943, -16.136812210083008, -3.569871425628662], [-12.256014823913574, -10.282251358032227, -0.021453989669680595, -4.233187198638916]]","[[-7.298791408538818, 4.976022243499756, 12.335108757019043, 28.913663864135742], [14.395263671875, 20.777713775634766, 21.951862335205078, 10.135398864746094], [-12.99707317352295, 32.36512756347656, -23.65703582763672, -22.85142707824707]]","[[-21.05515480041504, -12.511702537536621, 6.4885172843933105, 7.027595520019531], [10.353309631347656, 8.395435333251953, 2.216571807861328, -2.703104257583618], [-1.446360468864441, 3.8767664432525635, 9.25796127319336, 7.1006669998168945]]", +6621031,"[[75.76603698730469, 85.38924407958984, 33.53232192993164, 99.98907470703125], [-65.45123291015625, 6.995937824249268, 0.8409870862960815, -67.66314697265625], [20.245336532592773, 56.575958251953125, -21.347116470336914, 72.61317443847656]]","[[-29.93255043029785, -17.499338150024414, -73.63774108886719, 46.85200119018555], [4.411999702453613, -2.149071216583252, -60.68436050415039, -49.36929702758789], [-48.77145004272461, 44.28728103637695, -33.22221374511719, -71.23680877685547]]","[[-23.651613235473633, -21.843677520751953, -11.714576721191406, 8.189347267150879], [-20.203529357910156, -0.08122234791517258, 36.678199768066406, 7.1701836585998535], [-3.1512069702148438, 13.690710067749023, -15.387824058532715, -37.97343063354492]]", +6622936,"[[-8.63397216796875, 4.2975921630859375, -4.522152423858643, -2.323275327682495], [8.693713188171387, -3.304124355316162, 6.605628490447998, -1.9376112222671509], [-2.9686343669891357, -1.4563707113265991, 4.499622821807861, 1.661992073059082]]","[[-2.6573855876922607, 17.252361297607422, 13.382349967956543, 
-69.54588317871094], [33.054359436035156, -5.027382850646973, 30.04429054260254, -102.58087158203125], [-15.303414344787598, -38.95169448852539, -22.762155532836914, 7.357682228088379]]","[[-35.71727752685547, -27.900691986083984, 37.486934661865234, 10.394813537597656], [21.43458366394043, -68.33717346191406, -45.12615966796875, -24.56119728088379], [-16.092315673828125, 69.08224487304688, -8.384400367736816, 16.746034622192383]]", +6623139,"[[10.012069702148438, 14.790369033813477, 10.731794357299805, -20.580612182617188], [8.332863807678223, 1.4901442527770996, -28.56690216064453, -4.9121785163879395], [31.787015914916992, -15.51191520690918, 29.45456314086914, 15.340608596801758]]","[[17.096817016601562, 3.979304552078247, -6.065049171447754, -7.021650314331055], [22.048603057861328, 16.307952880859375, -8.492615699768066, 3.741187810897827], [20.787437438964844, -20.73921012878418, 8.975601196289062, 1.641739010810852]]","[[24.196489334106445, 10.123150825500488, -16.860660552978516, 2.4109020233154297], [11.026750564575195, -0.042206697165966034, 3.502345561981201, 34.46019744873047], [44.30955123901367, -31.93711280822754, -6.689527988433838, -17.10219955444336]]", +6627063,"[[-1.0590970516204834, 0.7565467357635498, 0.20126891136169434, -0.6364921927452087], [0.004584392067044973, -2.307931900024414, 1.4653981924057007, 0.8097707629203796], [0.4444141685962677, -1.1070520877838135, -0.4160226285457611, 1.122387170791626]]","[[-13.375662803649902, 8.191596031188965, -32.305240631103516, 44.86949157714844], [15.18038272857666, -61.08460998535156, -41.45825958251953, -32.916629791259766], [48.51971435546875, 32.51784133911133, 1.7294328212738037, -62.1268310546875]]","[[-36.20771408081055, 26.2475643157959, 8.842000961303711, 43.27354431152344], [-62.30213165283203, 41.48006820678711, -6.880943775177002, -21.451093673706055], [-5.928592205047607, -0.6987577080726624, -12.249312400817871, 48.32103729248047]]", +6634908,"[[-50.7746696472168, -85.51974487304688, 
-10.108343124389648, 17.88229751586914], [-57.28345489501953, 28.485719680786133, -44.194435119628906, -51.714439392089844], [-53.68115997314453, 9.014735221862793, 45.90130615234375, 30.313220977783203]]","[[-13.381082534790039, -5.4600396156311035, -23.924701690673828, 49.98859786987305], [-29.438074111938477, 13.526389122009277, 5.885373115539551, -22.310937881469727], [-32.35851287841797, 15.571061134338379, -17.888227462768555, -8.200702667236328]]","[[-0.16054195165634155, 7.1971588134765625, -17.90960121154785, 27.956819534301758], [-10.453465461730957, 9.157516479492188, 8.646942138671875, -25.173913955688477], [38.51534652709961, -1.6076290607452393, 5.644974708557129, 20.121034622192383]]", +6635583,"[[-12.722966194152832, 71.30017852783203, 17.90323257446289, -12.047688484191895], [3.3535823822021484, 78.04386901855469, -50.834327697753906, 54.36956787109375], [35.454402923583984, -20.053314208984375, -25.02804183959961, -47.87135314941406]]","[[-43.83270263671875, 32.041934967041016, 11.113154411315918, -20.175634384155273], [32.75693130493164, 17.827775955200195, 4.969196319580078, -20.265155792236328], [-43.46480178833008, 12.0179443359375, 40.43904113769531, -60.19890594482422]]","[[4.519973278045654, -6.386752605438232, -3.6178507804870605, 6.750082015991211], [2.758470058441162, 17.978330612182617, -3.6265249252319336, -11.85763168334961], [7.1331305503845215, -1.1875079870224, 3.5858330726623535, -6.4079179763793945]]", +6638234,"[[-4.133094787597656, 5.524759769439697, -0.41305333375930786, -3.8603403568267822], [2.0264086723327637, 3.7351791858673096, 4.967355251312256, 5.839091777801514], [2.6239752769470215, -6.777492523193359, -5.668360233306885, -0.9872243404388428]]","[[129.13450622558594, 40.467010498046875, -4.255222797393799, -26.798913955688477], [-1.1929067373275757, 49.3392333984375, -22.62386131286621, 27.016782760620117], [-69.80500793457031, 31.683269500732422, -89.05852508544922, -3.53053879737854]]","[[1.06186842918396, 
14.646303176879883, 5.54833984375, -10.574101448059082], [-0.4040781557559967, 1.0404623746871948, 9.219944953918457, -8.739606857299805], [-6.794262409210205, 1.0300147533416748, -11.154375076293945, 5.335386276245117]]", +6640106,"[[3.3120908737182617, 11.714212417602539, -3.5947165489196777, -5.550578594207764], [9.422024726867676, 6.105061054229736, -2.6055192947387695, -4.635824680328369], [-8.967622756958008, 1.748089075088501, 4.3646955490112305, 2.2075018882751465]]","[[49.388126373291016, -3.505145788192749, 6.966372013092041, 20.567304611206055], [9.32462215423584, 30.148683547973633, -1.962703824043274, -9.71722412109375], [-96.27658081054688, 24.82595443725586, 108.67559051513672, -20.08876609802246]]","[[-25.174936294555664, 9.584662437438965, -25.932655334472656, -26.213214874267578], [17.62385368347168, -25.324230194091797, -35.68978500366211, -26.302963256835938], [12.274901390075684, 16.87058448791504, -14.158609390258789, 8.233379364013672]]", +6644818,"[[-0.19737370312213898, -1.0540887117385864, 0.02383515052497387, 0.46185705065727234], [-1.2415547370910645, 0.8366656303405762, 0.3741966784000397, 0.9099264740943909], [0.3436168134212494, 0.6154376268386841, 1.1926648616790771, 1.6477248668670654]]","[[1.9663442373275757, 0.15551914274692535, -0.8715013861656189, 0.32070425152778625], [-5.85474967956543, 1.7047394514083862, -1.0024793148040771, 1.3307985067367554], [0.06319630891084671, -2.030783176422119, -5.436811447143555, -0.42979586124420166]]","[[-82.127197265625, 0.9534303545951843, -28.78610610961914, -10.762138366699219], [-16.467313766479492, 60.92831802368164, -36.08392333984375, 31.648052215576172], [20.485767364501953, 45.4570198059082, 15.208494186401367, 31.43212890625]]", +6647000,"[[17.123088836669922, 0.7197161912918091, 67.95402526855469, 30.830045700073242], [1.918267011642456, 3.1925175189971924, -60.516944885253906, 33.09083557128906], [-17.23439598083496, 4.878037452697754, -27.22907829284668, 
44.515987396240234]]","[[13.926458358764648, 41.11029815673828, 3.837980031967163, 29.635908126831055], [-45.734840393066406, 52.18793487548828, 5.066276550292969, 11.72782039642334], [-97.8349838256836, 28.44172477722168, -43.70535659790039, -25.272418975830078]]","[[-4.117476940155029, 2.6530921459198, -2.3165719509124756, 2.692505359649658], [-4.646360397338867, -6.495508670806885, -2.042623281478882, -4.2362542152404785], [-6.218053817749023, -5.21392822265625, 4.337059020996094, 5.960870742797852]]", +6650398,"[[15.136759757995605, -43.95924377441406, -113.8025894165039, 75.7243423461914], [-53.820682525634766, 6.7568206787109375, 11.68793773651123, -59.304115295410156], [17.12495231628418, -80.94425964355469, -24.56743621826172, 72.69660949707031]]","[[-44.538997650146484, 17.452163696289062, -22.793365478515625, -19.52366828918457], [22.004854202270508, -30.501188278198242, 17.9410343170166, -11.477399826049805], [13.915644645690918, -3.8742470741271973, -20.8011531829834, 10.137035369873047]]","[[94.21515655517578, 38.48592758178711, 8.827954292297363, -11.255606651306152], [9.103065490722656, -26.855743408203125, -58.49977111816406, 56.034507751464844], [36.73604202270508, 72.35386657714844, -5.083021640777588, -91.17439270019531]]", +6654031,"[[-0.3339273929595947, -0.5318685173988342, 2.0381877422332764, 0.33716848492622375], [-0.5744379758834839, -0.005252655595541, 1.7914447784423828, -0.27126064896583557], [-0.5965532064437866, -1.8395336866378784, 0.9394988417625427, 0.33245497941970825]]","[[14.26134204864502, 6.246121883392334, 16.684396743774414, 13.413249015808105], [18.52135467529297, -4.069742202758789, -8.969866752624512, -4.116239547729492], [-15.628847122192383, 1.7585363388061523, -7.5017409324646, 14.045808792114258]]","[[22.479833602905273, 10.961353302001953, -37.169498443603516, -3.8920676708221436], [22.053361892700195, 6.5353474617004395, 16.050573348999023, 1.3947471380233765], [53.140769958496094, 9.212106704711914, 
-23.101652145385742, 16.18545913696289]]", +6657209,"[[35.31223678588867, -22.971651077270508, 39.2910270690918, 53.64673614501953], [4.348316192626953, -16.771831512451172, -43.288639068603516, 28.353843688964844], [-39.88884353637695, -15.335161209106445, 31.237241744995117, 79.95108795166016]]","[[65.86861419677734, -66.7222671508789, -65.81453704833984, -53.20375442504883], [-21.004255294799805, -36.96867370605469, -49.629032135009766, 4.206972122192383], [-3.789904832839966, 53.388763427734375, -9.336297035217285, 59.61064147949219]]","[[-23.81308937072754, 17.60015106201172, -3.6998114585876465, -44.74623107910156], [39.35554122924805, 21.279781341552734, 25.772464752197266, -23.060672760009766], [23.75358772277832, 24.503376007080078, -16.24953842163086, -52.87109375]]", +6664919,"[[-4.587403774261475, 21.389644622802734, -0.8058542013168335, -15.177589416503906], [-30.81719207763672, 16.282472610473633, -30.25719451904297, -4.179492473602295], [-25.774181365966797, 2.7025911808013916, -15.140970230102539, 35.96717071533203]]","[[7.414963722229004, -73.7457275390625, -23.34357261657715, -59.4050178527832], [-37.04792022705078, 39.79411697387695, -28.534637451171875, -9.650754928588867], [-51.72665786743164, 5.047584533691406, 49.07307815551758, -10.990396499633789]]","[[-68.04022979736328, -18.380733489990234, -48.01260757446289, 4.971027851104736], [69.10401153564453, -24.525249481201172, -13.450369834899902, -31.265018463134766], [-17.114112854003906, 88.62091827392578, -49.525413513183594, -44.12730026245117]]", +6665234,"[[-58.47925567626953, -6.0381927490234375, 59.27907943725586, 3.1732094287872314], [18.26332664489746, -7.849457263946533, -13.465660095214844, 24.281553268432617], [50.79582214355469, -76.18656158447266, 37.01697540283203, 11.83233642578125]]","[[-19.288162231445312, -10.204944610595703, -34.52219009399414, 6.608401775360107], [6.105591297149658, 17.177522659301758, 41.61949920654297, 59.6090202331543], [-24.104351043701172, 
-3.944885015487671, 21.40576934814453, 4.275631427764893]]","[[9.385692596435547, 16.75665855407715, -22.38917350769043, 16.042783737182617], [25.004865646362305, -13.535341262817383, -1.943082571029663, -46.629024505615234], [-30.355073928833008, 7.768237590789795, -38.768218994140625, 10.225730895996094]]", +6667584,"[[-19.96772003173828, 78.34546661376953, 75.89759826660156, -2.6900768280029297], [-29.359573364257812, 36.52899932861328, -12.29699993133545, -49.58512496948242], [31.139745712280273, -32.53242111206055, -29.126243591308594, 47.14161682128906]]","[[7.564728736877441, -26.98961067199707, 14.280712127685547, -15.75666332244873], [16.308378219604492, -22.065216064453125, -4.871706008911133, -17.236865997314453], [34.73072814941406, 17.954309463500977, -9.708355903625488, -5.783907413482666]]","[[-37.74998474121094, 4.323561191558838, -15.50539779663086, 11.531072616577148], [42.675811767578125, -22.62323760986328, 16.255107879638672, 39.860836029052734], [-47.075279235839844, -16.6245174407959, 25.671066284179688, 0.7380026578903198]]", +6673385,"[[-36.83106231689453, 7.701159954071045, 68.32613372802734, 27.919288635253906], [-4.696883678436279, -3.020927906036377, -58.74855422973633, 9.028343200683594], [-35.0689811706543, 1.2082970142364502, 34.020774841308594, 1.3042429685592651]]","[[-16.659854888916016, 4.926011562347412, -10.322761535644531, 2.21010422706604], [-2.2169840335845947, -29.646329879760742, 3.1019740104675293, -26.244598388671875], [9.877544403076172, 14.357820510864258, -9.583166122436523, -2.8210270404815674]]","[[-23.14494514465332, -20.618330001831055, 21.521303176879883, 22.030481338500977], [-8.146241188049316, 40.73670196533203, -14.541313171386719, 36.61677169799805], [-4.874123573303223, 19.764659881591797, 22.12094497680664, 19.16327667236328]]", +6674521,"[[-25.607240676879883, 0.8884523510932922, 0.6508675217628479, -55.56379699707031], [-13.399002075195312, -16.82448387145996, 38.20143508911133, 24.327342987060547], 
[15.179723739624023, 0.8729835152626038, -21.22319793701172, -12.520669937133789]]","[[-1.71498441696167, 5.201033592224121, -11.60346508026123, -5.809443950653076], [5.660852909088135, 3.109971284866333, 8.395805358886719, -13.36353874206543], [8.876840591430664, -4.089223861694336, -5.3249406814575195, -0.06105200946331024]]","[[1.4979137182235718, -1.590577244758606, 0.5862289071083069, -0.25769785046577454], [-2.033968687057495, -2.840846300125122, -0.10808251053094864, -0.877410352230072], [-1.72908616065979, -2.7925174236297607, -0.9345141053199768, -1.441590666770935]]", +6676367,"[[69.78858947753906, 53.50902557373047, 63.16095733642578, -34.84336471557617], [28.3581485748291, 8.022404670715332, -29.504745483398438, 47.268455505371094], [19.14493751525879, 24.85756492614746, -53.912784576416016, -74.6889877319336]]","[[18.038612365722656, -22.782411575317383, -23.893470764160156, -1.615665078163147], [-15.572731018066406, 6.45712947845459, -22.30083465576172, 43.831302642822266], [14.26174259185791, -11.669622421264648, -0.7779999375343323, 23.53053855895996]]","[[-11.959007263183594, 4.956019401550293, -7.027940273284912, 0.33905893564224243], [-2.8083086013793945, 3.9247255325317383, -6.848259449005127, -15.556193351745605], [-59.83305358886719, -8.178778648376465, 12.3436861038208, -59.3747673034668]]", +6679119,"[[-101.50489044189453, 36.98146057128906, 1.1366711854934692, 63.51799774169922], [-46.073787689208984, 46.47813034057617, -104.55518341064453, 22.0133113861084], [53.618961334228516, 13.772979736328125, -20.572906494140625, 70.72864532470703]]","[[-17.259140014648438, -40.22257614135742, -83.89826965332031, 51.26460266113281], [-75.39704895019531, -31.925426483154297, 59.456993103027344, 45.64452362060547], [69.73789978027344, -66.56124877929688, -11.179588317871094, 32.963687896728516]]","[[-3.6455090045928955, -0.9025506377220154, -0.38173770904541016, -0.06332780420780182], [0.3781833350658417, 0.2325819581747055, -1.8777588605880737, 
-2.328604221343994], [-3.388141632080078, 3.184062957763672, 1.2819944620132446, -1.5070877075195312]]", +6679413,"[[-4.243306636810303, -7.72474479675293, 9.24826717376709, 0.6170316934585571], [6.635672092437744, -5.348725318908691, -9.360831260681152, -0.630814790725708], [7.321711540222168, 9.013751983642578, -3.012909412384033, -1.1782095432281494]]","[[37.10626983642578, -16.38568687438965, 28.856245040893555, 18.558990478515625], [34.18963623046875, -33.233551025390625, -18.14596939086914, -1.2347465753555298], [60.65745162963867, -18.7542781829834, -27.731416702270508, -7.704121112823486]]","[[-0.621583878993988, 5.9794158935546875, 6.679257392883301, -6.4304022789001465], [-14.296626091003418, -1.3633551597595215, 14.830915451049805, 5.049946308135986], [-7.077603340148926, 1.2459709644317627, -8.561247825622559, 26.62421417236328]]", +6684315,"[[-10.12170124053955, -3.4088053703308105, -4.571578502655029, -18.48915672302246], [11.414365768432617, -14.282251358032227, -3.3087334632873535, 5.553377151489258], [3.9793548583984375, 10.295501708984375, -13.825559616088867, 17.952943801879883]]","[[-28.288414001464844, -2.9670181274414062, -47.25498580932617, -24.528217315673828], [-29.686786651611328, -26.871307373046875, 9.226241111755371, -18.19982147216797], [-56.61760330200195, 3.2128732204437256, -17.26993179321289, 2.0199966430664062]]","[[5.996661186218262, 12.5982084274292, 2.2733395099639893, -3.376871347427368], [-1.1309608221054077, -7.995626926422119, 1.557140827178955, 2.26686692237854], [-0.9957780838012695, 8.70053482055664, -0.016650710254907608, -3.3884832859039307]]", +6684666,"[[50.06999969482422, -49.15105438232422, 3.551252841949463, -0.5807281732559204], [-40.77178955078125, -50.042633056640625, 5.525672435760498, -14.99709415435791], [33.131587982177734, -57.343509674072266, 57.94015884399414, 44.270660400390625]]","[[18.365150451660156, 20.553979873657227, 42.92595672607422, -18.730113983154297], [-50.91423034667969, 21.813535690307617, 
3.634154796600342, 4.894844055175781], [-34.97544479370117, -49.256649017333984, -22.60733413696289, -22.322555541992188]]","[[-9.029924392700195, 15.004953384399414, 17.38901138305664, 7.252900123596191], [-18.01449203491211, 5.875588417053223, 5.970248222351074, -1.5177265405654907], [2.648449420928955, -4.20261812210083, -9.511507987976074, -12.106976509094238]]", +6685415,"[[35.113365173339844, -77.67644500732422, 50.20078659057617, 8.420661926269531], [-0.2904244661331177, 29.286212921142578, 3.7101621627807617, -40.164581298828125], [17.763580322265625, 18.826738357543945, -35.414276123046875, 1.2325835227966309]]","[[-18.88475227355957, -16.808897018432617, -34.2154541015625, -41.5155029296875], [1.1879208087921143, -34.65090560913086, -52.162071228027344, 35.442989349365234], [15.725982666015625, -15.781286239624023, -68.25137329101562, -37.75694274902344]]","[[-2.339825391769409, 0.8828161954879761, -0.016874248161911964, -4.491491794586182], [-1.3867239952087402, 1.5863149166107178, -1.9972578287124634, 2.9583418369293213], [-1.0680830478668213, -0.24776069819927216, -4.734472274780273, -1.0308974981307983]]", +6685730,"[[-25.302352905273438, -30.782779693603516, -27.899147033691406, -24.196697235107422], [-15.96635913848877, -8.835006713867188, -24.468276977539062, -8.283778190612793], [-44.42436981201172, 35.36659240722656, 13.494463920593262, 35.456520080566406]]","[[-1.8692127466201782, -64.31143951416016, 86.1197738647461, 62.821712493896484], [-69.96611022949219, -21.99472427368164, -10.378890991210938, 32.68306350708008], [55.907291412353516, -90.03367614746094, 2.1420021057128906, -67.6994857788086]]","[[-12.555275917053223, 18.81964874267578, -5.327393531799316, 49.367645263671875], [29.352062225341797, -30.749753952026367, -21.662687301635742, -8.036517143249512], [-53.613895416259766, 6.692556381225586, -0.9449663162231445, 24.52071762084961]]", +6687280,"[[-3.5868778228759766, -2.3832733631134033, 0.589530348777771, 11.288500785827637], 
[-10.782378196716309, -2.957477569580078, 7.451799392700195, 4.292666435241699], [7.849056720733643, -14.427438735961914, -11.46081829071045, 5.859859466552734]]","[[-48.13723373413086, -135.08383178710938, -40.40618133544922, 35.46659851074219], [-6.033477306365967, 40.350830078125, -42.84685516357422, -52.211002349853516], [-111.6717529296875, 9.398500442504883, -43.886268615722656, -82.76837921142578]]","[[33.583717346191406, 5.88064432144165, 6.754786014556885, -32.132755279541016], [-17.764244079589844, 82.55520629882812, -30.020442962646484, 24.64034652709961], [-89.43638610839844, -5.825024604797363, 56.58000564575195, -11.378416061401367]]", +6687805,"[[-41.159664154052734, -19.464439392089844, 25.74068260192871, 1.872757911682129], [-48.20577621459961, -22.682870864868164, -12.677288055419922, 10.750849723815918], [26.64324188232422, -59.10613250732422, 21.565441131591797, -40.196861267089844]]","[[18.406984329223633, -31.138717651367188, 25.356599807739258, 12.595141410827637], [-33.44847869873047, -21.521642684936523, 25.265287399291992, -23.24741554260254], [-29.790891647338867, 16.76827621459961, -24.88990020751953, 31.059619903564453]]","[[22.057342529296875, -39.534358978271484, 61.31902313232422, -44.508331298828125], [2.0680832862854004, 19.670103073120117, -17.426483154296875, -8.790121078491211], [10.202092170715332, -1.5388749837875366, 8.696026802062988, -5.0123610496521]]", +6687869,"[[0.07490905374288559, -1.273012638092041, 0.6518950462341309, -0.6229028701782227], [-0.36243799328804016, -2.1689934730529785, -1.3617095947265625, 2.652318000793457], [1.0193486213684082, -3.5647265911102295, 5.747159004211426, -1.6326216459274292]]","[[-7.607456207275391, 28.280895233154297, -18.318817138671875, 24.297714233398438], [11.687575340270996, -2.194053888320923, -3.6463708877563477, -9.412092208862305], [20.877317428588867, -12.823661804199219, 25.226224899291992, -1.2035841941833496]]","[[2.239046096801758, 2.992509365081787, -3.4312944412231445, 
2.049949884414673], [-6.315691947937012, 0.6009021997451782, -1.2477636337280273, -1.7523036003112793], [1.2447832822799683, -5.492430210113525, 3.384784698486328, 1.0218923091888428]]", +6689012,"[[10.561138153076172, 24.293989181518555, 8.856136322021484, -8.001155853271484], [-19.200115203857422, -4.852842330932617, 14.093475341796875, 39.930023193359375], [-22.35072898864746, 33.056114196777344, -6.0201215744018555, -39.17753601074219]]","[[-16.157075881958008, -1.0930131673812866, 5.450170040130615, -40.11588668823242], [5.857770919799805, 23.14008331298828, -15.793992042541504, -5.903223037719727], [-13.125252723693848, 11.460956573486328, 117.74522399902344, -92.62993621826172]]","[[-6.055703639984131, -12.719776153564453, 8.75866985321045, 0.42572861909866333], [2.1340320110321045, -13.77920150756836, 14.589826583862305, 2.4945802688598633], [14.152591705322266, 13.32497501373291, -2.875643014907837, -6.7089128494262695]]", +6690072,"[[-1.419852614402771, 0.2286374419927597, 0.3345853388309479, 0.2729721665382385], [0.6338568925857544, -0.8546611666679382, 0.869610607624054, -2.08027982711792], [-1.1930214166641235, -0.2104170322418213, -0.9776290655136108, -0.7793132066726685]]","[[33.953224182128906, -25.176177978515625, 11.610986709594727, 27.586698532104492], [25.29509925842285, 8.522324562072754, 4.060436248779297, 37.154415130615234], [-37.96092987060547, 53.34375, 49.52286148071289, 62.254817962646484]]","[[104.4286117553711, 68.14720916748047, 24.368501663208008, 44.12657928466797], [-13.939213752746582, 7.622133255004883, 4.232577323913574, 22.11945343017578], [-9.448180198669434, 40.441768646240234, -8.889688491821289, -3.6762876510620117]]", +6691144,"[[22.178943634033203, 12.019502639770508, 31.294391632080078, 46.64274978637695], [44.24049758911133, 19.075437545776367, -3.5804359912872314, -37.055137634277344], [-79.22559356689453, 63.357879638671875, 19.648544311523438, 70.82246398925781]]","[[-0.05032595247030258, 1.217042088508606, 
0.5083667635917664, 0.5189406871795654], [-0.6396045684814453, -0.5928763151168823, -0.7169412970542908, -0.1308005154132843], [1.3062458038330078, 1.1942483186721802, 1.5429742336273193, -1.3320108652114868]]","[[-0.9213723540306091, 2.8213400840759277, -1.3256995677947998, -1.1574915647506714], [0.6758412718772888, -0.9888588786125183, -0.7084240317344666, 0.0021383522544056177], [1.3449786901474, -0.31319916248321533, -0.27399250864982605, -0.09528028964996338]]", +6691398,"[[3.93902850151062, -18.370189666748047, 20.255130767822266, 12.914362907409668], [-9.456490516662598, 40.97371292114258, 40.47419738769531, -5.51539945602417], [27.83440589904785, -25.091535568237305, 27.184425354003906, -44.957481384277344]]","[[47.26085662841797, -0.5326224565505981, -32.40896987915039, 33.84670639038086], [28.53173065185547, -2.5092074871063232, 35.7825813293457, -13.36374282836914], [-78.28602600097656, 70.24180603027344, -28.044546127319336, 30.846582412719727]]","[[23.940338134765625, -29.73199462890625, -36.24672317504883, 38.23209762573242], [-9.304244995117188, 8.964479446411133, -10.025215148925781, -5.452454566955566], [34.69584274291992, -28.6850643157959, 51.733543395996094, 5.806751251220703]]", +6694392,"[[-32.651275634765625, 6.973701477050781, -19.875314712524414, 3.2574033737182617], [52.17797088623047, 17.92009735107422, -2.8579277992248535, 18.69631576538086], [4.248002529144287, -1.815674901008606, -18.37828826904297, -46.7224235534668]]","[[-12.207379341125488, -31.712554931640625, 28.55495262145996, 15.864582061767578], [44.41409683227539, -52.78776550292969, -31.630704879760742, -8.993188858032227], [9.403763771057129, -15.210376739501953, -2.9357264041900635, 7.387798309326172]]","[[12.651803016662598, -1.8179872035980225, -8.39204216003418, -10.712897300720215], [28.703163146972656, 17.471147537231445, -16.652137756347656, -19.11032485961914], [-13.872947692871094, 4.021103858947754, 2.7464599609375, -9.595636367797852]]", 
+6695534,"[[-13.405084609985352, 9.418004035949707, 19.678552627563477, -3.701709032058716], [22.82848358154297, 24.53850746154785, -11.5900297164917, 1.3250523805618286], [-19.844371795654297, -3.264390707015991, 1.127580165863037, -31.00078010559082]]","[[-4.07348108291626, 39.030860900878906, 51.24266052246094, 6.860040187835693], [-30.623088836669922, 36.461055755615234, 0.7204872965812683, -15.802491188049316], [-57.89153289794922, 44.76567459106445, -42.03072738647461, 9.589776039123535]]","[[-3.89924693107605, -54.185546875, -29.025619506835938, -23.89274024963379], [11.800339698791504, 6.643435001373291, -28.981651306152344, -24.919710159301758], [-2.8200690746307373, 24.84758186340332, -3.097881317138672, -16.181795120239258]]", +6697080,"[[-1.2122610807418823, -58.39776611328125, -1.0357400178909302, -2.3798153400421143], [-0.0755167007446289, 103.1572265625, 64.51415252685547, -16.170059204101562], [29.636972427368164, 34.911067962646484, -45.17095184326172, 32.64284896850586]]","[[8.11724853515625, 40.43547821044922, -5.311676025390625, -17.611391067504883], [-54.05378341674805, 2.7808048725128174, 43.850887298583984, -14.37118148803711], [3.0494017601013184, -36.97222900390625, 17.106531143188477, 0.831373929977417]]","[[-13.555829048156738, -6.29565954208374, 10.418721199035645, 56.299983978271484], [35.56388854980469, -16.308576583862305, 11.416728019714355, 8.350481033325195], [-58.433876037597656, -4.182812690734863, 37.23177719116211, 10.30872631072998]]", +6698610,"[[-6.041436195373535, 9.156213760375977, 1.798213005065918, -11.413951873779297], [-9.282434463500977, -6.500437259674072, 11.692222595214844, -0.07979172468185425], [-10.110106468200684, -1.6897872686386108, -6.469929218292236, 7.471987724304199]]","[[-36.39547348022461, 82.41608428955078, 50.620750427246094, -42.36143112182617], [22.46283721923828, 28.61631965637207, 2.2778117656707764, -10.755844116210938], [-42.1403694152832, 73.60430145263672, 8.922992706298828, 
23.86935806274414]]","[[0.27653270959854126, 4.387450695037842, 25.69333839416504, -21.38128662109375], [41.85139083862305, -5.535375118255615, -2.226491928100586, -15.576900482177734], [-4.816816806793213, 5.0897955894470215, -9.330107688903809, 15.888705253601074]]", +6699778,"[[-11.416746139526367, -6.792178630828857, -7.349616050720215, -17.06553077697754], [-22.967361450195312, -5.880982875823975, 23.729190826416016, 40.73758316040039], [10.738740921020508, 17.928733825683594, -5.958589553833008, -20.662246704101562]]","[[10.528449058532715, 5.952178955078125, -15.57581901550293, 3.7095937728881836], [-45.46097946166992, 18.745311737060547, -7.266114711761475, 17.222610473632812], [-6.2384467124938965, -5.815711975097656, 0.8986830711364746, -16.490060806274414]]","[[23.08553123474121, 24.72433853149414, -4.271855354309082, -7.6843366622924805], [3.9064691066741943, -2.9235575199127197, -2.4568283557891846, -11.444953918457031], [-2.5821399688720703, -16.44976806640625, 6.171452522277832, 11.762450218200684]]", +6699788,"[[12.881512641906738, -26.03125, -31.379695892333984, -0.18674679100513458], [-17.41887855529785, -15.735973358154297, -58.854496002197266, -8.337396621704102], [-51.53794860839844, -14.746176719665527, -56.51786422729492, 29.540775299072266]]","[[12.720685958862305, 5.674424171447754, 17.607324600219727, -5.377486228942871], [-61.59711456298828, -0.8513314127922058, 22.417675018310547, -11.627851486206055], [-5.549165725708008, -6.901993274688721, -21.57571029663086, 7.405230522155762]]","[[-46.049503326416016, 43.273468017578125, 1.1638909578323364, 16.160280227661133], [62.02693557739258, 58.560646057128906, -87.04951477050781, -86.58930969238281], [-28.130151748657227, -34.52029037475586, 14.721302032470703, 26.94708824157715]]", +6702876,"[[52.45893859863281, -100.3917007446289, -33.253238677978516, -29.724864959716797], [21.84295654296875, -13.550008773803711, -68.54496765136719, 49.517982482910156], [62.37989807128906, 
3.997654914855957, 44.85454559326172, -36.87840270996094]]","[[32.890541076660156, -15.513359069824219, 20.226282119750977, 20.90232276916504], [-16.48346710205078, 4.854818820953369, 11.418033599853516, -0.8921002745628357], [-25.857093811035156, 6.862802982330322, -7.632472515106201, -3.579554319381714]]","[[-60.80701446533203, -5.687557220458984, 15.815590858459473, 14.843527793884277], [-45.50221252441406, 20.321155548095703, -51.37971115112305, 15.229931831359863], [107.47465515136719, -24.3437442779541, -16.632457733154297, -16.89650535583496]]", +6702928,"[[18.301334381103516, -78.94622802734375, -6.391412734985352, 17.14784812927246], [43.93109893798828, 22.418628692626953, -55.411903381347656, -62.78773498535156], [-10.213981628417969, -45.5717658996582, -47.030120849609375, -38.92612838745117]]","[[18.429698944091797, -13.791690826416016, -31.753252029418945, 7.937379360198975], [-9.21422290802002, -16.363967895507812, 3.4244227409362793, 11.589090347290039], [0.4459109306335449, -42.66963577270508, -26.912336349487305, -45.24558639526367]]","[[3.989264965057373, -10.394329071044922, 2.9203333854675293, -15.809683799743652], [0.13381671905517578, 6.432409763336182, -13.255852699279785, 5.919764518737793], [0.5510143637657166, -4.444141387939453, -3.3461360931396484, -8.943599700927734]]", +6705228,"[[-28.573619842529297, 40.02886199951172, -15.857441902160645, 40.771156311035156], [-15.363801002502441, -22.323780059814453, 28.274812698364258, 22.44999885559082], [5.446697235107422, 10.143810272216797, 37.238800048828125, 15.129122734069824]]","[[18.833009719848633, 45.32041931152344, -24.788484573364258, -7.510664463043213], [40.43198013305664, 4.926146984100342, -3.4637563228607178, 0.8702676892280579], [28.37637710571289, 17.996984481811523, -9.186692237854004, -20.234752655029297]]","[[6.580479145050049, 1.9564368724822998, -7.401310920715332, -8.521329879760742], [-3.6674721240997314, -5.470401763916016, 11.234831809997559, -4.046061038970947], 
[4.957516193389893, -7.476752758026123, 2.370039224624634, -5.951333999633789]]", diff --git a/Lab1_2/Lab1&2_Transformers-base.ipynb b/Lab1_2/Lab1&2_Transformers-base.ipynb new file mode 100644 index 0000000..02c6d80 --- /dev/null +++ b/Lab1_2/Lab1&2_Transformers-base.ipynb @@ -0,0 +1,1146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Cv-9Vzunb_tf" + }, + "source": [ + "# Import Necessary Library" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "4f-K54nHb-Uq" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch.utils.data as data\n", + "import math\n", + "import os\n", + "import urllib.request\n", + "import pandas as pd\n", + "from functools import partial\n", + "from urllib.error import HTTPError\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BtW5eDFocsMA" + }, + "source": [ + "# What is Attention?\n", + "\n", + "Attention in neural networks, particularly relevant for sequential tasks, refers to a mechanism that selectively focuses on certain parts of input data. This concept has gained significant interest in recent years. In essence, attention computes a weighted average of elements in a sequence, with the weights being dynamically determined based on the relevance of each element to a specific query. This allows the model to prioritize certain inputs over others.\n", + "\n", + "The attention mechanism consists of four primary components:\n", + "\n", + "* **Query**: A feature vector representing the target of the attention, essentially indicating the information the model seeks within the sequence.\n", + "* **Keys**: Feature vectors corresponding to each input element, describing the content or relevance of the elements. 
The keys help the model identify which elements to focus on, relative to the query.\n", + "* **Values**: Feature vectors representing the actual content from each input element that the model should aggregate.\n", + "* **Score function**: A function used to calculate attention weights, representing the relevance of each key-query pair. Common implementations include simple operations like the dot product or more complex structures like a small neural network.\n", + "\n", + "The attention mechanism operates by first computing scores between the query and each key using the score function. These scores determine the attention weights through a softmax function, ensuring that they sum to one and are non-negative. The output is then calculated as the weighted sum of the value vectors, with weights corresponding to the calculated attention scores.\n", + "\n", + "Mathematically, this process can be represented as:\n", + "\n", + "$$\n", + "\\alpha_i = \\frac{\\exp\\left(f_{attn}\\left(\\text{key}_i, \\text{query}\\right)\\right)}{\\sum_j \\exp\\left(f_{attn}\\left(\\text{key}_j, \\text{query}\\right)\\right)}, \\hspace{5mm} \\text{out} = \\sum_i \\alpha_i \\cdot \\text{value}_i\n", + "$$\n", + "\n", + "In practice, attention mechanisms can vary based on the choice of queries, the definition of key and value vectors, and the specific score function used. A prominent example is the **self-attention** mechanism used in the Transformer architecture, where each element in a sequence provides its own key, value, and query. 
The self-attention mechanism allows each element to attend to all elements in the sequence, including itself, resulting in a representation that incorporates information from the entire sequence.\n", + "\n", + "The above explanation provides a conceptual understanding of the attention mechanism, highlighting its components and operational principles without delving into the specific details of any particular implementation, such as the scaled dot product attention used in Transformers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1DFh9Ic8dp-u" + }, + "source": [ + "### Scaled Dot Product Attention\n", + "\n", + "The scaled dot product attention is a fundamental component of the self-attention mechanism, enabling elements within a sequence to efficiently attend to one another. It operates on queries $Q\\in\\mathbb{R}^{T\\times d_k}$, keys $K\\in\\mathbb{R}^{T\\times d_k}$, and values $V\\in\\mathbb{R}^{T\\times d_v}$, where $T$ represents the sequence length and $d_k$, $d_v$ denote the dimensions of queries/keys and values, respectively.\n", + "\n", + "The mechanism calculates the attention values based on the dot product similarity between each query $Q_i$ and key $K_j$, and scales the results by the square root of the dimensionality of the keys, $d_k$. The formula for this calculation is:\n", + "\n", + "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n", + "\n", + "Here, the matrix product $QK^T$ computes the dot product between all pairs of queries and keys, forming a $T\\times T$ matrix where each entry represents the attention score from one element to another. After applying the softmax function, these scores are used as weights to compute a weighted average of the value vectors.\n", + "\n", + "The scaling factor $1/\\sqrt{d_k}$ is critical for maintaining the variance of the attention scores at an appropriate level. 
Without this scaling, the variance of the dot products could become too large, leading to a situation where the softmax function saturates, with most of its output concentrated on a single element. This would hinder learning by resulting in gradients that are almost zero.\n", + "\n", + "Additionally, the mechanism can include an optional masking step (denoted as `Mask (opt.)` in the diagram), useful in situations like batch processing of sequences of varying lengths. Padding is used to equalize the lengths of sequences, and the mask ensures that the padded positions do not affect the attention calculation, typically by assigning a very low value to these positions in the attention scores.\n", + "\n", + "In summary, the scaled dot product attention efficiently enables each element in a sequence to attend to all others, considering the relevance of each element, and is crucial for models that rely on self-attention, such as Transformers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VKvaGxqIdvba" + }, + "source": [ + "### Implementing Scaled Dot Product Attention\n", + "\n", + "Scaled dot product attention is a core mechanism allowing each element in a sequence to consider all other elements efficiently, which is fundamental in self-attention models like Transformers. Here's a detailed guide to implementing scaled dot product attention, breaking down the components and the steps involved.\n", + "\n", + "#### Inputs to the Attention Mechanism\n", + "The attention function takes three inputs:\n", + "1. **Queries (Q)**: $Q\\in\\mathbb{R}^{T\\times d_k}$, where $T$ is the sequence length and $d_k$ is the dimensionality of the queries and keys.\n", + "2. **Keys (K)**: $K\\in\\mathbb{R}^{T\\times d_k}$.\n", + "3. **Values (V)**: $V\\in\\mathbb{R}^{T\\times d_v}$, where $d_v$ is the dimensionality of the values.\n", + "\n", + "#### Step-by-Step Calculation\n", + "1. 
**Dot Product of Queries and Keys**: Calculate the dot product between each query and all keys to obtain a measure of compatibility or relevance between each query-key pair. This results in a matrix of shape $T \\times T$, where each element $(i, j)$ represents the dot product between query $i$ and key $j$.\n", + " \n", + " $$\\text{Score Matrix} = QK^T$$\n", + "\n", + "2. **Scaling**: Scale the scores obtained in the previous step by dividing by $\\sqrt{d_k}$ to ensure stable gradients: larger values of $d_k$ produce dot products of larger magnitude, which saturate the softmax and lead to extremely small gradients, slowing down learning and model convergence.\n", + "\n", + " $$\\text{Scaled Score Matrix} = \\frac{\\text{Score Matrix}}{\\sqrt{d_k}}$$\n", + "\n", + "3. **Optional Masking**: If masking is required (e.g., for padded positions in a batch of sequences), apply the mask by setting the scores for masked positions to a very large negative value, ensuring that they have minimal impact after the softmax step.\n", + "\n", + "4. **Softmax**: Apply the softmax function to the scaled scores along each row. This step converts the scores into probabilities, indicating the importance of each key relative to each query.\n", + "\n", + " $$\\text{Attention Weights} = \\text{softmax}(\\text{Scaled Score Matrix})$$\n", + "\n", + "5. **Output Calculation**: Multiply the attention weights by the value vectors to obtain the final output. This step computes a weighted average of the value vectors, where the weights are determined by the attention scores.\n", + "\n", + " $$\\text{Output} = \\text{Attention Weights} \\times V$$\n", + "\n", + "#### Implementation Tips\n", + "- **Dimensionality**: Ensure the dimensions of your matrices are correct. 
Matrix multiplication will not be possible if the inner dimensions do not match.\n", + "- **Numerical Stability**: When implementing the softmax function, ensure numerical stability by subtracting the maximum value in each row of the scores matrix before applying the exponential function.\n", + "- **Batch Processing**: If implementing attention in batch, include an additional batch dimension in your matrices (e.g., $Q\\in\\mathbb{R}^{B\\times T\\times d_k}$ for a batch size of $B$) and ensure your implementation supports this.\n", + "- **Testing**: Verify the correctness of your implementation with simple test cases to ensure it behaves as expected.\n", + "\n", + "This framework should provide a clear structure for students to implement scaled dot product attention, enhancing their understanding of its role and functionality in self-attention models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jsFoInPLeFk9" + }, + "source": [ + "# Task: Please implement a scaled dot product function" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 2)" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dk = 2\n", + "t = 3\n", + "v = torch.randn(t, dk)\n", + "\n", + "len(v[0]),v.shape[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_54620/1567451330.py:14: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", + " softmax = torch.nn.functional.softmax(scaled)\n" + ] + }, + { + "data": { + "text/plain": [ + "(tensor([[5., 5.],\n", + " [5., 5.],\n", + " [5., 5.]]),\n", + " tensor([[5., 5.],\n", + " [5., 5.],\n", + " [5., 5.]]),\n", + " tensor([[5., 5.],\n", + " [5., 5.],\n", + " [5., 5.]]),\n", + " tensor([[5., 5., 5.],\n", + " [5., 5., 5.]]),\n", + " tensor([[50., 50., 50.],\n", + " [50., 50., 50.],\n", + " [50., 50., 50.]]),\n", + " tensor([[35.3553, 35.3553, 35.3553],\n", + " [35.3553, 35.3553, 35.3553],\n", + " [35.3553, 35.3553, 35.3553]]),\n", + " tensor([[0.3333, 0.3333, 0.3333],\n", + " [0.3333, 0.3333, 0.3333],\n", + " [0.3333, 0.3333, 0.3333]]),\n", + " tensor([[5., 5.],\n", + " [5., 5.],\n", + " [5., 5.]]),\n", + " 2,\n", + " tensor([5., 5.]))" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dk = 2\n", + "t = 3\n", + "\n", + "a = torch.zeros(t, dk) + 5\n", + "b = torch.zeros(t, dk) + 5\n", + "v = torch.zeros(t, dk) + 5\n", + "\n", + "bt = b.mT\n", + "\n", + "dot = torch.mm(a, bt)\n", + "\n", + "scaled = dot/math.sqrt(dk)\n", + "\n", + "softmax = torch.nn.functional.softmax(scaled)\n", + "\n", + "output = torch.mm(softmax, v)\n", + "\n", + "a,b,v,bt, dot, scaled, softmax, output, len(a[0]),a[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "XCv8_IzSdut4" + }, + "outputs": [], + "source": [ + "def scaled_dot_product(q, k, v, mask=None):\n", + " # implemented by the student, you can ignore the mask implementation currently\n", + " # just assignment all the mask is on\n", + "\n", + " shape_len = len(k.shape)\n", + "\n", + " transpose = k.mT\n", + " d = k.shape[-1]\n", + "\n", + " score_scale = torch.matmul(q, transpose)/math.sqrt(d)\n", + "\n", + " attention_weight = torch.nn.functional.softmax(score_scale, 1)\n", + "\n", + " output = torch.matmul(attention_weight, v)\n", + "\n", + " return output, 
attention_weight" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "Q = torch.Tensor([[-0.19737370312213898, -1.0540887117385864, 0.02383515052497387, 0.46185705065727234], [-1.2415547370910645, 0.8366656303405762, 0.3741966784000397, 0.9099264740943909], [0.3436168134212494, 0.6154376268386841, 1.1926648616790771, 1.6477248668670654]])\n", + "K = torch.Tensor([[1.9663442373275757, 0.15551914274692535, -0.8715013861656189, 0.32070425152778625], [-5.85474967956543, 1.7047394514083862, -1.0024793148040771, 1.3307985067367554], [0.06319630891084671, -2.030783176422119, -5.436811447143555, -0.42979586124420166]])\n", + "V = torch.Tensor([[-82.127197265625, 0.9534303545951843, -28.78610610961914, -10.762138366699219], [-16.467313766479492, 60.92831802368164, -36.08392333984375, 31.648052215576172], [20.485767364501953, 45.4570198059082, 15.208494186401367, 31.43212890625]])\n", + "\n", + "ans = scaled_dot_product(Q, K, V)[0].tolist()\n", + "\n", + "pf = pd.read_csv(\"A1_template_template.csv\")\n", + "\n", + "pf.loc[0] = [6644818, Q.tolist(), K.tolist(), V.tolist(), ans]\n", + "\n", + "pf.to_csv('A1_template.csv', sep=',', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wMIShH5wcrUK", + "outputId": "b5e6f270-0cae-4f2c-d388-e0e26ed28b6a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 1.2840, 0.9623],\n", + " [ 1.0821, -0.2264],\n", + " [ 0.4840, -1.0348]])\n", + "tensor([[ 0.0392, 0.2658],\n", + " [ 3.1410, 1.9842],\n", + " [ 1.2559, -1.1543]])\n", + "tensor([[ 0.2172, -0.7752],\n", + " [-1.0788, -1.9513],\n", + " [ 0.9364, -1.2229]])\n", + "tensor([[-1.0142, -1.9154],\n", + " [-0.4535, -1.6679],\n", + " [ 0.5474, -1.2476]])\n", + "tensor(-1.2095e-05)\n" + ] + } + ], + "source": [ + "# Test case\n", + "seq_len, d_k = 3, 2\n", + 
"torch.manual_seed(3025)\n", + "q = torch.randn(seq_len, d_k)\n", + "k = torch.randn(seq_len, d_k)\n", + "v = torch.randn(seq_len, d_k)\n", + "valid = torch.tensor([[-1.0142, -1.9154],\n", + " [-0.4535, -1.6679],\n", + " [ 0.5474, -1.2476]])\n", + "output, attention_weight = scaled_dot_product(q,k,v)\n", + "differences = (output - valid).mean()\n", + "print(q)\n", + "print(k)\n", + "print(v)\n", + "print(output)\n", + "print(differences)\n", + "assert torch.abs(differences) < 0.0001, 'the product must be similar output as expected'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnDq4vT7kEGE" + }, + "source": [ + "# Multi-Head Attention\n", + "\n", + "Multi-Head Attention is an advancement over the scaled dot product attention, enabling the model to concurrently attend to information from different representation subspaces at different positions. This is particularly useful when dealing with complex data where different elements of the sequence may have different types of relevance or relationships to other elements.\n", + "\n", + "#### Concept\n", + "Instead of a single attention \"head,\" Multi-Head Attention uses multiple sets of Query, Key, and Value weight matrices to project the input into different subspaces, allowing the model to capture various aspects of the information. 
Each set of projections is referred to as a \"head.\" The attention outputs from each head are then concatenated and linearly transformed into the expected dimension.\n", + "\n", + "#### Mathematical Representation\n", + "Given Query, Key, and Value matrices (Q, K, V), the process can be mathematically described as:\n", + "\n", + "$$\n", + "\\begin{split}\n", + " \\text{Multihead}(Q,K,V) & = \\text{Concat}(\\text{head}_1,...,\\text{head}_h)W^{O}\\\\\n", + " \\text{where } \\text{head}_i & = \\text{Attention}(QW_i^Q,KW_i^K, VW_i^V)\n", + "\\end{split}\n", + "$$\n", + "\n", + "In this formula:\n", + "- $W_i^Q \\in \\mathbb{R}^{D \\times d_k}$, $W_i^K \\in \\mathbb{R}^{D \\times d_k}$, and $W_i^V \\in \\mathbb{R}^{D \\times d_v}$ are parameter matrices for the $i$-th attention head.\n", + "- $W^O \\in \\mathbb{R}^{h \\cdot d_v \\times d_{out}}$ is the parameter matrix for the linear transformation after concatenating the heads.\n", + "- $D$ is the dimensionality of the input, $h$ is the number of heads, and $d_{out}$ is the output dimensionality.\n", + "\n", + "#### Integration in Neural Networks\n", + "In a neural network, the Multi-Head Attention layer is typically applied to a feature map $X \\in \\mathbb{R}^{B \\times T \\times d_{\\text{model}}}$, where $B$ is the batch size, $T$ is the sequence length, and $d_{\\text{model}}$ is the dimensionality of the model's hidden layer. Here, $X$ serves as $Q$, $K$, and $V$. The transformation to query, key, and value representations is done using separate learnable weight matrices $W^Q$, $W^K$, and $W^V$.\n", + "\n", + "#### Implementation Notes\n", + "- **Heads**: Each head captures different aspects of the input data. 
More heads allow the model to simultaneously focus on different subspaces.\n", + "- **Dimensionality**: Ensure the dimensions of your weight matrices and inputs align correctly.\n", + "- **Efficiency**: Despite the increased complexity, Multi-Head Attention can be efficiently parallelized, making it suitable for large-scale problems.\n", + "\n", + "By utilizing Multi-Head Attention, models can gain a more nuanced understanding of the data, capturing various types of relationships within the sequence. This is especially beneficial in complex tasks like language understanding, where different words or phrases may have different kinds of relationships with others in the sequence." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([ 8, 9, 10]))" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.arange(11).chunk(3, dim=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "id": "zOiDz_FkkDDm" + }, + "outputs": [], + "source": [ + "class MultiheadAttention(nn.Module):\n", + " def __init__(self, input_dim, embed_dim, num_heads):\n", + " super().__init__()\n", + " assert embed_dim % num_heads == 0, \"Embedding dimension must be 0 modulo number of heads.\"\n", + "\n", + " self.embed_dim = embed_dim\n", + " self.num_heads = num_heads\n", + " self.head_dim = embed_dim // num_heads\n", + " self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)\n", + " self.o_proj = nn.Linear(embed_dim, embed_dim)\n", + "\n", + " self._reset_parameters()\n", + "\n", + " def _reset_parameters(self):\n", + " # Original Transformer initialization, see PyTorch documentation\n", + " nn.init.xavier_uniform_(self.qkv_proj.weight)\n", + " self.qkv_proj.bias.data.fill_(0)\n", + " nn.init.xavier_uniform_(self.o_proj.weight)\n", + " self.o_proj.bias.data.fill_(0)\n", + "\n", + " 
def forward(self, x, mask=None, return_attention=False):\n", + " batch_size, seq_length, embed_dim = x.size()\n", + " qkv = self.qkv_proj(x)\n", + " qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)\n", + " qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]\n", + " q, k, v = qkv.chunk(3, dim=-1)\n", + " values, attention = scaled_dot_product(q, k, v, mask=mask)\n", + " values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]\n", + " values = values.reshape(batch_size, seq_length, embed_dim)\n", + " o = self.o_proj(values)\n", + "\n", + " if return_attention:\n", + " return o, attention\n", + " else:\n", + " return o" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sLI_NEVtlSNI" + }, + "source": [ + "# Transformer Encoder\n", + "\n", + "The Transformer Encoder plays a crucial role in transforming input sequences into rich, attention-based representations, primarily used in Sequence-to-Sequence tasks like machine translation. While the original Transformer model consists of both encoder and decoder, the encoder alone has been foundational in numerous advances in NLP and beyond. This section focuses on the encoder's architecture, function, and key components.\n", + "\n", + "#### Overview\n", + "The Transformer Encoder is composed of a stack of $N$ identical layers, each containing two main sub-layers:\n", + "\n", + "1. **Multi-Head Attention Mechanism**: Enables the model to attend to different positions of the input sequence simultaneously.\n", + "2. **Position-wise Feed-Forward Networks**: Consists of fully connected layers applied to each position separately, allowing for individual processing of each sequence element.\n", + "\n", + "#### Encoder Architecture\n", + "Each layer in the encoder includes the following steps:\n", + "\n", + "1. **Input Processing**: The input $x$ (where $x$ can be $Q$, $K$, and $V$) is first passed through the Multi-Head Attention mechanism.\n", + "2. 
**Residual Connection and Layer Normalization**: The output from the Multi-Head Attention is then added back to the input $x$ through a residual connection, followed by layer normalization:\n", + " \n", + " $$\\text{LayerNorm}(x + \\text{Multihead}(x, x, x))$$\n", + "\n", + " The residual connections help in maintaining the flow of the original input information through the network and are crucial for training deeper models by improving gradient flow. Layer Normalization is used to stabilize the learning process and ensure consistent feature magnitude across sequence elements.\n", + "\n", + "3. **Position-wise Feed-Forward Networks (FFN)**: Each position is processed individually by a two-layered feed-forward network with ReLU activation in between:\n", + " \n", + " $$\n", + " \\begin{split}\n", + " \\text{FFN}(x) & = \\max(0, xW_1 + b_1)W_2 + b_2\\\\\n", + " x & = \\text{LayerNorm}(x + \\text{FFN}(x))\n", + " \\end{split}\n", + " $$\n", + "\n", + " This component allows for further processing of the information added by the attention mechanism, preparing it for the next layer.\n", + "\n", + "#### Considerations in Design\n", + "- **Layer Normalization**: Chosen over Batch Normalization due to its independence from batch size and better performance in language tasks.\n", + "- **Dimensionality of MLP in FFN**: Typically 2-8 times larger than the dimensionality of the input $x$ ($d_{\\text{model}}$), allowing for more complex transformations and faster parallelizable execution.\n", + "- **Dropout**: Applied in MLP and on the outputs of MLP and Multi-Head Attention for regularization.\n", + "\n", + "The Transformer Encoder's architecture, with its repetitive yet intricate structure, allows for effective processing and transformation of sequence data, making it a powerful tool in various sequence modeling tasks. 
The next steps involve implementing the encoder block, paying close attention to the integration of Multi-Head Attention, residual connections, layer normalization, and feed-forward networks within each layer." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "id": "a1HiddBnlW4J" + }, + "outputs": [], + "source": [ + "class EncoderBlock(nn.Module):\n", + " def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):\n", + " \"\"\"EncoderBlock.\n", + "\n", + " Args:\n", + " input_dim: Dimensionality of the input\n", + " num_heads: Number of heads to use in the attention block\n", + " dim_feedforward: Dimensionality of the hidden layer in the MLP\n", + " dropout: Dropout probability to use in the dropout layers\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " # Attention layer\n", + " self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)\n", + "\n", + " # Two-layer MLP\n", + " self.linear_net = nn.Sequential(\n", + " nn.Linear(input_dim, dim_feedforward),\n", + " nn.Dropout(dropout),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(dim_feedforward, input_dim),\n", + " )\n", + "\n", + " # Layers to apply in between the main layers\n", + " self.norm1 = nn.LayerNorm(input_dim)\n", + " self.norm2 = nn.LayerNorm(input_dim)\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x, mask=None):\n", + " # Attention part\n", + " attn_out = self.self_attn(x, mask=mask)\n", + " x = x + self.dropout(attn_out)\n", + " x = self.norm1(x)\n", + "\n", + " # MLP part\n", + " linear_out = self.linear_net(x)\n", + " x = x + self.dropout(linear_out)\n", + " x = self.norm2(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "\n", + "\n", + "class TransformerEncoder(nn.Module):\n", + " def __init__(self, num_layers, **block_args):\n", + " super().__init__()\n", + " self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])\n", + "\n", + " def forward(self, x, mask=None):\n", + " 
for layer in self.layers:\n", + " x = layer(x, mask=mask)\n", + " return x\n", + "\n", + " def get_attention_maps(self, x, mask=None):\n", + " attention_maps = []\n", + " for layer in self.layers:\n", + " _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)\n", + " attention_maps.append(attn_map)\n", + " x = layer(x)\n", + " return attention_maps\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "class PositionalEncoding(nn.Module):\n", + " def __init__(self, d_model, max_len=5000):\n", + " \"\"\"Positional Encoding.\n", + "\n", + " Args:\n", + " d_model: Hidden dimensionality of the input.\n", + " max_len: Maximum length of a sequence to expect.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs\n", + " pe = torch.zeros(max_len, d_model)\n", + " position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n", + " div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = torch.sin(position * div_term)\n", + " pe[:, 1::2] = torch.cos(position * div_term)\n", + " pe = pe.unsqueeze(0)\n", + "\n", + " # register_buffer => Tensor which is not a parameter, but should be part of the modules state.\n", + " # Used for tensors that need to be on the same device as the module.\n", + " # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)\n", + " self.register_buffer(\"pe\", pe, persistent=False)\n", + "\n", + " def forward(self, x):\n", + " x = x + self.pe[:, : x.size(1)]\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bKeMF9xLmQH0" + }, + "source": [ + "# Sequence to Sequence Tasks\n", + "\n", + "Sequence to Sequence (Seq2Seq) tasks involve converting an input sequence into an output sequence, where the input and output may vary in length. 
This model structure is commonly used in applications like machine translation, text summarization, and more. Typically, a Seq2Seq model comprises an encoder to interpret the input sequence and a decoder to generate the output sequence autoregressively.\n", + "\n", + "#### Simplified Task: Sequence Reversal\n", + "For educational purposes, we'll focus on a simplified Seq2Seq task: reversing a sequence of numbers. Despite its simplicity, this task is a good testbed for understanding Seq2Seq models, especially since it requires capturing long-term dependencies, something traditional RNNs might struggle with, but Transformers are well-equipped to handle.\n", + "\n", + "#### Task Description:\n", + "- **Input**: A sequence of $N$ numbers ranging from $0$ to $M$.\n", + "- **Output**: The reversed sequence of the input.\n", + "\n", + "In Numpy, if our input sequence is $x$, the desired output is $x$[::-1]. Although straightforward, this task provides a clear demonstration of a model's ability to handle sequences and understand dependencies across positions.\n", + "\n", + "#### Implementation Steps:\n", + "- **Create a Dataset Class**: The first step is to create a dataset class that can generate sequences of numbers and their reversed counterparts. This class will be used to train and evaluate the Seq2Seq model.\n", + "\n", + "By starting with this simple task, we can focus on the mechanics and capabilities of the Transformer encoder in handling sequences, setting the stage for tackling more complex Seq2Seq tasks in the future." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "id": "PSBkeOmtmPhX" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(16, 10)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class ReverseDataset(data.Dataset):\n", + " def __init__(self, num_categories, seq_len, size):\n", + " super().__init__()\n", + " self.num_categories = num_categories\n", + " self.seq_len = seq_len\n", + " self.size = size\n", + "\n", + " self.data = torch.randint(self.num_categories, size=(self.size, self.seq_len))\n", + "\n", + " def __len__(self):\n", + " return self.size\n", + "\n", + " def __getitem__(self, idx):\n", + " inp_data = self.data[idx]\n", + " labels = torch.flip(inp_data, dims=(0,))\n", + " return inp_data, labels\n", + "\n", + "seq_len = 16\n", + "num_categories = 10\n", + "batch_size = 128\n", + "dataset = partial(ReverseDataset, num_categories, seq_len)\n", + "train_loader = data.DataLoader(dataset(10000), batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)\n", + "val_loader = data.DataLoader(dataset(1000), batch_size=64, drop_last=True, shuffle=False)\n", + "\n", + "seq_len, num_categories" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VZ52A-Hhma4b" + }, + "source": [ + "# Compose the network" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "id": "JxlCGvdomaDJ" + }, + "outputs": [], + "source": [ + "class TransformerPredictor(nn.Module):\n", + " def __init__(\n", + " self,\n", + " input_dim,\n", + " model_dim,\n", + " num_classes,\n", + " num_heads,\n", + " num_layers,\n", + " dropout=0.0,\n", + " input_dropout=0.0,\n", + " ):\n", + " \"\"\"TransformerPredictor.\n", + "\n", + " Args:\n", + " input_dim: Hidden dimensionality of the input\n", + " model_dim: Hidden dimensionality to use inside the Transformer\n", + " num_classes: Number of classes to predict per sequence element\n", + " num_heads: Number of 
heads to use in the Multi-Head Attention blocks\n", + " num_layers: Number of encoder blocks to use.\n", + " dropout: Dropout to apply inside the model\n", + " input_dropout: Dropout to apply on the input features\n", + " \"\"\"\n", + " super().__init__()\n", + " # Input dim -> Model dim\n", + " self.input_net = nn.Sequential(\n", + " nn.Dropout(input_dropout),\n", + " nn.Linear(input_dim, model_dim)\n", + " )\n", + " # Positional encoding for sequences\n", + " self.positional_encoding = PositionalEncoding(d_model=model_dim)\n", + " # Transformer\n", + " self.transformer = TransformerEncoder(\n", + " num_layers=num_layers,\n", + " input_dim=model_dim,\n", + " dim_feedforward=2 * model_dim,\n", + " num_heads=num_heads,\n", + " dropout=dropout,\n", + " )\n", + " # Output classifier per sequence element\n", + " self.output_net = nn.Sequential(\n", + " nn.Linear(model_dim, model_dim),\n", + " nn.LayerNorm(model_dim),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(model_dim, num_classes),\n", + " )\n", + "\n", + " def forward(self, x, mask=None, add_positional_encoding=True):\n", + " \"\"\"\n", + " Args:\n", + " x: Input features of shape [Batch, SeqLen, input_dim]\n", + " mask: Mask to apply on the attention outputs (optional)\n", + " add_positional_encoding: If True, we add the positional encoding to the input.\n", + " Might not be desired for some tasks.\n", + " \"\"\"\n", + " x = self.input_net(x)\n", + " if add_positional_encoding:\n", + " x = self.positional_encoding(x)\n", + " x = self.transformer(x, mask=mask)\n", + " x = self.output_net(x)\n", + " return x\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uUuW7DbBnjsS" + }, + "source": [ + "# Task: Writing Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ +
"torch.cuda.device_count()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'NVIDIA GeForce RTX 3090 Ti'" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.cuda.get_device_name(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cZaOx-7qni7y", + "outputId": "a181d978-f0e0-451b-9e95-587ce9d8c2bd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "EPOCH 1:\n", + "LOSS train 2.3047391891479494\n", + "EPOCH 2:\n", + "LOSS train 2.3059343338012694\n", + "EPOCH 3:\n", + "LOSS train 2.303365612030029\n", + "EPOCH 4:\n", + "LOSS train 2.278448724746704\n", + "EPOCH 5:\n", + "LOSS train 2.271419906616211\n", + "EPOCH 6:\n", + "LOSS train 2.272932434082031\n", + "EPOCH 7:\n", + "LOSS train 2.2925734519958496\n", + "EPOCH 8:\n", + "LOSS train 2.273321104049683\n", + "EPOCH 9:\n", + "LOSS train 2.2358773708343507\n", + "EPOCH 10:\n", + "LOSS train 2.213832139968872\n", + "EPOCH 11:\n", + "LOSS train 2.1974945068359375\n", + "EPOCH 12:\n", + "LOSS train 2.138561820983887\n", + "EPOCH 13:\n", + "LOSS train 2.131034755706787\n", + "EPOCH 14:\n", + "LOSS train 2.092240905761719\n", + "EPOCH 15:\n", + "LOSS train 2.028573417663574\n", + "EPOCH 16:\n", + "LOSS train 2.009480619430542\n", + "EPOCH 17:\n", + "LOSS train 2.0119842529296874\n", + "EPOCH 18:\n", + "LOSS train 1.966892457008362\n", + "EPOCH 19:\n", + "LOSS train 1.9739716291427611\n", + "EPOCH 20:\n", + "LOSS train 1.9691336631774903\n", + "EPOCH 21:\n", + "LOSS train 1.9692431688308716\n", + "EPOCH 22:\n", + "LOSS train 1.9720533609390258\n", + "EPOCH 23:\n", + "LOSS train 1.9636199712753295\n", + "EPOCH 24:\n", + "LOSS train 2.0134324550628664\n", + "EPOCH 25:\n", + "LOSS train 1.9760711431503295\n", + "EPOCH 26:\n", + "LOSS train 
1.9674718379974365\n", + "EPOCH 27:\n", + "LOSS train 1.9762607574462892\n", + "EPOCH 28:\n", + "LOSS train 1.9764994859695435\n", + "EPOCH 29:\n", + "LOSS train 1.9699514150619506\n", + "EPOCH 30:\n", + "LOSS train 1.95670006275177\n", + "EPOCH 31:\n", + "LOSS train 1.946057105064392\n", + "EPOCH 32:\n", + "LOSS train 1.9565371990203857\n", + "EPOCH 33:\n", + "LOSS train 1.9599705457687377\n", + "EPOCH 34:\n", + "LOSS train 1.9696622133255004\n", + "EPOCH 35:\n", + "LOSS train 1.9993957996368408\n", + "EPOCH 36:\n", + "LOSS train 1.9636467695236206\n", + "EPOCH 37:\n", + "LOSS train 1.980830192565918\n", + "EPOCH 38:\n", + "LOSS train 1.9654539108276368\n", + "EPOCH 39:\n", + "LOSS train 1.9689129829406737\n", + "EPOCH 40:\n", + "LOSS train 1.955962347984314\n", + "EPOCH 41:\n", + "LOSS train 1.9647478580474853\n", + "EPOCH 42:\n", + "LOSS train 1.9532663106918335\n", + "EPOCH 43:\n", + "LOSS train 1.9503717422485352\n", + "EPOCH 44:\n", + "LOSS train 1.9499874591827393\n", + "EPOCH 45:\n", + "LOSS train 1.9529696941375732\n", + "EPOCH 46:\n", + "LOSS train 1.9518198251724244\n", + "EPOCH 47:\n", + "LOSS train 1.9523835182189941\n", + "EPOCH 48:\n", + "LOSS train 1.9561205148696899\n", + "EPOCH 49:\n", + "LOSS train 1.9675297260284423\n", + "EPOCH 50:\n", + "LOSS train 2.123178768157959\n", + "EPOCH 51:\n", + "LOSS train 1.970911931991577\n", + "EPOCH 52:\n", + "LOSS train 1.9587018251419068\n", + "EPOCH 53:\n", + "LOSS train 1.9622526168823242\n", + "EPOCH 54:\n", + "LOSS train 1.9551706790924073\n", + "EPOCH 55:\n", + "LOSS train 1.953707218170166\n", + "EPOCH 56:\n", + "LOSS train 1.9466333389282227\n", + "EPOCH 57:\n", + "LOSS train 1.9582770824432374\n", + "EPOCH 58:\n", + "LOSS train 1.9466321229934693\n", + "EPOCH 59:\n", + "LOSS train 1.9557215929031373\n", + "EPOCH 60:\n", + "LOSS train 1.9505679607391357\n", + "EPOCH 61:\n", + "LOSS train 1.9520682334899901\n", + "EPOCH 62:\n", + "LOSS train 1.955586814880371\n", + "EPOCH 63:\n", + "LOSS train 
1.9475157499313354\n", + "EPOCH 64:\n", + "LOSS train 1.9377191305160522\n", + "EPOCH 65:\n", + "LOSS train 1.938973307609558\n", + "EPOCH 66:\n", + "LOSS train 1.9429319143295287\n", + "EPOCH 67:\n", + "LOSS train 1.9438214540481566\n", + "EPOCH 68:\n", + "LOSS train 1.9364233016967773\n", + "EPOCH 69:\n", + "LOSS train 1.9589627027511596\n", + "EPOCH 70:\n", + "LOSS train 1.9416004180908204\n", + "EPOCH 71:\n", + "LOSS train 1.9382025003433228\n", + "EPOCH 72:\n", + "LOSS train 1.9310474157333375\n", + "EPOCH 73:\n", + "LOSS train 1.939470887184143\n", + "EPOCH 74:\n", + "LOSS train 1.9370584964752198\n", + "EPOCH 75:\n", + "LOSS train 1.9400960445404052\n", + "EPOCH 76:\n", + "LOSS train 1.9455738306045531\n", + "EPOCH 77:\n", + "LOSS train 1.9308057308197022\n", + "EPOCH 78:\n", + "LOSS train 1.9302024841308594\n", + "EPOCH 79:\n", + "LOSS train 1.9345638751983643\n", + "EPOCH 80:\n", + "LOSS train 1.9394316673278809\n", + "EPOCH 81:\n", + "LOSS train 1.9305338621139527\n", + "EPOCH 82:\n", + "LOSS train 1.9336636304855346\n", + "EPOCH 83:\n", + "LOSS train 1.921869659423828\n", + "EPOCH 84:\n", + "LOSS train 1.9273949146270752\n", + "EPOCH 85:\n", + "LOSS train 1.916986870765686\n", + "EPOCH 86:\n", + "LOSS train 1.9170085191726685\n", + "EPOCH 87:\n", + "LOSS train 1.9086025476455688\n", + "EPOCH 88:\n", + "LOSS train 1.911173439025879\n", + "EPOCH 89:\n", + "LOSS train 1.8953789472579956\n", + "EPOCH 90:\n", + "LOSS train 1.9057527542114259\n", + "EPOCH 91:\n", + "LOSS train 1.883132243156433\n", + "EPOCH 92:\n", + "LOSS train 1.895125651359558\n", + "EPOCH 93:\n", + "LOSS train 1.8788848161697387\n", + "EPOCH 94:\n", + "LOSS train 1.870681929588318\n", + "EPOCH 95:\n", + "LOSS train 1.8580753803253174\n", + "EPOCH 96:\n", + "LOSS train 1.8578021287918092\n", + "EPOCH 97:\n", + "LOSS train 1.8478908300399781\n", + "EPOCH 98:\n", + "LOSS train 1.8258821010589599\n", + "EPOCH 99:\n", + "LOSS train 1.8657661914825439\n", + "EPOCH 100:\n", + "LOSS train 
1.8161733627319336\n" + ] + } + ], + "source": [ + "input_dim = 10 # Must equal num_categories (10): the inputs are one-hot encoded\n", + "model_dim = 1024 # hidden dimensionality used inside the Transformer\n", + "num_classes = train_loader.dataset.num_categories\n", + "num_heads = 8\n", + "num_layers = 1\n", + "with torch.cuda.device(torch.device('cuda')):\n", + " \n", + " # please create the model\n", + " model = TransformerPredictor(input_dim, model_dim, num_classes, num_heads, num_layers).cuda()\n", + "\n", + " # please create the optimizer\n", + " optimizer = torch.optim.Adam(model.parameters())\n", + " loss_fn = torch.nn.CrossEntropyLoss()\n", + " # please train the model, with the whole training pipeline\n", + " \n", + " def train(epoch_index, tb_writer):\n", + " running_loss = 0\n", + " last_loss = 0\n", + " \n", + " for i, data in enumerate(train_loader):\n", + " inputs, labels = data\n", + " \n", + " # inputs = inputs.to(torch.float32)\n", + " inputs = F.one_hot(inputs, num_classes=num_classes).float().cuda()\n", + "\n", + " labels = labels.cuda()\n", + " \n", + " outputs = model(inputs)\n", + " \n", + " loss = loss_fn(outputs.view(-1, 10), labels.view(-1))\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " \n", + " # Adjust learning weights\n", + " optimizer.step()\n", + " \n", + " running_loss += loss.item()\n", + " if i % 5 == 0:\n", + " last_loss = running_loss / 5 # mean loss of the last 5 batches (too low at i == 0, where only 1 batch was summed)\n", + " # print(' batch {} loss: {}'.format(i + 1, last_loss))\n", + " tb_x = epoch_index * len(train_loader) + i + 1\n", + " # tb_writer.add_scalar('Loss/train', last_loss, tb_x)\n", + " running_loss = 0.\n", + " \n", + " return last_loss\n", + " \n", + " \n", + " timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')\n", + " # writer = torch.utils.tensorboard.writer.SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))\n", + " epoch_number = 0\n", + " \n", + " EPOCHS = 100\n", + " \n", + " best_vloss = 1_000_000.\n", + " \n", + " for epoch in range(EPOCHS):\n", + " print('EPOCH
{}:'.format(epoch_number + 1))\n", + " \n", + " # Make sure gradient tracking is on, and do a pass over the data\n", + " model.train(True)\n", + " avg_loss = train(epoch_number, None)\n", + " \n", + " avg_vloss = 0\n", + " print('LOSS train {}'.format(avg_loss))\n", + " \n", + " epoch_number += 1\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NVqbotkCrCSy" + }, + "source": [ + "# Evaluation\n", + "Here is the evaluation code, can you do better than 2.0?" + ] + }, + { + "cell_type": "code", + "execution_count": 218, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vkRNkGBspuZh", + "outputId": "f5a0ff6d-e24c-4a94-d5f8-8113956f3b18" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 2.3046957651774087\n" + ] + } + ], + "source": [ + "# Validating the validation loss\n", + "criterion = nn.CrossEntropyLoss()\n", + "# Validation loop\n", + "model.eval()\n", + "with torch.no_grad():\n", + " val_loss = 0\n", + " for inputs, labels in val_loader:\n", + " inp_data = F.one_hot(inputs, num_classes=10).float().cuda()\n", + " outputs = model(inp_data)\n", + " loss = criterion(outputs.view(1024,10), labels.view(-1).cuda())\n", + " val_loss += loss.item()\n", + " print(f\"Validation Loss: {val_loss / len(val_loader)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Lab1_2/Lab1&2_Transformers.ipynb 
b/Lab1_2/Lab1&2_Transformers.ipynb new file mode 100644 index 0000000..a5a6b3f --- /dev/null +++ b/Lab1_2/Lab1&2_Transformers.ipynb @@ -0,0 +1,93 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Cv-9Vzunb_tf" + }, + "source": [ + "# Import Necessary Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "4f-K54nHb-Uq" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch.utils.data as data\n", + "import math\n", + "import os\n", + "import urllib.request\n", + "import pandas as pd\n", + "from functools import partial\n", + "from urllib.error import HTTPError\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "XCv8_IzSdut4" + }, + "outputs": [], + "source": [ + "def scaled_dot_product(q, k, v, mask=None):\n", + " # Implemented by the student; the mask argument is currently ignored (every position is attended to).\n", + " # NOTE(review): softmax runs over dim 1, which is the key axis only for 2-D q/k/v - confirm for batched input.\n", + "\n", + " shape_len = len(k.shape)\n", + "\n", + " transpose = k.mT\n", + " d = k.shape[-1]\n", + "\n", + " score_scale = torch.matmul(q, transpose)/math.sqrt(d)\n", + "\n", + " attention_weight = torch.nn.functional.softmax(score_scale, 1)\n", + "\n", + " output = torch.matmul(attention_weight, v)\n", + "\n", + " return output, attention_weight" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", +
"version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Lab3/Week3_Autoencoder+MAE - Copy.ipynb b/Lab3/Week3_Autoencoder+MAE - Copy.ipynb new file mode 100644 index 0000000..97d3ff1 --- /dev/null +++ b/Lab3/Week3_Autoencoder+MAE - Copy.ipynb @@ -0,0 +1,7058 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "YLLecc85VRCL" + }, + "source": [ + "# Introduction & Import Necessary Setup\n", + "In this labsheet, we'll delve into the fascinating world of autoencoders (AEs), a type of neural network renowned for its ability to compress and reconstruct data. Autoencoders work by first encoding input data, such as images, into a compact feature vector through an encoder network. This process effectively distills the essence of the data into a smaller, more manageable form. The feature vector, often referred to as the \"bottleneck,\" plays a crucial role in this compression process, allowing us to represent the input data with fewer features.\n", + "\n", + "Following compression, a second neural network, known as the decoder, takes over to reconstruct the original data from the compressed feature vector. This remarkable ability to compress and then reconstruct data makes autoencoders extremely valuable in various applications, including data compression and image comparison at a more meaningful level than mere pixel-by-pixel analysis.\n", + "\n", + "Moreover, our exploration will not stop at the autoencoder framework itself. We will also introduce the concept of \"deconvolution\" (also known as transposed convolution), a powerful operator used to enlarge feature maps in both height and width dimensions. Deconvolution networks are indispensable in scenarios where we begin with a compact feature vector and aim to generate a full-sized image. 
This technique is pivotal in various advanced neural network applications, such as Variational Autoencoders (VAEs), Generative Adversarial Networks (GANs), and super-resolution.\n", + "\n", + "To kick things off, we'll start by importing our standard libraries, setting the stage for our deep dive into the inner workings and applications of autoencoders." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4e2G_wAgIWxD", + "outputId": "a8fc0cbc-1aa7-4dd6-a94e-cf5320483c9f" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_192718/407458918.py:11: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`\n", + " set_matplotlib_formats('svg', 'pdf') # For export\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Device: cuda:0\n" + ] + } + ], + "source": [ + "## Standard libraries\n", + "import os\n", + "import json\n", + "import math\n", + "import numpy as np\n", + "\n", + "## Imports for plotting\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "from IPython.display import set_matplotlib_formats\n", + "set_matplotlib_formats('svg', 'pdf') # For export\n", + "from matplotlib.colors import to_rgb\n", + "import matplotlib\n", + "matplotlib.rcParams['lines.linewidth'] = 2.0\n", + "## Progress bar\n", + "from tqdm.notebook import tqdm\n", + "\n", + "## PyTorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.utils.data as data\n", + "import torch.optim as optim\n", + "# Torchvision\n", + "import torchvision\n", + "from torchvision.datasets import CIFAR10\n", + "from torchvision import transforms\n", + "\n", + "DATASET_PATH = \"dataset\"\n", + "\n", + "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else 
torch.device(\"cpu\")\n", + "print(\"Device:\", device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "INLuLKepWdvC" + }, + "source": [ + "# Download and setup the dataset\n", + "In this labsheet, our focus shifts to the CIFAR10 dataset, a collection known for its rich, colored images. Each image within CIFAR10 is equipped with 3 color channels and boasts a resolution of 32x32 pixels. This characteristic is particularly advantageous when working with autoencoders, as they are not bound by the constraints of probabilistic image modeling.\n", + "\n", + "Should you already have the CIFAR10 dataset downloaded in a different directory, it's important to adjust the DATASET_PATH variable accordingly. This step ensures you avoid unnecessary additional downloads, streamlining your workflow and allowing you to dive into the practical exercises more swiftly." + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yH_BjGbuJIrJ", + "outputId": "fc3e192f-fd42-4cd6-a5f8-82971331940b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'list' object has no attribute 'DataLoader'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[105], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m test_set \u001b[38;5;241m=\u001b[39m CIFAR10(root\u001b[38;5;241m=\u001b[39mDATASET_PATH, train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, transform\u001b[38;5;241m=\u001b[39mtransform, download\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# We define a set 
of data loaders that we can use for various purposes later.\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m \u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDataLoader\u001b[49m(train_set, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m256\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, drop_last\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, pin_memory\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m)\n\u001b[1;32m 14\u001b[0m val_loader \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mDataLoader(val_set, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m256\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, drop_last\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m)\n\u001b[1;32m 15\u001b[0m test_loader \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mDataLoader(test_set, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m256\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, drop_last\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m)\n", + "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'DataLoader'" + ] + } + ], + "source": [ + "# Transformations applied on each image => only make them a tensor\n", + "transform = transforms.Compose([transforms.ToTensor(),\n", + " transforms.Normalize((0.5,),(0.5,))])\n", + "\n", + "# Loading the training dataset. 
We need to split it into a training and validation part\n", + "train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)\n", + "train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])\n", + "\n", + "# Loading the test set\n", + "test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)\n", + "\n", + "# We define a set of data loaders that we can use for various purposes later.\n", + "train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)\n", + "val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)\n", + "test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)\n", + "\n", + "def get_train_images(num):\n", + " return torch.stack([train_dataset[i][0] for i in range(num)], dim=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4Jl0CTGSkym-" + }, + "source": [ + "# Building the autoencoder\n", + "\n", + "In general, an autoencoder consists of an **encoder** that maps the input $x$ to a lower-dimensional feature vector $z$, and a **decoder** that reconstructs the input $\\hat{x}$ from $z$. We train the model by comparing $x$ to $\\hat{x}$ and optimizing the parameters to increase the similarity between $x$ and $\\hat{x}$. See below for a small illustration of the autoencoder framework.\n", + "\n", + "\n", + "![img](https://raw.githubusercontent.com/hqsiswiliam/COM3025_Torch/main/autoencoder.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_lhJZFR3k1e_" + }, + "source": [ + "\n", + "## The encoder\n", + "\n", + "To kick off our exploration, we begin by constructing the encoder. This component is fundamentally a deep convolutional network tailored for progressively diminishing the image's dimensions.
This diminution is achieved through the use of strided convolutions, which methodically reduce the image's size layer by layer. Following the thrice-executed downscaling process, we transition the architecture from convolutional layers to a flattened feature representation. This is achieved by flattening the spatial features into a single vector, which is then processed through several linear layers. As a result, we obtain the latent representation, denoted as\n", + "$z$, encapsulating the compressed essence of the input image. The size of this latent vector, $d$, is adjustable, providing flexibility in the encoding capacity of our network." + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "id": "i6fToFroJMMT" + }, + "outputs": [], + "source": [ + "class Encoder(nn.Module):\n", + "\n", + " def __init__(self,\n", + " num_input_channels : int,\n", + " base_channel_size : int,\n", + " latent_dim : int,\n", + " act_fn : object = nn.GELU):\n", + " \"\"\"\n", + " Inputs:\n", + " - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3\n", + " - base_channel_size : Number of channels we use in the first convolutional layers. 
Deeper layers might use a duplicate of it.\n", + " - latent_dim : Dimensionality of latent representation z\n", + " - act_fn : Activation function used throughout the encoder network\n", + " \"\"\"\n", + " super().__init__()\n", + " c_hid = base_channel_size\n", + " self.net = nn.Sequential(\n", + " nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16\n", + " act_fn(),\n", + " nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),\n", + " act_fn(),\n", + " nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8\n", + " act_fn(),\n", + " nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),\n", + " act_fn(),\n", + " nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4\n", + " act_fn(),\n", + " nn.Flatten(), # Image grid to single feature vector\n", + " nn.Linear(2*16*c_hid, latent_dim)\n", + " )\n", + "\n", + " # self.flatten = nn.Sequential(\n", + " # nn.Flatten(), # Image grid to single feature vector\n", + " # nn.Linear(2*16*c_hid, latent_dim)\n", + " # )\n", + "\n", + " def forward(self, x):\n", + " # x = self.net(x)\n", + "\n", + " # print(x.shape)\n", + " \n", + " # return self.flatten(x)\n", + "\n", + " return self.net(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AOOi0C4wm99b" + }, + "source": [ + "# Task1\n", + "Now Complete the decoder implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 133, + "metadata": { + "id": "kV2FkEk6JTjk" + }, + "outputs": [], + "source": [ + "class Decoder(nn.Module):\n", + "\n", + " def __init__(self,\n", + " num_input_channels : int,\n", + " base_channel_size : int,\n", + " latent_dim : int,\n", + " act_fn : object = nn.GELU):\n", + " \"\"\"\n", + " Inputs:\n", + " - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3\n", + " - base_channel_size : Number of channels we use in the last convolutional layers. 
Early layers might use a multiple of it.\n", + " - latent_dim : Dimensionality of latent representation z\n", + " - act_fn : Activation function used throughout the decoder network\n", + " \"\"\"\n", + " super().__init__()\n", + " c_hid = base_channel_size\n", + " self.net = nn.Sequential(\n", + " nn.Linear(latent_dim, 2*16*c_hid),\n", + " act_fn(),\n", + " nn.Unflatten(1, (2*c_hid, 4, 4)),\n", + " nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2, output_padding=1), # 8x8 <= 4x4\n", + " act_fn(),\n", + " nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),\n", + " act_fn(),\n", + " nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, padding=1, stride=2, output_padding=1), # 16x16 <= 8x8\n", + " act_fn(),\n", + " nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),\n", + " act_fn(), \n", + " nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, padding=1, stride=2, output_padding=1), # 32x32 <= 16x16\n", + " nn.Tanh(),\n", + " # nn.Sigmoid(),\n", + " )\n", + " # Your code goes here.\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + " # Your code goes here." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-DYpDGTznGVL" + }, + "source": [ + "# Combining Encoder and Decoder\n", + "## Loss Function: Mean Squared Error (MSE)\n", + "\n", + "For our loss function, we opt for the Mean Squared Error (MSE). MSE is particularly effective in emphasizing the significance of accurately predicting pixel values that are substantially misestimated by the network. For instance, a minor deviation, such as predicting 127 instead of 128, is deemed less critical. However, larger discrepancies, like confusing a pixel value of 0 with 128, are considered more severe and detrimental to the reconstruction quality.\n", + "\n", + "Unlike Variational Autoencoders (VAEs) that predict the probability for each pixel value, we employ MSE as a straightforward distance measure.
This approach significantly reduces the number of parameters, streamlining the training process. To enhance our understanding of the per-pixel performance, we calculate the summed squared error, averaged across the batch dimension. It's important to note that alternative aggregations (mean or sum) yield equivalent outcomes in terms of resulting parameters.\n", + "\n", + "### Limitations of MSE\n", + "\n", + "Despite its advantages, MSE is not without drawbacks. Primarily, it tends to produce blurrier images, as it inherently removes small noise and high-frequency patterns, which contribute minimally to the overall error. To mitigate this and achieve more realistic reconstructions, integrating Generative Adversarial Networks (GANs) with autoencoders has proven effective. This hybrid approach is explored in various studies ([example 1](https://arxiv.org/abs/1704.02304), [example 2](https://arxiv.org/abs/1511.05644), and [slides](http://elarosca.net/slides/iccv_autoencoder_gans.pdf)).\n", + "\n", + "Furthermore, MSE may not always accurately reflect visual similarity between images. A case in point is when an autoencoder produces an image that is slightly shifted—despite the near-identical appearance, the MSE can significantly increase, showcasing a limitation in capturing true visual fidelity. 
A potential solution involves leveraging a pre-trained CNN to measure distance based on visual features extracted from lower layers, offering a more nuanced comparison than pixel-level MSE.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": { + "id": "hd0hdMVuJxhZ" + }, + "outputs": [], + "source": [ + "class Autoencoder(nn.Module):\n", + "\n", + " def __init__(self,\n", + " base_channel_size: int,\n", + " latent_dim: int,\n", + " encoder_class : object = Encoder,\n", + " decoder_class : object = Decoder,\n", + " num_input_channels: int = 3,\n", + " width: int = 32,\n", + " height: int = 32):\n", + " super().__init__()\n", + " # Creating encoder and decoder\n", + " self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)\n", + " self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)\n", + " # Example input array needed for visualizing the graph of the network\n", + " self.example_input_array = torch.zeros(2, num_input_channels, width, height)\n", + "\n", + " def forward(self, x):\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " return x_hat\n", + "\n", + " def _get_reconstruction_loss(self, batch):\n", + " x = batch # We do not need the labels\n", + " x_hat = self.forward(x)\n", + " loss = F.mse_loss(x, x_hat, reduction=\"none\")\n", + " loss = loss.sum(dim=[1,2,3]).mean(dim=[0])\n", + " return loss\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AOHmolo8nkBM" + }, + "source": [ + "# Utility code for comparing Images" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 779 + }, + "id": "_ttCZos4JWpr", + "outputId": "aaae7d91-f4ca-4f36-a495-8e4380992e2e" + }, + "outputs": [ + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T13:49:03.109846\n", + " image/svg+xml\n", + " \n", + " 
\n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T13:49:03.134798\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T13:49:03.158725\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T13:49:03.182878\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def compare_imgs(img1, img2, title_prefix=\"\"):\n", + " # Calculate MSE loss between both images\n", + " loss = F.mse_loss(img1, img2, reduction=\"sum\")\n", + " # Plot images for visual comparison\n", + " grid = torchvision.utils.make_grid(torch.stack([img1, img2], dim=0), nrow=2, normalize=True)\n", + " grid = grid.permute(1, 2, 0)\n", + " plt.figure(figsize=(4,2))\n", + " plt.title(f\"{title_prefix} Loss: {loss.item():4.2f}\")\n", + " plt.imshow(grid)\n", + " plt.axis('off')\n", + " plt.show()\n", + "\n", + "for i in range(2):\n", + " # Load example image\n", + " img, _ = train_dataset[i]\n", + " img_mean = img.mean(dim=[1,2], keepdims=True)\n", + "\n", + " # Shift image by one pixel\n", + " SHIFT = 1\n", + " img_shifted = torch.roll(img, shifts=SHIFT, dims=1)\n", + " img_shifted = torch.roll(img_shifted, shifts=SHIFT, dims=2)\n", + " img_shifted[:,:1,:] = img_mean\n", + " img_shifted[:,:,:1] = img_mean\n", + " compare_imgs(img, img_shifted, \"Shifted -\")\n", + "\n", + " # Set half of the image to zero\n", + " img_masked = img.clone()\n", + " img_masked[:,:img_masked.shape[1]//2,:] = img_mean\n", + " compare_imgs(img, img_masked, \"Masked -\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pwn3hj6hnq3z" + }, + "source": [ + "# Task2\n", + "Add training code to train the AutoEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "1c5e517cebbc4d96b4d260676eca961f", + "20f188a8bf53479983bdf9c2df2a55e0", + "6e196069369c4126acdc53ac6da328aa", + "f3b39579ce11475ea2ca67198c64f66a", + "b15d030dacc84d3d89a38d3f48c094e7", + "f0a2ddf4f8a54ccd8e301b54943bb88c", + "d9164be3d674410da5eee962bc243727", + "ba9e7a3a9bbb46faaabdcc944650f4af", + "a4c68ce78f024f9aa4bd2fe3ac296d25", + "20a2d55104cc4a98a06c5bf50a80b51c", + 
"94ea667c7972465e8166c8747ef24d94", + "35a4be6db57c4ae2b36604feae29d861", + "6cc313af7666494794ee75d95aff9289", + "9152e497a92b4ef1b33612cfd628f739", + "29bed95021784427be407f74be7daff0", + "323712cc4b21465f894dbb2ab3960178", + "9f4caca2a12a4e64998d8e4977e2d038", + "be92dee7c4dc46edb88c555550f9ae37", + "680ba64a00484feba87714c8ac2b2f1e", + "38c1e5f53c674dffb84d117171fc2563", + "9387e448cba046bb913d9d60ceefe363", + "ee038165676342fd952cc958de73697f", + "78274c291a2044a196f5ea743d2853a7", + "1c1872a626dd4856a621bbc425c4a947", + "6f1de415698948e2b4747e34f36684a4", + "529be14dea3e41d89f80b7dcb6347f22", + "a3fc766655ce445dabb70df3fe051df0", + "27587b51d0694c67b0465a4929911cd6", + "1e0d7aa70cf4490d97ad44009f6105c5", + "c9d3591403d242dab30233318dd592ea", + "126c0bb014d84c3abf231605148a8353", + "24729235af21409696bed8f0b01e5127", + "047ea2ea246342f18a2d70390099c0f5", + "caf7a900394c4705bc30d3cdecb0e24d", + "ea89a8bbbf8d43f19fcf91ff8935ec84", + "a0091f5723714384a05b0c27c489ff1b", + "baf572d365b74d8e81eb468b66e6b045", + "c5fac4f1cbf64dd08351ea32fb4b4a59", + "72a4a60175c34a6680df47c0e1002d90", + "c9b3bccd02ee402c99c5c10e6c03530d", + "e3806093ea654edcac0e153bd7ccdf9e", + "1f6e9b83a8744173a162f479dd04dc6f", + "d918542d0318488c88829bc650b6b8cc", + "fcb5fce90c82415589f58057fc51812b", + "223a29e7d81049debc22a75a4027e113", + "96ddd57e6ccf4e9a82a6575b4a9843f1", + "ef976e12f2144286af54f4ee339c08de", + "8da43c5164f64e7d8bb645099e1ee3e6", + "3d98aea664c645089d693365d784a580", + "ec552968d3f64c57b2f854f118dd234b", + "7913abef7f9147e19c39b54878f1d73e", + "4f882fc7f8054471855b77b116fd566b", + "b4f60946184f439ba90f79cda27aa34a", + "2189d23f386a4c00ae11995e974569eb", + "2805a35efa4a401ea88c3a22dd9752f5", + "bcc33f1b00f14139b3719c2f7a622960", + "45fead35e2114df598508e5694f62bef", + "f892714895654834a2bd95d04f2aff67", + "b6a804a7415c41a19cfdd2b3af153629", + "d64e3edbe0914733b7a27f35a71bc9c8", + "a5da50968aae432aa5b3c90c8e7ddb04", + "95c22c37f1b04034b4132ea248af2e94", + 
"93eb82f2191a4e42887695889f30a503", + "a3e241c8ed1449aa9e3adce6f9fa69bf", + "1a6a711d2b5e4ea7a8c769c36e194335", + "53dac5aeffcb41d388da7f4aaf5e19b9", + "1fc700d2efc1488b84cc18c540f6e497", + "8afa3c2d298b40c9be0312768d98fd7c", + "948e95b9112b4e31945e509c68ab8ec9", + "381466138eaa45278002150b1219293f", + "2a516ab4e8c247599fb7faf9ec95f676", + "e33d1e0e7cb646aa8803b7338f6da888", + "1de75ad1a9e740c18f8cf2ed2cd5955b", + "471be6fd1d8c4925a5ed4d2a9ec7673c", + "0c29f121e84a44608393ade2b1381116", + "0c60f13e34d14fef9d0bbf6d7ded673a", + "2434f5e02bdf4fc78c88d4c146ff6ae7", + "6599dc2951474e4282ff1894ec0851e8", + "b61d6e41157e43ee99f40d8b018877c4", + "2f347bd02cb944ccad43744dd7e4eeea", + "94f52e75bcfe48a2a51fbaf59c22352c", + "0c3bdf21300f4610a68d9dbfa566a1fd", + "8e950e1ea7d047618f38b6619a6312e1", + "5a0653f04c624aed8ba9af0eaad8b0fd", + "cdd01d42b3d24f99b390b8a8fbf8dcf7", + "6c39e034913845b797600de2fafe98aa", + "8dffad816bec4292a338af6c8b3a1e5d", + "9b2657e17fca494aa30a07f515e1d35d", + "4d5f2b6c66904e7bb2f49f7b39174de4", + "f5fbf0c6280c41b59e7f48e05afe8d20", + "c05fd736bc514fd4810efe5b3eaa9c55", + "3e6f882231d2445cbad2dc940eb1c056", + "828eeb2e8ab346a297928bbea0eec155", + "297cbe85fb7041cf94532a565e83397b", + "13477a6d079a45c8b340cfe7f18df03d", + "48d4a0722e9d42398e3ae796f44170e8", + "bb28b9a67e63424ba4203d4a44de8dd4", + "851361f7fb4f4db6a4a165b21d627af3", + "0f00ffdd775042aa8860456d5b3440a2", + "177d3c0836c14df986e732b5331825d6", + "d1c852ac83ca466b839620a2d362d587", + "2eb19f9206bf48809789b6eb15723c10", + "855c15c10a2844b29c036c2ce58c866a", + "070c3ed170534dcaa3da118dccc78aa5", + "c8a81fdd7aa143e9a7bb54d490cb105e", + "c0da3b83a124493fa8364441caa4cf00", + "581562b79aac46fb947d61130897e232", + "aaa0adc4eefd49c1a96055eae215d7eb", + "d2946fea7b7d42fa98ec040769478597", + "96b3a74edb9b4400bf9c1fe1b7010f03" + ] + }, + "id": "dnD8g-r8KB0K", + "outputId": "de2a26d7-50cd-44a2-fdb9-3e3952aa68c8" + }, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'Autoencoder' is not 
defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# for batch in tqdm(train_loader, total=len(train_loader)):\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01moptim\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01moptim\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mAutoencoder\u001b[49m(\u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m128\u001b[39m, ) \u001b[38;5;66;03m# you code here\u001b[39;00m\n\u001b[1;32m 5\u001b[0m model\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 6\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(model\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-3\u001b[39m) \u001b[38;5;66;03m# your code here\u001b[39;00m\n", + "\u001b[0;31mNameError\u001b[0m: name 'Autoencoder' is not defined" + ] + } + ], + "source": [ + "# for batch in tqdm(train_loader, total=len(train_loader)):\n", + "import torch.optim as optim\n", + "\n", + "model = Autoencoder(64, 128, ) # you code here\n", + "model.to(device)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # your code here\n", + "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') # your code here, can use ReduceLROnPlateau\n", + "# Write training loop here\n", + "\n", + "loss_fn = nn.MSELoss()\n", + "\n", + "n_epoch = 40\n", + "model.train()\n", + "for epoch in range(n_epoch):\n", + " print(f\"\\nEpoch {epoch}:\")\n", + "\n", + " avg_loss = 0\n", + "\n", + " for i, data in enumerate(train_loader):\n", + " inputs, _ = data\n", + "\n", + " 
inputs = inputs.cuda()\n", + "\n", + " loss = model._get_reconstruction_loss(inputs) #loss_fn(outputs, inputs)\n", + "\n", + " optimizer.zero_grad()\n", + "\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + "\n", + " avg_loss += loss\n", + "\n", + " print(f'\\rBatch: {i}: Loss:{loss} avg_Loss: {avg_loss/(i + 1)} ', end='')\n", + "\n", + " scheduler.step(loss)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 324 + }, + "id": "5OfaUMh-U3eJ", + "outputId": "bd25e0cd-7c0e-40fe-c472-527750e268cc" + }, + "outputs": [ + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T16:37:03.453161\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def visualize_reconstructions(model, input_imgs):\n", + " # Reconstruct images\n", + " model.eval()\n", + " with torch.no_grad():\n", + " reconst_imgs = model(input_imgs.to(device))\n", + " reconst_imgs = reconst_imgs.cpu()\n", + "\n", + " # Plotting\n", + " imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0,1)\n", + " grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True)\n", + " grid = grid.permute(1, 2, 0)\n", + " plt.figure(figsize=(7,4.5))\n", + " plt.title(f\"Reconstructed from model\")\n", + " plt.imshow(grid)\n", + " plt.axis('off')\n", + " plt.show()\n", + " \n", + "input_imgs = get_train_images(6)\n", + "visualize_reconstructions(model, input_imgs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z5keSdPPnzhP" + }, + "source": [ + "# Masked AutoEncoder\n", + "The following code demonstrates the Masked Autoencoder implementation and visualization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qHp1VzhtoYql" + }, + "source": [ + "# Import Necessary Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DKr2eCkDny9B", + "outputId": "b9997a91-52ba-4ca5-c38a-3ac64a224c78" + }, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import requests\n", + "\n", + "import torch\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "\n", + "# check whether run in Colab\n", + "if 'google.colab' in sys.modules:\n", + " print('Running in Colab.')\n", + " !pip3 install timm==0.4.5 # 0.3.2 does not work in Colab\n", + " !git clone https://github.com/facebookresearch/mae.git\n", + " sys.path.append('./mae')\n", + "else:\n", + " sys.path.append('./mae')\n", + "import models_mae" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vrdyiqpWod8J" + }, 
+ "source": [ + "# Build up necessary utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": { + "id": "_De1rOh8ny51" + }, + "outputs": [], + "source": [ + "# define the utils\n", + "\n", + "imagenet_mean = np.array([0.485, 0.456, 0.406])\n", + "imagenet_std = np.array([0.229, 0.224, 0.225])\n", + "\n", + "def show_image(image, title=''):\n", + " # image is [H, W, 3]\n", + " assert image.shape[2] == 3\n", + " plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())\n", + " plt.title(title, fontsize=16)\n", + " plt.axis('off')\n", + " return\n", + "\n", + "def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):\n", + " # build model\n", + " model = getattr(models_mae, arch)()\n", + " # load model\n", + " checkpoint = torch.load(chkpt_dir, map_location='cpu')\n", + " msg = model.load_state_dict(checkpoint['model'], strict=False)\n", + " print(msg)\n", + " return model\n", + "\n", + "def run_one_image(img, model):\n", + " x = torch.tensor(img)\n", + "\n", + " # make it a batch-like\n", + " x = x.unsqueeze(dim=0)\n", + " x = torch.einsum('nhwc->nchw', x)\n", + "\n", + " # run MAE\n", + " loss, y, mask = model(x.float(), mask_ratio= 0.75)\n", + " y = model.unpatchify(y)\n", + " y = torch.einsum('nchw->nhwc', y).detach().cpu()\n", + "\n", + " # visualize the mask\n", + " mask = mask.detach()\n", + " mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3) # (N, H*W, p*p*3)\n", + " mask = model.unpatchify(mask) # 1 is removing, 0 is keeping\n", + " mask = torch.einsum('nchw->nhwc', mask).detach().cpu()\n", + "\n", + " x = torch.einsum('nchw->nhwc', x)\n", + "\n", + " # masked image\n", + " im_masked = x * (1 - mask)\n", + "\n", + " # MAE reconstruction pasted with visible patches\n", + " im_paste = x * (1 - mask) + y * mask\n", + "\n", + " # make the plt figure larger\n", + " plt.rcParams['figure.figsize'] = [24, 24]\n", + "\n", + " plt.subplot(1, 4, 1)\n", + " show_image(x[0], 
\"original\")\n", + "\n", + " plt.subplot(1, 4, 2)\n", + " show_image(im_masked[0], \"masked\")\n", + "\n", + " plt.subplot(1, 4, 3)\n", + " show_image(y[0], \"reconstruction\")\n", + "\n", + " plt.subplot(1, 4, 4)\n", + " show_image(im_paste[0], \"reconstruction + visible\")\n", + "\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8wt9sd2tolyv" + }, + "source": [ + "# Load one image" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 422 + }, + "id": "_EbEF8gQolnq", + "outputId": "df774563-3070-4585-9809-8a950e3c303e" + }, + "outputs": [ + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T21:06:56.375875\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# load an image\n", + "img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145\n", + "# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851\n", + "img = Image.open(requests.get(img_url, stream=True).raw)\n", + "img = img.resize((224, 224))\n", + "img = np.array(img) / 255.\n", + "\n", + "assert img.shape == (224, 224, 3)\n", + "\n", + "# normalize by ImageNet mean and std\n", + "img = img - imagenet_mean\n", + "img = img / imagenet_std\n", + "\n", + "plt.rcParams['figure.figsize'] = [5, 5]\n", + "show_image(torch.tensor(img))" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RaMd6bemoqLB", + "outputId": "a8612b2b-6695-4fdf-eeef-e1b709fccab0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "File ‘mae_visualize_vit_large.pth’ already there; not retrieving.\n", + "\n", + "\n", + "Model loaded.\n" + ] + } + ], + "source": [ + "# Patch for numpy error\n", + "np.float = float\n", + "np.int = int #module 'numpy' has no attribute 'int'\n", + "np.object = object #module 'numpy' has no attribute 'object'\n", + "np.bool = bool #module 'numpy' has no attribute 'bool'\n", + "# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)\n", + "\n", + "# download checkpoint if not exist\n", + "!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n", + "\n", + "chkpt_dir = 'mae_visualize_vit_large.pth'\n", + "model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')\n", + "print('Model loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": { + "colab": { + 
"base_uri": "https://localhost:8080/" + }, + "id": "RaMd6bemoqLB", + "outputId": "a8612b2b-6695-4fdf-eeef-e1b709fccab0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "File ‘mae_visualize_vit_huge.pth’ already there; not retrieving.\n", + "\n", + "\n", + "Model loaded.\n" + ] + } + ], + "source": [ + "# Patch for numpy error\n", + "np.float = float\n", + "np.int = int #module 'numpy' has no attribute 'int'\n", + "np.object = object #module 'numpy' has no attribute 'object'\n", + "np.bool = bool #module 'numpy' has no attribute 'bool'\n", + "# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)\n", + "\n", + "# download checkpoint if not exist\n", + "!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_huge.pth\n", + "\n", + "chkpt_dir = 'mae_visualize_vit_huge.pth'\n", + "model_mae = prepare_model(chkpt_dir, 'mae_vit_huge_patch14')\n", + "print('Model loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 503 + }, + "id": "xymH8jt4orm6", + "outputId": "a60ce3ce-bfc2-48f4-e92f-451a55b1322a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MAE with pixel reconstruction:\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "einsum(): the number of subscripts in the equation (4) does not match the number of dimensions (3) for operand 0 and no ellipsis was given", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[110], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMAE 
with pixel reconstruction:\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mrun_one_image\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_mae\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[46], line 28\u001b[0m, in \u001b[0;36mrun_one_image\u001b[0;34m(img, model)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# make it a batch-like\u001b[39;00m\n\u001b[1;32m 27\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39munsqueeze(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m---> 28\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meinsum\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mnhwc->nchw\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# run MAE\u001b[39;00m\n\u001b[1;32m 31\u001b[0m loss, y, mask \u001b[38;5;241m=\u001b[39m model(x\u001b[38;5;241m.\u001b[39mfloat(), mask_ratio\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.75\u001b[39m)\n", + "File \u001b[0;32m/usr/lib/python3.11/site-packages/torch/functional.py:380\u001b[0m, in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 375\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m einsum(equation, \u001b[38;5;241m*\u001b[39m_operands)\n\u001b[1;32m 377\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(operands) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m opt_einsum\u001b[38;5;241m.\u001b[39menabled:\n\u001b[1;32m 378\u001b[0m \u001b[38;5;66;03m# the path for contracting 0 or 1 time(s) is already optimized\u001b[39;00m\n\u001b[1;32m 379\u001b[0m \u001b[38;5;66;03m# or the user has disabled using 
opt_einsum\u001b[39;00m\n\u001b[0;32m--> 380\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meinsum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mequation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moperands\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[1;32m 382\u001b[0m path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 383\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m opt_einsum\u001b[38;5;241m.\u001b[39mis_available():\n", + "\u001b[0;31mRuntimeError\u001b[0m: einsum(): the number of subscripts in the equation (4) does not match the number of dimensions (3) for operand 0 and no ellipsis was given" + ] + } + ], + "source": [ + "# make random mask reproducible (comment out to make it change)\n", + "torch.manual_seed(2)\n", + "print('MAE with pixel reconstruction:')\n", + "run_one_image(img, model_mae)" + ] + }, + { + "cell_type": "code", + "execution_count": 347, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2400, 2400)\n" + ] + }, + { + "data": { + "text/plain": [ + "118" + ] + }, + "execution_count": 347, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T23:20:24.414200\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145\n", + "# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851\n", + "mine_img = Image.open('./st2/6644818.png', formats=('PNG',)).convert('RGB')# Image.open(requests.get(img_url, stream=True).raw)\n", + "\n", + "print(mine_img.size)\n", + "\n", + "# mine_img.show()\n", + "mine_img = mine_img.resize((224, 224))\n", + "\n", + "mine_img = np.array(mine_img) / 255.\n", + "\n", + "# print(mine_img.shape, mine_img[0][0])\n", + "\n", + "assert mine_img.shape == (224, 224, 3)\n", + "\n", + "# target = np.array([118, 111, 95])\n", + "target = np.array([123, 116, 103])\n", + "\n", + "pre_ids_to_restore = []\n", + "\n", + "for y in range(14):\n", + " for x in range(14):\n", + " if (np.array(mine_img[y * 16 + 8][x * 16 + 8]) * 255 == target).all():\n", + " pre_ids_to_restore.append(x + y * 14)\n", + " #if y == 0: \n", + " # print(np.array([[mine_img[y * 16 + 8][x * 16 + 8]]]) * 255)\n", + " # plt.imshow(np.array([[mine_img[y * 16 + 8][x * 16 + 8]]]))\n", + "\n", + "# normalize by ImageNet mean and std\n", + "mine_img = mine_img - imagenet_mean\n", + "mine_img = mine_img / imagenet_std\n", + "\n", + "plt.rcParams['figure.figsize'] = [5, 5]\n", + "show_image(torch.tensor(mine_img))\n", + "\n", + "len(pre_ids_to_restore)" + ] + }, + { + "cell_type": "code", + "execution_count": 350, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "78 196 [19, 91, 86, 80, 149, 94, 96, 60, 78, 59, 48, 29, 122, 52, 11, 132, 72, 143, 21, 99, 172, 53, 92, 161, 134, 89, 77, 195, 35, 67, 63, 44, 123, 101, 128, 162, 84, 76, 10, 137, 152, 26, 27, 0, 46, 49, 190, 194, 120, 184, 133, 165, 126, 112, 65, 115, 90, 20, 159, 192, 
154, 51, 32, 98, 151, 125, 93, 81, 107, 1, 116, 124, 182, 127, 23, 41, 121, 17]\n" + ] + }, + { + "data": { + "text/plain": [ + "(torch.Size([1, 196]), torch.Size([1, 78]))" + ] + }, + "execution_count": 350, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "d = pd.read_csv('st2/6644818/shuffle_info.csv', header=None)\n", + "\n", + "ids_keep = eval(d.loc[0][1])\n", + "ids_restore = eval(d.loc[1][1])\n", + "print(len(ids_keep[0]), len(ids_restore[0]), [ x for x in ids_restore[0] if x in ids_keep[0]])\n", + "\n", + "\n", + "ids_restore = torch.Tensor(ids_restore).type(torch.int64) # torch.Tensor([ids_keep[0] + ids_restore[0]]).type(torch.int64)\n", + "ids_keep = torch.Tensor(ids_keep).type(torch.int64)\n", + "\n", + "ids_restore.shape, ids_keep.shape\n", + "\n", + "# ids_keep = [ x for x in range(14 * 14) if x not in pre_ids_to_restore ]\n", + "\n", + "# ids_restore = torch.Tensor([ids_keep + pre_ids_to_restore]).type(torch.int64)\n", + "\n", + "# ids_keep = torch.Tensor([ids_keep]).type(torch.int64)\n", + "\n", + "# show_image(torch.tensor(mine_img))\n", + "\n", + "# ids_restore, ids_restore.shape, ids_keep" + ] + }, + { + "cell_type": "code", + "execution_count": 352, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MAE with pixel reconstruction:\n", + "1024\n", + "tensor([[125, 182, 133, 91, 99, 151, 107, 27, 93, 44, 115, 35, 10, 159,\n", + " 1, 86, 92, 195, 116, 0, 154, 49, 84, 190, 123, 134, 121, 124,\n", + " 65, 26, 76, 19, 162, 194, 59, 90, 23, 63, 51, 29, 41, 192,\n", + " 132, 165, 101, 80, 127, 126, 21, 128, 137, 161, 32, 60, 78, 77,\n", + " 89, 67, 11, 20, 17, 52, 152, 96, 184, 149, 72, 94, 143, 172,\n", + " 122, 53, 46, 98, 48, 120, 112, 81]])\n", + "tensor([[[ 1.0417, 0.7687, 0.1073, ..., 0.5065, 0.5881, 1.1361],\n", + " [ 0.6540, 0.6927, 0.5497, ..., 0.6509, 0.6065, 1.6248],\n", + " [ 1.2860, 1.1067, 0.9193, ..., 0.4753, 0.5549, 1.2062],\n", + 
" ...,\n", + " [ 1.5381, 1.5673, 1.4636, ..., 0.4356, 0.4163, 1.4357],\n", + " [ 0.7640, 0.8398, 0.6108, ..., 0.6543, 0.6467, 1.4214],\n", + " [-0.3706, -0.1278, -0.1562, ..., 0.5149, 0.7952, 1.8260]]],\n", + " grad_fn=)\n", + "torch.Size([1, 50, 1024]) torch.Size([1, 79, 1024])\n", + "torch.Size([1, 196]) torch.Size([1, 196])\n", + "tensor([[160, 26, 174, 158, 178, 25, 36, 141, 4, 195, 125, 113, 14, 132,\n", + " 137, 116, 191, 129, 11, 179, 133, 54, 6, 150, 190, 20, 105, 134,\n", + " 56, 81, 37, 55, 9, 101, 153, 143, 188, 12, 90, 194, 117, 74,\n", + " 10, 140, 168, 171, 176, 124, 164, 77, 173, 96, 82, 52, 146, 135,\n", + " 157, 40, 49, 189, 46, 43, 76, 70, 34, 172, 111, 138, 166, 39,\n", + " 169, 51, 84, 167, 185, 182, 152, 186, 66, 99, 67, 98, 91, 147,\n", + " 161, 29, 47, 72, 114, 61, 41, 83, 53, 104, 27, 78, 64, 89,\n", + " 170, 159, 71, 7, 30, 24, 94, 57, 31, 126, 106, 86, 65, 181,\n", + " 3, 100, 139, 163, 13, 85, 175, 118, 131, 58, 162, 62, 48, 184,\n", + " 128, 102, 32, 80, 23, 35, 87, 123, 19, 193, 165, 5, 155, 42,\n", + " 21, 187, 60, 69, 112, 149, 107, 75, 108, 50, 28, 127, 110, 145,\n", + " 63, 1, 79, 2, 97, 144, 109, 151, 121, 136, 183, 115, 119, 59,\n", + " 154, 177, 33, 192, 130, 88, 8, 0, 22, 16, 38, 73, 18, 17,\n", + " 93, 15, 95, 120, 122, 156, 148, 44, 68, 103, 180, 142, 45, 92]])\n" + ] + }, + { + "data": { + "application/pdf": "", + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-02-20T23:49:44.104486\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + " def my_masking(self, x):\n", + " \"\"\"\n", + " Perform per-sample random masking by per-sample shuffling.\n", + " Per-sample shuffling is done by argsort random noise.\n", + " x: [N, L, D], sequence\n", + " \"\"\"\n", + " N, L, D = x.shape # batch, length, dim\n", + " len_keep = 14 * 14 - ids_resotre.shape[-1]\n", + "\n", + " print(D)\n", + " \n", + " # keep the first subset\n", + " # ids_keep = torch.Tensor([[ x for x in range(14 * 14) if x not in ids_restore[0] ]]).type(torch.int64)\n", + "\n", + " print(ids_keep)\n", + " \n", + " x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))\n", + "\n", + " print(x_masked)\n", + " \n", + " return x_masked # , mask # , ids_restore\n", + "\n", + "\n", + "def forward_encoder(model, x):\n", + " # embed patches\n", + " x = model.patch_embed(x)\n", + "\n", + " # add pos embed w/o cls token\n", + " x = x + model.pos_embed[:, 1:, :]\n", + "\n", + " x = my_masking(model, x)\n", + "\n", + " # append cls token\n", + " cls_token = model.cls_token + model.pos_embed[:, :1, :]\n", + " cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n", + " x = torch.cat((cls_tokens, x), dim=1)\n", + "\n", + " # apply Transformer blocks\n", + " for blk in model.blocks:\n", + " x = blk(x)\n", + " x = model.norm(x)\n", + "\n", + " return x\n", + "\n", + "def restore_one_image(img, model):\n", + " x = torch.tensor(img)\n", + "\n", + " # make it a batch-like\n", + " x = x.unsqueeze(dim=0)\n", + " x = torch.einsum('nhwc->nchw', x)\n", + "\n", + " # run MAE\n", + " # loss, ty, mask = model(x.float(), mask_ratio=0)\n", + "\n", + " tx = forward_encoder(model, x.float())\n", + "\n", + " l, m, i = model.forward_encoder(x.float(), 0.75);\n", + "\n", + " print(l.shape, tx.shape)\n", + " print(i.shape, ids_restore.shape)\n", + "\n", + " print(i)\n", + "\n", + " ty = model.forward_decoder(tx, ids_restore)\n", + " \n", + " y = model.unpatchify(ty)\n", + 
" y = torch.einsum('nchw->nhwc', y).detach().cpu()\n", + "\n", + " x = torch.einsum('nchw->nhwc', x)\n", + " \n", + " #mask = model.unpatchify(x.float()) # 1 is removing, 0 is keeping\n", + " #mask = torch.einsum('nchw->nhwc', mask).detach().cpu()\n", + "\n", + " # make the plt figure larger\n", + " plt.rcParams['figure.figsize'] = [12, 12]\n", + "\n", + " plt.subplot(1, 2, 1)\n", + " show_image(x[0], \"original\")\n", + "\n", + " plt.subplot(1, 2, 2)\n", + " show_image(y[0], \"reconstruction\")\n", + " \n", + " plt.show()\n", + "\n", + "torch.manual_seed(5)\n", + "print('MAE with pixel reconstruction:')\n", + "restore_one_image(mine_img, model_mae)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "047ea2ea246342f18a2d70390099c0f5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "070c3ed170534dcaa3da118dccc78aa5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": 
"@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0c29f121e84a44608393ade2b1381116": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0c3bdf21300f4610a68d9dbfa566a1fd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + 
"align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0c60f13e34d14fef9d0bbf6d7ded673a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + 
"visibility": null, + "width": null + } + }, + "0f00ffdd775042aa8860456d5b3440a2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "126c0bb014d84c3abf231605148a8353": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "13477a6d079a45c8b340cfe7f18df03d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "177d3c0836c14df986e732b5331825d6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_d1c852ac83ca466b839620a2d362d587", + 
"IPY_MODEL_2eb19f9206bf48809789b6eb15723c10", + "IPY_MODEL_855c15c10a2844b29c036c2ce58c866a" + ], + "layout": "IPY_MODEL_070c3ed170534dcaa3da118dccc78aa5" + } + }, + "1a6a711d2b5e4ea7a8c769c36e194335": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c1872a626dd4856a621bbc425c4a947": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_27587b51d0694c67b0465a4929911cd6", + "placeholder": "​", + "style": 
"IPY_MODEL_1e0d7aa70cf4490d97ad44009f6105c5", + "value": "100%" + } + }, + "1c5e517cebbc4d96b4d260676eca961f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_20f188a8bf53479983bdf9c2df2a55e0", + "IPY_MODEL_6e196069369c4126acdc53ac6da328aa", + "IPY_MODEL_f3b39579ce11475ea2ca67198c64f66a" + ], + "layout": "IPY_MODEL_b15d030dacc84d3d89a38d3f48c094e7" + } + }, + "1de75ad1a9e740c18f8cf2ed2cd5955b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1e0d7aa70cf4490d97ad44009f6105c5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1f6e9b83a8744173a162f479dd04dc6f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1fc700d2efc1488b84cc18c540f6e497": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_8afa3c2d298b40c9be0312768d98fd7c", + "IPY_MODEL_948e95b9112b4e31945e509c68ab8ec9", + "IPY_MODEL_381466138eaa45278002150b1219293f" + ], + "layout": "IPY_MODEL_2a516ab4e8c247599fb7faf9ec95f676" + } + }, + "20a2d55104cc4a98a06c5bf50a80b51c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + 
"overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "20f188a8bf53479983bdf9c2df2a55e0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f0a2ddf4f8a54ccd8e301b54943bb88c", + "placeholder": "​", + "style": "IPY_MODEL_d9164be3d674410da5eee962bc243727", + "value": "100%" + } + }, + "2189d23f386a4c00ae11995e974569eb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": 
null + } + }, + "223a29e7d81049debc22a75a4027e113": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_96ddd57e6ccf4e9a82a6575b4a9843f1", + "IPY_MODEL_ef976e12f2144286af54f4ee339c08de", + "IPY_MODEL_8da43c5164f64e7d8bb645099e1ee3e6" + ], + "layout": "IPY_MODEL_3d98aea664c645089d693365d784a580" + } + }, + "2434f5e02bdf4fc78c88d4c146ff6ae7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "24729235af21409696bed8f0b01e5127": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + 
"height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "27587b51d0694c67b0465a4929911cd6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2805a35efa4a401ea88c3a22dd9752f5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "297cbe85fb7041cf94532a565e83397b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "29bed95021784427be407f74be7daff0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9387e448cba046bb913d9d60ceefe363", + "placeholder": "​", + "style": 
"IPY_MODEL_ee038165676342fd952cc958de73697f", + "value": " 175/175 [00:12<00:00, 18.64it/s, loss: 206.76303100585938]" + } + }, + "2a516ab4e8c247599fb7faf9ec95f676": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2eb19f9206bf48809789b6eb15723c10": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_581562b79aac46fb947d61130897e232", + "max": 175, + "min": 0, + 
"orientation": "horizontal", + "style": "IPY_MODEL_aaa0adc4eefd49c1a96055eae215d7eb", + "value": 175 + } + }, + "2f347bd02cb944ccad43744dd7e4eeea": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cdd01d42b3d24f99b390b8a8fbf8dcf7", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6c39e034913845b797600de2fafe98aa", + "value": 175 + } + }, + "323712cc4b21465f894dbb2ab3960178": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": 
null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "35a4be6db57c4ae2b36604feae29d861": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6cc313af7666494794ee75d95aff9289", + "IPY_MODEL_9152e497a92b4ef1b33612cfd628f739", + "IPY_MODEL_29bed95021784427be407f74be7daff0" + ], + "layout": "IPY_MODEL_323712cc4b21465f894dbb2ab3960178" + } + }, + "381466138eaa45278002150b1219293f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0c60f13e34d14fef9d0bbf6d7ded673a", + "placeholder": "​", + "style": "IPY_MODEL_2434f5e02bdf4fc78c88d4c146ff6ae7", + "value": " 175/175 [00:12<00:00, 17.95it/s, loss: 156.27389526367188]" + } + }, + "38c1e5f53c674dffb84d117171fc2563": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "3d98aea664c645089d693365d784a580": { + 
"model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3e6f882231d2445cbad2dc940eb1c056": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_851361f7fb4f4db6a4a165b21d627af3", + "placeholder": "​", + "style": "IPY_MODEL_0f00ffdd775042aa8860456d5b3440a2", + "value": " 175/175 [00:12<00:00, 16.92it/s, loss: 152.1123046875]" + } + }, + "45fead35e2114df598508e5694f62bef": { + "model_module": "@jupyter-widgets/controls", + 
"model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a5da50968aae432aa5b3c90c8e7ddb04", + "placeholder": "​", + "style": "IPY_MODEL_95c22c37f1b04034b4132ea248af2e94", + "value": "100%" + } + }, + "471be6fd1d8c4925a5ed4d2a9ec7673c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "48d4a0722e9d42398e3ae796f44170e8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": 
"@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4d5f2b6c66904e7bb2f49f7b39174de4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f5fbf0c6280c41b59e7f48e05afe8d20", + "IPY_MODEL_c05fd736bc514fd4810efe5b3eaa9c55", + "IPY_MODEL_3e6f882231d2445cbad2dc940eb1c056" + ], + "layout": "IPY_MODEL_828eeb2e8ab346a297928bbea0eec155" + } + }, + "4f882fc7f8054471855b77b116fd566b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + 
"_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "529be14dea3e41d89f80b7dcb6347f22": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_24729235af21409696bed8f0b01e5127", + "placeholder": "​", + "style": "IPY_MODEL_047ea2ea246342f18a2d70390099c0f5", + "value": " 175/175 [00:13<00:00, 17.49it/s, loss: 204.53260803222656]" + } + }, + "53dac5aeffcb41d388da7f4aaf5e19b9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "581562b79aac46fb947d61130897e232": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5a0653f04c624aed8ba9af0eaad8b0fd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6599dc2951474e4282ff1894ec0851e8": { + "model_module": "@jupyter-widgets/controls", + 
"model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b61d6e41157e43ee99f40d8b018877c4", + "IPY_MODEL_2f347bd02cb944ccad43744dd7e4eeea", + "IPY_MODEL_94f52e75bcfe48a2a51fbaf59c22352c" + ], + "layout": "IPY_MODEL_0c3bdf21300f4610a68d9dbfa566a1fd" + } + }, + "680ba64a00484feba87714c8ac2b2f1e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6c39e034913845b797600de2fafe98aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + 
"state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "6cc313af7666494794ee75d95aff9289": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9f4caca2a12a4e64998d8e4977e2d038", + "placeholder": "​", + "style": "IPY_MODEL_be92dee7c4dc46edb88c555550f9ae37", + "value": "100%" + } + }, + "6e196069369c4126acdc53ac6da328aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ba9e7a3a9bbb46faaabdcc944650f4af", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a4c68ce78f024f9aa4bd2fe3ac296d25", + "value": 175 + } + }, + "6f1de415698948e2b4747e34f36684a4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c9d3591403d242dab30233318dd592ea", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_126c0bb014d84c3abf231605148a8353", + "value": 175 + } + }, + "72a4a60175c34a6680df47c0e1002d90": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "78274c291a2044a196f5ea743d2853a7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, 
+ "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c1872a626dd4856a621bbc425c4a947", + "IPY_MODEL_6f1de415698948e2b4747e34f36684a4", + "IPY_MODEL_529be14dea3e41d89f80b7dcb6347f22" + ], + "layout": "IPY_MODEL_a3fc766655ce445dabb70df3fe051df0" + } + }, + "7913abef7f9147e19c39b54878f1d73e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "828eeb2e8ab346a297928bbea0eec155": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + 
"right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "851361f7fb4f4db6a4a165b21d627af3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "855c15c10a2844b29c036c2ce58c866a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d2946fea7b7d42fa98ec040769478597", + "placeholder": "​", + "style": "IPY_MODEL_96b3a74edb9b4400bf9c1fe1b7010f03", + "value": " 175/175 [00:12<00:00, 17.72it/s, loss: 
156.78253173828125]" + } + }, + "8afa3c2d298b40c9be0312768d98fd7c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e33d1e0e7cb646aa8803b7338f6da888", + "placeholder": "​", + "style": "IPY_MODEL_1de75ad1a9e740c18f8cf2ed2cd5955b", + "value": "100%" + } + }, + "8da43c5164f64e7d8bb645099e1ee3e6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2189d23f386a4c00ae11995e974569eb", + "placeholder": "​", + "style": "IPY_MODEL_2805a35efa4a401ea88c3a22dd9752f5", + "value": " 175/175 [00:12<00:00, 17.29it/s, loss: 169.45787048339844]" + } + }, + "8dffad816bec4292a338af6c8b3a1e5d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, 
+ "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8e950e1ea7d047618f38b6619a6312e1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9152e497a92b4ef1b33612cfd628f739": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_680ba64a00484feba87714c8ac2b2f1e", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_38c1e5f53c674dffb84d117171fc2563", + "value": 175 + } + }, + "9387e448cba046bb913d9d60ceefe363": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "93eb82f2191a4e42887695889f30a503": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + 
"state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "948e95b9112b4e31945e509c68ab8ec9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_471be6fd1d8c4925a5ed4d2a9ec7673c", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_0c29f121e84a44608393ade2b1381116", + "value": 175 + } + }, + "94ea667c7972465e8166c8747ef24d94": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "94f52e75bcfe48a2a51fbaf59c22352c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8dffad816bec4292a338af6c8b3a1e5d", + "placeholder": "​", + "style": "IPY_MODEL_9b2657e17fca494aa30a07f515e1d35d", + "value": " 175/175 [00:12<00:00, 17.75it/s, loss: 151.23312377929688]" + } + }, + "95c22c37f1b04034b4132ea248af2e94": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "96b3a74edb9b4400bf9c1fe1b7010f03": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "96ddd57e6ccf4e9a82a6575b4a9843f1": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ec552968d3f64c57b2f854f118dd234b", + "placeholder": "​", + "style": "IPY_MODEL_7913abef7f9147e19c39b54878f1d73e", + "value": "100%" + } + }, + "9b2657e17fca494aa30a07f515e1d35d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9f4caca2a12a4e64998d8e4977e2d038": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + 
"max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a0091f5723714384a05b0c27c489ff1b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e3806093ea654edcac0e153bd7ccdf9e", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1f6e9b83a8744173a162f479dd04dc6f", + "value": 175 + } + }, + "a3e241c8ed1449aa9e3adce6f9fa69bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a3fc766655ce445dabb70df3fe051df0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + 
"border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4c68ce78f024f9aa4bd2fe3ac296d25": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a5da50968aae432aa5b3c90c8e7ddb04": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + 
"grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aaa0adc4eefd49c1a96055eae215d7eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b15d030dacc84d3d89a38d3f48c094e7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + 
"min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b4f60946184f439ba90f79cda27aa34a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b61d6e41157e43ee99f40d8b018877c4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8e950e1ea7d047618f38b6619a6312e1", + "placeholder": "​", + "style": "IPY_MODEL_5a0653f04c624aed8ba9af0eaad8b0fd", + "value": "100%" + } + }, + "b6a804a7415c41a19cfdd2b3af153629": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1a6a711d2b5e4ea7a8c769c36e194335", + "placeholder": "​", + "style": "IPY_MODEL_53dac5aeffcb41d388da7f4aaf5e19b9", 
+ "value": " 175/175 [00:12<00:00, 16.65it/s, loss: 171.0970916748047]" + } + }, + "ba9e7a3a9bbb46faaabdcc944650f4af": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "baf572d365b74d8e81eb468b66e6b045": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d918542d0318488c88829bc650b6b8cc", + "placeholder": "​", + "style": "IPY_MODEL_fcb5fce90c82415589f58057fc51812b", + "value": " 175/175 [00:12<00:00, 20.34it/s, loss: 
192.555908203125]" + } + }, + "bb28b9a67e63424ba4203d4a44de8dd4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "bcc33f1b00f14139b3719c2f7a622960": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_45fead35e2114df598508e5694f62bef", + "IPY_MODEL_f892714895654834a2bd95d04f2aff67", + "IPY_MODEL_b6a804a7415c41a19cfdd2b3af153629" + ], + "layout": "IPY_MODEL_d64e3edbe0914733b7a27f35a71bc9c8" + } + }, + "be92dee7c4dc46edb88c555550f9ae37": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c05fd736bc514fd4810efe5b3eaa9c55": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, 
+ "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_48d4a0722e9d42398e3ae796f44170e8", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_bb28b9a67e63424ba4203d4a44de8dd4", + "value": 175 + } + }, + "c0da3b83a124493fa8364441caa4cf00": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c5fac4f1cbf64dd08351ea32fb4b4a59": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": 
null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c8a81fdd7aa143e9a7bb54d490cb105e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c9b3bccd02ee402c99c5c10e6c03530d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c9d3591403d242dab30233318dd592ea": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": 
"@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "caf7a900394c4705bc30d3cdecb0e24d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ea89a8bbbf8d43f19fcf91ff8935ec84", + "IPY_MODEL_a0091f5723714384a05b0c27c489ff1b", + "IPY_MODEL_baf572d365b74d8e81eb468b66e6b045" + ], + "layout": "IPY_MODEL_c5fac4f1cbf64dd08351ea32fb4b4a59" + } + }, + "cdd01d42b3d24f99b390b8a8fbf8dcf7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + 
"_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d1c852ac83ca466b839620a2d362d587": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c8a81fdd7aa143e9a7bb54d490cb105e", + "placeholder": "​", + "style": "IPY_MODEL_c0da3b83a124493fa8364441caa4cf00", + "value": "100%" + } + }, + "d2946fea7b7d42fa98ec040769478597": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d64e3edbe0914733b7a27f35a71bc9c8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": 
null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d9164be3d674410da5eee962bc243727": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d918542d0318488c88829bc650b6b8cc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e33d1e0e7cb646aa8803b7338f6da888": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": 
"LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e3806093ea654edcac0e153bd7ccdf9e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": 
null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ea89a8bbbf8d43f19fcf91ff8935ec84": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_72a4a60175c34a6680df47c0e1002d90", + "placeholder": "​", + "style": "IPY_MODEL_c9b3bccd02ee402c99c5c10e6c03530d", + "value": "100%" + } + }, + "ec552968d3f64c57b2f854f118dd234b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + 
"min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ee038165676342fd952cc958de73697f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ef976e12f2144286af54f4ee339c08de": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4f882fc7f8054471855b77b116fd566b", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b4f60946184f439ba90f79cda27aa34a", + "value": 175 + } + }, + "f0a2ddf4f8a54ccd8e301b54943bb88c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + 
"flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f3b39579ce11475ea2ca67198c64f66a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_20a2d55104cc4a98a06c5bf50a80b51c", + "placeholder": "​", + "style": "IPY_MODEL_94ea667c7972465e8166c8747ef24d94", + "value": " 175/175 [00:12<00:00, 17.98it/s, loss: 253.90286254882812]" + } + }, + "f5fbf0c6280c41b59e7f48e05afe8d20": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_297cbe85fb7041cf94532a565e83397b", + "placeholder": "​", + "style": 
"IPY_MODEL_13477a6d079a45c8b340cfe7f18df03d", + "value": "100%" + } + }, + "f892714895654834a2bd95d04f2aff67": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_93eb82f2191a4e42887695889f30a503", + "max": 175, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a3e241c8ed1449aa9e3adce6f9fa69bf", + "value": 175 + } + }, + "fcb5fce90c82415589f58057fc51812b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Lab3/Week3_Autoencoder+MAE - Copy.py b/Lab3/Week3_Autoencoder+MAE - Copy.py new file mode 100644 index 0000000..301f60d --- /dev/null +++ b/Lab3/Week3_Autoencoder+MAE - Copy.py @@ -0,0 +1,562 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # Introduction & Import Necessary Setup +# In this labsheet, we'll delve into the fascinating world of autoencoders (AEs), a type of neural network renowned for its ability to compress and reconstruct data. Autoencoders work by first encoding input data, such as images, into a compact feature vector through an encoder network. 
This process effectively distills the essence of the data into a smaller, more manageable form. The feature vector, often referred to as the "bottleneck," plays a crucial role in this compression process, allowing us to represent the input data with fewer features. +# +# Following compression, a second neural network, known as the decoder, takes over to reconstruct the original data from the compressed feature vector. This remarkable ability to compress and then reconstruct data makes autoencoders extremely valuable in various applications, including data compression and image comparison at a more meaningful level than mere pixel-by-pixel analysis. +# +# Moreover, our exploration will not stop at the autoencoder framework itself. We will also introduce the concept of "deconvolution" (also known as transposed convolution), a powerful operator used to enlarge feature maps in both height and width dimensions. Deconvolution networks are indispensable in scenarios where we begin with a compact feature vector and aim to generate a full-sized image. This technique is pivotal in various advanced neural network applications, such as Variational Autoencoders (VAEs), Generative Adversarial Networks (GANs), and super-resolution. +# +# To kick things off, we'll start by importing our standard libraries, setting the stage for our deep dive into the inner workings and applications of autoencoders. 
# In[1]:


## Standard libraries
import os
import json
import math
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')  # vector formats, for export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

DATASET_PATH = "dataset"

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:", device)


# # Download and setup the dataset
# In this labsheet, our focus shifts to the CIFAR10 dataset, a collection
# known for its rich, colored images. Each image within CIFAR10 is equipped
# with 3 color channels and boasts a resolution of 32x32 pixels. This
# characteristic is particularly advantageous when working with autoencoders,
# as they are not bound by the constraints of probabilistic image modeling.
#
# Should you already have the CIFAR10 dataset downloaded in a different
# directory, it's important to adjust the DATASET_PATH variable accordingly.
# This step ensures you avoid unnecessary additional downloads, streamlining
# your workflow and allowing you to dive into the practical exercises more
# swiftly.

# In[105]:


# Transformations applied on each image: convert it to a tensor, then
# normalise every channel as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Loading the training dataset. We need to split it into a training and
# validation part.
train_dataset = CIFAR10(
    root=DATASET_PATH, train=True, transform=transform, download=True
)
train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(
    root=DATASET_PATH, train=False, transform=transform, download=True
)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)
val_loader = data.DataLoader(
    val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4
)
test_loader = data.DataLoader(
    test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4
)


def get_train_images(num):
    """Return the first `num` images of `train_dataset` stacked on a new dim 0.

    Labels are discarded; only the (transformed) image tensors are returned.
    """
    first_images = []
    for idx in range(num):
        image, _label = train_dataset[idx]
        first_images.append(image)
    return torch.stack(first_images, dim=0)


# # Building the autoencoder
#
# In general, an autoencoder consists of an **encoder** that maps the input
# $x$ to a lower-dimensional feature vector $z$, and a **decoder** that
# reconstructs the input $\hat{x}$ from $z$. We train the model by comparing
# $x$ to $\hat{x}$ and optimizing the parameters to increase the similarity
# between $x$ and $\hat{x}$. See below for a small illustration of the
# autoencoder framework.
#
#
# ![img](https://raw.githubusercontent.com/hqsiswiliam/COM3025_Torch/main/autoencoder.png)

#
# To kick off our exploration, we initiate with the construction of the
# encoder. This component is fundamentally a deep convolutional network
# tailored for progressively diminishing the image's dimensions. This
# diminution is achieved through the use of strided convolutions, which
# methodically reduce the image's size layer by layer. Following the
# thrice-executed downscaling process, we transition the architecture from
# convolutional layers to a flattened feature representation.
# This is achieved by flattening the spatial features into a single vector,
# which is then processed through a linear layer. As a result, we obtain the
# latent representation, denoted as $z$, encapsulating the compressed essence
# of the input image. The size of this latent vector, $d$, is adjustable,
# providing flexibility in the encoding capacity of our network.

# In[59]:


class Encoder(nn.Module):
    """Convolutional encoder: image batch -> latent vector z.

    Three stride-2 convolutions halve the spatial size each time; the final
    feature map is flattened and projected to `latent_dim` by one linear
    layer. The linear layer's input size (2 * 16 * base_channel_size) matches
    a 4x4 final feature map, i.e. a 32x32 input image.
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels used in the first convolutional layers. Deeper layers use twice as many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class instantiated after every convolution
        """
        super().__init__()
        narrow = base_channel_size
        wide = 2 * base_channel_size
        # Layers are created in forward order so that the indices inside
        # `self.net` (and therefore the state-dict keys) stay stable.
        stages = [
            nn.Conv2d(num_input_channels, narrow, kernel_size=3, padding=1, stride=2),  # 32x32 -> 16x16
            act_fn(),
            nn.Conv2d(narrow, narrow, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(narrow, wide, kernel_size=3, padding=1, stride=2),  # 16x16 -> 8x8
            act_fn(),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1, stride=2),  # 8x8 -> 4x4
            act_fn(),
            nn.Flatten(),  # image grid to a single feature vector
            nn.Linear(16 * wide, latent_dim),  # 4 * 4 * (2 * base_channel_size) features
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Encode the image batch `x` into latent vectors."""
        latent = self.net(x)
        return latent


# # Task1
# Now complete the decoder implementation

# In[133]:


class Decoder(nn.Module):
    """Convolutional decoder: latent vector z -> image batch.

    Mirrors the encoder: a linear layer expands z to a 4x4 feature map, and
    three stride-2 transposed convolutions double the spatial size each time.
    The output passes through Tanh, so values lie in [-1, 1].
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels used in the last convolutional layers. Earlier layers use twice as many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class instantiated after every hidden layer
        """
        super().__init__()
        narrow = base_channel_size
        wide = 2 * base_channel_size
        # Same ordering rule as in the encoder: keep `self.net` indices stable.
        stages = [
            nn.Linear(latent_dim, 16 * wide),
            act_fn(),
            nn.Unflatten(1, (wide, 4, 4)),  # feature vector back to a 4x4 grid
            nn.ConvTranspose2d(wide, wide, kernel_size=3, padding=1, stride=2, output_padding=1),  # 4x4 -> 8x8
            act_fn(),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(wide, narrow, kernel_size=3, padding=1, stride=2, output_padding=1),  # 8x8 -> 16x16
            act_fn(),
            nn.Conv2d(narrow, narrow, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(narrow, num_input_channels, kernel_size=3, padding=1, stride=2, output_padding=1),  # 16x16 -> 32x32
            nn.Tanh(),  # squashes the reconstruction into [-1, 1]
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Decode latent vectors `x` into reconstructed images."""
        reconstruction = self.net(x)
        return reconstruction


# # Combining Encoder and Decoder
# ## Loss Function: Mean Squared Error (MSE)
#
# For our loss function, we opt for the Mean Squared Error (MSE). MSE is
# particularly effective in emphasizing the significance of accurately
# predicting pixel values that are substantially misestimated by the network.
# For instance, a minor deviation, such as predicting 127 instead of 128, is
# deemed less critical. However, larger discrepancies, like confusing a pixel
# value of 0 with 128, are considered more severe and detrimental to the
# reconstruction quality.
#
# Unlike Variational Autoencoders (VAEs) that predict the probability for each
# pixel value, we employ MSE as a straightforward distance measure. This
# approach significantly reduces the number of parameters, streamlining the
# training process. To enhance our understanding of the per-pixel performance,
# we calculate the summed squared error, averaged across the batch dimension.
# It's important to note that alternative aggregations (mean or sum) yield
# equivalent outcomes in terms of resulting parameters.
#
# ### Limitations of MSE
#
# Despite its advantages, MSE is not without drawbacks. Primarily, it tends to
# produce blurrier images, as it inherently removes small noise and
# high-frequency patterns, which contribute minimally to the overall error. To
# mitigate this and achieve more realistic reconstructions, integrating
# Generative Adversarial Networks (GANs) with autoencoders has proven
# effective. This hybrid approach is explored in various studies
# ([example 1](https://arxiv.org/abs/1704.02304),
# [example 2](https://arxiv.org/abs/1511.05644), and
# [slides](http://elarosca.net/slides/iccv_autoencoder_gans.pdf)).
#
# Furthermore, MSE may not always accurately reflect visual similarity between
# images. A case in point is when an autoencoder produces an image that is
# slightly shifted: despite the near-identical appearance, the MSE can
# significantly increase, showcasing a limitation in capturing true visual
# fidelity. A potential solution involves leveraging a pre-trained CNN to
# measure distance based on visual features extracted from lower layers,
# offering a more nuanced comparison than pixel-level MSE.
#

# In[134]:


class Autoencoder(nn.Module):
    """Encoder/decoder pair trained to reconstruct its own input."""

    def __init__(self,
                 base_channel_size: int,
                 latent_dim: int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 3,
                 width: int = 32,
                 height: int = 32):
        """
        Inputs:
            - base_channel_size : Channel count forwarded to encoder and decoder
            - latent_dim : Dimensionality of the latent representation z
            - encoder_class : Class used to build the encoder
            - decoder_class : Class used to build the decoder
            - num_input_channels : Number of image channels (3 for CIFAR)
            - width, height : Spatial size of the example input array
        """
        super().__init__()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network.
        # PyTorch image batches are laid out (N, C, H, W), so height comes
        # before width.
        self.example_input_array = torch.zeros(2, num_input_channels, height, width)

    def forward(self, x):
        """Encode `x` to a latent vector and decode it back to an image."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Return the squared error summed per image and averaged over the batch.

        `batch` is the image tensor only; labels must already be stripped.
        """
        x = batch  # We do not need the labels
        x_hat = self.forward(x)
        # mse_loss takes (input, target): prediction first, ground truth second.
        loss = F.mse_loss(x_hat, x, reduction="none")
        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
        return loss


# # Utility code for comparing Images

# In[14]:


def compare_imgs(img1, img2, title_prefix=""):
    """Plot two images side by side, titled with their summed squared error."""
    # Calculate MSE loss between both images
    loss = F.mse_loss(img1, img2, reduction="sum")
    # Plot images for visual comparison
    grid = torchvision.utils.make_grid(torch.stack([img1, img2], dim=0), nrow=2, normalize=True)
    grid = grid.permute(1, 2, 0)  # channels-first -> channels-last for imshow
    plt.figure(figsize=(4, 2))
    plt.title(f"{title_prefix} Loss: {loss.item():4.2f}")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()


for i in range(2):
    # Load example image
    img, _ = train_dataset[i]
    img_mean = img.mean(dim=[1, 2], keepdims=True)

    # Shift image by SHIFT pixels along both spatial axes and fill the
    # wrapped-around border with the per-channel mean.
    SHIFT = 1
    img_shifted = torch.roll(img, shifts=SHIFT, dims=1)
    img_shifted = torch.roll(img_shifted, shifts=SHIFT, dims=2)
    img_shifted[:, :SHIFT, :] = img_mean
    img_shifted[:, :, :SHIFT] = img_mean
    compare_imgs(img, img_shifted, "Shifted -")

    # Replace the top half of the image with the per-channel mean
    img_masked = img.clone()
    img_masked[:, :img_masked.shape[1] // 2, :] = img_mean
    compare_imgs(img, img_masked, "Masked -")
# # Task2
# Add training code to train the AutoEncoder

# In[2]:


import torch.optim as optim

model = Autoencoder(64, 128)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Lower the learning rate when the monitored (epoch-mean) loss stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# Not used by the loop below, which takes its loss from
# `model._get_reconstruction_loss`; kept so the name stays defined.
loss_fn = nn.MSELoss()

n_epoch = 40
model.train()
for epoch in range(n_epoch):
    print(f"\nEpoch {epoch}:")

    running_loss = 0.0
    num_batches = 0

    # Unpack the batch directly rather than binding it to `data`, which would
    # clobber the `torch.utils.data as data` module alias.
    for i, (inputs, _) in enumerate(train_loader):
        # Follow the selected device instead of assuming CUDA is present.
        inputs = inputs.to(device)

        loss = model._get_reconstruction_loss(inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float so no tensors are held across batches.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches = i + 1

        print(f'\rBatch: {i}: Loss:{batch_loss} avg_Loss: {running_loss / num_batches} ', end='')

    # Step the plateau scheduler on the epoch mean, not on the last (noisy)
    # mini-batch loss.
    scheduler.step(running_loss / max(num_batches, 1))


# In[144]:


def visualize_reconstructions(model, input_imgs):
    """Plot each input image next to the model's reconstruction of it.

    Puts `model` into eval mode (it is not switched back afterwards) and runs
    the forward pass without gradient tracking on the global `device`.
    """
    # Reconstruct images
    model.eval()
    with torch.no_grad():
        reconst_imgs = model(input_imgs.to(device))
    reconst_imgs = reconst_imgs.cpu()

    # Plotting: interleave (input, reconstruction) pairs into one image list
    imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
    grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True)
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(7, 4.5))
    plt.title("Reconstructed from model")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()


input_imgs = get_train_images(6)
visualize_reconstructions(model, input_imgs)


# # Masked AutoEncoder
# The following code demonstrates the Masked Autoencoder (MAE) implementation
# and its visualization.

# # Import Necessary Libraries

# In[4]:


import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    get_ipython().system('pip3 install timm==0.4.5 # 0.3.2 does not work in Colab')
    get_ipython().system('git clone https://github.com/facebookresearch/mae.git')
# In both environments the MAE checkout is expected at ./mae.
sys.path.append('./mae')
import models_mae


# # Build up necessary utilities

# In[131]:


# define the utils

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image, title=''):
    """Draw an ImageNet-normalised [H, W, 3] image on the current axes.

    The normalisation is undone with `imagenet_mean` / `imagenet_std` and the
    result is clipped to the 0..255 range before display.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')


def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build the `models_mae` architecture `arch` and load weights into it.

    `chkpt_dir` is the path of the checkpoint file; its 'model' entry is
    loaded non-strictly and the resulting key report is printed.
    """
    # build model
    model = getattr(models_mae, arch)()
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model


def run_one_image(img, model, mask_ratio=0.75):
    """Run MAE on one [H, W, 3] image and plot the four-panel comparison.

    Panels: original, masked input, raw reconstruction, and reconstruction
    pasted together with the visible patches. `mask_ratio` is forwarded to
    the model's forward pass.
    """
    x = torch.tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE
    loss, y, mask = model(x.float(), mask_ratio=mask_ratio)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask: expand the per-patch flag to every pixel of the patch
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()


# # Load one image

# In[189]:
# load an image
img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
img = Image.open(requests.get(img_url, stream=True).raw).resize((224, 224))
img = np.array(img) / 255.

assert img.shape == (224, 224, 3)

# normalize by ImageNet mean and std
img = (img - imagenet_mean) / imagenet_std

plt.rcParams['figure.figsize'] = [5, 5]
show_image(torch.tensor(img))


# In[141]:


# Patch for numpy error: restore the aliases whose absence raises
# "module 'numpy' has no attribute 'int'" (and likewise for 'object' and
# 'bool'); `np.float` is restored alongside them.
np.float, np.int, np.object, np.bool = float, int, object, bool

# This is an MAE model trained with pixels as targets for visualization
# (ViT-Large, training mask ratio=0.75)

# download the checkpoint unless the file already exists (wget -nc)
get_ipython().system('wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth')

chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')


# Load the locally stored masked image, force three channels, and bring it
# to the same 224x224 size as the demo image above.
mine_img = Image.open('./st2/6644818.png', formats=('PNG',)).convert('RGB')
mine_img = mine_img.resize((224, 224))

# Scale pixel values by 1/255.
mine_img = np.array(mine_img) / 255.
assert mine_img.shape == (224, 224, 3)

# normalize by ImageNet mean and std
mine_img = mine_img - imagenet_mean
mine_img = mine_img / imagenet_std

plt.rcParams['figure.figsize'] = [5, 5]
show_image(torch.tensor(mine_img))


import pandas as pd

d = pd.read_csv('st2/6644818/shuffle_info.csv', header=None)

# NOTE(security): `eval` executes whatever text is stored in the CSV cell.
# This is only acceptable because the file is a trusted local asset. If the
# cells are plain list literals, `ast.literal_eval` is the safe replacement;
# confirm the file format before switching.
ids_keep = torch.Tensor(eval(d.loc[0][1])).type(torch.int64)
ids_restore = torch.Tensor(eval(d.loc[1][1])).type(torch.int64)


def masking(self, x):
    """Keep only the patch tokens selected by the module-level `ids_keep`.

    `x` is a 3-D tensor (batch, length, dim). `self` is unused and exists
    only so the signature mirrors a model method.
    """
    _, _, D = x.shape  # batch, length, dim
    # Repeat the kept indices across the feature dimension for torch.gather.
    gather_index = ids_keep.unsqueeze(-1).repeat(1, 1, D)
    return torch.gather(x, dim=1, index=gather_index)  # Creates the masked images


def forward_encoder(self, x):
    """Run the encoder of the MAE model `self` with the fixed `ids_keep` mask.

    Replaces the model's random masking with the predetermined kept-patch
    indices loaded from the CSV file above.
    """
    # embed patches
    x = self.patch_embed(x)

    # add pos embed w/o cls token
    x = x + self.pos_embed[:, 1:, :]

    x = masking(self, x)

    # append cls token; the positional embedding must come from the model
    # being run (`self`), not from the unrelated global `model`.
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # apply Transformer blocks
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x


def restore_one_image(img, model):
    """Reconstruct one [H, W, 3] image with the fixed mask and plot the result.

    The encoder runs through `forward_encoder`; the decoder is given the
    module-level `ids_restore`. Shows the input next to the reconstruction.
    """
    x = torch.tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    temp_x = forward_encoder(model, x.float())

    y = model.forward_decoder(temp_x, ids_restore)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [12, 12]

    plt.subplot(1, 2, 1)
    show_image(x[0], "original")

    plt.subplot(1, 2, 2)
    show_image(y[0], "reconstruction")

    plt.show()


torch.manual_seed(5)
print('MAE with pixel reconstruction:')
restore_one_image(mine_img, model_mae)


# In[ ]: